# -*- coding: utf-8 -*-
import os

import ydb

from travel.avia.library.python.ydb.session_manager import YdbSessionManager


class MonthAndYearPricesByCityToTable(object):
    """Accessor for the YDB table of median prices per destination city.

    Each row is keyed by ``(to_id, national_version)`` and holds a yearly median
    price plus (year, month, median price) triples for the ``popular_*``,
    ``min_*`` and ``max_*`` months.

    The column list lives in one place (``_COLUMNS``) and feeds both the
    ``create_table`` description and the ``DECLARE`` of the ``REPLACE`` query,
    so the table schema and the write query cannot drift apart.
    """

    # (column name, YQL primitive type name), in table order. Type names are kept
    # as strings and resolved through ``ydb.PrimitiveType`` only when the table is
    # created, so merely defining this class does not touch the ydb module.
    _COLUMNS = (
        ('to_id', 'Uint32'),
        ('national_version', 'Utf8'),
        ('currency', 'Utf8'),

        ('year_median_price', 'Uint32'),

        ('popular_month_year', 'Uint32'),
        ('popular_month', 'Uint8'),
        ('popular_month_median_price', 'Uint32'),

        ('min_month_year', 'Uint32'),
        ('min_month', 'Uint8'),
        ('min_month_median_price', 'Uint32'),

        ('max_month_year', 'Uint32'),
        ('max_month', 'Uint8'),
        ('max_month_median_price', 'Uint32'),
    )

    _PRIMARY_KEY = ('to_id', 'national_version')

    def __init__(self, ydb_session_manager, database, table_name, operation_timeout=5):
        # type: (YdbSessionManager, str, str, float) -> None
        """
        :param ydb_session_manager: source of the session pool used by every request.
        :param database: database path; used as ``TablePathPrefix`` in queries and
            as the parent path when creating the table.
        :param table_name: table name relative to ``database``.
        :param operation_timeout: value passed to ``with_operation_timeout`` for the
            requests issued by ``replace_batch`` and ``count`` (defaults to 5).
        """
        self._session_manager = ydb_session_manager
        self._database = database
        self._table_name = table_name
        self._operation_timeout = operation_timeout

    def _run(self, callee):
        """Run ``callee(session)`` via the pool's ``retry_operation_sync`` and return its result."""
        with self._session_manager.get_session_pool() as session_pool:
            return session_pool.retry_operation_sync(callee)

    def _request_settings(self):
        """Build per-request settings carrying the configured operation timeout."""
        return ydb.table.settings.BaseRequestSettings().with_operation_timeout(self._operation_timeout)

    def create_if_doesnt_exist(self):
        """Create the table using the schema declared in ``_COLUMNS``.

        Every column is created as ``Optional<...>``; the primary key is
        ``_PRIMARY_KEY``. The table profile enables promotion and per-availability-zone
        creation, with a replicas count of 1.

        NOTE(review): there is no explicit existence check here; what happens for an
        already existing table is whatever ``session.create_table`` does — confirm.
        """
        def callee(session):
            # type: (ydb.table.Session) -> None
            profile = (
                ydb.TableProfile()
                .with_replication_policy(
                    ydb.ReplicationPolicy()
                    .with_allow_promotion(ydb.FeatureFlag.ENABLED)
                    .with_create_per_availability_zone(ydb.FeatureFlag.ENABLED)
                    .with_replicas_count(1)
                )
            )
            description = ydb.TableDescription()
            for name, type_name in self._COLUMNS:
                description = description.with_column(
                    ydb.Column(name, ydb.OptionalType(getattr(ydb.PrimitiveType, type_name)))
                )
            description = description.with_primary_keys(*self._PRIMARY_KEY).with_profile(profile)
            session.create_table(os.path.join(self._database, self._table_name), description)

        self._run(callee)

    def replace_batch(self, batch):
        """Write ``batch`` with ``REPLACE INTO`` in a single committed transaction.

        :param batch: rows bound to the ``$data`` parameter, declared as
            ``List<Struct<...>>`` with every column of ``_COLUMNS`` (non-optional).
            An empty batch is a no-op: no session is acquired and no query is sent.
        """
        if not batch:
            return

        # One "name: Type" member per column, in table order, for the DECLARE below.
        declared_members = ',\n'.join(
            '            {}: {}'.format(name, type_name) for name, type_name in self._COLUMNS
        )
        query = """
        --!syntax_v1
        PRAGMA TablePathPrefix("{path}");

        DECLARE $data AS List<Struct<
{members}
            >>;

        REPLACE INTO {table}
        SELECT * FROM AS_TABLE($data);
        """.format(path=self._database, table=self._table_name, members=declared_members)

        def callee(session):
            # type: (ydb.table.Session) -> None
            prepared_query = session.prepare(query)
            session.transaction(ydb.SerializableReadWrite()).execute(
                prepared_query,
                commit_tx=True,
                parameters={
                    '$data': batch,
                },
                settings=self._request_settings(),
            )

        self._run(callee)

    def count(self):
        """Return the number of rows in the table (``SELECT COUNT(*)``).

        :rtype: int
        """
        query = """
        --!syntax_v1
        PRAGMA TablePathPrefix("{path}");

        SELECT COUNT(*) as c FROM {table};
        """.format(path=self._database, table=self._table_name)

        def callee(session):
            # type: (ydb.table.Session) -> int
            result_sets = session.transaction(ydb.SerializableReadWrite()).execute(
                query,
                commit_tx=True,
                settings=self._request_settings(),
            )
            # The query yields one result set with one row; the count is aliased as ``c``.
            return result_sets[0].rows[0]['c']

        return self._run(callee)


class MonthAndYearPricesByCityTo(object):
    """A single price-statistics record for one destination city.

    This is a plain slotted attribute container. ``__slots__`` fixes the field
    order, and ``__str__`` walks that same order to render the record as a dict.
    """

    __slots__ = (
        'to_id', 'national_version', 'currency',
        'year_median_price',
        'popular_month_year', 'popular_month', 'popular_month_median_price',
        'min_month_year', 'min_month', 'min_month_median_price',
        'max_month_year', 'max_month', 'max_month_median_price',
    )

    def __init__(
            self,
            to_id, national_version, currency,
            year_median_price,
            popular_month_year, popular_month, popular_month_median_price,
            min_month_year, min_month, min_month_median_price,
            max_month_year, max_month, max_month_median_price,
    ):
        # Destination, national version and the currency the prices are in.
        self.to_id, self.national_version, self.currency = to_id, national_version, currency
        self.year_median_price = year_median_price

        # Each group is a (year, month, median price) triple. The min_/max_ groups
        # presumably describe the lowest/highest-priced months — confirm with the producer.
        self.popular_month_year, self.popular_month, self.popular_month_median_price = (
            popular_month_year, popular_month, popular_month_median_price,
        )
        self.min_month_year, self.min_month, self.min_month_median_price = (
            min_month_year, min_month, min_month_median_price,
        )
        self.max_month_year, self.max_month, self.max_month_median_price = (
            max_month_year, max_month, max_month_median_price,
        )

    def __str__(self):
        # Render as a dict of every slot in declaration order.
        as_dict = dict((field, getattr(self, field)) for field in self.__slots__)
        return str(as_dict)
