# -*- coding: utf-8 -*-
from __future__ import unicode_literals

import logging
import os
import typing
import ujson
from typing import Any, Type

import ydb

from travel.avia.library.python.ticket_daemon.ydb.banned_variants.data import _DataManager, BannedVariantParams
from travel.avia.library.python.ticket_daemon.ydb.banned_variants.errors import BannedVariantsPartnerNotFound
from travel.avia.library.python.ticket_daemon.ydb.banned_variants import operations, data
from travel.avia.library.python.ydb.session_manager import YdbSessionManager

log = logging.getLogger(__name__)


class BannedVariantsCache(object):
    """Banned-variants storage backed by a YDB table, addressed by ``(query, partner_code)``.

    Every public method obtains a session pool from ``session_manager`` and runs its
    table operation through ``session_pool.retry_operation_sync``.
    """

    def __init__(self, session_manager, relative_table_path='banned_variants'):
        # type: (YdbSessionManager, basestring)->None
        """
        :param session_manager: exposes ``database`` (the prefix the table path is
            joined onto) and ``get_session_pool()`` (a context manager yielding a
            session pool).
        :param relative_table_path: table path relative to ``session_manager.database``.
        """
        self.session_manager = session_manager
        self.table_path = os.path.join(session_manager.database, relative_table_path)
        log.info('Creating BannedVariantsCache(%s)', self.table_path)

    def create_tables(self):
        # type: ()->ydb.Operation
        """Create the table(s) at ``self.table_path`` via ``operations.create_tables``.

        :return: the value produced by ``retry_operation_sync`` for that call.
        """
        with self.session_manager.get_session_pool() as session_pool:
            def callee(session):
                return operations.create_tables(session, self.table_path)

            return session_pool.retry_operation_sync(callee=callee)

    def set(self, query, partner_code, payload, ttl_in_seconds):
        # type: (Query, basestring, bytes, int)->typing.Any
        """Upsert ``payload`` for ``(query, partner_code)`` with the given TTL.

        :param payload: serialized value to store.
        :param ttl_in_seconds: TTL forwarded to ``operations.upsert``.
        :return: the value produced by ``retry_operation_sync`` for the upsert.
        """
        with self.session_manager.get_session_pool() as session_pool:
            def callee(session):
                return operations.upsert(session, self.table_path, query, partner_code, payload, ttl_in_seconds)

            return session_pool.retry_operation_sync(callee=callee)

    def get(self, query, partner_code, columns=('payload',)):
        # type: (Query, basestring, typing.Iterable[basestring])->typing.Optional[typing.Any]
        """Fetch the stored row for ``(query, partner_code)``.

        :param columns: columns to select; ``('payload',)`` by default.
        :return: the first row returned by ``operations.select_prepared``, or ``None``
            when it returns no rows.
        """
        with self.session_manager.get_session_pool() as session_pool:
            def callee(session):
                rows = operations.select_prepared(session, self.table_path, query, partner_code, columns=columns)
                # Only the "no rows" case maps to None; an IndexError raised inside
                # select_prepared itself is a real error and must propagate.
                try:
                    return rows[0]
                except IndexError:
                    return None

            return session_pool.retry_operation_sync(callee=callee)

    def _load_stored_variants(self, query, partner_code):
        # type: (Query, basestring)->typing.Any
        """Return the JSON-decoded stored payload, or ``None`` if absent or undecodable.

        Errors from the read itself are NOT swallowed: treating a failed read as
        "nothing stored" would make the caller overwrite existing data.
        """
        row = self.get(query, partner_code)
        if row is None:
            return None
        try:
            return ujson.loads(row['payload'])
        except (TypeError, ValueError, KeyError):
            # Missing/NULL payload column or malformed JSON: start from scratch,
            # but leave a trace instead of failing silently.
            log.warning(
                'Could not decode stored banned variants for partner %s; ignoring them',
                partner_code,
                exc_info=True,
            )
            return None

    def update_banned_variants(self, query, partner_code, variant_info):
        # type: (Query, basestring, BannedVariantParams)->Any
        """Merge ``variant_info`` into the stored variants and write the result back.

        The partner-specific updater (from ``data.variants_updater_by_partner_code``)
        performs the merge and supplies the TTL of the written row.

        :return: the value returned by :meth:`set`.
        :raises BannedVariantsPartnerNotFound: if ``partner_code`` has no registered updater.
        """
        # Keep the try to the lookup only, so a KeyError raised by the updater's
        # constructor is not misreported as an unknown partner.
        try:
            variant_info_updater_cls = data.variants_updater_by_partner_code[partner_code]  # type: Type[_DataManager]
        except KeyError:
            raise BannedVariantsPartnerNotFound(partner_code)
        variant_info_updater = variant_info_updater_cls()  # type: _DataManager

        # NOTE(review): the read below and the write in `set` run as separate
        # retry_operation_sync calls, so concurrent updates for the same key may
        # overwrite each other — confirm that is acceptable.
        old_variants = self._load_stored_variants(query, partner_code)

        variants = variant_info_updater.merge(
            old_variants_info=old_variants,
            new_variant_info=variant_info,
        )
        payload = ujson.dumps(variants)
        return self.set(
            query=query,
            partner_code=partner_code,
            payload=payload,
            ttl_in_seconds=variant_info_updater.ttl,
        )
