from datetime import datetime, timedelta
from typing import List, Optional

from travel.avia.country_restrictions.lib.table_format.base_format import BaseFormat
from travel.avia.country_restrictions.lib.table_format.metrics_as_columns_format import MetricsAsColumnsFormat
from travel.avia.country_restrictions.lib.parsers.yt_tables_to_yt_table_parser import YtTablesToYtTableParser
from travel.avia.country_restrictions.lib.types import CountryInfo
from travel.avia.country_restrictions.lib.types.metric_type.metric_type import MetricType


class AssessorsBaseParser(YtTablesToYtTableParser):
    """Shared plumbing for parsers that read assessors' YT tables.

    Lists timestamp-named input tables from the configured country/region
    folders, resolves a row's point key from its ``geo_id`` and merges the
    results of ``SUBPARSERS`` for every row. The folder lists, ``METRIC_TYPES``
    and ``SUBPARSERS`` are empty here, so concrete subclasses presumably
    override them.
    """

    # YT folders whose children are tables named by timestamp
    # ('%Y-%m-%dT%H:%M:%S.%f'); all of them are scanned by get_input_table_names.
    COUNTRIES_INPUT_YT_TABLES_FOLDERS = []
    REGIONS_INPUT_YT_TABLES_FOLDERS = []
    SOURCE_ID = 'assessors-base'

    METRIC_TYPES: List[MetricType] = []

    UPDATING_TABLE_NAME = 'source-assessors-result'
    OUTPUT_TABLE_FORMATS: List[BaseFormat] = [
        MetricsAsColumnsFormat(output_table_short_name=UPDATING_TABLE_NAME)
    ]

    # NOTE: bound once at class creation; a subclass that overrides SOURCE_ID
    # must override PARSER_NAME too if both should stay equal.
    PARSER_NAME = SOURCE_ID
    # False: read every table in the input folders.
    # True: read only tables named after newest_element_timestamp - NEW_TABLES_TIME_MARGIN.
    LOOK_ONLY_NEW_TABLES = False
    # Names of assessors_main tables don't match the real time the data was added
    # (the average difference is about 3 hours), so the cut-off is moved back by a
    # generous margin to avoid skipping tables that were written late.
    NEW_TABLES_TIME_MARGIN = timedelta(hours=12)

    # Callables invoked as `subparser(data, row)`; each returns a mapping that is
    # merged into the row's result (values are None or objects that accept a
    # `last_modification_time` attribute).
    SUBPARSERS = []

    def get_input_tables_names_by_base_path(self, newest_element_timestamp: datetime, base_path: str) -> List[str]:
        """Return the full paths of the tables under ``base_path`` to read.

        When ``LOOK_ONLY_NEW_TABLES`` is False, every child of ``base_path`` is
        returned and the table names are not parsed at all. When it is True,
        each name is parsed as a ``%Y-%m-%dT%H:%M:%S.%f`` timestamp. Only tables
        named strictly later than
        ``newest_element_timestamp - NEW_TABLES_TIME_MARGIN`` are kept.

        :param newest_element_timestamp: reference timestamp; used only when
            ``LOOK_ONLY_NEW_TABLES`` is True.
        :param base_path: YT folder to list (via ``self.yt_client``, which is
            presumably provided by the base parser class).
        :return: ``base_path + '/' + name`` paths, sorted case-insensitively.
        :raises ValueError: if filtering is enabled and a table name is not a
            timestamp in the expected format.
        """
        names = self.yt_client.list(base_path)
        if self.LOOK_ONLY_NEW_TABLES:
            cutoff = newest_element_timestamp - self.NEW_TABLES_TIME_MARGIN
            names = [
                name
                for name in names
                if datetime.strptime(name, '%Y-%m-%dT%H:%M:%S.%f') > cutoff
            ]

        return sorted((base_path + '/' + name for name in names), key=str.lower)

    def get_input_table_names(self, newest_element_timestamp: datetime) -> List[str]:
        """Collect the input table paths from every configured folder.

        Country folders come first, then region folders, each in declaration
        order. Within a folder, paths are ordered as returned by
        ``get_input_tables_names_by_base_path``.
        """
        folders = [*self.COUNTRIES_INPUT_YT_TABLES_FOLDERS, *self.REGIONS_INPUT_YT_TABLES_FOLDERS]
        return [
            table
            for folder in folders
            for table in self.get_input_tables_names_by_base_path(newest_element_timestamp, folder)
        ]

    def get_point_key(self, row) -> Optional[str]:
        """Resolve the point key for ``row`` from its ``geo_id`` column.

        Delegates to ``self.geo_format_manager`` (presumably set up by the base
        parser class). The Optional return suggests a key may not exist for
        every geo_id.
        """
        geo_id = row['geo_id']
        return self.geo_format_manager.get_point_key_by_geo_id(geo_id)

    def parse_line(self, row, table_modification_time) -> CountryInfo:
        """Run every sub-parser over ``row`` and merge their results.

        Sub-parsers run in ``SUBPARSERS`` order. Each receives the data merged
        so far, so it can depend on earlier results, and its keys override
        earlier ones. Every non-None merged value is then stamped with
        ``table_modification_time``.

        NOTE(review): a plain dict is returned; presumably ``CountryInfo`` is a
        dict-compatible type — confirm.
        """
        data = {}
        for subparser in self.SUBPARSERS:
            data.update(subparser(data, row))

        for value in data.values():
            if value is not None:
                value.last_modification_time = table_modification_time
        return data
