# coding: utf8
from __future__ import absolute_import, division, print_function, unicode_literals

from typing import List, Dict, Union, AnyStr

import ydb

from travel.rasp.library.python.ydb.base_table import BaseTable


class Cache(BaseTable):
    """Key/value cache stored in a YDB table with per-row expiration.

    Each row holds a Utf8 ``key`` (the primary key), a JsonDocument ``value``
    and a Timestamp ``expire_at``. The table is declared with a TTL on
    ``expire_at`` and :meth:`get` additionally ignores rows whose
    ``expire_at`` is already in the past.
    """

    # NOTE(review): PATH/NAME are presumably consumed by BaseTable to build
    # self.full_name -- confirm against BaseTable.
    PATH = 'cache'
    NAME = 'cache'

    def description(self):
        """Build the YDB schema for the cache table.

        Returns:
            A ydb.TableDescription with three Optional (nullable) columns,
            ``key`` as the primary key, and a TTL keyed on ``expire_at``.
        """
        columns = (
            ('key', ydb.PrimitiveType.Utf8),
            ('value', ydb.PrimitiveType.JsonDocument),
            ('expire_at', ydb.PrimitiveType.Timestamp),
        )
        table = ydb.TableDescription()
        for column_name, primitive in columns:
            # Every column is wrapped in OptionalType, i.e. declared nullable.
            table = table.with_column(ydb.Column(column_name, ydb.OptionalType(primitive)))
        table = table.with_primary_key('key')
        # expire_after_seconds=0: no extra grace period is added on top of expire_at.
        expiry = ydb.TtlSettings().with_date_type_column('expire_at', expire_after_seconds=0)
        return table.with_ttl(expiry)

    def add(self, data):
        # type: (Cache, List[Dict]) -> None
        """Upsert a batch of cache entries with a single query.

        Args:
            data: list of dicts, each providing ``key``, ``value`` and
                ``expire_at`` so that it matches the declared
                ``List<Struct<key: Utf8, value: JsonDocument, expire_at: Timestamp>>``
                parameter. A row whose ``key`` already exists is overwritten
                (UPSERT on the primary key).
        """
        upsert_query = """
            DECLARE $data AS "List<Struct<
                key: Utf8,
                value: JsonDocument,
                expire_at: Timestamp
            >>";

            UPSERT INTO `{}` (key, value, expire_at)
            SELECT
                key, value, expire_at
            FROM AS_TABLE($data);
        """.format(self.full_name)
        self.execute(upsert_query, {'$data': data})

    def get(self, key):
        # type: (Cache, AnyStr) -> Union[None, ydb.convert._Row]
        """Look up a cache entry that has not expired yet.

        Args:
            key: cache key, bound as the Utf8 ``$key`` query parameter.

        Returns:
            The first row (all columns) for ``key`` whose ``expire_at`` is not
            earlier than the current UTC time, or None when the query yields
            no result set or no rows.
        """
        select_query = """
            DECLARE $key as Utf8;

            SELECT *
            FROM `{}`
            WHERE
                key = $key AND
                expire_at >= CurrentUtcTimestamp();
        """.format(self.full_name)
        # TTL cleanup is not instantaneous, so rows past expire_at may still be
        # stored; the expire_at predicate above filters them out explicitly.
        result_sets = self.execute(select_query, {'$key': key})
        if len(result_sets) > 0:
            first_rows = result_sets[0].rows
            if first_rows:
                return first_rows[0]
        return None
