import json
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from travel.avia.country_restrictions.lib.geo_format_manager import GeoFormatManager
from travel.avia.country_restrictions.lib.types import Metric
from travel.avia.country_restrictions.lib.utils.multi_cluster_yt_client import get_yt_column_description


class BaseFormat(ABC):
    """Abstract base for turning per-geo metric data into YT table rows.

    Subclasses implement ``dict_data_to_yt_format`` (row layout) and
    ``get_yt_schema`` (column schema). This base class builds the shared
    index column descriptions and normalises metric objects into plain
    JSON-compatible dicts before handing them to the subclass.
    """

    # Column name for the geo id; declared with YT type 'uint64' in __init__.
    YT_GEO_ID_COLUMN_NAME: str = 'key'
    # Column name for the point key; declared with YT type 'string' in __init__.
    YT_POINT_KEY_COLUMN_NAME: str = 'point_key'

    def __init__(self, output_table_short_name: str):
        """Store the output table name and build the index column descriptions.

        :param output_table_short_name: stored as-is on the instance; presumably
            the short name of the YT table this format writes to (confirm with callers).
        """
        self.output_table_short_name = output_table_short_name
        self.point_key_column_description = get_yt_column_description(self.YT_POINT_KEY_COLUMN_NAME, 'string')
        self.geo_id_column_description = get_yt_column_description(self.YT_GEO_ID_COLUMN_NAME, 'uint64')
        # Index columns in a fixed order: point key first, then geo id.
        self.indexes_column_descriptions = [
            self.point_key_column_description,
            self.geo_id_column_description,
        ]

    @staticmethod
    def reformat_metric_dict(metric_dict: Dict[str, Optional[Metric]]) -> Dict[str, Optional[Dict[str, Any]]]:
        """Convert every metric object into a plain JSON-compatible value.

        :param metric_dict: mapping of metric name -> metric object, or ``None``.
        :return: a new dict with the same keys in the same order. ``None`` values
            stay ``None``; every other value is the parsed result of ``metric.json()``.
        """
        # Round-tripping through the metric's own JSON serialisation (rather than
        # taking its attributes directly) yields only JSON primitives, presumably so
        # non-primitive fields end up in a form the YT writer accepts.
        return {
            name: None if metric is None else json.loads(metric.json())
            for name, metric in metric_dict.items()
        }

    @abstractmethod
    def dict_data_to_yt_format(self, data, geo_format_manager: GeoFormatManager):
        """Convert prepared per-geo data into the subclass-specific YT representation.

        :param data: mapping of geo -> (metric name -> JSON-compatible dict or ``None``),
            as produced by ``prepare_data``. ``None`` geos have already been removed.
        :param geo_format_manager: passed through unchanged from ``prepare_data``.
        """

    def prepare_data(self, data, geo_format_manager: GeoFormatManager):
        """Normalise metric data and delegate to ``dict_data_to_yt_format``.

        :param data: mapping of geo -> (metric name -> metric object or ``None``).
            Entries whose geo key is ``None`` are skipped.
        :param geo_format_manager: forwarded to ``dict_data_to_yt_format``.
        :return: whatever the subclass's ``dict_data_to_yt_format`` returns.
        """
        prepared = {
            geo: self.reformat_metric_dict(metrics)
            for geo, metrics in data.items()
            if geo is not None
        }
        return self.dict_data_to_yt_format(prepared, geo_format_manager)

    @abstractmethod
    def get_yt_schema(self, metric_types) -> List[Dict[str, str]]:
        """Return the YT table schema as a list of column description dicts.

        :param metric_types: subclass-defined description of the metrics to
            include; its exact shape is not visible here (TODO confirm).
        """
