# -*- coding: utf-8 -*-
import logging
import os
import random

from django.conf import settings
import ydb

from travel.avia.ticket_daemon_api.jsonrpc.lib import date as date_utils
from travel.avia.library.python.ticket_daemon.decorators import log_elapsed_time
from travel.avia.library.python.ticket_daemon.ydb import utils as ydb_utils
from travel.avia.library.python.ticket_daemon.ydb.django import utils as django_ydb_utils

logger = logging.getLogger(__name__)

# Main table holding cached partner results (variants, meta, timestamps).
RESULTS_TABLE_NAME = 'results'
# Base name of the expiration-queue tables; real table names carry a numeric
# suffix appended by format_ttl_table_name().
_RESULTS_TTL_TABLE_NAME = 'results_expiration_queue'
# Number of expiration-queue tables; upsert_prepared writes each entry into one
# of them picked at random.
RESULTS_TTL_TABLES_COUNT = 10

# Table holding redirect data, written and read separately from RESULTS_TABLE_NAME.
REDIRECT_DATA_TABLE_NAME = 'redirect_data'


def format_ttl_table_name(idx):
    """Return the name of the ``idx``-th expiration-queue table.

    The name is the queue base name with ``idx`` appended,
    e.g. ``results_expiration_queue3`` for ``idx == 3``.
    """
    return '{0}{1}'.format(_RESULTS_TTL_TABLE_NAME, idx)


def delete():
    """Drop the results table and every expiration-queue table of the configured database.

    NOTE(review): the redirect-data table (REDIRECT_DATA_TABLE_NAME) is not
    dropped here — confirm that is intended.
    """
    config = settings.DRIVER_CONFIG
    with ydb.Driver(config) as driver:
        with ydb.SessionPool(driver, size=10) as pool:
            database = config.database
            # Results table first, then queue tables 0..N-1, same order as before.
            doomed = [RESULTS_TABLE_NAME]
            doomed.extend(format_ttl_table_name(i) for i in xrange(RESULTS_TTL_TABLES_COUNT))

            def drop_all(session):
                for table_name in doomed:
                    session.drop_table(os.path.join(database, table_name))

            pool.retry_operation_sync(drop_all)


def create_tables(session_pool, path):
    """Create the results table and the expiration-queue tables under ``path``.

    NOTE(review): the redirect-data table (REDIRECT_DATA_TABLE_NAME) that
    upsert_prepared writes to is not created here; presumably it is
    provisioned elsewhere — confirm.

    :param session_pool: YDB session pool used to run the DDL (with retries).
    :param path: database directory in which the tables are created.
    :return: whatever ``session_pool.retry_operation_sync`` returns for the callee.
    """
    # Primary-key columns shared by every table, in key order, with their types.
    key_columns = [
        ('point_from', ydb.PrimitiveType.Utf8),
        ('point_to', ydb.PrimitiveType.Utf8),
        ('date_forward', ydb.PrimitiveType.Uint32),
        ('date_backward', ydb.PrimitiveType.Uint32),
        ('klass', ydb.PrimitiveType.Uint8),
        ('passengers', ydb.PrimitiveType.Uint32),
        ('national_version', ydb.PrimitiveType.Utf8),
        ('lang', ydb.PrimitiveType.Utf8),
        ('partner_code', ydb.PrimitiveType.Utf8),
    ]
    primary_key = [name for name, _ in key_columns]

    def _optional_column(name, primitive_type):
        # Every column in these tables is nullable.
        return ydb.Column(name, ydb.OptionalType(primitive_type))

    def _add_key_columns(description):
        # Append all key columns, in key order, to a table description.
        for name, primitive_type in key_columns:
            description = description.with_column(_optional_column(name, primitive_type))
        return description

    def callee(session):
        results_description = (
            _add_key_columns(ydb.TableDescription())
            .with_column(_optional_column('variants', ydb.PrimitiveType.String))
            .with_column(_optional_column('redirect_data', ydb.PrimitiveType.String))
            # upsert_prepared writes `meta` into this table, so it must exist.
            .with_column(_optional_column('meta', ydb.PrimitiveType.String))
            .with_column(_optional_column('created_at', ydb.PrimitiveType.Uint64))  # unixtime
            .with_column(_optional_column('expires_at', ydb.PrimitiveType.Uint64))  # unixtime
            .with_primary_keys(*primary_key)
        )
        session.create_table(os.path.join(path, RESULTS_TABLE_NAME), results_description)

        # Queue tables are keyed by expires_at first so entries sort by expiry.
        expiration_table_description = (
            _add_key_columns(
                ydb.TableDescription()
                .with_column(_optional_column('expires_at', ydb.PrimitiveType.Uint64))  # unixtime
            )
            .with_primary_keys('expires_at', *primary_key)
        )
        for idx in xrange(RESULTS_TTL_TABLES_COUNT):
            session.create_table(
                os.path.join(path, format_ttl_table_name(idx)),
                expiration_table_description,
            )

    return session_pool.retry_operation_sync(callee)


