# -*- coding: utf-8 -*-
import calendar
import logging
import os
import time

import ydb
from django.conf import settings
from travel.avia.library.python.ticket_daemon.date import unixtime
from travel.avia.library.python.ticket_daemon.decorators import log_elapsed_time
from travel.avia.library.python.ticket_daemon.ydb.django import utils as django_ydb_utils
from travel.avia.library.python.ticket_daemon.ydb.utils import (
    to_days, get_klass_id, passengers_integer_key, ensure_path_exists
)

# Module logger; also handed to ``log_elapsed_time`` for timing the cache calls.
logger = logging.getLogger(__name__)

# YDB driver configuration, read from Django settings once at import time.
# Its ``database`` attribute is used as the path the results table lives under.
DRIVER_CONFIG = settings.WIZARD_DRIVER_CONFIG
# Name of the table that stores wizard results per partner.
WIZARD_RESULTS_PARTNER_TABLE_NAME = 'wizard_results_by_partner'


def create_tables(session_pool, path):
    """Create the per-partner wizard results table under ``path``.

    Every column is ``Optional``. The primary key is the search key (points,
    class, passengers, national version, dates) followed by ``partner_code``.
    A TTL is configured on the ``ttl_expires_at`` column.

    NOTE: the key change only affects tables created from now on; an already
    existing table keeps the primary key it was created with and has to be
    migrated separately.

    :param session_pool: ``ydb.SessionPool`` used to run the DDL with retries.
    :param path: directory the table is created in.
    :return: the value returned by ``session_pool.retry_operation_sync``.
    """
    def callee(session):
        primary_key = [
            'point_from', 'point_to', 'klass', 'passengers',
            'national_version', 'date_backward', 'date_forward',
            # partner_code must be part of the key: without it, upserting
            # results of different partners for the same search overwrote
            # one another, and the ``partner_code IN $partnerCodes`` lookup
            # could never return more than one row.
            'partner_code',
        ]
        profile = (
            ydb.TableProfile()
            .with_replication_policy(
                ydb.ReplicationPolicy()
                .with_allow_promotion(ydb.FeatureFlag.ENABLED)
                .with_create_per_availability_zone(ydb.FeatureFlag.ENABLED)
                .with_replicas_count(1)
            )
        )
        description = (
            ydb.TableDescription()
            .with_column(ydb.Column('point_from', ydb.OptionalType(ydb.PrimitiveType.Utf8)))
            .with_column(ydb.Column('point_to', ydb.OptionalType(ydb.PrimitiveType.Utf8)))
            .with_column(ydb.Column('klass', ydb.OptionalType(ydb.PrimitiveType.Uint8)))
            .with_column(ydb.Column('passengers', ydb.OptionalType(ydb.PrimitiveType.Uint32)))
            .with_column(ydb.Column('national_version', ydb.OptionalType(ydb.PrimitiveType.Utf8)))
            .with_column(ydb.Column('date_backward', ydb.OptionalType(ydb.PrimitiveType.Uint32)))
            .with_column(ydb.Column('date_forward', ydb.OptionalType(ydb.PrimitiveType.Uint32)))
            .with_column(ydb.Column('partner_code', ydb.OptionalType(ydb.PrimitiveType.Utf8)))
            # unixtime
            .with_column(ydb.Column('expires_at', ydb.OptionalType(ydb.PrimitiveType.Uint64)))
            # unixtime
            .with_column(ydb.Column('expires_at_by_partner', ydb.OptionalType(ydb.PrimitiveType.Uint64)))
            .with_column(ydb.Column('min_price', ydb.OptionalType(ydb.PrimitiveType.Int32)))
            .with_column(ydb.Column('search_result', ydb.OptionalType(ydb.PrimitiveType.String)))
            .with_column(ydb.Column('filter_state', ydb.OptionalType(ydb.PrimitiveType.String)))
            .with_column(ydb.Column('ttl_expires_at', ydb.OptionalType(ydb.PrimitiveType.Datetime)))
            .with_primary_keys(*primary_key)
            .with_profile(profile)
            .with_ttl(ydb.TtlSettings().with_date_type_column('ttl_expires_at'))
        )

        session.create_table(
            os.path.join(path, WIZARD_RESULTS_PARTNER_TABLE_NAME),
            description,
        )

    return session_pool.retry_operation_sync(callee)


