# -*- coding: utf-8 -*-
import os

import ydb

from travel.avia.library.python.ydb.session_manager import YdbSessionManager


class CityRouteCrosslinksTable(object):
    """Data-access wrapper around the YDB table that stores city route crosslinks.

    Every operation takes a session pool from the injected session manager and
    runs its work through the pool's ``retry_operation_sync``.
    """

    # Value passed to BaseRequestSettings.with_operation_timeout for data queries
    # (presumably seconds, as interpreted by the ydb SDK — confirm).
    DEFAULT_OPERATION_TIMEOUT = 5

    def __init__(self, ydb_session_manager, database, table_name, operation_timeout=DEFAULT_OPERATION_TIMEOUT):
        # type: (YdbSessionManager, str, str, float) -> None
        """
        :param ydb_session_manager: provides session pools via ``get_session_pool()``.
        :param database: database path; used both as the YQL ``TablePathPrefix``
            and as the directory for the table when creating it.
        :param table_name: table name relative to ``database``.
        :param operation_timeout: operation timeout applied to every data query.
        """
        self._session_manager = ydb_session_manager
        self._database = database
        self._table_name = table_name
        self._operation_timeout = operation_timeout

    def _request_settings(self):
        """Build the request settings shared by all data queries of this table."""
        return ydb.table.settings.BaseRequestSettings().with_operation_timeout(self._operation_timeout)

    def _select_single_value(self, query):
        # type: (str) -> int
        """Execute a read query and return column ``c`` of its first row.

        Shared implementation of ``count`` and ``cities_count``.
        """
        def callee(session):
            # type: (ydb.table.Session) -> int
            result_sets = session.transaction(ydb.SerializableReadWrite()).execute(
                query,
                commit_tx=True,
                settings=self._request_settings(),
            )
            return result_sets[0].rows[0]['c']

        with self._session_manager.get_session_pool() as session_pool:
            return session_pool.retry_operation_sync(callee)

    def create_if_doesnt_exist(self):
        """Create the table with its fixed schema and replication profile.

        All columns are declared Optional; the primary key is
        (to_id, national_version, position). The profile enables promotion and
        per-availability-zone creation with a single replica.

        NOTE(review): CreateTable is issued unconditionally, so "if doesn't exist"
        relies on how the YDB server treats an already existing table with the
        same schema — confirm it is accepted rather than reported as an error.
        """
        def callee(session):
            # type: (ydb.table.Session) -> None
            primary_key = [
                'to_id', 'national_version', 'position',
            ]
            profile = (
                ydb.TableProfile()
                .with_replication_policy(
                    ydb.ReplicationPolicy()
                    .with_allow_promotion(ydb.FeatureFlag.ENABLED)
                    .with_create_per_availability_zone(ydb.FeatureFlag.ENABLED)
                    .with_replicas_count(1)
                )
            )
            session.create_table(
                os.path.join(self._database, self._table_name),
                ydb.TableDescription()
                .with_column(ydb.Column('to_id', ydb.OptionalType(ydb.PrimitiveType.Uint32)))
                .with_column(ydb.Column('national_version', ydb.OptionalType(ydb.PrimitiveType.Utf8)))
                .with_column(ydb.Column('position', ydb.OptionalType(ydb.PrimitiveType.Uint8)))
                .with_column(ydb.Column('crosslink_from_id', ydb.OptionalType(ydb.PrimitiveType.Uint32)))
                .with_column(ydb.Column('crosslink_to_id', ydb.OptionalType(ydb.PrimitiveType.Uint32)))
                .with_column(ydb.Column('price', ydb.OptionalType(ydb.PrimitiveType.Uint32)))
                .with_column(ydb.Column('currency', ydb.OptionalType(ydb.PrimitiveType.Utf8)))
                .with_column(ydb.Column('date', ydb.OptionalType(ydb.PrimitiveType.Utf8)))
                .with_primary_keys(*primary_key)
                .with_profile(profile)
            )

        with self._session_manager.get_session_pool() as session_pool:
            session_pool.retry_operation_sync(callee)

    def replace_batch(self, batch):
        """Write a batch of crosslink rows with REPLACE INTO.

        REPLACE INTO overwrites rows by primary key (to_id, national_version,
        position). Consequently, if fewer rows are written for a destination city
        than before, the old rows at the higher positions remain, so duplicate
        directions are possible and must be filtered out by readers. For the same
        reason, readers must also filter out prices for past dates.

        :param batch: value bound to the ``$data`` query parameter — presumably a
            list of dicts keyed by the struct fields declared below; TODO confirm.
        """
        query = """
        --!syntax_v1
        PRAGMA TablePathPrefix("{path}");

        DECLARE $data AS List<Struct<
            to_id: Uint32,
            national_version: Utf8,
            position: Uint8,
            crosslink_from_id: Uint32,
            crosslink_to_id: Uint32,
            price: Uint32?,
            currency: Utf8?,
            `date`: Utf8?>>;

        REPLACE INTO {table}
        SELECT * FROM AS_TABLE($data);
        """.format(path=self._database, table=self._table_name)

        def callee(session):
            # type: (ydb.table.Session) -> None
            prepared_query = session.prepare(query)
            session.transaction(ydb.SerializableReadWrite()).execute(
                prepared_query,
                commit_tx=True,
                parameters={
                    '$data': batch,
                },
                settings=self._request_settings(),
            )

        with self._session_manager.get_session_pool() as session_pool:
            session_pool.retry_operation_sync(callee)

    def count(self):
        # type: () -> int
        """Return the total number of rows in the table."""
        query = """
        --!syntax_v1
        PRAGMA TablePathPrefix("{path}");

        SELECT COUNT(*) as c FROM {table};
        """.format(path=self._database, table=self._table_name)
        return self._select_single_value(query)

    def cities_count(self):
        # type: () -> int
        """Return the number of distinct ``to_id`` values in the table."""
        query = """
        --!syntax_v1
        PRAGMA TablePathPrefix("{path}");

        SELECT COUNT(DISTINCT to_id) as c FROM {table};
        """.format(path=self._database, table=self._table_name)
        return self._select_single_value(query)


class CityRouteCrosslink(object):
    """Lightweight record with one attribute per crosslink field.

    The attribute names match the columns created by
    ``CityRouteCrosslinksTable.create_if_doesnt_exist``. ``__slots__`` is the
    canonical field order: it drives both ``__init__`` and ``__str__``.
    """

    __slots__ = (
        'to_id',
        'national_version',
        'position',
        'crosslink_from_id',
        'crosslink_to_id',
        'price',
        'currency',
        'date',
    )

    def __init__(
        self,
        to_id,
        national_version,
        position,
        crosslink_from_id,
        crosslink_to_id,
        price,
        currency,
        date,
    ):
        field_values = (
            to_id,
            national_version,
            position,
            crosslink_from_id,
            crosslink_to_id,
            price,
            currency,
            date,
        )
        # The parameters above are listed in exactly the __slots__ order. The
        # class is referenced by name so a subclass's own __slots__ is not used here.
        for field_name, field_value in zip(CityRouteCrosslink.__slots__, field_values):
            setattr(self, field_name, field_value)

    def __str__(self):
        as_mapping = dict((field_name, getattr(self, field_name)) for field_name in self.__slots__)
        return str(as_mapping)
