# -*- coding: utf-8 -*-
from __future__ import unicode_literals

import logging
from typing import Iterable, Optional, Dict, Any

import ydb

from travel.avia.library.python.ticket_daemon.ydb import utils as ydb_utils
from travel.avia.library.python.ticket_daemon import date as date_util

log = logging.getLogger(__name__)


def create_tables(session, table_path):
    # type: (ydb.Session, basestring)->ydb.Operation
    """Create the banned-variants table at ``table_path``.

    Every column is declared as an optional (nullable) primitive. The primary
    key is the full search key plus ``partner_code``. A TTL is configured on
    ``expires_at``, interpreted as seconds since the Unix epoch.

    :param session: YDB session used to issue the DDL request.
    :param table_path: path of the table to create.
    :return: whatever ``session.create_table`` returns.
    """
    key_columns = (
        'point_from', 'point_to', 'date_forward', 'date_backward', 'klass',
        'passengers', 'national_version', 'partner_code',
    )
    # (column name, primitive type); each is wrapped in OptionalType below.
    column_specs = (
        ('point_from', ydb.PrimitiveType.Utf8),
        ('point_to', ydb.PrimitiveType.Utf8),
        ('date_forward', ydb.PrimitiveType.Uint32),
        ('date_backward', ydb.PrimitiveType.Uint32),
        ('klass', ydb.PrimitiveType.Uint8),
        ('passengers', ydb.PrimitiveType.Uint32),
        ('national_version', ydb.PrimitiveType.Utf8),
        ('partner_code', ydb.PrimitiveType.Utf8),
        ('payload', ydb.PrimitiveType.String),
        ('created_at', ydb.PrimitiveType.Uint64),  # unixtime
        ('expires_at', ydb.PrimitiveType.Uint64),  # unixtime
    )
    log.info('Create banned variants table: %s', table_path)

    description = ydb.TableDescription()
    for column_name, primitive_type in column_specs:
        description = description.with_column(
            ydb.Column(column_name, ydb.OptionalType(primitive_type))
        )
    description = description.with_primary_keys(*key_columns)
    # Rows become eligible for TTL removal once the expires_at timestamp passes.
    ttl_settings = ydb.TtlSettings().with_value_since_unix_epoch(
        'expires_at', ydb.ColumnUnit.UNIT_SECONDS
    )
    description = description.with_ttl(ttl_settings)

    return session.create_table(table_path, description)


def upsert(session, table_path, q, partner_code, payload, ttl_in_seconds):
    # type: (ydb.Session, basestring, Query, basestring, bytes, int) -> ydb.convert.ResultSets
    """Insert or replace the banned-variant entry for ``q`` and ``partner_code``.

    The row key is built from the search query ``q`` (route point keys,
    forward/backward dates, class, passengers, national version) plus
    ``partner_code``. ``created_at`` is set to the current unixtime and
    ``expires_at`` to ``created_at + ttl_in_seconds``.

    :param session: YDB session used to prepare and execute the query.
    :param table_path: path of the banned-variants table.
    :param q: search query providing the key fields.
    :param partner_code: partner the entry belongs to.
    :param payload: opaque bytes stored in the ``payload`` column.
    :param ttl_in_seconds: lifetime of the entry, in seconds.
    :return: result of executing the UPSERT in a committed
        serializable read-write transaction.
    """
    query = """
        DECLARE $pointFrom AS Utf8;
        DECLARE $pointTo AS Utf8;
        DECLARE $dateForward AS Uint32;
        DECLARE $dateBackward AS Uint32;
        DECLARE $klass AS Uint8;
        DECLARE $passengers AS Uint32;
        DECLARE $nationalVersion AS Utf8;
        DECLARE $partnerCode AS Utf8;
        DECLARE $payload AS String;
        DECLARE $created_at AS Uint64;
        DECLARE $expires_at AS Uint64;

        UPSERT INTO `{table_path}` (
            point_from, point_to, date_forward, date_backward, klass, passengers,
            national_version, partner_code, payload, created_at, expires_at)
        VALUES (
            $pointFrom, $pointTo, $dateForward, $dateBackward, $klass, $passengers,
            $nationalVersion, $partnerCode, $payload, $created_at, $expires_at
        );
    """.format(
        table_path=table_path,
    )
    prepared_query = session.prepare(query)
    # Read the clock once so expires_at - created_at is exactly ttl_in_seconds,
    # even if a second boundary is crossed while building the parameters.
    now = date_util.unixtime()
    return session.transaction(ydb.SerializableReadWrite()).execute(
        prepared_query,
        commit_tx=True,
        parameters={
            '$pointFrom': q.point_from.point_key,
            '$pointTo': q.point_to.point_key,
            '$dateForward': ydb_utils.to_days(q.date_forward),
            '$dateBackward': ydb_utils.to_days(q.date_backward),
            '$klass': ydb_utils.get_klass_id(q.klass),
            '$passengers': ydb_utils.passengers_integer_key(q),
            '$nationalVersion': q.national_version,
            '$partnerCode': partner_code,
            '$payload': payload,
            '$created_at': now,
            '$expires_at': now + ttl_in_seconds,
        }
    )


