from datacloud.dev_utils.ydb.lib.core.ydb_table import YdbTable
import ydb


class ScoreTable(YdbTable):
    """YDB table wrapper that stores one ``score`` per ``hashed_cid``.

    Schema (see :meth:`create`): ``hashed_cid`` is an optional Uint64 and the
    primary key, ``score`` is an optional Double.
    """

    def __init__(self, ydb_manager, database, table):
        """Forward the manager, database and table name to ``YdbTable``."""
        super(ScoreTable, self).__init__(ydb_manager, database, table)

    def create(self):
        """Create the table at ``self.full_table_path`` with the score schema."""
        with self._init_session() as session:
            session.create_table(
                self.full_table_path,
                ydb.TableDescription()
                .with_column(ydb.Column('hashed_cid', ydb.OptionalType(ydb.DataType.Uint64)))
                .with_column(ydb.Column('score', ydb.OptionalType(ydb.DataType.Double)))
                .with_primary_key('hashed_cid')
            )

    def get_one(self, record):
        """Return the result of ``self._get_one`` for ``record.hashed_cid``.

        Only ``record.hashed_cid`` is read from ``record``.
        """
        query_params = {
            '$hashed_cid': record.hashed_cid
        }
        return self._get_one(query_params)

    def get(self, record):
        """Yield every item produced by ``self._get`` for ``record.hashed_cid``.

        Only ``record.hashed_cid`` is read from ``record``.
        """
        query_params = {
            '$hashed_cid': record.hashed_cid
        }
        # A distinct loop name keeps the ``record`` argument from being shadowed.
        for rec in self._get(query_params):
            yield rec

    def get_multiple(self, hashed_cids_list):
        """Yield a ``Record`` for each stored row whose key is in the given ids.

        This is a generator: the query is only sent once iteration starts, and
        the session stays open until the generator is exhausted or closed.

        :param hashed_cids_list: iterable of values convertible with ``int()``.
        :raises TypeError, ValueError: if an id cannot be converted to ``int``.
        """
        # TODO: Raw request, refactor later
        # The ids are pasted into the query text, so force them to integers:
        # nothing but numeric literals can end up inside the IN (...) list.
        id_literals = [str(int(hashed_cid)) for hashed_cid in hashed_cids_list]
        if not id_literals:
            # "IN ()" is not a valid expression; no ids means no rows.
            return

        with self._init_session() as session:
            query = self._select_multiple_request.format(
                database=self.database,
                table=self.table,
                items='(' + ', '.join(id_literals) + ')'
            )
            result_sets = session.transaction(ydb.SerializableReadWrite()).execute(
                query,
                commit_tx=True
            )
            for row in result_sets[0].rows:
                yield self.Record(**row)

    class Record(object):
        """Plain value object holding one ``(hashed_cid, score)`` pair."""

        __slots__ = ('hashed_cid', 'score')

        def __init__(self, hashed_cid, score=None):
            self.hashed_cid = hashed_cid
            self.score = score

        def __str__(self):
            return '({hashed_cid}: {score})'.format(hashed_cid=self.hashed_cid, score=self.score)

        def __repr__(self):
            return self.__str__()

        def __eq__(self, other):
            # Any object exposing both attributes is comparable (as before);
            # anything else yields NotImplemented instead of AttributeError.
            try:
                return self.hashed_cid == other.hashed_cid and self.score == other.score
            except AttributeError:
                return NotImplemented

        def __ne__(self, other):
            # Python 2 does not derive "!=" from __eq__, so define it explicitly.
            equal = self.__eq__(other)
            if equal is NotImplemented:
                return equal
            return not equal

    # Query templates. ``_select_multiple_request`` is formatted in
    # get_multiple(); the other two are not referenced in this class and are
    # presumably consumed by YdbTable helpers -- TODO confirm.
    _insert_request = """
        PRAGMA TablePathPrefix("{database}");

        DECLARE $records AS "List<Struct<
            hashed_cid: Uint64,
            score: Double>>";

        REPLACE INTO [{table}]
        SELECT
            hashed_cid,
            score
        FROM AS_TABLE($records);
    """

    _select_request = """
        PRAGMA TablePathPrefix("{database}");

        DECLARE $hashed_cid AS Uint64;

        SELECT *
        FROM [{table}]
        WHERE hashed_cid = $hashed_cid;
    """

    _select_multiple_request = """
        PRAGMA TablePathPrefix("{database}");

        SELECT *
        FROM [{table}]
        WHERE hashed_cid IN {items};
    """