def describe_table(session_pool, path, name):
    """Print the columns of table ``name`` located under ``path`` (debug helper).

    :param session_pool: YDB session pool used to run the request (with retries).
    :param path: database directory containing the table.
    :param name: table name.
    :return: the description object returned by ``session.describe_table``.
    """
    def callee(session):
        result = session.describe_table(os.path.join(path, name))
        # Each print receives a single pre-formatted string. Without
        # `from __future__ import print_function`, a multi-argument print(...)
        # prints a tuple under Python 2, which this module targets (it uses xrange).
        print('\n> describe table: %s' % name)
        for column in result.columns:
            # `.item` unwraps Optional<...>; assumes every column is Optional,
            # as create_tables declares them.
            print('column, name: %s, %s' % (column.name, str(column.type.item).strip()))
        return result

    return session_pool.retry_operation_sync(callee)


def upsert_prepared(session_pool, path, q, partner_code, variants, redirect_data, meta,
                    store_time, redirect_data_store_time):
    """Store one partner's result for query ``q`` in the cache.

    A single serializable transaction upserts:
      * variants and meta into the results table,
      * the redirect data into the redirect-data table (with its own expiry),
      * an expiry entry into one expiration-queue table picked at random.

    :param session_pool: YDB session pool used to run the query (with retries).
    :param path: database path, used as the ``TablePathPrefix``.
    :param q: the ``Query`` whose key fields identify the cached row.
    :param partner_code: partner the result belongs to.
    :param variants: payload stored in the ``variants`` column.
    :param redirect_data: payload stored in the redirect-data table.
    :param meta: payload stored in the ``meta`` column.
    :param store_time: added to the current unix time to get ``expires_at``.
    :param redirect_data_store_time: added to the current unix time to get the
        redirect data's ``expires_at``.
    """
    fill_data_query = """PRAGMA TablePathPrefix("{path}");
        DECLARE $pointFrom AS Utf8;
        DECLARE $pointTo AS Utf8;
        DECLARE $dateForward AS Uint32;
        DECLARE $dateBackward AS Uint32;
        DECLARE $klass AS Uint8;
        DECLARE $passengers AS Uint32;
        DECLARE $nationalVersion AS Utf8;
        DECLARE $lang AS Utf8;
        DECLARE $partnerCode AS Utf8;
        DECLARE $variants AS String;
        DECLARE $redirectData AS String;
        DECLARE $meta AS String;
        DECLARE $created_at AS Uint64;
        DECLARE $expires_at AS Uint64;
        DECLARE $redirect_data_expires_at AS Datetime;
        UPSERT INTO {results} (
            point_from, point_to, date_forward, date_backward, klass, passengers,
            national_version, lang, partner_code, variants, meta, created_at, expires_at)
        VALUES (
            $pointFrom, $pointTo, $dateForward, $dateBackward, $klass, $passengers,
            $nationalVersion, $lang, $partnerCode, $variants, $meta, $created_at, $expires_at
        );
        UPSERT INTO {redirect_data} (
            point_from, point_to, date_forward, date_backward, klass, passengers,
            national_version, lang, partner_code, redirect_data, created_at, expires_at)
        VALUES (
            $pointFrom, $pointTo, $dateForward, $dateBackward, $klass, $passengers,
            $nationalVersion, $lang, $partnerCode, $redirectData, $created_at, $redirect_data_expires_at
        );
        UPSERT INTO {results_ttl} (
            expires_at, point_from, point_to, date_forward, date_backward, klass,
            passengers, national_version, lang, partner_code)
        VALUES (
            $expires_at, $pointFrom, $pointTo, $dateForward, $dateBackward, $klass,
            $passengers, $nationalVersion, $lang, $partnerCode
        );
    """.format(
        path=path,
        results=RESULTS_TABLE_NAME,
        redirect_data=REDIRECT_DATA_TABLE_NAME,
        # Spread writes across the queue tables.
        results_ttl=format_ttl_table_name(random.randint(0, RESULTS_TTL_TABLES_COUNT - 1)))

    def callee(session):
        # fill_data_query is already fully formatted; formatting it a second
        # time would break if `path` contained braces.
        prepared_query = session.prepare(fill_data_query)
        # Read the clock once so created_at and both expiry times share one base.
        now = date_utils.unixtime()
        session.transaction(ydb.SerializableReadWrite()).execute(
            prepared_query,
            commit_tx=True,
            parameters={
                '$pointFrom': q.point_from.point_key,
                '$pointTo': q.point_to.point_key,
                '$dateForward': ydb_utils.to_days(q.date_forward),
                '$dateBackward': ydb_utils.to_days(q.date_backward),
                '$klass': ydb_utils.get_klass_id(q.klass),
                '$passengers': ydb_utils.passengers_integer_key(q),
                '$nationalVersion': q.national_version,
                '$lang': 'any',
                '$partnerCode': partner_code,
                '$variants': variants,
                '$redirectData': redirect_data,
                '$meta': meta,
                '$created_at': now,
                '$expires_at': now + store_time,
                '$redirect_data_expires_at': now + redirect_data_store_time,
            }
        )

    return session_pool.retry_operation_sync(callee)


