from abc import abstractmethod
from datetime import datetime, timezone
from typing import List, Optional

import yt.wrapper as yt

from travel.avia.country_restrictions.lib.parsers.to_yt_table_parser import ToYtTableParser
from travel.avia.country_restrictions.lib.types import CountryInfo


class YtTablesToYtTableParser(ToYtTableParser):
    """
    Abstract class for parsing YT tables that share one format into a YT table.

    Subclasses choose which source tables to read (``get_input_table_names``),
    how to key each row (``get_point_key``) and how to turn a row into a
    ``CountryInfo`` (``parse_line``). ``get_data`` reads all rows of all source
    tables, merges rows that share a key, and stamps every resulting metric with
    the modification time of the most recently modified source table.
    """

    # Lower bound for the "newest element" search: naive 1970-01-01 00:00:00,
    # the same value the deprecated ``datetime.utcfromtimestamp(0)`` produced.
    _EPOCH = datetime(1970, 1, 1)

    @staticmethod
    def _to_naive_utc(timestamp: datetime) -> datetime:
        """
        Return ``timestamp`` as a naive datetime in UTC.

        Aware values are converted to UTC before the timezone is dropped, so
        timestamps with different offsets compare correctly. Naive values are
        returned unchanged (NOTE(review): presumably they are already UTC).
        """
        if timestamp.utcoffset() is not None:
            timestamp = timestamp.astimezone(timezone.utc)
        return timestamp.replace(tzinfo=None)

    def get_newest_element_timestamp(self, data) -> datetime:
        """
        Analyze stored data and return the timestamp of the newest element.

        ``data`` maps point keys to per-point containers of metrics that expose
        ``items()``. This matches the shape ``get_data`` returns (TODO confirm
        that callers always pass that shape). Metrics that are None, or whose
        ``updated_time`` is None, are ignored.

        :param data: previously stored data, or None when nothing is stored yet.
        :return: the newest ``updated_time`` as a naive UTC datetime, or the
            Unix epoch when ``data`` is None, empty, or holds no timestamps.
        """
        newest = self._EPOCH
        if data is None:
            return newest

        for _, country_info in data.items():
            for _, metric in country_info.items():
                # get_data keeps None metrics, so the stored data may hold them.
                if metric is None or metric.updated_time is None:
                    continue
                candidate = self._to_naive_utc(metric.updated_time)
                if candidate > newest:
                    newest = candidate

        return newest

    @abstractmethod
    def get_input_table_names(self, newest_element_timestamp: datetime) -> List[str]:
        """
        Returns a list of full YT paths to the source tables.

        All rows from all source tables are concatenated and passed to the row
        parser.

        :param newest_element_timestamp: result of ``get_newest_element_timestamp``
            for the previously stored data.
        """
        pass

    @abstractmethod
    def get_point_key(self, row) -> Optional[str]:
        """
        Return the key under which the data parsed from ``row`` is stored.

        Returning None makes ``get_data`` skip the row entirely.
        """
        pass

    @abstractmethod
    def parse_line(self, row, table_modification_time: datetime) -> CountryInfo:
        """
        Convert one source-table row into a ``CountryInfo``.

        :param row: a row read from one of the source tables.
        :param table_modification_time: parsed ``modification_time`` attribute
            of the table the row came from. Note that ``get_data`` later
            overwrites ``updated_time`` of every non-None metric with the
            newest source table's modification time.
        """
        pass

    def _get_table_modification_time(self, table_name: str) -> datetime:
        """Fetch and parse the ``modification_time`` attribute of a YT table."""
        raw_value = self.yt_client.get_attribute(table_name, "modification_time")
        # %z accepts the trailing "Z", so the result is timezone-aware.
        return datetime.strptime(raw_value, "%Y-%m-%dT%H:%M:%S.%f%z")

    def get_data(self, old_data):
        """
        Read all source tables and build a mapping of point key to ``CountryInfo``.

        Rows whose point key is None are skipped. Rows sharing a key are merged
        into the first ``CountryInfo`` seen for that key via
        ``CountryInfo.update``, in table and row order. Afterwards, every
        non-None metric gets ``updated_time`` set to the modification time of
        the most recently modified source table (None if no tables were read).

        :param old_data: previously stored data. It is only used to compute the
            newest element timestamp passed to ``get_input_table_names``.
        :return: dict mapping point key to ``CountryInfo``.
        """
        newest_element_timestamp = self.get_newest_element_timestamp(old_data)
        data = {}
        newest_table_mod_time = None

        for table_name in self.get_input_table_names(newest_element_timestamp):
            table_mod_time = self._get_table_modification_time(table_name)
            for row in self.yt_client.read_table(yt.TablePath(table_name)):
                key = self.get_point_key(row)
                if key is None:
                    continue

                country_info = self.parse_line(row, table_mod_time)
                if key in data:
                    data[key].update(country_info)
                else:
                    data[key] = country_info

            if newest_table_mod_time is None or table_mod_time > newest_table_mod_time:
                newest_table_mod_time = table_mod_time

        # ``data`` is a plain dict built above, so iterating values() is safe.
        for country_info in data.values():
            for _, metric in country_info.items():
                if metric is not None:
                    metric.updated_time = newest_table_mod_time

        return data
