# -*- coding: utf-8 -*-
import logging
import itertools
from collections import defaultdict
from typing import Dict

import numpy as np
import six

from yt.wrapper import YtClient  # noqa

from travel.avia.avia_statistics.lib.settlements_geo_index import SettlementsGeoIndex  # noqa
from travel.avia.avia_statistics.landing_routes import LandingRoute
from travel.library.python.dicts.avia.near_direction_repository import NearDirectionRepository  # noqa
from travel.avia.avia_statistics.landing_routes import get_landing_route_crosslinks

logger = logging.getLogger(__name__)


class CrosslinksProvider(object):
    """Build crosslink pages: for every landing route, a list of other landing routes to link to.

    Crosslinks are generated independently for every national version, so a page only
    links to routes of its own national version. For each route the page is filled with:

    1. "close" routes between settlements near the route's origin and its destination
       (or settlements near the destination);
    2. routes departing from the route's destination;

    then, across all pages of the national version,

    3. every route that no page links to yet is put on a random incomplete page;
    4. incomplete pages are topped up with random routes weighted by ``redirs``.

    A page never links to a route from its ban list: the route itself, its inverse,
    every route departing from the same origin, and every route already on the page.
    A page holds at most ``CROSSLINKS_PER_PAGE`` links.
    """

    # Maximum (and target) number of crosslinks on one page.
    CROSSLINKS_PER_PAGE = 10

    def __init__(
        self,
        landing_routes,
        route_weights,
        default_distance,
        near_direction_repository,
        settlements_geo_index,
        yt_client,
    ):
        """
        :param List[LandingRoute] landing_routes: routes to build pages for; within a national
            version, routes earlier in this list are preferred as close/popular crosslinks
        :param Dict[LandingRoute, RouteWeight] route_weights: weight of each landing route;
            ``.redirs`` is used as the sampling weight when pages are filled randomly
        :param int default_distance: default passed to
            ``near_direction_repository.get_default_distance``
        :param NearDirectionRepository near_direction_repository: provides the search radius
            for nearby settlements of a route
        :param SettlementsGeoIndex settlements_geo_index: nearest-settlement lookups
        :param YtClient yt_client: client used by ``read_from_yt``
        """

        self._landing_routes = landing_routes
        self._route_weights = route_weights
        self._default_distance = default_distance
        self._near_direction_repository = near_direction_repository
        self._settlements_geo_index = settlements_geo_index
        self._yt_client = yt_client

    def read_from_yt(self, yt_crosslinks_table):
        """Load crosslinks from ``yt_crosslinks_table`` via ``get_landing_route_crosslinks``."""
        return get_landing_route_crosslinks(self._yt_client, yt_crosslinks_table)

    def generate(self):
        """Generate crosslinks for all landing routes and log summary statistics.

        :return: mapping route -> list of crosslinked routes
        """
        crosslinks = self._generate_crosslinks()
        logger.debug('crosslinks count = %d', len(crosslinks))
        logger.debug('non empty crosslinks count = %d', sum(len(v) != 0 for v in crosslinks.values()))
        logger.debug('empty crosslinks count = %d', sum(len(v) == 0 for v in crosslinks.values()))
        logger.debug('fullfilled pages = %d', sum(len(v) >= self.CROSSLINKS_PER_PAGE for v in crosslinks.values()))
        return crosslinks

    def _generate_crosslinks(self):
        """Group landing routes by national version and generate crosslinks per group."""
        landing_routes_by_national_version = defaultdict(list)
        for route in self._landing_routes:
            landing_routes_by_national_version[route.national_version].append(route)

        crosslinks_by_route = {}
        for landing_routes in landing_routes_by_national_version.values():
            crosslinks_by_route.update(
                self._generate_crosslinks_for_routes_with_same_national_version(landing_routes)
            )
        return crosslinks_by_route

    def _generate_crosslinks_for_routes_with_same_national_version(self, landing_routes):
        """Run all filling phases for routes that share one national version.

        :param list landing_routes: routes of a single national version, in preference order
        :return: defaultdict route -> list of crosslinked routes
        """
        # Position in landing_routes is the preference rank (lower is better).
        index_by_route = {route: i for i, route in enumerate(landing_routes)}
        # city id -> landing routes departing from it, in landing_routes order.
        outgoing_routes_by_city = defaultdict(list)
        for route in landing_routes:
            outgoing_routes_by_city[route.from_id].append(route)

        crosslinks_by_route = defaultdict(list)
        banned_by_route = defaultdict(set)
        uncompleted_routes = set(landing_routes)

        logger.info('Start generating crosslinks')
        for route in landing_routes:
            self._init_banned_routes(banned_by_route, route, outgoing_routes_by_city, index_by_route)
            self._add_close_routes(route, crosslinks_by_route, index_by_route, banned_by_route)
            if self._check_completeness(crosslinks_by_route, route, uncompleted_routes):
                continue
            self._add_popular_routes_from_destination(
                route,
                outgoing_routes_by_city,
                index_by_route,
                crosslinks_by_route,
                uncompleted_routes,
                banned_by_route,
            )
        logger.info('Close routes and popular routes from destination were added')

        # Routes already linked from some page do not need to be placed again.
        appeared_routes = set(itertools.chain.from_iterable(crosslinks_by_route.values()))
        self._ensure_each_route_appeared_at_least_once(
            landing_routes,
            uncompleted_routes,
            crosslinks_by_route,
            banned_by_route,
            appeared_routes,
        )
        self._fill_with_random_routes(
            landing_routes,
            crosslinks_by_route,
            uncompleted_routes,
            banned_by_route,
        )

        return crosslinks_by_route

    def _add_close_routes(self, route, crosslinks_by_route, index_by_route, banned_by_route):
        """Add landing routes between settlements near ``route``'s origin and its destination
        (or settlements near the destination).

        The search radius comes from ``near_direction_repository.get_default_distance``.
        Candidates must be landing routes of this national version and must not be banned
        for ``route``'s page. At most ``CROSSLINKS_PER_PAGE`` of them are added, in
        ``landing_routes`` order.
        """
        distance = self._near_direction_repository.get_default_distance(
            route.from_id,
            route.to_id,
            self._default_distance,
        )
        alternative_from_ids = [
            s.Id for s in self._settlements_geo_index.get_nearest(route.from_id, distance)
        ]
        alternative_to_ids = [route.to_id] + [
            s.Id for s in self._settlements_geo_index.get_nearest(route.to_id, distance)
        ]

        banned = banned_by_route[route]
        # A set drops duplicate candidates, e.g. if get_nearest(to_id) also returns to_id.
        close_routes = set()
        for from_id, to_id in itertools.product(alternative_from_ids, alternative_to_ids):
            candidate = LandingRoute(from_id, to_id, route.national_version)
            # Check against this page's ban list, not the candidate's own one.
            if candidate in index_by_route and candidate not in banned:
                close_routes.add(candidate)

        selected = sorted(close_routes, key=lambda r: index_by_route[r])[:self.CROSSLINKS_PER_PAGE]
        crosslinks_by_route[route].extend(selected)
        banned.update(selected)

    def _add_popular_routes_from_destination(
        self,
        route,
        outgoing_routes_by_city,
        index_by_route,
        crosslinks_by_route,
        uncompleted_routes,
        banned_by_route,
    ):
        """Append routes departing from ``route``'s destination to its page.

        ``outgoing_routes_by_city`` maps a city id to the landing routes departing from it.
        Candidates are taken in ``landing_routes`` order (via ``index_by_route``), banned
        ones are skipped, and adding stops once the page is full.
        """
        page = crosslinks_by_route[route]
        banned = banned_by_route[route]
        # The values are LandingRoute objects (not city ids), so they are used as-is.
        next_routes = sorted(
            (r for r in outgoing_routes_by_city[route.to_id] if r in index_by_route),
            key=lambda r: index_by_route[r],
        )
        for next_route in next_routes:
            if len(page) >= self.CROSSLINKS_PER_PAGE:
                break
            if next_route in banned:
                continue
            page.append(next_route)
            banned.add(next_route)
        self._check_completeness(crosslinks_by_route, route, uncompleted_routes)

    def _ensure_each_route_appeared_at_least_once(
        self,
        landing_routes,
        uncompleted_routes,
        crosslinks_by_route,
        banned_by_route,
        appeared_routes,
    ):
        """Put every route that no page links to yet onto a random incomplete page.

        ``appeared_routes`` (routes already linked from some page) is updated in place.
        A route is skipped, and counted in a warning, when no incomplete page accepts it.
        """
        logger.info('Ensure each route appeared on at least one page')
        unappeared_routes = [r for r in landing_routes if r not in appeared_routes]
        unplaced_count = 0

        for i, route in enumerate(unappeared_routes):
            if i and i % 1000 == 0:
                logger.info('processed: %d', i)
            choices = tuple(
                r for r in uncompleted_routes
                if route not in banned_by_route[r]
            )
            if not choices:
                # np.random.choice(0) would raise ValueError (e.g. every page is already full).
                unplaced_count += 1
                continue
            page = choices[np.random.choice(len(choices))]
            appeared_routes.add(route)
            crosslinks_by_route[page].append(route)
            banned_by_route[page].add(route)
            self._check_completeness(crosslinks_by_route, page, uncompleted_routes)

        if unplaced_count:
            logger.warning('routes that could not be placed on any page: %d', unplaced_count)

    def _fill_with_random_routes(
        self,
        landing_routes,
        crosslinks_by_route,
        uncompleted_routes,
        banned_by_route,
    ):
        """Top up every incomplete page with random routes weighted by ``redirs``.

        A page that has fewer eligible routes than free slots receives all of them and
        stays incomplete; such pages are counted in a warning.
        """
        logger.info('Fill uncompleted routes with random routes')
        routes_to_be_filled = [r for r in landing_routes if r in uncompleted_routes]
        logger.info('total routes to be filled: %d', len(routes_to_be_filled))

        incomplete_count = 0
        for i, route in enumerate(routes_to_be_filled):
            if i and i % 100 == 0:
                logger.info('processed: %d', i)

            banned = banned_by_route[route]
            free_slots = self.CROSSLINKS_PER_PAGE - len(crosslinks_by_route[route])
            for random_route in self._pick_weighted_random_routes(landing_routes, banned, free_slots):
                if random_route in banned:
                    # Only possible when landing_routes contains duplicates.
                    continue
                crosslinks_by_route[route].append(random_route)
                banned.add(random_route)

            if not self._check_completeness(crosslinks_by_route, route, uncompleted_routes):
                incomplete_count += 1

        if incomplete_count:
            logger.warning('pages left incomplete: %d', incomplete_count)

    def _pick_weighted_random_routes(self, landing_routes, banned, count):
        """Draw up to ``count`` distinct routes not in ``banned``, weighted by ``redirs``.

        Routes with non-positive ``redirs`` are never drawn; this also avoids dividing by a
        zero total weight. Returns fewer than ``count`` routes when candidates run out.
        """
        if count <= 0:
            return []
        candidates = [
            r for r in landing_routes
            if r not in banned and self._route_weights[r].redirs > 0
        ]
        if not candidates:
            return []
        weights = np.array([float(self._route_weights[r].redirs) for r in candidates])
        # Sampling without replacement: no rejection loop that could spin forever.
        picked = np.random.choice(
            len(candidates),
            size=min(count, len(candidates)),
            replace=False,
            p=weights / weights.sum(),
        )
        return [candidates[idx] for idx in picked]

    def _check_completeness(self, crosslinks_by_route, route, uncompleted_routes):
        """If ``route``'s page is full, truncate it to ``CROSSLINKS_PER_PAGE`` and drop the
        route from ``uncompleted_routes``.

        :return: True if the page is full
        """
        if len(crosslinks_by_route[route]) >= self.CROSSLINKS_PER_PAGE:
            crosslinks_by_route[route] = crosslinks_by_route[route][:self.CROSSLINKS_PER_PAGE]
            uncompleted_routes.discard(route)
            return True
        return False

    @staticmethod
    def _inverse_route(route):
        # type: (LandingRoute) -> LandingRoute
        """Return the route in the opposite direction, same national version."""
        return LandingRoute(route.to_id, route.from_id, route.national_version)

    def _get_alternative_routes(self, route, index_by_route):
        """Return landing routes from ``route``'s origin to settlements near its destination."""
        distance = self._near_direction_repository.get_default_distance(
            route.from_id,
            route.to_id,
            self._default_distance,
        )
        candidates = (
            LandingRoute(route.from_id, s.Id, route.national_version)
            for s in self._settlements_geo_index.get_nearest(route.to_id, distance)
        )
        return [r for r in candidates if r in index_by_route]

    def _init_banned_routes(self, banned_by_route, route, outgoing_routes_by_city, index_by_route):
        """Seed ``route``'s ban list: itself, its inverse, alternative destinations from the
        same origin, and every route departing from the same origin."""
        banned = banned_by_route[route]
        banned.add(route)
        banned.add(self._inverse_route(route))
        banned.update(self._get_alternative_routes(route, index_by_route))
        banned.update(outgoing_routes_by_city[route.from_id])


class RoutesDict(object):
    """Convert a ``{LandingRoute: [LandingRoute, ...]}`` mapping into JSON-compatible data."""

    @classmethod
    def as_json(cls, routes_dict):
        """Return ``{map_key(route): [map_value(crosslink), ...]}`` for every entry.

        Values are real lists: on Python 3 ``map`` returns a one-shot iterator that the
        ``json`` module cannot serialize.
        """
        return {
            cls.map_key(route): [cls.map_value(crosslink) for crosslink in crosslinks]
            for route, crosslinks in routes_dict.items()
        }

    @staticmethod
    def map_key(k):
        # type: (LandingRoute) -> str
        """Encode a route as ``'<from_id>_<to_id>_<national_version>'``."""
        return '{}_{}_{}'.format(k.from_id, k.to_id, k.national_version)

    @staticmethod
    def map_value(value):
        # type: (LandingRoute) -> Dict
        """Encode a route as a dict with ``from_id``, ``to_id`` and ``national_version``."""
        return {
            'from_id': value.from_id,
            'to_id': value.to_id,
            'national_version': value.national_version,
        }