def select_prepared_without_redirect_data(
    session_pool, path, q, partner_code,
    columns=('created_at', 'partner_code', 'variants', 'redirect_data')
):
    """Read the newest unexpired results-table row for a single partner.

    Only the results table is queried; ``columns`` are selected from it verbatim.

    :param session_pool: YDB session pool that executes the query (with retries).
    :param path: database path, used as the ``TablePathPrefix``.
    :param q: the ``Query`` supplying the key fields.
    :param partner_code: partner whose row is wanted.
    :param columns: names of the columns to select.
    :return: the result rows, at most one because of ``LIMIT 1``.
    """
    selected = ', '.join(columns)
    template = """
    PRAGMA TablePathPrefix("{path}");

    DECLARE $pointFrom AS Utf8;
    DECLARE $pointTo AS Utf8;
    DECLARE $dateForward AS Uint32;
    DECLARE $dateBackward AS Uint32;
    DECLARE $klass AS Uint8;
    DECLARE $passengers AS Uint32;
    DECLARE $nationalVersion AS Utf8;
    DECLARE $partnerCode AS Utf8;
    DECLARE $unixtime AS Uint32;

    SELECT {columns}
    FROM {results}
    WHERE
        point_from = $pointFrom
        AND point_to = $pointTo
        AND date_forward = $dateForward
        AND date_backward = $dateBackward
        AND klass = $klass
        AND passengers = $passengers
        AND national_version = $nationalVersion
        AND partner_code = $partnerCode
        AND expires_at > $unixtime
    ORDER BY created_at DESC
    LIMIT 1;
    """
    query = template.format(path=path, results=RESULTS_TABLE_NAME, columns=selected)

    def run(session):
        prepared = session.prepare(query)
        tx = session.transaction(ydb.SerializableReadWrite())
        params = {
            '$pointFrom': q.point_from.point_key,
            '$pointTo': q.point_to.point_key,
            '$dateForward': ydb_utils.to_days(q.date_forward),
            '$dateBackward': ydb_utils.to_days(q.date_backward),
            '$klass': ydb_utils.get_klass_id(q.klass),
            '$passengers': ydb_utils.passengers_integer_key(q),
            '$nationalVersion': q.national_version,
            '$partnerCode': partner_code,
            # Evaluated per attempt so retries use a fresh "now".
            '$unixtime': date_utils.unixtime(),
        }
        result_sets = tx.execute(prepared, commit_tx=True, parameters=params)
        return result_sets[0].rows

    return session_pool.retry_operation_sync(run)