def init():
    """Make sure the database root exists, then create the results table in it."""
    with ydb.Driver(DRIVER_CONFIG) as driver:
        database = DRIVER_CONFIG.database
        # An empty sub-path means "the database root itself".
        ensure_path_exists(driver, database, '')
        with ydb.SessionPool(driver, size=10) as pool:
            create_tables(pool, database)


def delete():
    """Drop the per-partner results table from the configured database."""
    with ydb.Driver(DRIVER_CONFIG) as driver:
        with ydb.SessionPool(driver, size=10) as pool:
            table_path = os.path.join(DRIVER_CONFIG.database, WIZARD_RESULTS_PARTNER_TABLE_NAME)

            def drop(session):
                session.drop_table(table_path)

            pool.retry_operation_sync(drop)


# Upserts a single row into the per-partner results table.
# ``{path}`` (the TablePathPrefix) and ``{results}`` (the table name) are
# substituted with str.format before the query is prepared; the ``$``
# parameters are bound at execution time.
UPSERT_PARTNER_QUERY = """
    PRAGMA TablePathPrefix("{path}");

    DECLARE $partnerCode AS Utf8;
    DECLARE $pointFrom AS Utf8;
    DECLARE $pointTo AS Utf8;
    DECLARE $dateForward AS Uint32;
    DECLARE $dateBackward AS Uint32;
    DECLARE $klass AS Uint8;
    DECLARE $passengers AS Uint32;
    DECLARE $nationalVersion AS Utf8;
    DECLARE $search_result AS String;
    DECLARE $filter_state AS String;
    DECLARE $min_price AS Int32;
    DECLARE $expires_at AS Uint64;
    DECLARE $expires_at_by_partner AS Uint64;
    DECLARE $ttl_expires_at AS Datetime;

    UPSERT INTO {results} (
        partner_code, point_from, point_to, date_forward, date_backward, klass, passengers,
        national_version, search_result, filter_state,
        min_price, expires_at, expires_at_by_partner, ttl_expires_at
    )
    VALUES (
        $partnerCode, $pointFrom, $pointTo, $dateForward, $dateBackward, $klass, $passengers,
        $nationalVersion, $search_result, $filter_state,
        $min_price, $expires_at, $expires_at_by_partner, $ttl_expires_at
    );
    """


def upsert_partner_prepared(
    session_pool,
    path,
    q,
    search_result,
    filter_state,
    min_price,
    expires_at,
    expires_at_by_partner,
    ttl_expires_at,
    partner_code
):
    """Write one partner's wizard result for search ``q`` into the results table.

    :param session_pool: ``ydb.SessionPool`` used to run the query with retries.
    :param path: database path used as the ``TablePathPrefix``.
    :param q: search query object; its points, dates, class, passengers and
        national version are converted into the row's search columns.
    :param search_result: value for the ``String`` column ``search_result``.
    :param filter_state: value for the ``String`` column ``filter_state``.
    :param min_price: value for the ``Int32`` column ``min_price``.
    :param expires_at: unix time (``Uint64``).
    :param expires_at_by_partner: unix time (``Uint64``).
    :param ttl_expires_at: ``datetime`` converted here to unix seconds for the
        ``Datetime`` TTL column. NOTE(review): the query DECLAREs it as a
        non-optional ``Datetime``, so a falsy value (e.g. ``None``) is passed
        through unchanged and presumably rejected by YDB - confirm.
    :param partner_code: partner code (``Utf8``).
    :return: the value returned by ``session_pool.retry_operation_sync``.
    """
    query = UPSERT_PARTNER_QUERY.format(
        path=path,
        results=WIZARD_RESULTS_PARTNER_TABLE_NAME,
    )

    if ttl_expires_at:
        if ttl_expires_at.utcoffset() is not None:
            # Aware datetime: utctimetuple() is already expressed in UTC, so it
            # must be converted with timegm. time.mktime would treat it as
            # local time and shift the TTL by the server's UTC offset.
            ttl_expires_at = calendar.timegm(ttl_expires_at.utctimetuple())
        else:
            # Naive datetime: keep the historical conversion, which interprets
            # the wall-clock value as server-local time.
            ttl_expires_at = int(time.mktime(ttl_expires_at.utctimetuple()))

    def callee(session):
        prepared_query = session.prepare(query)
        session.transaction(ydb.SerializableReadWrite()).execute(
            prepared_query,
            commit_tx=True,
            parameters={
                '$partnerCode': partner_code,
                '$pointFrom': q.point_from.point_key,
                '$pointTo': q.point_to.point_key,
                '$dateForward': to_days(q.date_forward),
                '$dateBackward': to_days(q.date_backward),
                '$klass': get_klass_id(q.klass),
                '$passengers': passengers_integer_key(q),
                '$nationalVersion': q.national_version,
                '$search_result': search_result,
                '$filter_state': filter_state,
                '$min_price': min_price,
                '$expires_at': expires_at,
                '$expires_at_by_partner': expires_at_by_partner,
                '$ttl_expires_at': ttl_expires_at,
            }
        )

    return session_pool.retry_operation_sync(callee)


