import copy
from dataclasses import dataclass, field
import logging
from typing import Any, Dict, List, Optional, Tuple

from travel.avia.country_restrictions.aggregator.hierarchy_applier.custom_parents import CUSTOM_PARENTS
from travel.avia.country_restrictions.lib.geo_format_manager import GeoFormatManager
from travel.avia.country_restrictions.lib.types import CountryInfo, InformationTable, PointType, Metric


@dataclass
class Stats:
    """Counters accumulated while parent metrics are propagated through the hierarchy tree."""

    # Incremented by TreeNode.update_metrics_with_hierarchy for every metric cell whose value
    # differs before and after a node is merged with its parent's metrics.
    replaced_cells: int = 0


@dataclass
class TreeNode:
    """
    Node of the geo hierarchy tree.

    The tree is used to apply parent metrics to children quickly: forwarding metrics from the top down is faster than
    iterating through the whole list of points to find the parent of each point.
    """
    # Parent node, or None for the root. Excluded from the generated __repr__/__eq__: it is a back-reference (the parent
    # lists this node in its `children`), so following it would render/compare the whole tree starting from the root.
    parent_node: Any = field(repr=False, compare=False)
    point_key: str
    country_info: CountryInfo
    children: List[Any] = field(default_factory=list)

    def __post_init__(self):
        # Register this node with its parent so top-down traversal can reach it.
        if self.parent_node is not None:
            self.parent_node.children.append(self)

    def update_metrics_with_hierarchy(self, country_info: CountryInfo, stats: Stats):
        """
        Fill in this node's missing metrics from its parent's metrics, then forward the result to the children.

        A metric the node already has (a non-None value) is kept; any metric that is absent or None on the node is
        taken from `country_info`. The node's `country_info` is replaced by the merged result. The dict received from
        the parent is deep-copied first, so the parent's data is never mutated.

        :param country_info: CountryInfo: metrics of the parent node (already merged with its own ancestors)
        :param stats: Stats: statistics; `replaced_cells` is incremented for every metric whose value changed
        :return: None
        """
        merged_info: CountryInfo = copy.deepcopy(country_info)
        for metric_name, data in self.country_info.items():
            if data is not None:
                merged_info[metric_name] = data

        # Count every cell whose value differs between the node's original data and the merged data.
        for metric_name in set(self.country_info) | set(merged_info):
            old_metric = self.country_info.get(metric_name)
            new_metric = merged_info.get(metric_name)
            if old_metric is None and new_metric is None:
                continue
            if (
                old_metric is None
                or new_metric is None
                or not Metric.equal_without_meta_info(old_metric, new_metric)
            ):
                stats.replaced_cells += 1

        self.country_info = merged_info

        for child in self.children:
            child.update_metrics_with_hierarchy(merged_info, stats)


# Maps a point key to the TreeNode built for that point.
TreeNodeStorage = Dict[str, TreeNode]
# Point keys grouped into hierarchy layers, top layer first: [countries, regions, settlements].
Layers = List[List[str]]


def add_point_key_to_point_type_list(point_key: str, settlements, regions, countries):
    """
    Put `point_key` into the set that matches its point type.

    The type is derived from the first character of the key. Settlement, region and country keys go to `settlements`,
    `regions` and `countries` respectively. Any other recognised PointType is not stored; a warning is logged instead.

    NOTE(review): if `PointType` is an Enum, an unrecognised prefix character raises ValueError from `PointType(...)`
    (and an empty key raises IndexError) before the warning branch is reached. Confirm whether that is intended.

    :param point_key: str: point key whose first character encodes the point type
    :param settlements: set receiving settlement keys
    :param regions: set receiving region keys
    :param countries: set receiving country keys
    :return: None
    """
    point_type = PointType(point_key[0])
    bucket_for_type = {
        PointType.SETTLEMENT: settlements,
        PointType.REGION: regions,
        PointType.COUNTRY: countries,
    }
    bucket = bucket_for_type.get(point_type)
    if bucket is None:
        logging.getLogger().warning(f'Unknown point key prefix: {point_type}. Point Key: {point_key}')
        return
    bucket.add(point_key)


def get_layers_from_information_table(info_table: InformationTable) -> Layers:
    """
    Collect every point key that needs a tree node and split the keys into hierarchy layers.

    Keys are taken from the InformationTable and from CUSTOM_PARENTS (both the points listed there and their custom
    parents). The result is [countries, regions, settlements]. An element of layer i can't be a child of an element
    of layer i + 1, nor of an element of its own layer.
    This lets tree building search a smaller list of parent candidates, and guarantees that if a parent exists, it has
    already been built from an earlier layer.

    :param info_table: InformationTable from parser
    :return: Layers
    """
    countries, regions, settlements = set(), set(), set()

    def register(key: str) -> None:
        add_point_key_to_point_type_list(key, settlements, regions, countries)

    for table_key in info_table:
        register(table_key)

    for custom_key, custom_parent_keys in CUSTOM_PARENTS.items():
        register(custom_key)
        for parent_key in custom_parent_keys:
            register(parent_key)

    return [list(countries), list(regions), list(settlements)]


def get_node_parent(
    tree_node_storage: TreeNodeStorage,
    geo_format_manager: GeoFormatManager,
    point_key: str,
) -> Optional[TreeNode]:
    """
    Find the already-built tree node that should become the parent of `point_key`.

    Candidate parent keys come from CUSTOM_PARENTS whenever the point has an entry there (an empty entry means "no
    candidates"). Only otherwise does geo_format_manager.get_point_key_parents supply them. Candidates are checked in
    order, and the first one that has a node in `tree_node_storage` is returned.

    :param tree_node_storage: TreeNodeStorage: nodes built so far, keyed by point key
    :param geo_format_manager: GeoFormatManager: source of ordinary (non-custom) parents
    :param point_key: str
    :return: Optional[TreeNode]: parent tree node, or None if no candidate has a node yet
    """
    candidate_keys = CUSTOM_PARENTS.get(point_key, None)
    if candidate_keys is None:
        candidate_keys = geo_format_manager.get_point_key_parents(point_key)

    existing_nodes = (tree_node_storage.get(candidate, None) for candidate in candidate_keys)
    return next((node for node in existing_nodes if node is not None), None)


def build_tree(
    info_table: InformationTable,
    geo_format_manager: GeoFormatManager,
) -> Tuple[TreeNode, TreeNodeStorage]:
    """
    Build the geo hierarchy tree, using get_layers_from_information_table for optimisation.

    Layers are processed in order, so a point's parent is looked up among nodes built from earlier layers. A point
    whose parent is not found is attached to a synthetic root node that has an empty point key and no metrics.
    No Stats are produced here; they are collected by TreeNode.update_metrics_with_hierarchy.

    :param info_table: InformationTable
    :param geo_format_manager: GeoFormatManager
    :return: (TreeNode, TreeNodeStorage): (root hierarchy node; hashmap of all nodes with point_key as key)
    """
    layers = get_layers_from_information_table(info_table)

    root_node = TreeNode(parent_node=None, point_key='', country_info={})
    tree_node_storage: TreeNodeStorage = {}

    for layer in layers:
        for point_key in layer:
            parent_node = get_node_parent(tree_node_storage, geo_format_manager, point_key)
            if parent_node is None:
                parent_node = root_node
            tree_node_storage[point_key] = TreeNode(
                parent_node=parent_node,
                point_key=point_key,
                # Points that appear only in CUSTOM_PARENTS have no table row and start without metrics.
                country_info=info_table.get(point_key, {}),
            )

    return root_node, tree_node_storage