def select_prepared(
    session_pool, path, q, partner_code,
    columns=('created_at', 'partner_code', 'variants', 'redirect_data')
):
    """Read the newest unexpired cached result for a single partner.

    When ``redirect_data`` is requested, it is read from the redirect-data
    table and merged into the returned row. Otherwise the call is delegated
    to ``select_prepared_without_redirect_data``.

    :param session_pool: YDB session pool that executes the query (with retries).
    :param path: database path, used as the ``TablePathPrefix``.
    :param q: the ``Query`` supplying the key fields.
    :param partner_code: partner whose result is wanted.
    :param columns: names of the columns to return.
    :return: the result rows, at most one. When ``redirect_data`` was requested,
        the row carries a ``redirect_data`` key (None if no redirect data was found).
    """
    if 'redirect_data' not in columns:
        return select_prepared_without_redirect_data(session_pool, path, q, partner_code, columns)

    # redirect_data lives in its own table, so it is excluded from the results SELECT.
    columns = tuple(col for col in columns if col != 'redirect_data')

    query = """
    PRAGMA TablePathPrefix("{path}");

    DECLARE $pointFrom AS Utf8;
    DECLARE $pointTo AS Utf8;
    DECLARE $dateForward AS Uint32;
    DECLARE $dateBackward AS Uint32;
    DECLARE $klass AS Uint8;
    DECLARE $passengers AS Uint32;
    DECLARE $nationalVersion AS Utf8;
    DECLARE $partnerCode AS Utf8;
    DECLARE $unixtime AS Uint32;

    SELECT {columns}
    FROM {results}
    WHERE
        point_from = $pointFrom
        AND point_to = $pointTo
        AND date_forward = $dateForward
        AND date_backward = $dateBackward
        AND klass = $klass
        AND passengers = $passengers
        AND national_version = $nationalVersion
        AND partner_code = $partnerCode
        AND expires_at > $unixtime
    ORDER BY created_at DESC
    LIMIT 1;


    SELECT
        redirect_data
    FROM {redirect_data}
    WHERE
        point_from = $pointFrom
        AND point_to = $pointTo
        AND date_forward = $dateForward
        AND date_backward = $dateBackward
        AND klass = $klass
        AND passengers = $passengers
        AND national_version = $nationalVersion
        AND partner_code = $partnerCode
    ORDER BY created_at DESC
    LIMIT 1;
    """.format(path=path, results=RESULTS_TABLE_NAME, redirect_data=REDIRECT_DATA_TABLE_NAME, columns=', '.join(columns))

    def callee(session):
        prepared_query = session.prepare(query)
        result_sets = session.transaction(ydb.SerializableReadWrite()).execute(
            prepared_query,
            commit_tx=True,
            parameters={
                '$pointFrom': q.point_from.point_key,
                '$pointTo': q.point_to.point_key,
                '$dateForward': ydb_utils.to_days(q.date_forward),
                '$dateBackward': ydb_utils.to_days(q.date_backward),
                '$klass': ydb_utils.get_klass_id(q.klass),
                '$passengers': ydb_utils.passengers_integer_key(q),
                '$nationalVersion': q.national_version,
                '$partnerCode': partner_code,
                '$unixtime': date_utils.unixtime(),
            }
        )

        variants_rows = result_sets[0].rows
        redirect_data_rows = result_sets[1].rows
        redirect_data = None
        if redirect_data_rows and 'redirect_data' in redirect_data_rows[0]:
            redirect_data = redirect_data_rows[0]['redirect_data']
        if variants_rows:
            # Always expose the requested key, like select_many_prepared does.
            variants_rows[0]['redirect_data'] = redirect_data
        return variants_rows

    return session_pool.retry_operation_sync(callee)