# Reads ``search_result`` for one partner and one search. Rows whose
# ``expires_at`` is not strictly greater than ``$now`` are skipped.
# ``{path}`` (the TablePathPrefix) and ``{results}`` (the table name) are
# substituted with str.format before the query is prepared.
SELECT_PARTNER_QUERY = """
    PRAGMA TablePathPrefix("{path}");

    DECLARE $partnerCode AS Utf8;
    DECLARE $pointFrom AS Utf8;
    DECLARE $pointTo AS Utf8;
    DECLARE $klass AS Uint8;
    DECLARE $dateForward AS Uint32;
    DECLARE $dateBackward AS Uint32;
    DECLARE $passengers AS Uint32;
    DECLARE $nationalVersion AS Utf8;
    DECLARE $now AS Uint64;

    SELECT search_result
    FROM {results}
    WHERE
        point_from = $pointFrom
        AND point_to = $pointTo
        AND klass = $klass
        AND passengers = $passengers
        AND national_version = $nationalVersion
        AND date_forward = $dateForward
        AND date_backward = $dateBackward
        AND partner_code = $partnerCode
        AND expires_at > $now
    ;
    """


def select_partner_prepared(session_pool, path, q, partner_code):
    """Fetch the unexpired ``search_result`` rows stored by one partner for ``q``.

    :param session_pool: ``ydb.SessionPool`` used to run the query with retries.
    :param path: database path used as the ``TablePathPrefix``.
    :param q: search query object whose fields form the lookup key.
    :param partner_code: partner whose result is requested.
    :return: rows of the first result set, each carrying ``search_result``.
    """
    yql = SELECT_PARTNER_QUERY.format(path=path, results=WIZARD_RESULTS_PARTNER_TABLE_NAME)

    def run(session):
        prepared = session.prepare(yql)
        tx = session.transaction(ydb.SerializableReadWrite())
        params = {
            '$partnerCode': partner_code,
            '$pointFrom': q.point_from.point_key,
            '$pointTo': q.point_to.point_key,
            '$klass': get_klass_id(q.klass),
            '$dateForward': to_days(q.date_forward),
            '$dateBackward': to_days(q.date_backward),
            '$passengers': passengers_integer_key(q),
            '$nationalVersion': q.national_version,
            # Evaluated on every attempt, so retries compare against a fresh clock.
            '$now': unixtime(),
        }
        first_result_set = tx.execute(prepared, commit_tx=True, parameters=params)[0]
        return first_result_set.rows

    return session_pool.retry_operation_sync(run)


