from pydantic import parse_obj_as
from typing import List
import yt.wrapper as yt

from travel.avia.country_restrictions.aggregator.metric_postprocess import process_tourism_availability_date, \
    remove_flights_availability_if_russia, remove_vaccines_if_no_need, union_visa_and_visa_issuance
from travel.avia.country_restrictions.lib.parsers.to_yt_table_parser import ToYtTableParser
from travel.avia.country_restrictions.lib.table_format.base_format import BaseFormat
from travel.avia.country_restrictions.lib.table_format.metrics_as_columns_format import MetricsAsColumnsFormat
from travel.avia.country_restrictions.lib.table_format.metrics_as_json_format import MetricsAsJsonFormat
from travel.avia.country_restrictions.lib.table_format.extended_metrics_format import ExtendedMetricsFormat
from travel.avia.country_restrictions.lib.types import CountryInfo, InformationTable
from travel.avia.country_restrictions.lib.types.metric_type import ALL_METRICS, METRICS_FOR_EXTENDED_BANNER


class MetricPostprocessParser(ToYtTableParser):
    """Parser stage that runs the metric postprocessors over every input row.

    Rows are read from the table named by ``INPUT_TABLE_NAME``, converted to
    ``CountryInfo`` and passed through ``POSTPROCESSORS`` in order. How the
    class-level settings below are consumed is defined by ``ToYtTableParser``,
    which is not visible from this module.
    """

    METRIC_TYPES = ALL_METRICS

    INPUT_TABLE_NAME = 'apply-hierarchy-as-columns'
    UPDATING_TABLE_NAME = 'result-as-columns'
    OUTPUT_TABLE_FORMATS: List[BaseFormat] = [
        MetricsAsColumnsFormat(output_table_short_name='result-as-columns'),
        MetricsAsJsonFormat(output_table_short_name='result-as-json'),
        ExtendedMetricsFormat(
            output_table_short_name='result-as-extended-metrics',
            metrics_list=METRICS_FOR_EXTENDED_BANNER,
        ),
    ]

    PARSER_NAME = 'metric-postprocess'
    SKIP_PREVIOUS_DATA = True

    # Applied to every row in this exact order; each processor receives the
    # result of the previous one.
    POSTPROCESSORS = [
        process_tourism_availability_date.processor,
        remove_flights_availability_if_russia.processor,
        union_visa_and_visa_issuance.processor,
        remove_vaccines_if_no_need.processor,
    ]

    def get_input_table_fullname(self) -> str:
        """Return the full path of the input table resolved from its short name."""
        return self.get_path_by_shortname(self.INPUT_TABLE_NAME)

    def get_data(self, old_data):
        """Read the input table and return postprocessed metrics per point key.

        :param old_data: not used by this implementation; kept for interface
            compatibility with the base class.
        :return: mapping from point key to the postprocessed row data with all
            ``None``-valued metrics removed. Rows without a key, rows that fail
            to produce data, and rows dropped by a postprocessor are omitted.
        """
        data: InformationTable = {}
        source = [dict(row) for row in self.yt_client.read_table(yt.TablePath(self.get_input_table_fullname()))]

        for row in source:
            key = self.get_point_key_from_row(row)
            if key is None:
                continue

            row_data = parse_obj_as(CountryInfo, row)
            if row_data is None:
                continue

            for func in self.POSTPROCESSORS:
                row_data = func(key, row_data, self.geo_format_manager)
                if row_data is None:
                    # A processor dropped the row: stop here instead of
                    # handing None to the remaining processors.
                    break

            if row_data is None:
                continue

            # Snapshot the keys first: entries are deleted while walking them.
            for metric_name in list(row_data.keys()):
                if row_data[metric_name] is None:
                    del row_data[metric_name]

            data[key] = row_data

        return data
