import logging


# Number of rows handled per chunk: used both as the SQLAlchemy ``yield_per``
# fetch size and as the number of rows buffered before each YT ``write_table``.
BATCH_SIZE = 1000000


class YabsAwapsExportYtHandler():
    """Copy a Microsoft SQL Server table into a YT table.

    The constructor only stores connection settings; all work is done by
    :meth:`export`. Third-party packages (``yt.wrapper``, ``sqlalchemy``) are
    imported inside the methods that need them, so importing this module does
    not require them to be installed.
    """

    def __init__(self, yt_cluster, yt_path, yt_token, sql_host, sql_database, sql_username, sql_password):
        """Store YT and SQL Server connection settings.

        :param yt_cluster: value passed to ``yt.YtClient`` as the cluster/proxy.
        :param yt_path: YT directory; a table is written to ``<yt_path>/<table>``.
        :param yt_token: token passed to ``yt.YtClient``.
        :param sql_host: SQL Server host used in the connection URL.
        :param sql_database: database name used in the connection URL.
        :param sql_username: user name used in the connection URL.
        :param sql_password: password used in the connection URL.
        """
        self.yt_cluster = yt_cluster
        self.yt_path = yt_path
        self.yt_token = yt_token

        self.sql_host = sql_host
        self.sql_database = sql_database
        self.sql_username = sql_username
        self.sql_password = sql_password

    @staticmethod
    def _get_yt_type(sql_type):
        """Map a reflected SQLAlchemy column type instance to a YT type name.

        :raises Exception: if the SQL type has no known YT counterpart.
        """
        from sqlalchemy.dialects.mssql import base
        from sqlalchemy.sql import sqltypes

        # Groups are checked in order; the first group that matches wins.
        type_map = (
            ((base.BIGINT,), 'int64'),
            ((base.INTEGER, base.SMALLINT), 'int32'),
            ((base.TINYINT,), 'uint8'),
            ((base.BIT,), 'boolean'),
            ((base.MONEY, base.FLOAT, base.REAL, sqltypes.DECIMAL, base.NUMERIC), 'double'),
            # Date/time values and GUIDs are exported as text (see _cast_to_yt).
            ((base.DATETIME, base.SMALLDATETIME, base.TIMESTAMP, base.DATE), 'string'),
            ((base.UNIQUEIDENTIFIER,), 'string'),
            ((base.NVARCHAR, base.VARCHAR, sqltypes.TEXT, base.CHAR), 'string'),
        )
        for sql_classes, yt_type in type_map:
            if isinstance(sql_type, sql_classes):
                return yt_type

        raise Exception('Unknown sql type [{sql_type}] instance of {python_type}'
                        .format(sql_type=str(sql_type).lower(), python_type=type(sql_type)))

    @staticmethod
    def _order_yt_schema(yt_schema, key_names):
        """Return schema columns with sorted (key) columns first.

        Key columns keep the order given by ``key_names``; the remaining
        columns follow, ordered by name.
        """
        # 'sort_order' fields must be first in YT schema
        return sorted(yt_schema,
                      key=lambda col: (
                          'sort_order' not in col,
                          key_names.index(col['name']) if col['name'] in key_names else 0,
                          col['name'],
                      ))

    @staticmethod
    def _generate_yt_schema(table_meta, primary_key=None):
        """Build a YT schema (list of column dicts) from a reflected table.

        :param table_meta: SQLAlchemy ``Table`` whose columns were reflected.
        :param primary_key: names of the columns to mark as sorted, in key
            order. ``None`` means "use the columns flagged as primary key in
            ``table_meta``"; an empty list produces an unsorted schema.
        """
        if primary_key is None:
            # Previously this case crashed while sorting (membership test on None).
            key_names = [c.name for c in table_meta.c if c.primary_key]
        else:
            key_names = list(primary_key)

        yt_schema = []
        for c in table_meta.c:
            yt_column = {
                'name': c.name,
                'type': YabsAwapsExportYtHandler._get_yt_type(c.type),
                'required': not c.nullable,
            }
            if c.name in key_names:
                yt_column['sort_order'] = 'ascending'
            yt_schema.append(yt_column)

        return YabsAwapsExportYtHandler._order_yt_schema(yt_schema, key_names)

    @staticmethod
    def _cast_to_yt(value):
        """Convert a SQL result value into something the JSON writer accepts."""
        from datetime import datetime, date
        from decimal import Decimal
        from uuid import UUID

        if isinstance(value, UUID):
            return str(value)
        # datetime must be tested before date: datetime is a subclass of date.
        if isinstance(value, datetime):
            return value.strftime('%Y-%m-%dT%H:%M:%S')
        if isinstance(value, date):
            return value.strftime('%Y-%m-%d')
        # Decimal-backed columns are declared as 'double' in the YT schema and
        # Decimal itself is not JSON-serializable.
        if isinstance(value, Decimal):
            return float(value)

        return value

    def export(self, sql_table_name):
        """Export one SQL table to ``<yt_path>/<sql_table_name>`` in YT.

        The table structure is reflected from the database, an existing YT
        table at the target path is removed and re-created with the derived
        schema, and rows are streamed over in chunks of ``BATCH_SIZE``.

        Exceptions raised during reflection or transfer are logged (with
        traceback) and not re-raised, so a failed export may leave the YT
        table missing or partially filled.

        :param sql_table_name: name of the source table; also used as the YT
            table name.
        """
        import yt.wrapper as yt
        from sqlalchemy import create_engine, MetaData, Table
        from sqlalchemy.engine.reflection import Inspector
        from sqlalchemy.orm import sessionmaker

        yt_table = "{path}/{table}".format(path=self.yt_path, table=sql_table_name)

        logging.info("{host}/{db}/{table} -> {yt}".format(host=self.sql_host, db=self.sql_database, table=sql_table_name, yt=yt_table))

        yt_client = yt.YtClient(self.yt_cluster, token=self.yt_token)
        # NOTE(review): credentials are interpolated into the URL verbatim;
        # presumably callers pass values that are already URL-safe -- confirm.
        sql_engine = create_engine(
            'mssql+pymssql://{username}:{password}@{host}/{database}'.format(host=self.sql_host,
                                                                             database=self.sql_database,
                                                                             username=self.sql_username,
                                                                             password=self.sql_password))
        session = sessionmaker(bind=sql_engine)()

        def write_batch(rows, append):
            # The first chunk overwrites the freshly created table, later ones append.
            yt_client.write_table(yt.TablePath(yt_table, append=append),
                                  rows,
                                  format=yt.JsonFormat(attributes={"encode_utf8": False}, encoding='utf-8'))

        try:
            sql_inspector = Inspector.from_engine(sql_engine)
            sql_table = Table(sql_table_name, MetaData())
            sql_inspector.reflecttable(sql_table, None)
            sql_column_names = [c.name for c in sql_table.columns]
            logging.info(sql_column_names)

            # An empty key list keeps the YT schema unsorted.
            schema = self._generate_yt_schema(sql_table, [])
            logging.info(schema)

            if yt_client.exists(yt_table):
                yt_client.remove(yt_table, recursive=True)

            yt_client.create('table', yt_table, recursive=True, attributes={'schema': schema})

            query = session.query(sql_table)
            cast = self._cast_to_yt

            yt_rows_sent = 0
            yt_batch = []
            for r in query.yield_per(BATCH_SIZE):
                yt_batch.append({column: cast(r[idx]) for idx, column in enumerate(sql_column_names)})

                if len(yt_batch) >= BATCH_SIZE:
                    write_batch(yt_batch, yt_rows_sent > 0)
                    yt_rows_sent += len(yt_batch)
                    yt_batch = []
                    logging.info("{rows} rows sent".format(rows=yt_rows_sent))

            # Flush the final, partially filled chunk.
            if yt_batch:
                write_batch(yt_batch, yt_rows_sent > 0)
                yt_rows_sent += len(yt_batch)
                logging.info("{rows} rows sent".format(rows=yt_rows_sent))

        except Exception as e:
            # Same message as before, but keep the traceback for diagnosis.
            logging.exception("%s", e)
        finally:
            session.close()
            # Release pooled connections; a new engine is created on every call.
            sql_engine.dispose()
