import os

import ydb

from travel.rasp.crosslink.tools.db_base import BaseTable


class AssociatesTable(BaseTable):
    """Wrapper for the ``associates`` YDB table.

    All six columns together form the composite primary key.
    """

    # Consumed by BaseTable (not visible here); presumably a sub-path prefix
    # for TABLE_NAME — TODO confirm.
    PATH = ''
    TABLE_NAME = 'associates'

    def get_table_description(self):
        """Build the table schema.

        Every column is nullable (``Optional``). ``graph_version`` is
        ``Uint64`` and the remaining columns are ``Utf8``. The primary key
        consists of all columns, in the order listed below.
        """
        schema = (
            ('graph_version', ydb.PrimitiveType.Uint64),
            ('source_from_key', ydb.PrimitiveType.Utf8),
            ('source_to_key', ydb.PrimitiveType.Utf8),
            ('from_key', ydb.PrimitiveType.Utf8),
            ('to_key', ydb.PrimitiveType.Utf8),
            ('transport_type', ydb.PrimitiveType.Utf8),
        )
        description = ydb.TableDescription()
        for column_name, primitive_type in schema:
            description.with_column(ydb.Column(column_name, ydb.OptionalType(primitive_type)))
        key_columns = [column_name for column_name, _ in schema]
        return description.with_primary_keys(*key_columns)

    def fill(self, associates_data):
        """Write rows into the table with a YQL ``REPLACE INTO``.

        Args:
            associates_data: value bound to ``$associatesData``. Every row must
                supply all six fields declared in the query, and the declared
                struct members are non-optional. Presumably this is a list of
                dicts/structs keyed by column name — TODO confirm against callers.

        Returns:
            Whatever ``BaseTable._execute`` returns for the query.
        """
        query = f"""
        DECLARE $associatesData AS "List<Struct<
            graph_version: Uint64,
            source_from_key: Utf8,
            source_to_key: Utf8,
            from_key: Utf8,
            to_key: Utf8,
            transport_type: Utf8>>";

        REPLACE INTO {self.TABLE_NAME}
        SELECT
            graph_version,
            source_from_key,
            source_to_key,
            from_key,
            to_key,
            transport_type
        FROM AS_TABLE($associatesData);
        """
        return self._execute(query, {'$associatesData': associates_data})


class PointsTable(BaseTable):
    """Wrapper for the ``points`` YDB table, whose primary key is ``key``."""

    # Consumed by BaseTable (not visible here); presumably a sub-path prefix
    # for TABLE_NAME — TODO confirm.
    PATH = ''
    TABLE_NAME = 'points'

    def get_table_description(self):
        """Build the table schema: five nullable ``Utf8`` columns keyed by ``key``."""
        description = ydb.TableDescription()
        for column_name in ('key', 'slug', 'title', 'title_ru_accusative', 'title_ru_genitive'):
            description.with_column(ydb.Column(column_name, ydb.OptionalType(ydb.PrimitiveType.Utf8)))
        return description.with_primary_key('key')

    def fill(self, points_data):
        """Write rows into the table with a YQL ``REPLACE INTO``.

        Args:
            points_data: value bound to ``$pointsData``. Every row must supply
                all five fields declared in the query, and the declared struct
                members are non-optional. Presumably this is a list of
                dicts/structs keyed by column name — TODO confirm against callers.

        Returns:
            Whatever ``BaseTable._execute`` returns for the query.
        """
        query = f"""
        DECLARE $pointsData AS "List<Struct<
            key: Utf8,
            slug: Utf8,
            title: Utf8,
            title_ru_genitive: Utf8,
            title_ru_accusative: Utf8>>";

        REPLACE INTO {self.TABLE_NAME}
        SELECT
            key,
            slug,
            title,
            title_ru_genitive,
            title_ru_accusative
        FROM AS_TABLE($pointsData);
        """
        return self._execute(query, {'$pointsData': points_data})


def credentials_from_environ():
    """Return the token stored in the ``YDB_TOKEN`` environment variable.

    Returns:
        The variable's value (which may be an empty string) when it is set,
        otherwise ``None``.
    """
    # One lookup is enough: os.getenv already yields None for an unset variable.
    return os.getenv('YDB_TOKEN')


def get_tables(endpoint, db, token):
    """Create the points and associates table wrappers.

    Args:
        endpoint: passed through unchanged to each table's constructor.
        db: passed through unchanged to each table's constructor.
        token: access token. When falsy (e.g. ``None`` or ``''``), the value
            returned by ``credentials_from_environ()`` is used instead.

    Returns:
        A ``(points_table, associates_table)`` tuple.
    """
    # Resolve credentials once so both tables are guaranteed the same token.
    resolved_token = token or credentials_from_environ()
    points_table = PointsTable(endpoint, db, resolved_token)
    associates_table = AssociatesTable(endpoint, db, resolved_token)
    return points_table, associates_table
