# -*- coding: utf-8 -*-
import os
import posixpath
from typing import List

import ydb

from travel.avia.library.python.ydb.session_manager import YdbSessionManager


class CityToNearestCities(object):
    """A single row of the city-to-nearest-cities table.

    Attribute names are the table's column names and the struct member names
    declared by ``CityToNearestCitiesTable.replace_batch``. ``__slots__`` keeps
    instances free of a per-instance ``__dict__``.
    """

    __slots__ = ('to_id', 'national_version', 'nearest_city_ids')

    def __init__(self, to_id: int, national_version: str, nearest_city_ids: str):
        """
        :param to_id: destination city id (stored as Uint32).
        :param national_version: national version code (stored as Utf8).
        :param nearest_city_ids: nearest city ids serialized into one string
            (stored as Utf8); NOTE(review): the exact serialization format is
            decided by the caller — confirm.
        """
        self.to_id, self.national_version, self.nearest_city_ids = (
            to_id,
            national_version,
            nearest_city_ids,
        )

    def __str__(self):
        # Render as a dict keyed by slot name, in declaration order.
        fields = dict((name, getattr(self, name)) for name in self.__slots__)
        return str(fields)


class CityToNearestCitiesTable(object):
    """Access layer for the YDB table holding ``CityToNearestCities`` rows.

    Rows are keyed by ``(to_id, national_version)``. Every operation obtains a
    session pool from the session manager and runs its work through the pool's
    synchronous retry helper.
    """

    # Operation timeout passed to ``with_operation_timeout`` for data queries.
    _OPERATION_TIMEOUT = 5

    def __init__(
            self,
            ydb_session_manager: YdbSessionManager,
            database: str,
            table_name: str,
    ) -> None:
        """
        :param ydb_session_manager: source of the session pool used by every operation.
        :param database: database path; used as ``TablePathPrefix`` in queries
            and as the parent directory of the table when it is created.
        :param table_name: table name relative to ``database``.
        """
        self._session_manager = ydb_session_manager
        self._database = database
        self._table_name = table_name

    def _run(self, callee):
        """Execute ``callee(session)`` via the pool's retry helper and return its result."""
        with self._session_manager.get_session_pool() as session_pool:
            return session_pool.retry_operation_sync(callee)

    def _format_query(self, template: str) -> str:
        """Fill the ``{path}`` and ``{table}`` placeholders of a YQL template."""
        return template.format(path=self._database, table=self._table_name)

    def create_if_doesnt_exist(self) -> None:
        """Create the table with its schema and replication profile.

        NOTE(review): this issues an unconditional ``create_table``; the
        "if doesn't exist" behaviour relies on YDB accepting a create request
        for an already existing table with the same schema — confirm.
        """
        def callee(session: ydb.table.Session) -> None:
            profile = ydb.TableProfile().with_replication_policy(
                ydb.ReplicationPolicy()
                .with_allow_promotion(ydb.FeatureFlag.ENABLED)
                .with_create_per_availability_zone(ydb.FeatureFlag.ENABLED)
                .with_replicas_count(1)
            )
            description = (
                ydb.TableDescription()
                .with_column(ydb.Column('to_id', ydb.OptionalType(ydb.PrimitiveType.Uint32)))
                .with_column(ydb.Column('national_version', ydb.OptionalType(ydb.PrimitiveType.Utf8)))
                .with_column(ydb.Column('nearest_city_ids', ydb.OptionalType(ydb.PrimitiveType.Utf8)))
                .with_primary_keys('to_id', 'national_version')
                .with_profile(profile)
            )
            # YDB scheme paths are always '/'-separated regardless of the
            # client OS, so join with posixpath instead of os.path.
            session.create_table(
                posixpath.join(self._database, self._table_name),
                description,
            )

        self._run(callee)

    def replace_batch(self, batch: List[CityToNearestCities]) -> None:
        """Write ``batch`` with a single ``REPLACE INTO`` statement in one transaction.

        Rows whose primary key already exists are overwritten, and the other rows
        are inserted. Each item must expose ``to_id``, ``national_version`` and
        ``nearest_city_ids`` attributes matching the struct declared in the query.
        An empty batch is a no-op and does not contact the database.
        """
        if not batch:
            # Nothing to write: skip acquiring a session and a transaction round trip.
            return

        query = self._format_query("""
        --!syntax_v1
        PRAGMA TablePathPrefix("{path}");

        DECLARE $data AS List<Struct<
            to_id: Uint32,
            national_version: Utf8,
            nearest_city_ids: Utf8
        >>;

        REPLACE INTO {table}
        SELECT * FROM AS_TABLE($data);
        """)

        def callee(session: ydb.table.Session) -> None:
            prepared_query = session.prepare(query)
            session.transaction(ydb.SerializableReadWrite()).execute(
                prepared_query,
                commit_tx=True,
                parameters={
                    '$data': batch,
                },
                settings=ydb.table.settings.BaseRequestSettings().with_operation_timeout(
                    self._OPERATION_TIMEOUT,
                ),
            )

        self._run(callee)

    def count(self) -> int:
        """Return the number of rows currently stored in the table."""
        query = self._format_query("""
        --!syntax_v1
        PRAGMA TablePathPrefix("{path}");

        SELECT COUNT(*) as c FROM {table};
        """)

        def callee(session: ydb.table.Session) -> int:
            result_sets = session.transaction(ydb.SerializableReadWrite()).execute(
                query,
                commit_tx=True,
                settings=ydb.table.settings.BaseRequestSettings().with_operation_timeout(
                    self._OPERATION_TIMEOUT,
                ),
            )
            # The query produces exactly one result set with one row.
            return result_sets[0].rows[0]['c']

        return self._run(callee)
