from abc import abstractmethod
from typing import Dict, List, Optional

from yt.common import YtError
import yt.wrapper as yt
from pydantic import parse_obj_as

from travel.avia.country_restrictions.lib.table_format.base_format import BaseFormat
from travel.avia.country_restrictions.lib.parsers.abc_parser import AbcParser
from travel.avia.country_restrictions.lib.types import Environment, InformationTable, Metric, OutputFormat
from travel.avia.country_restrictions.lib.types.metric_type.metric_type import MetricType
from travel.avia.country_restrictions.lib.parsers.metric_correctness_validators import run_all_metric_correctness_validators


class ToYtTableParser(AbcParser):
    """
    Abstract class for parsing to YT table.
    Algorithm:
    1. Read result YT table if it exists.
    2. Update data with new values from source.
    3. Write to result YT table.

    Subclasses implement ``get_data`` and configure the class attributes below
    (metric types, output table formats, the table previous data is read from).
    """

    SOURCE_ID = None
    OUTPUT_FORMAT = OutputFormat.MetricsAsColumns
    TABLE_BASE_PATH = '//home/avia/data/country-restrictions'
    YT_PROXIES = ('hahn', 'arnold')
    METRIC_TYPES: List[MetricType] = []

    # Short name of the table previous data is loaded from before updating.
    # It should be one of the tables written via OUTPUT_TABLE_FORMATS.
    UPDATING_TABLE_NAME = ''

    # Formats (and target tables) the result is written to.
    OUTPUT_TABLE_FORMATS: List[BaseFormat] = []

    YT_POINT_KEY_COLUMN_NAME: str = 'point_key'
    YT_GEO_ID_COLUMN_NAME: str = 'key'
    YT_METRICS_AS_LIST_COLUMN_NAME: str = 'value'

    # If the flag is enabled, new data is not merged into the data from the previous run:
    # the YT tables are simply overwritten with the new data.
    # In this case no stats/errors are sent to Solomon either.
    SKIP_PREVIOUS_DATA = False

    def __init__(self, base_path: str = TABLE_BASE_PATH, yt_client=None, **kwargs):
        """
        :param base_path: root YT directory for the parser's tables. The default is
            ``ToYtTableParser.TABLE_BASE_PATH`` as evaluated when this class body ran,
            so overriding ``TABLE_BASE_PATH`` in a subclass does not change it.
        :param yt_client: YT client used for reading and writing tables.
        :param kwargs: forwarded to ``AbcParser.__init__``.
        """
        super().__init__(**kwargs)
        self.table_base_path = base_path
        self.yt_client = yt_client
        self.metric_names = [x.name for x in self.METRIC_TYPES]

    @staticmethod
    def get_base_path_static(
        base_path: str,
        environment: Environment,
        version: int,
        **kwargs,
    ) -> str:
        """Return ``{base_path}/{environment}/v{version}``; extra kwargs are ignored."""
        return f'{base_path}/{environment.value}/v{version}'

    @classmethod
    def get_path_by_shortname_static(
        cls,
        base_path: str,
        environment: Environment,
        version: int,
        shortname: str,
        **kwargs,
    ) -> str:
        """Return the full YT path of table ``shortname`` under the versioned base path."""
        return f'{cls.get_base_path_static(base_path, environment, version)}/{shortname}'

    def get_path_by_shortname(self, shortname: str) -> str:
        """Instance shortcut for ``get_path_by_shortname_static`` using this parser's settings."""
        # environment/version are presumably provided by AbcParser — not visible here.
        return self.get_path_by_shortname_static(self.table_base_path, self.environment, self.version, shortname)

    @classmethod
    def get_updating_table_fullname_static(cls, **kwargs) -> str:
        """Full path of ``UPDATING_TABLE_NAME``; kwargs are the path parameters (base_path, environment, version)."""
        return cls.get_path_by_shortname_static(shortname=cls.UPDATING_TABLE_NAME, **kwargs)

    def get_updating_table_fullname(self) -> str:
        """Full path of the table previous data is read from."""
        return self.get_path_by_shortname(self.UPDATING_TABLE_NAME)

    def get_point_key_from_row(self, row) -> Optional[str]:
        """
        Extract the point key from a YT row, removing the key columns from ``row`` in place.

        If the row has a point-key column its value is used directly; otherwise the
        geo id column is converted with ``geo_format_manager``. The result may be None,
        in which case the caller skips the row.
        """
        if self.YT_POINT_KEY_COLUMN_NAME in row:
            key = row.pop(self.YT_POINT_KEY_COLUMN_NAME)
            # The geo id column is redundant here; drop it if present but don't require it.
            row.pop(self.YT_GEO_ID_COLUMN_NAME, None)
            return key

        geo_id = row.pop(self.YT_GEO_ID_COLUMN_NAME)
        return self.geo_format_manager.get_point_key_by_geo_id(geo_id)

    def get_initial_data(self) -> InformationTable:
        """
        Get data from result YT table if it exists.

        Reads the ``UPDATING_TABLE_NAME`` table and returns, per point key, the non-null
        values of the metric columns listed in ``METRIC_TYPES``; other columns are ignored.
        Returns an empty (or partially filled) table if reading raises ``YtError``.
        """
        data: InformationTable = {}
        metric_names = set(self.metric_names)

        try:
            table = self.yt_client.read_table(yt.TablePath(self.get_updating_table_fullname()))
            for row in table:
                key = self.get_point_key_from_row(row)
                if key is None:
                    continue

                # Parse only the metric columns so unrelated columns cannot fail validation.
                metric_columns = {k: v for k, v in row.items() if k in metric_names}
                country_value = parse_obj_as(Dict[str, Optional[Metric]], metric_columns)
                data[key] = {k: v for k, v in country_value.items() if v is not None}
        except YtError:
            # A missing table is the expected case (see docstring).
            # NOTE(review): this also swallows transient/permission errors, including ones
            # raised mid-read after some rows were loaded; run() would then write back an
            # incomplete table. Consider an explicit existence check instead.
            pass

        return data

    @abstractmethod
    def get_data(self, old_data) -> InformationTable:
        """Return new data from the source; ``old_data`` is the result of ``get_initial_data``."""
        pass

    def write_to_yt(self, data: InformationTable):
        """Write ``data`` to every table in ``OUTPUT_TABLE_FORMATS``, using each format's schema."""
        for table_format in self.OUTPUT_TABLE_FORMATS:
            output_yt_table = self.get_path_by_shortname(table_format.output_table_short_name)
            self.yt_client.write_table(
                yt.TablePath(output_yt_table, schema=table_format.get_yt_schema(self.METRIC_TYPES)),
                table_format.prepare_data(data, self.geo_format_manager),
            )

    def run(self):
        """
        Load previous data, obtain new data, merge them (unless ``SKIP_PREVIOUS_DATA``),
        push stats/errors to Solomon and write the result to YT.
        """
        table_data = self.get_initial_data()
        new_data = self.get_data(table_data)

        if self.SKIP_PREVIOUS_DATA:
            # NOTE(review): SOURCE_ID is not stamped on metrics on this path — confirm intended.
            table_data = new_data
        else:
            updated_geo, updated_metrics_set, updated_cells = self._merge_new_data(table_data, new_data)

            # If SOURCE_ID is set, mark every stored metric (old and new) with it.
            if self.SOURCE_ID is not None:
                for metrics in table_data.values():
                    for metric_value in metrics.values():
                        if metric_value is not None:
                            metric_value.source = self.SOURCE_ID

            self.solomon_pusher.push_stats(
                **self.count_stats(table_data, updated_geo, updated_metrics_set, updated_cells),
            )
            self.solomon_pusher.push_errors(self.count_errors(table_data))

        self.write_to_yt(table_data)

    def _merge_new_data(self, table_data: InformationTable, new_data: InformationTable):
        """
        Merge non-null metrics from ``new_data`` into ``table_data`` in place.

        A metric equal to the stored one (ignoring meta info) only refreshes the stored
        ``updated_time``; otherwise the new value replaces/adds the cell.

        :return: ``(updated_geo, updated_metrics_set, updated_cells)`` — number of geo keys
            with any changed cell, names of metrics with any changed cell, number of changed cells.
        """
        updated_geo = 0
        updated_metrics_set = set()
        updated_cells = 0

        for country_key, country_metrics in new_data.items():
            is_current_geo_updated = False
            old_country_metrics = table_data.get(country_key, {})

            for metric_name, metric_value in country_metrics.items():
                if metric_value is None:
                    continue

                if metric_name in old_country_metrics and Metric.equal_without_meta_info(
                    old_country_metrics[metric_name], metric_value,
                ):
                    # Same content: keep the stored object, only refresh its timestamp.
                    old_country_metrics[metric_name].updated_time = metric_value.updated_time
                    continue

                old_country_metrics[metric_name] = metric_value
                updated_metrics_set.add(metric_name)
                is_current_geo_updated = True
                updated_cells += 1

            # Only keep geo keys that end up with at least one metric.
            if old_country_metrics:
                table_data[country_key] = old_country_metrics

            if is_current_geo_updated:
                updated_geo += 1

        return updated_geo, updated_metrics_set, updated_cells

    def count_stats(self, table_data, updated_geo, updated_metrics_set, updated_cells):
        """
        Build the Solomon stats dict:

        updated_geo - geo with ANY cell updated
        updated_metrics - metrics with ANY cell updated
        updated_cells - updated cells
        null_geo - geo with ALL empty cells
        null_metrics - metrics with ALL empty cells
        null_cells - empty cells
        total_geo - geo count
        total_metrics - metrics count
        total_cells - cells count
        """
        total_geo = len(table_data)
        total_metrics_set = set(self.metric_names)
        not_null_metrics_set = set()
        null_geo = 0
        null_cells = 0

        for country_metrics in table_data.values():
            geo_has_value = False

            for metric_type in self.METRIC_TYPES:
                metric_name = metric_type.name
                if country_metrics.get(metric_name) is None:
                    null_cells += 1
                else:
                    geo_has_value = True
                    not_null_metrics_set.add(metric_name)

            if not geo_has_value:
                null_geo += 1

        total_metrics = len(total_metrics_set)

        return dict(
            updated_geo=updated_geo,
            updated_metrics=len(updated_metrics_set),
            updated_cells=updated_cells,
            null_geo=null_geo,
            null_metrics=len(total_metrics_set - not_null_metrics_set),
            null_cells=null_cells,
            total_geo=total_geo,
            total_metrics=total_metrics,
            total_cells=total_geo * total_metrics,
        )

    @staticmethod
    def count_errors(table_data):
        """Map each validator error to the number of errors reported for it across all rows of ``table_data``."""
        errors = {}

        for country_metrics in table_data.values():
            for error in run_all_metric_correctness_validators(country_metrics):
                errors[error] = errors.get(error, 0) + 1

        return errors
