# -*- coding: utf-8 -*-
import os
from datetime import date, timedelta
from typing import List

import ydb

from travel.avia.library.python.ydb.session_manager import YdbSessionManager
from travel.avia.avia_statistics.landing_routes import LandingRoute


class FlightsTable(object):
    """Access layer for a YDB table of flights.

    Rows are keyed by ``from_id``, ``to_id``, ``company_id``, ``flight_number``,
    ``national_version`` and ``departure_date``. ``departure_date`` is stored as
    an integer of the form ``YYYYMMDD`` (see ``_date_to_int``).
    """

    def __init__(self, ydb_session_manager, database, table_name):
        # type: (YdbSessionManager, str, str) -> None
        """
        :param ydb_session_manager: supplies the session pool used by every operation.
        :param database: database path; used as ``TablePathPrefix`` and as the
            parent path when the table is created.
        :param table_name: table name relative to ``database``.
        """
        self._session_manager = ydb_session_manager
        self._database = database
        self._table_name = table_name

    @staticmethod
    def _date_to_int(value):
        # type: (date) -> int
        """Encode ``value`` the way ``departure_date`` is stored: ``YYYYMMDD`` as an int."""
        return int(value.strftime('%Y%m%d'))

    @classmethod
    def _departure_window(cls, today, days=30):
        # type: (date, int) -> tuple
        """Return ``(first, last)`` encoded dates covering ``today`` .. ``today + days`` inclusive.

        The end is computed with date arithmetic, not by adding to the encoded
        integer. Adding to ``YYYYMMDD`` would never cross a month boundary.
        """
        return cls._date_to_int(today), cls._date_to_int(today + timedelta(days=days))

    @classmethod
    def _expired_dates(cls, expired_date, days):
        # type: (date, int) -> List[int]
        """Return the ``days`` encoded dates strictly before ``expired_date``, newest first."""
        return [cls._date_to_int(expired_date - timedelta(days=offset)) for offset in range(1, days + 1)]

    def create_if_doesnt_exist(self):
        """Create the table with its columns, primary key, departure-date index and profile.

        NOTE(review): this issues a plain ``create_table``. Whether an
        already-existing table is tolerated depends on how YDB answers that
        call; confirm before relying on the method name.
        """
        def callee(session):
            # type: (ydb.table.Session) -> None
            primary_key = [
                'from_id', 'to_id', 'company_id', 'flight_number', 'national_version', 'departure_date',
            ]
            profile = ydb.TableProfile().with_replication_policy(
                ydb.ReplicationPolicy()
                .with_allow_promotion(ydb.FeatureFlag.ENABLED)
                .with_create_per_availability_zone(ydb.FeatureFlag.ENABLED)
                .with_replicas_count(1)
            )
            description = (
                ydb.TableDescription()
                .with_column(ydb.Column('from_id', ydb.OptionalType(ydb.PrimitiveType.Uint32)))
                .with_column(ydb.Column('to_id', ydb.OptionalType(ydb.PrimitiveType.Uint32)))
                .with_column(ydb.Column('company_id', ydb.OptionalType(ydb.PrimitiveType.Uint32)))
                .with_column(ydb.Column('flight_number', ydb.OptionalType(ydb.PrimitiveType.Utf8)))
                .with_column(ydb.Column('national_version', ydb.OptionalType(ydb.PrimitiveType.Utf8)))
                .with_column(ydb.Column('departure_date', ydb.OptionalType(ydb.PrimitiveType.Uint32)))
                .with_primary_keys(*primary_key)
                # Secondary index so lookups by departure_date need not rely on the key prefix.
                .with_index(ydb.TableIndex('departure_date_index').with_index_columns('departure_date'))
                .with_profile(profile)
            )
            session.create_table(os.path.join(self._database, self._table_name), description)

        with self._session_manager.get_session_pool() as session_pool:
            session_pool.retry_operation_sync(callee)

    def replace_batch(self, batch, timeout=5):
        """Write ``batch`` into the table with ``REPLACE INTO`` in a single transaction.

        :param batch: rows to write. Each element must provide the six fields
            declared in ``$data`` (presumably ``Flight`` instances; confirm
            against callers). An empty batch is a no-op.
        :param timeout: operation timeout in seconds (default 5, the previous fixed value).
        """
        if not batch:
            # Nothing to write: skip the round trip.
            return

        query = """
        PRAGMA TablePathPrefix("{path}");

        DECLARE $data AS "List<Struct<
            from_id: Uint32,
            to_id: Uint32,
            company_id: Uint32,
            flight_number: Utf8,
            national_version: Utf8,
            departure_date: Uint32>>";

        REPLACE INTO {table}
        SELECT * FROM AS_TABLE($data);
        """.format(path=self._database, table=self._table_name)

        def callee(session):
            # type: (ydb.table.Session) -> None
            prepared_query = session.prepare(query)
            session.transaction(ydb.SerializableReadWrite()).execute(
                prepared_query,
                commit_tx=True,
                parameters={
                    '$data': batch,
                },
                settings=ydb.table.settings.BaseRequestSettings().with_operation_timeout(timeout),
            )

        with self._session_manager.get_session_pool() as session_pool:
            session_pool.retry_operation_sync(callee)

    def delete_old(self, expired_date, timeout=120, days=7):
        # type: (date, int, int) -> None
        """Delete the rows departing on each of the ``days`` dates before ``expired_date``.

        Rows departing on ``expired_date`` itself are kept. Each date is deleted
        in its own retried transaction, newest date first.

        :param expired_date: dates strictly before this one are deleted.
        :param timeout: per-transaction operation timeout in seconds.
        :param days: number of preceding dates to delete (default 7, the
            previously hard-coded span).
        """
        query = """
        PRAGMA TablePathPrefix("{path}");

        DECLARE $date as Uint32;

        DELETE FROM {table}
        WHERE departure_date = $date;
        """.format(path=self._database, table=self._table_name)

        def make_callee(departure_date):
            # Bind the date explicitly so every retry of one step targets the
            # same day, instead of reading a variable mutated by the loop.
            def callee(session):
                # type: (ydb.table.Session) -> None
                prepared_query = session.prepare(query)
                session.transaction(ydb.SerializableReadWrite()).execute(
                    prepared_query,
                    commit_tx=True,
                    parameters={
                        '$date': departure_date,
                    },
                    settings=ydb.table.settings.BaseRequestSettings().with_operation_timeout(timeout),
                )
            return callee

        with self._session_manager.get_session_pool() as session_pool:
            for departure_date in self._expired_dates(expired_date, days):
                session_pool.retry_operation_sync(make_callee(departure_date))

    def get_top_airlines_by_route(self, route, limit=5):
        # type: (LandingRoute, int) -> List[int]
        """Return up to ``limit`` carrier ids with the most distinct flight numbers on ``route``.

        Only departures from today through today + 30 days (inclusive) for
        ``route.from_id`` -> ``route.to_id`` in ``route.national_version`` are
        counted. Ids are ordered by that count, descending; ties are broken by
        ascending id so the result is deterministic.
        """
        # The 30-day bound is computed with date arithmetic. Adding 30 to the
        # encoded YYYYMMDD integer would stop at the end of the current month.
        date_from, date_to = self._departure_window(date.today())

        # Ordering is done with ORDER BY. AGGREGATE_LIST / FLATTEN BY give no
        # ordering guarantee, so they are not used here.
        query = """
        PRAGMA TablePathPrefix("{path}");

        DECLARE $date_from AS Uint32;
        DECLARE $date_to AS Uint32;
        DECLARE $from_id AS Uint32;
        DECLARE $to_id AS Uint32;
        DECLARE $national_version AS Utf8;
        DECLARE $limit AS Uint32;

        SELECT
            company_id AS popular_companies,
            COUNT(DISTINCT flight_number) AS flights_count
        FROM {table}
        WHERE
            from_id = $from_id
            AND to_id = $to_id
            AND national_version = $national_version
            AND departure_date >= $date_from
            AND departure_date <= $date_to
        GROUP BY company_id
        ORDER BY flights_count DESC, popular_companies
        LIMIT $limit;
        """.format(path=self._database, table=self._table_name)

        def callee(session):
            # type: (ydb.table.Session) -> List[int]
            prepared_query = session.prepare(query)
            result_sets = session.transaction(ydb.SerializableReadWrite()).execute(
                prepared_query,
                commit_tx=True,
                parameters={
                    '$date_from': date_from,
                    '$date_to': date_to,
                    '$limit': limit,
                    '$from_id': route.from_id,
                    '$to_id': route.to_id,
                    '$national_version': route.national_version,
                },
            )
            return [row['popular_companies'] for row in result_sets[0].rows]

        with self._session_manager.get_session_pool() as session_pool:
            return session_pool.retry_operation_sync(callee)


class Flight(object):
    """Lightweight record whose fields mirror the columns of ``FlightsTable``.

    ``__slots__`` fixes the attribute set, so instances carry no per-instance
    ``__dict__``.
    """
    __slots__ = ('from_id', 'to_id', 'company_id', 'flight_number', 'national_version', 'departure_date',)

    def __init__(self, from_id, to_id, company_id, flight_number, national_version, departure_date):
        (
            self.from_id,
            self.to_id,
            self.company_id,
            self.flight_number,
            self.national_version,
            self.departure_date,
        ) = (from_id, to_id, company_id, flight_number, national_version, departure_date)

    def __str__(self):
        # Render as a dict literal with keys in ``__slots__`` order.
        names = self.__slots__
        values = (getattr(self, name) for name in names)
        return str(dict(zip(names, values)))
