# -*- coding: utf-8 -*-
import os
import json
from typing import List, Optional

import ydb

from travel.avia.library.python.ydb.session_manager import YdbSessionManager


class RouteInfosTable(object):
    """Accessor for a YDB table of route infos keyed by ``(from_id, to_id)``.

    Every operation borrows a session pool from the supplied session manager
    and runs through the pool's ``retry_operation_sync``.
    """

    def __init__(self, ydb_session_manager, database, table_name):
        # type: (YdbSessionManager, str, str) -> None
        """Store the session manager, the database path and the table name."""
        self._table_name = table_name
        self._database = database
        self._session_manager = ydb_session_manager

    def create_if_doesnt_exist(self):
        """Send a create-table request for ``<database>/<table_name>``.

        The table gets six optional columns (four ``Uint32``, two ``String``),
        the primary key ``(from_id, to_id)`` and a profile whose replication
        policy enables promotion and per-availability-zone creation with a
        replicas count of 1.

        NOTE(review): nothing here checks for an existing table; the
        "if doesn't exist" part presumably relies on how YDB handles
        ``create_table`` for an existing table -- confirm.
        """
        def create_table(session):
            # type: (ydb.table.Session) -> None
            replication = ydb.ReplicationPolicy()
            replication = replication.with_allow_promotion(ydb.FeatureFlag.ENABLED)
            replication = replication.with_create_per_availability_zone(ydb.FeatureFlag.ENABLED)
            replication = replication.with_replicas_count(1)
            table_profile = ydb.TableProfile().with_replication_policy(replication)

            table_path = os.path.join(self._database, self._table_name)

            # (column name, primitive type); every column is wrapped in OptionalType below.
            uint32 = ydb.PrimitiveType.Uint32
            string = ydb.PrimitiveType.String
            column_specs = (
                ('from_id', uint32),
                ('to_id', uint32),
                ('distance', uint32),
                ('duration', uint32),
                ('from_airports', string),
                ('to_airports', string),
            )

            description = ydb.TableDescription()
            for column_name, primitive in column_specs:
                description = description.with_column(
                    ydb.Column(column_name, ydb.OptionalType(primitive))
                )
            description = description.with_primary_keys('from_id', 'to_id')
            description = description.with_profile(table_profile)

            session.create_table(table_path, description)

        with self._session_manager.get_session_pool() as pool:
            pool.retry_operation_sync(create_table)

    def replace_batch(self, batch):
        """Write ``batch`` into the table with a single ``REPLACE INTO`` query.

        ``batch`` is passed as the ``$data`` query parameter, which the query
        declares as a list of structs with the fields ``from_id``, ``to_id``,
        ``distance``, ``duration``, ``from_airports`` and ``to_airports``.
        Presumably the items are ``RouteInfo``-like objects -- verify against
        callers.

        The query runs in a serializable read-write transaction that is
        committed by the same call, with an operation timeout of 5
        (presumably seconds -- confirm against the YDB SDK).
        """
        template = """
        PRAGMA TablePathPrefix("{path}");

        DECLARE $data AS "List<Struct<
            from_id: Uint32,
            to_id: Uint32,
            distance: Uint32?,
            duration: Uint32?,
            from_airports: String,
            to_airports: String>>";

        REPLACE INTO {table}
        SELECT * FROM AS_TABLE($data);
        """
        query = template.format(path=self._database, table=self._table_name)

        def replace_rows(session):
            # type: (ydb.table.Session) -> None
            prepared = session.prepare(query)
            transaction = session.transaction(ydb.SerializableReadWrite())
            query_parameters = {'$data': batch}
            request_settings = ydb.table.settings.BaseRequestSettings().with_operation_timeout(5)
            transaction.execute(
                prepared,
                commit_tx=True,
                parameters=query_parameters,
                settings=request_settings,
            )

        with self._session_manager.get_session_pool() as pool:
            pool.retry_operation_sync(replace_rows)


class RouteInfo(object):
    """Slotted record describing one route between ``from_id`` and ``to_id``.

    The two airport lists given to the constructor are kept as JSON text, not
    as lists.
    """

    __slots__ = ('from_id', 'to_id', 'distance', 'duration', 'from_airports', 'to_airports',)

    def __init__(self, from_id, to_id, distance, duration, from_airports, to_airports):
        # type: (int, int, int, Optional[int], List[int], List[int]) -> None
        """Store the scalar fields as given and JSON-encode both airport lists."""
        self.from_id, self.to_id = from_id, to_id
        self.distance, self.duration = distance, duration

        # Airport lists are serialized right away, so the attributes hold strings.
        encode = json.dumps
        self.from_airports = encode(from_airports)
        self.to_airports = encode(to_airports)

    def __str__(self):
        """Return the ``str()`` of a dict mapping each slot name to its value."""
        snapshot = dict((field, getattr(self, field)) for field in self.__slots__)
        return str(snapshot)