def select_many_prepared(
    session_pool, path, q, partner_codes,
    columns=('created_at', 'partner_code', 'variants')
):
    """Read unexpired cached results for several partners at once.

    When ``redirect_data`` is requested, it is read from the redirect-data
    table and attached to each row by ``partner_code``. Otherwise the call is
    delegated to ``select_many_prepared_without_redirect_data``.

    :param session_pool: YDB session pool that executes the query (with retries).
    :param path: database path, used as the ``TablePathPrefix``.
    :param q: the ``Query`` supplying the key fields.
    :param partner_codes: list of partner codes to fetch.
    :param columns: names of the columns to return.
    :return: list of result rows. When ``redirect_data`` was requested, every
        row carries a ``redirect_data`` key (None if the partner has none) and
        a ``partner_code`` key.
    """
    if 'redirect_data' not in columns:
        return select_many_prepared_without_redirect_data(session_pool, path, q, partner_codes, columns)

    # redirect_data lives in its own table and is joined back by partner_code.
    # partner_code is therefore always selected: the join needs it, and this
    # also keeps the SELECT list non-empty.
    columns = tuple(col for col in columns if col != 'redirect_data')
    if 'partner_code' not in columns:
        columns += ('partner_code',)

    query = """
    PRAGMA TablePathPrefix("{path}");

    DECLARE $pointFrom AS Utf8;
    DECLARE $pointTo AS Utf8;
    DECLARE $dateForward AS Uint32;
    DECLARE $dateBackward AS Uint32;
    DECLARE $klass AS Uint8;
    DECLARE $passengers AS Uint32;
    DECLARE $nationalVersion AS Utf8;
    DECLARE $partnerCodes AS 'List<Utf8>';
    DECLARE $unixtime AS Uint32;

    SELECT {columns}
    FROM {results}
    WHERE
        point_from = $pointFrom
        AND point_to = $pointTo
        AND date_forward = $dateForward
        AND date_backward = $dateBackward
        AND klass = $klass
        AND passengers = $passengers
        AND national_version = $nationalVersion
        AND partner_code IN $partnerCodes
        AND expires_at > $unixtime;

    SELECT
        redirect_data,
        partner_code
    FROM {redirect_data}
    WHERE
        point_from = $pointFrom
        AND point_to = $pointTo
        AND date_forward = $dateForward
        AND date_backward = $dateBackward
        AND klass = $klass
        AND passengers = $passengers
        AND national_version = $nationalVersion
        AND partner_code IN $partnerCodes;
    """.format(path=path, results=RESULTS_TABLE_NAME, redirect_data=REDIRECT_DATA_TABLE_NAME, columns=', '.join(columns))

    def callee(session):
        prepared_query = session.prepare(query)
        result_sets = session.transaction(ydb.SerializableReadWrite()).execute(
            prepared_query,
            commit_tx=True,
            parameters={
                '$pointFrom': q.point_from.point_key,
                '$pointTo': q.point_to.point_key,
                '$dateForward': ydb_utils.to_days(q.date_forward),
                '$dateBackward': ydb_utils.to_days(q.date_backward),
                '$klass': ydb_utils.get_klass_id(q.klass),
                '$passengers': ydb_utils.passengers_integer_key(q),
                '$nationalVersion': q.national_version,
                '$partnerCodes': partner_codes,
                '$unixtime': date_utils.unixtime(),
            }
        )

        variants = result_sets[0].rows
        # If a partner has several redirect rows, the last one returned wins.
        redirect_data_by_partner = {row['partner_code']: row['redirect_data'] for row in result_sets[1].rows}
        for v in variants:
            v['redirect_data'] = redirect_data_by_partner.get(v['partner_code'])
        return variants

    return session_pool.retry_operation_sync(callee)


