# -*- coding: utf-8 -*-
from __future__ import annotations

from enum import Enum
from textwrap import dedent
from types import ModuleType
from typing import Any, Dict, Generator, List, NamedTuple, Set
import importlib
import itertools
import json
import logging

from travel.hotels.content_manager.lib.common import get_dc_defaults, get_dc_yql_schema, get_dc_yt_schema


LOG: logging.Logger = logging.getLogger(__name__)


# Maps each storage table name to the name of the dataclass describing its
# rows. The dataclass is looked up by this name in every storage module being
# compared, and tables are checked in this (insertion) order.
DATACLASSES_TO_MIGRATE = dict(
    hotels_wl='StorageHotelWL',
    permalinks='StoragePermalink',
    permalinks_wl='StoragePermalinkWL',
    permarooms='StoragePermaroom',
    mappings='StorageMapping',
    urls='StorageUrl',
    sc_descriptions='StorageSCDescription',
)


# A table that has no dataclass in the older storage module of a version pair,
# so it gets no in-place migration query. ``schema`` holds whatever
# ``get_dc_yt_schema`` returns for the newer dataclass (presumably
# column name -> YT type name; confirm against that helper).
NewTable = NamedTuple('NewTable', [('name', str), ('schema', Dict[str, str])])


# An imported storage-definition module paired with the schema version number
# it represents.
VersionedModule = NamedTuple('VersionedModule', [('version', int), ('module', ModuleType)])


