from collections import Counter
from itertools import product
from typing import Callable, Dict, Iterable, List, NamedTuple, Set, Tuple
import logging
from math import floor

from travel.hotels.tools.region_pages_builder.common.tanker_data import ConfigRegionCrossLinks


class CrossLinkGenerator:
    """
    Generates cross-links between cities.

    Every city gets ``link_count_limit`` outgoing links. The number of incoming links per city
    is either taken from the custom dict or derived from the popularity function, and the
    actual link selection is delegated to ``generate_greedy``.
    """

    def __init__(
        self,
        link_count_limit: int,
        custom_incoming_link_count: Dict[int, int] = None,
        fixed_links: Set[Tuple[int, int]] = None,
        cost_function: Callable[[int, int], float] = None,
        popularity_function: Callable[[int], float] = None,
        city_id_to_name: Dict[int, str] = None,
    ):
        """
        Initialize cross-link generator
        :param link_count_limit: outgoing link count assigned to every city
        :param custom_incoming_link_count: dict geo_id -> custom incoming link count
        :param fixed_links: set of tuple from_geo_id -> to_geo_id of fixed links that must be presented in result graph
        :param cost_function: cost between two cities for some algorithms (defaults to constant 0)
        :param popularity_function: some city popularity measure to define incoming link count (defaults to constant 1)
        :param city_id_to_name: city id to city name dict, used only for log messages
        """

        self.incoming_link_count = custom_incoming_link_count if custom_incoming_link_count is not None else {}
        self.link_count_limit = link_count_limit
        self.fixed_links = fixed_links if fixed_links is not None else set()
        self.cost_function = cost_function if cost_function is not None else lambda _1, _2: 0
        self.popularity_function = popularity_function if popularity_function is not None else lambda _: 1
        self.city_id_to_name = city_id_to_name if city_id_to_name is not None else {}

    def generate_links(self, cities: Iterable[int]) -> Dict[int, List[int]]:
        """
        Build outgoing links for the given cities
        :param cities: city ids to link with each other
        :return: dict of geo_id -> list of linked geo_ids, most popular target first
        """
        # Materialize once: the iterable is traversed several times below.
        city_list = list(cities)
        outgoing_link_count = {city_id: self.link_count_limit for city_id in city_list}
        incoming_link_count = self._calculate_incoming_link_count(city_list, self.incoming_link_count, self.popularity_function)

        result = generate_greedy(city_list, outgoing_link_count, incoming_link_count, self.fixed_links, self.cost_function, self.link_count_limit)
        return {geo_id: sorted(links_to_city, key=self.popularity_function, reverse=True) for geo_id, links_to_city in result.items()}

    def _get_city_name(self, city_id: int) -> str:
        """
        :return: "<name>[<id>]" when the city name is known, "geoId=<id>" otherwise
        """
        if city_id in self.city_id_to_name:
            return f"{self.city_id_to_name[city_id]}[{city_id}]"
        return f"geoId={city_id}"

    def _calculate_incoming_link_count(
        self,
        cities: List[int],
        custom_incoming_link_count: Dict[int, int],
        popularity: Callable[[int], float]
    ) -> Dict[int, int]:
        """
        Split the total link budget (len(cities) * link_count_limit) into per-city incoming counts.

        Cities with a custom count keep it. Every other city gets one guaranteed link plus a share
        of the remaining budget proportional to its popularity; whatever is left after rounding
        down goes to the most popular cities first.

        :param cities: list of cities id
        :param custom_incoming_link_count: dict geo_id -> custom incoming link count
        :param popularity: popularity function
        :return: dict of geo_id -> incoming_link_count for all cities
        :raises AssertionError: if a custom count is out of range, the custom counts exceed the
            budget, or the budget cannot be distributed exactly
        """
        logger = logging.getLogger("incoming_link_calculator")

        for link_count in custom_incoming_link_count.values():
            assert 0 < link_count <= len(cities)

        total_link_count = len(cities) * self.link_count_limit

        # Most popular first, so rounding leftovers land on the most popular cities.
        cities_without_fixed_link_count = sorted(
            (city_id for city_id in cities if city_id not in custom_incoming_link_count),
            key=popularity,
            reverse=True,
        )

        # Budget left after the custom counts and one guaranteed link per remaining city.
        arbitrary_link_count = total_link_count - len(cities_without_fixed_link_count) - sum(custom_incoming_link_count.values())
        assert 0 <= arbitrary_link_count

        total_popularity = sum(map(popularity, cities_without_fixed_link_count))
        # Zero total popularity (no free cities, or all of them have popularity 0) used to raise
        # ZeroDivisionError; fall back to an even split performed by the leftover step below.
        link_for_popularity = arbitrary_link_count / total_popularity if total_popularity else 0

        arbitrary_link_dict = {}
        for city_id in cities_without_fixed_link_count:
            arbitrary_link_count_for_city = floor(popularity(city_id) * link_for_popularity)
            arbitrary_link_dict[city_id] = 1 + arbitrary_link_count_for_city
            arbitrary_link_count -= arbitrary_link_count_for_city

        # Hand out what rounding down left over. divmod keeps this bounded even when the
        # leftover is larger than the number of cities (e.g. the even-split fallback).
        if cities_without_fixed_link_count and arbitrary_link_count > 0:
            per_city, leftover = divmod(arbitrary_link_count, len(cities_without_fixed_link_count))
            for index, geo_id in enumerate(cities_without_fixed_link_count):
                arbitrary_link_dict[geo_id] += per_city + (1 if index < leftover else 0)

        result = {**arbitrary_link_dict, **custom_incoming_link_count}
        assert sum(result.values()) == total_link_count

        top_3_arbitrary = sorted(arbitrary_link_dict.items(), key=lambda x: x[1], reverse=True)[:3]
        top_3_arbitrary_message = ", ".join(map(lambda x: f"{self._get_city_name(x[0])}(link_count={x[1]})", top_3_arbitrary))

        top_3 = sorted(result.items(), key=lambda x: x[1], reverse=True)[:3]
        top_3_message = ", ".join(map(lambda x: f"{self._get_city_name(x[0])}(link_count={x[1]})", top_3))

        logger.info(f"Incoming link count calculated, top 3 generated are [{top_3_arbitrary_message}], top 3 are [{top_3_message}]")

        return result