def select_many_prepared_without_redirect_data(
    session_pool, path, q, partner_codes,
    columns=('created_at', 'partner_code', 'variants')
):
    """Read unexpired results-table rows for several partners.

    Only the results table is queried; ``columns`` are selected from it verbatim.

    :param session_pool: YDB session pool that executes the query (with retries).
    :param path: database path, used as the ``TablePathPrefix``.
    :param q: the ``Query`` supplying the key fields.
    :param partner_codes: list of partner codes to fetch.
    :param columns: names of the columns to select.
    :return: the rows of the single result set.
    """
    selected = ', '.join(columns)
    template = """
    PRAGMA TablePathPrefix("{path}");

    DECLARE $pointFrom AS Utf8;
    DECLARE $pointTo AS Utf8;
    DECLARE $dateForward AS Uint32;
    DECLARE $dateBackward AS Uint32;
    DECLARE $klass AS Uint8;
    DECLARE $passengers AS Uint32;
    DECLARE $nationalVersion AS Utf8;
    DECLARE $partnerCodes AS 'List<Utf8>';
    DECLARE $unixtime AS Uint32;

    SELECT {columns}
    FROM {results}
    WHERE
        point_from = $pointFrom
        AND point_to = $pointTo
        AND date_forward = $dateForward
        AND date_backward = $dateBackward
        AND klass = $klass
        AND passengers = $passengers
        AND national_version = $nationalVersion
        AND partner_code IN $partnerCodes
        AND expires_at > $unixtime;
    """
    query = template.format(path=path, results=RESULTS_TABLE_NAME, columns=selected)

    def run(session):
        prepared = session.prepare(query)
        tx = session.transaction(ydb.SerializableReadWrite())
        params = {
            '$pointFrom': q.point_from.point_key,
            '$pointTo': q.point_to.point_key,
            '$dateForward': ydb_utils.to_days(q.date_forward),
            '$dateBackward': ydb_utils.to_days(q.date_backward),
            '$klass': ydb_utils.get_klass_id(q.klass),
            '$passengers': ydb_utils.passengers_integer_key(q),
            '$nationalVersion': q.national_version,
            '$partnerCodes': partner_codes,
            # Evaluated per attempt so retries use a fresh "now".
            '$unixtime': date_utils.unixtime(),
        }
        result_sets = tx.execute(prepared, commit_tx=True, parameters=params)
        return result_sets[0].rows

    return session_pool.retry_operation_sync(run)


def init():
    """Ensure the database root exists, then create the cache tables in it."""
    config = settings.DRIVER_CONFIG
    with ydb.Driver(config) as driver:
        database = config.database
        # Empty relative path: only the database root itself must exist.
        ydb_utils.ensure_path_exists(driver, database, '')
        with ydb.SessionPool(driver, size=10) as pool:
            create_tables(pool, database)


class ServiceCache(object):
    """Facade over the YDB result cache.

    All methods are static. Each one borrows a session pool from
    ``django_ydb_utils.get_session_pool()`` and works against ``DATABASE``.
    """

    # Database path, read from settings once when the class is defined.
    DATABASE = settings.DRIVER_CONFIG.database

    @staticmethod
    @log_elapsed_time(logger, statsd_prefix='ydb.set.elapsed')
    def set(query, p_code, variants, redirect_data, meta, store_time, redirect_data_store_time):
        """Store a partner's result for ``query`` in the cache.

        :param Query query: search query whose result is stored.
        :param basestring p_code: partner code.
        :param str variants: payload stored in the ``variants`` column.
        :param str redirect_data: payload stored in the redirect-data table.
        :param str meta: payload stored in the ``meta`` column.
        :param int store_time: added to the current unix time to get the result's expiry.
        :param int redirect_data_store_time: added to the current unix time to get
            the redirect data's expiry.
        """
        with django_ydb_utils.get_session_pool() as session_pool:
            upsert_prepared(
                session_pool, ServiceCache.DATABASE, query, p_code, variants, redirect_data, meta, store_time,
                redirect_data_store_time,
            )

    @staticmethod
    @log_elapsed_time(logger, statsd_prefix='ydb.get.elapsed')
    def get(query, p_code, **kwargs):
        """Fetch the cached result of one partner for ``query``.

        :param Query query: search query to look up.
        :param basestring p_code: partner code.
        :param kwargs: forwarded to ``select_prepared`` (e.g. ``columns``).
        :return: list of result rows (at most one).
        """
        with django_ydb_utils.get_session_pool() as session_pool:
            return select_prepared(session_pool, ServiceCache.DATABASE, query, p_code, **kwargs)

    @staticmethod
    @log_elapsed_time(logger, statsd_prefix='ydb.get_many.elapsed')
    def get_many(query, p_codes, **kwargs):
        """Fetch cached results of several partners for ``query``.

        :param Query query: search query to look up.
        :param typing.List[basestring] p_codes: partner codes.
        :param kwargs: forwarded to ``select_many_prepared`` (e.g. ``columns``).
        :return: list of result rows.
        """
        with django_ydb_utils.get_session_pool() as session_pool:
            return select_many_prepared(session_pool, ServiceCache.DATABASE, query, p_codes, **kwargs)