class MigrationGenerator(object):
    """Generate YQL data migrations between consecutive storage schema versions.

    On construction, versioned snapshots
    ``travel.hotels.content_manager.migrations.storage_<N>`` are imported for
    ``N = version_from, version_from + 1, ...`` until the next one does not
    exist, and ``travel.hotels.content_manager.data_model.storage`` is appended
    as the newest version. One migration is produced per adjacent pair.

    Attributes:
        migrations: one query string per adjacent pair of versions, oldest
            first. Each contains a ``{storage_path}`` placeholder and literal
            braces are doubled, so the text is presumably filled in with
            ``str.format`` by the caller.
        new_tables: tables whose dataclass is absent from the older module of
            some pair; no migration query is generated for them.
    """

    def __init__(self, version_from: int):
        self.migrations: List[str] = list()
        self.new_tables: List[NewTable] = list()
        self.update_migrations(version_from)

    @staticmethod
    def get_query(
            table_name: str,
            schema: Dict[str, str],
            new_fields: Set[str],
            new_types: Set[str],
            default_values: Dict[str, Any],
    ) -> str:
        """Return an ``INSERT ... WITH TRUNCATE`` query rewriting a table in place.

        Args:
            table_name: table name relative to ``{storage_path}``.
            schema: target YQL schema (field name -> YQL type); defines the
                selected columns and their order.
            new_fields: fields absent from the old table; filled from
                ``default_values``.
            new_types: existing fields whose type changed; cast to the new type.
            default_values: YQL expressions (see ``get_yql_default_values``)
                for every field in ``new_fields``.

        Returns:
            The query text (not dedented) with a ``{storage_path}`` placeholder.
        """
        select_items = list()
        for field_name, field_type in schema.items():
            if field_name in new_fields:
                default_value = default_values[field_name]
                if field_type == 'Any':
                    # 'Any' columns take the default expression without a CAST.
                    item = f'{default_value} AS `{field_name}`'
                else:
                    item = f'CAST({default_value} AS {field_type}) AS `{field_name}`'
            elif field_name in new_types:
                item = f'CAST(`{field_name}` AS {field_type}) AS `{field_name}`'
            else:
                item = f'`{field_name}`'
            select_items.append(item)

        select_list = ',\n                '.join(select_items)

        # ``{{storage_path}}`` renders as a literal ``{storage_path}`` placeholder.
        return f'''
            INSERT INTO `{{storage_path}}/{table_name}` WITH TRUNCATE
            SELECT
                {select_list}
            FROM `{{storage_path}}/{table_name}`;
        '''

    @staticmethod
    def _escape_format_braces(text: str) -> str:
        """Double ``{`` and ``}`` so they survive formatting of the query template."""
        return text.replace('{', '{{').replace('}', '}}')

    @staticmethod
    def get_yql_default_values(default_values: Dict[str, Any]) -> Dict[str, Any]:
        """Render dataclass default values as YQL expressions for ``get_query``.

        * ``Enum`` members are replaced by their ``value`` first, which is then
          rendered by the rules below (so string-valued enums get quoted).
        * ``None`` becomes ``NULL``.
        * ``str`` becomes a double-quoted YQL string literal with backslashes
          and double quotes escaped.
        * ``dict`` / ``list`` become ``Yson::SerializeText(Json(@@<json>@@))``.
        * Anything else (numbers, booleans, ...) is returned unchanged and is
          formatted as-is into the query.

        Braces in string and JSON literals are doubled because the query
        template carries a ``{storage_path}`` placeholder and is presumably
        passed through ``str.format`` once more.
        """
        result = dict()
        for key, value in default_values.items():
            if isinstance(value, Enum):
                # Unwrap first so the underlying value is rendered like any
                # other value instead of being emitted verbatim.
                value = value.value
            if value is None:
                value = 'NULL'
            elif isinstance(value, str):
                escaped = value.replace('\\', '\\\\').replace('"', '\\"')
                value = f'"{MigrationGenerator._escape_format_braces(escaped)}"'
            elif isinstance(value, (dict, list)):
                serialized = MigrationGenerator._escape_format_braces(json.dumps(value))
                value = f'Yson::SerializeText(Json(@@{serialized}@@))'
            result[key] = value
        return result

    def get_migration(self, prev_module: VersionedModule, next_module: VersionedModule) -> str:
        """Build the migration query text for one pair of adjacent storage versions.

        For every table in ``DATACLASSES_TO_MIGRATE``: if its dataclass is
        missing from ``prev_module``, the table is appended to
        ``self.new_tables`` (with the YT schema of the new dataclass) and no
        query is generated; otherwise a rewrite query is generated when fields
        were added, removed or changed type.

        Returns:
            The concatenated dedented queries; empty when the only changes for
            this pair are new tables.

        Raises:
            RuntimeError: if this pair yields neither a query nor a new table.
            KeyError: if a dataclass is missing from ``next_module``.
        """
        LOG.info(f'Previous version: {prev_module.version}')
        LOG.info(f'Next version: {next_module.version}')
        queries = list()
        # ``self.new_tables`` accumulates across pairs; only count this pair's additions.
        new_tables_before = len(self.new_tables)
        for table_name, dc_name in DATACLASSES_TO_MIGRATE.items():
            LOG.info(f'Checking {dc_name}')

            prev_dc = prev_module.module.__dict__.get(dc_name)
            next_dc = next_module.module.__dict__[dc_name]

            if prev_dc is None:
                new_table = NewTable(
                    name=table_name,
                    schema=get_dc_yt_schema(next_dc)
                )
                self.new_tables.append(new_table)
                continue

            prev_schema = get_dc_yql_schema(prev_dc)
            next_schema = get_dc_yql_schema(next_dc)

            prev_fields = set(prev_schema.keys())
            next_fields = set(next_schema.keys())

            new_fields = next_fields - prev_fields
            LOG.debug(f'New fields: {new_fields}')

            removed_fields = prev_fields - next_fields
            LOG.debug(f'Removed fields: {removed_fields}')

            common_fields = prev_fields & next_fields
            new_types = {f for f in common_fields if prev_schema[f] != next_schema[f]}
            LOG.debug(f'New types: {new_types}')

            if not (new_fields or removed_fields or new_types):
                continue

            default_values = self.get_yql_default_values(get_dc_defaults(next_dc))
            query = self.get_query(table_name, next_schema, new_fields, new_types, default_values)
            queries.append(dedent(query))

        if not queries and len(self.new_tables) == new_tables_before:
            raise RuntimeError(f'No changes detected between\n{prev_module}\nand\n{next_module}')

        return ''.join(queries)

    def generate_migrations_for_modules(
            self, modules: List[VersionedModule],
    ) -> Generator[str, None, None]:
        """Yield ``get_migration`` for every adjacent pair of ``modules``."""
        previous_modules, next_modules = itertools.tee(modules)
        next(next_modules, None)
        for previous_module, next_module in zip(previous_modules, next_modules):
            yield self.get_migration(previous_module, next_module)

    @staticmethod
    def _is_missing_module(error: ModuleNotFoundError, module_name: str) -> bool:
        """Tell whether ``error`` means ``module_name`` or one of its parent packages does not exist.

        Any other missing module is a broken dependency of an existing module.
        """
        missing = error.name
        if missing is None:
            return False
        return module_name == missing or module_name.startswith(missing + '.')

    def update_migrations(self, version_from: int) -> None:
        """Import storage versions starting at ``version_from`` and append their migrations.

        Versioned snapshots are imported until the first version whose module
        does not exist; the current data model module then takes that version
        number. A ``ModuleNotFoundError`` for any other module, or any other
        ``ImportError`` raised while importing an existing snapshot, is
        propagated instead of being mistaken for the end of the sequence.
        """
        modules = list()
        version = version_from
        for version in itertools.count(version_from):
            module_name = f'travel.hotels.content_manager.migrations.storage_{version}'
            try:
                module = importlib.import_module(module_name)
            except ModuleNotFoundError as error:
                if not MigrationGenerator._is_missing_module(error, module_name):
                    raise
                break
            modules.append(VersionedModule(version, module))
        module = importlib.import_module('travel.hotels.content_manager.data_model.storage')
        modules.append(VersionedModule(version, module))

        self.migrations.extend(self.generate_migrations_for_modules(modules))