def extract_restrictions_from_cross_links_config(config: ConfigRegionCrossLinks, slug_dict: Dict[str, int]) -> Tuple[Dict[int, int], Set[Tuple[int, int]]]:
    """
    Translate the slug-keyed cross-links config into geo_id-keyed restrictions.

    :param config: cross-links config; config.regions maps a region slug to an object that
        exposes incomingLinksCount and fixedLinks (an iterable of target region slugs)
    :param slug_dict: dict region slug -> geo_id
    :return: Tuple of incoming link count dict (geo_id -> count, only for regions where the count
        is set) and fixed links set of (from_geo_id, to_geo_id)
    :raises Exception: if a region slug or a fixed-link target slug is missing from slug_dict
    """
    def to_geo_id(slug):
        # Fail loudly on slugs the caller does not know about.
        if slug not in slug_dict:
            raise Exception(f"Unknown region slug '{slug}'")
        return slug_dict[slug]

    counts_by_geo_id = {}
    link_pairs = set()

    for source_slug, links_data in config.regions.items():
        source_geo_id = to_geo_id(source_slug)

        requested_count = links_data.incomingLinksCount
        if requested_count is not None:
            counts_by_geo_id[source_geo_id] = requested_count

        for target_slug in links_data.fixedLinks:
            link_pairs.add((source_geo_id, to_geo_id(target_slug)))

    return counts_by_geo_id, link_pairs


class WeightedLink(NamedTuple):
    """Directed candidate link between two cities together with its cost."""

    # geo_id of the city the link starts from
    c_from: int
    # geo_id of the city the link points to
    c_to: int
    # cost of the link; candidates with lower cost are considered earlier
    cost: float


def generate_greedy(
    cities: List[int],
    outgoing_link_count: Dict[int, int],
    incoming_link_count: Dict[int, int],
    fixed_links: Set[Tuple[int, int]],
    cost_function: Callable[[int, int], float],
    link_count_limit: int,
) -> Dict[int, List[int]]:
    """
    Greedily pick directed links between cities, cheapest first, within per-city link budgets.

    Fixed links are considered before every other candidate, so another link with an equal or
    lower cost can never use up their endpoints' capacity first. Previously they were only given
    cost 0 and sorted with the rest, which let any other zero-cost (or negative-cost) link that
    came earlier in iteration order displace them.

    :param cities: list of cities id; self-links are never generated
    :param outgoing_link_count: dict geo_id -> how many links may start from the city
    :param incoming_link_count: dict geo_id -> how many links may point to the city
    :param fixed_links: set of (from_geo_id, to_geo_id) links to place first; pairs that are
        self-links or mention cities outside of ``cities`` are ignored
    :param cost_function: cost between two cities, not called for fixed links
    :param link_count_limit: not used by this algorithm, kept for interface compatibility
    :return: dict of geo_id -> list of geo_ids it links to, in the order the links were chosen;
        cities without outgoing links are absent
    """
    # Work on copies: Counter also yields 0 for exhausted (deleted) cities instead of KeyError.
    incoming_left = Counter(incoming_link_count)
    outgoing_left = Counter(outgoing_link_count)

    fixed_candidates = []
    priced_candidates = []
    for c_from, c_to in product(cities, cities):
        if c_from == c_to:
            continue

        if (c_from, c_to) in fixed_links:
            fixed_candidates.append(WeightedLink(c_from, c_to, 0))
        else:
            priced_candidates.append(WeightedLink(c_from, c_to, cost_function(c_from, c_to)))

    # Stable sort: equally priced candidates keep their product() order.
    priced_candidates.sort(key=lambda link: link.cost)

    result = {}
    for link in fixed_candidates + priced_candidates:
        if incoming_left[link.c_to] == 0 or outgoing_left[link.c_from] == 0:
            continue

        result.setdefault(link.c_from, []).append(link.c_to)
        incoming_left[link.c_to] -= 1
        outgoing_left[link.c_from] -= 1

        # Exhausted cities are dropped so an empty counter means nothing is left to assign.
        if incoming_left[link.c_to] == 0:
            del incoming_left[link.c_to]
        if not incoming_left:
            logging.info('No more incoming links')
            break

        if outgoing_left[link.c_from] == 0:
            del outgoing_left[link.c_from]
        if not outgoing_left:
            logging.info('No more outgoing links')
            break

    return result