# Same lookup as SELECT_PARTNER_QUERY, but for a list of partner codes; it
# also returns ``partner_code`` next to ``search_result``. Rows whose
# ``expires_at`` is not strictly greater than ``$now`` are skipped.
# ``$partnerCodes`` is declared with the quoted type-string form of DECLARE.
# ``{path}`` and ``{results}`` are substituted with str.format before preparing.
SELECT_ALL_PARTNER_QUERY = """
    PRAGMA TablePathPrefix("{path}");

    DECLARE $pointFrom AS Utf8;
    DECLARE $pointTo AS Utf8;
    DECLARE $klass AS Uint8;
    DECLARE $dateForward AS Uint32;
    DECLARE $dateBackward AS Uint32;
    DECLARE $passengers AS Uint32;
    DECLARE $nationalVersion AS Utf8;
    DECLARE $now AS Uint64;
    DECLARE $partnerCodes AS 'List<Utf8>';

    SELECT partner_code, search_result
    FROM {results}
    WHERE
        point_from = $pointFrom
        AND point_to = $pointTo
        AND klass = $klass
        AND passengers = $passengers
        AND national_version = $nationalVersion
        AND partner_code IN $partnerCodes
        AND date_forward = $dateForward
        AND date_backward = $dateBackward
        AND expires_at > $now
    ;
    """


def select_all_partner_prepared(session_pool, path, q, partner_codes):
    """Fetch unexpired results for search ``q`` from several partners at once.

    :param session_pool: ``ydb.SessionPool`` used to run the query with retries.
    :param path: database path used as the ``TablePathPrefix``.
    :param q: search query object whose fields form the lookup key.
    :param partner_codes: partner codes, bound as the ``List<Utf8>`` parameter.
    :return: rows of the first result set with ``partner_code`` and ``search_result``.
    """
    yql = SELECT_ALL_PARTNER_QUERY.format(path=path, results=WIZARD_RESULTS_PARTNER_TABLE_NAME)

    def run(session):
        prepared = session.prepare(yql)
        tx = session.transaction(ydb.SerializableReadWrite())
        params = {
            '$pointFrom': q.point_from.point_key,
            '$pointTo': q.point_to.point_key,
            '$klass': get_klass_id(q.klass),
            '$dateForward': to_days(q.date_forward),
            '$dateBackward': to_days(q.date_backward),
            '$passengers': passengers_integer_key(q),
            '$nationalVersion': q.national_version,
            '$now': unixtime(),
            '$partnerCodes': partner_codes,
        }
        first_result_set = tx.execute(prepared, commit_tx=True, parameters=params)[0]
        return first_result_set.rows

    return session_pool.retry_operation_sync(run)


class WizardCacheByPartner(object):
    """YDB-backed store of wizard search results, one row per partner.

    Each call borrows a session pool from ``django_ydb_utils.get_session_pool``
    for the duration of the call and is timed by ``log_elapsed_time``.
    """

    # Database path handed to the query helpers as the table path prefix.
    DATABASE = DRIVER_CONFIG.database

    @log_elapsed_time(logger, statsd_prefix='ydb_wizard.set.elapsed')
    def set(self, query, search_result, filter_state, min_price, expires_at,
            expires_at_by_partner, ttl_expires_at, partner_code):
        """Store ``partner_code``'s result for ``query``; returns ``None``."""
        with django_ydb_utils.get_session_pool() as pool:
            upsert_partner_prepared(
                session_pool=pool,
                path=self.DATABASE,
                q=query,
                search_result=search_result,
                filter_state=filter_state,
                min_price=min_price,
                expires_at=expires_at,
                expires_at_by_partner=expires_at_by_partner,
                ttl_expires_at=ttl_expires_at,
                partner_code=partner_code,
            )

    @log_elapsed_time(logger, statsd_prefix='ydb_wizard.get.elapsed')
    def get(self, query, partner_code):
        """Return the unexpired result rows of one partner for ``query``."""
        with django_ydb_utils.get_session_pool() as pool:
            return select_partner_prepared(
                session_pool=pool,
                path=self.DATABASE,
                q=query,
                partner_code=partner_code,
            )

    # NOTE(review): shares the 'ydb_wizard.get.elapsed' metric with get(), so
    # single- and multi-partner lookups are not distinguishable in statsd.
    @log_elapsed_time(logger, statsd_prefix='ydb_wizard.get.elapsed')
    def get_all(self, query, partner_codes):
        """Return the unexpired result rows of several partners for ``query``."""
        with django_ydb_utils.get_session_pool() as pool:
            return select_all_partner_prepared(
                session_pool=pool,
                path=self.DATABASE,
                q=query,
                partner_codes=partner_codes,
            )