def select_prepared(
    session, table_path, q, partner_code,
    columns=('payload',)
):
    # type: (ydb.Session, basestring, Query, basestring, Iterable[basestring]) -> Optional[Iterable[Any]]
    """Look up the non-expired banned-variant entry for ``q`` and ``partner_code``.

    Rows are matched on the full search key derived from ``q`` plus
    ``partner_code``. Only rows whose ``expires_at`` is later than the current
    unixtime are returned; this filter is applied explicitly rather than
    relying on TTL removal alone. At most one row is returned
    (``ORDER BY created_at LIMIT 1``).

    :param session: YDB session used to prepare and execute the query.
    :param table_path: path of the banned-variants table.
    :param q: search query providing the key fields.
    :param partner_code: partner the entry belongs to.
    :param columns: names of the columns to select. ``created_at`` is always
        selected as well, because the query orders by it. Duplicates are
        dropped and the caller's order is preserved.
    :return: ``rows`` of the first result set (possibly empty), or ``None``
        when the execute result has no first result set.
    """
    # De-duplicate while keeping the caller's order. A set() would make the
    # column order, and thus the generated YQL text, depend on the per-process
    # string hash seed.
    selected = []
    for column in columns:
        if column not in selected:
            selected.append(column)
    if 'created_at' not in selected:
        selected.append('created_at')

    query = """
    DECLARE $pointFrom AS Utf8;
    DECLARE $pointTo AS Utf8;
    DECLARE $dateForward AS Uint32;
    DECLARE $dateBackward AS Uint32;
    DECLARE $klass AS Uint8;
    DECLARE $passengers AS Uint32;
    DECLARE $nationalVersion AS Utf8;
    DECLARE $partnerCode AS Utf8;
    DECLARE $unixtime AS Uint64;

    SELECT {columns}
    FROM `{table_path}`
    WHERE
        point_from = $pointFrom
        AND point_to = $pointTo
        AND date_forward = $dateForward
        AND date_backward = $dateBackward
        AND klass = $klass
        AND passengers = $passengers
        AND national_version = $nationalVersion
        AND partner_code = $partnerCode
        AND expires_at > $unixtime
    ORDER BY `created_at`
    LIMIT 1;
    """.format(table_path=table_path, columns=', '.join(selected))

    prepared_query = session.prepare(query)  # type: ydb.types.DataQuery
    result_sets = session.transaction(ydb.SerializableReadWrite()).execute(
        prepared_query,
        commit_tx=True,
        parameters={
            '$pointFrom': q.point_from.point_key,
            '$pointTo': q.point_to.point_key,
            '$dateForward': ydb_utils.to_days(q.date_forward),
            '$dateBackward': ydb_utils.to_days(q.date_backward),
            '$klass': ydb_utils.get_klass_id(q.klass),
            '$passengers': ydb_utils.passengers_integer_key(q),
            '$nationalVersion': q.national_version,
            '$partnerCode': partner_code,
            '$unixtime': date_util.unixtime(),
        }
    )
    try:
        return result_sets[0].rows
    except (TypeError, IndexError):
        # No result sets (empty sequence) or a non-subscriptable result.
        return None