class MetaScoreTable(YdbTable):
    """YDB table wrapper that stores one ``score`` per ``hashed_id``.

    Schema (see :meth:`create`): ``hashed_id`` is an optional Uint64 and the
    primary key, ``score`` is an optional Double.
    """

    def __init__(self, ydb_manager, database, table):
        """Forward the manager, database and table name to ``YdbTable``."""
        super(MetaScoreTable, self).__init__(ydb_manager, database, table)

    def create(self):
        """Create the table at ``self.full_table_path`` with the score schema."""
        with self._init_session() as session:
            session.create_table(
                self.full_table_path,
                ydb.TableDescription()
                .with_column(ydb.Column('hashed_id', ydb.OptionalType(ydb.DataType.Uint64)))
                .with_column(ydb.Column('score', ydb.OptionalType(ydb.DataType.Double)))
                .with_primary_key('hashed_id')
            )

    def get_one(self, record):
        """Return the result of ``self._get_one`` for ``record.hashed_id``.

        Only ``record.hashed_id`` is read from ``record``.
        """
        query_params = {
            '$hashed_id': record.hashed_id
        }
        return self._get_one(query_params)

    def get(self, record):
        """Yield every item produced by ``self._get`` for ``record.hashed_id``.

        Only ``record.hashed_id`` is read from ``record``.
        """
        query_params = {
            '$hashed_id': record.hashed_id
        }
        for rec in self._get(query_params):
            yield rec

    def get_multiple(self, hashed_ids_list):
        """Yield a ``Record`` for each stored row whose key is in the given ids.

        This is a generator: the query is only sent once iteration starts, and
        the session stays open until the generator is exhausted or closed.

        :param hashed_ids_list: iterable of values convertible with ``int()``.
        :raises TypeError, ValueError: if an id cannot be converted to ``int``.
        """
        # TODO: Raw request, refactor later
        # The ids are pasted into the query text, so force them to integers:
        # nothing but numeric literals can end up inside the IN (...) list.
        id_literals = [str(int(hashed_id)) for hashed_id in hashed_ids_list]
        if not id_literals:
            # "IN ()" is not a valid expression; no ids means no rows.
            return

        with self._init_session() as session:
            query = self._select_multiple_request.format(
                database=self.database,
                table=self.table,
                items='(' + ', '.join(id_literals) + ')'
            )
            result_sets = session.transaction(ydb.SerializableReadWrite()).execute(
                query,
                commit_tx=True
            )
            for row in result_sets[0].rows:
                yield self.Record(**row)

    class Record(object):
        """Plain value object holding one ``(hashed_id, score)`` pair."""

        __slots__ = ('hashed_id', 'score')

        def __init__(self, hashed_id, score=None):
            self.hashed_id = hashed_id
            self.score = score

        def __str__(self):
            return '({hashed_id}: {score})'.format(hashed_id=self.hashed_id, score=self.score)

        def __repr__(self):
            return self.__str__()

        def __eq__(self, other):
            # Any object exposing both attributes is comparable (as before);
            # anything else yields NotImplemented instead of AttributeError.
            try:
                return self.hashed_id == other.hashed_id and self.score == other.score
            except AttributeError:
                return NotImplemented

        def __ne__(self, other):
            # Python 2 does not derive "!=" from __eq__, so define it explicitly.
            equal = self.__eq__(other)
            if equal is NotImplemented:
                return equal
            return not equal

    # Query templates. ``_select_multiple_request`` is formatted in
    # get_multiple(); the other two are not referenced in this class and are
    # presumably consumed by YdbTable helpers -- TODO confirm.
    _insert_request = """
        PRAGMA TablePathPrefix("{database}");

        DECLARE $records AS "List<Struct<
            hashed_id: Uint64,
            score: Double>>";

        REPLACE INTO [{table}]
        SELECT
            hashed_id,
            score
        FROM AS_TABLE($records);
    """

    _select_request = """
        PRAGMA TablePathPrefix("{database}");

        DECLARE $hashed_id AS Uint64;

        SELECT *
        FROM [{table}]
        WHERE hashed_id = $hashed_id;
    """

    _select_multiple_request = """
        PRAGMA TablePathPrefix("{database}");

        SELECT *
        FROM [{table}]
        WHERE hashed_id IN {items};
    """
