# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from collections import defaultdict, namedtuple
from itertools import chain

import numpy as np
from django.conf import settings
from yt.wrapper import YtClient

from common.models.geo import Settlement, Station
from common.utils.geobase import geobase

from travel.rasp.trains.scripts.lib.fetch_routes import fetch_routes
from travel.rasp.trains.scripts.lib.fetch_searches import fetch_searches
from travel.rasp.trains.scripts.generate_crosslinks.stat_yql import stat_yql


# File (a path relative to the current working directory) into which Runner.generate_links
# writes one line per point pair whose coordinates were missing during distance calculation.
ERRORS_FILE_NAME = 'relevance_errors.txt'


class RouteItem(object):
    """A route between two points together with its popularity count.

    Equality and hashing depend only on the pair of point keys
    (``from_key``, ``to_key``); ``count`` takes no part in them, so two items
    for the same pair of points are interchangeable as dict/set keys.
    """

    def __init__(self, point_from, point_to, count):
        """
        :param point_from: departure point; must expose ``point_key``.
        :param point_to: arrival point; must expose ``point_key``.
        :param count: popularity count of the route (used by callers for ordering).
        """
        self.point_from = point_from
        self.from_key = point_from.point_key
        self.point_to = point_to
        self.to_key = point_to.point_key
        self.count = count

    def __repr__(self):
        return '{}(point_from={!r}, point_to={!r}, count={!r})'.format(
            self.__class__.__name__, self.from_key, self.to_key, self.count
        )

    def __hash__(self):
        # Must stay consistent with __eq__: only the point keys matter.
        return hash((self.from_key, self.to_key))

    def __eq__(self, other):
        try:
            return self.from_key == other.from_key and self.to_key == other.to_key
        except AttributeError:
            # Not route-like: let Python try the reflected comparison / fall back.
            return NotImplemented

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; without this method
        # ``!=`` would compare object identity instead of the point keys.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result


# Candidate crosslink: the linked RouteItem and its relevance score for the source route.
RelevantRouteItem = namedtuple('RelevantRouteItem', ['route', 'relevance'])


class Runner(object):
    """Generate crosslinks between routes that share a departure point.

    Algorithm description: https://st.yandex-team.ru/TRAVELORGANIC-47#6163bbe769137a21bcf439ae

    For every route A-B the candidate links are the other routes A-C with the
    same departure point. A candidate's relevance grows as C gets closer to B
    and with the statistics count of A-C (see ``calc_relevance``). The first
    ``unviewed_links_count`` links are the most relevant candidates not yet
    linked from another A-* route; the remaining slots (up to ``links_count``
    in total) are filled at random, weighted by relevance.
    """

    def __init__(
            self, yt_proxy, yt_token,
            links_count, unviewed_links_count,
            proximity_relevance, distance_additive,
            stat
    ):
        """
        :param yt_proxy: proxy passed to ``YtClient``.
        :param yt_token: token passed to ``YtClient``.
        :param links_count: maximum number of crosslinks per route.
        :param unviewed_links_count: how many of those links are reserved for the
            most relevant routes not yet linked from a sibling route.
        :param proximity_relevance: numerator of the proximity term of the relevance.
        :param distance_additive: added to the distance in the proximity term's denominator.
        :param stat: mapping ``(from_point_key, to_point_key) -> count``;
            missing pairs count as 0.
        """
        self.yt_client = YtClient(proxy=yt_proxy, token=yt_token)

        self.links_count = links_count
        self.unviewed_links_count = unviewed_links_count
        self.proximity_relevance = proximity_relevance
        self.distance_additive = distance_additive

        self.stat = stat

        # Tab-separated lines describing point pairs with missing coordinates.
        self.errors = []

    def generate_links(self, yt_path, dry_run=False):
        """Compute crosslinks and write them to ``yt_path`` unless ``dry_run``.

        :raises Exception: if some points lacked coordinates; the affected pairs
            are dumped to ``ERRORS_FILE_NAME`` first. Links are written before
            this check, so the table is updated even when this is raised.
        """
        links = self.generate_from_sitemap()
        prepared = self.prepare_dicts(links)

        if not dry_run:
            self.write_links(prepared, yt_path)
        logging.info('All done')

        if self.errors:
            with open(ERRORS_FILE_NAME, 'w') as f:
                f.writelines(self.errors)
            raise Exception('points without coords')

    def get_stat(self, point_from, point_to):
        """Return the statistics count for the pair of points, 0 if absent."""
        return self.stat.get((point_from.point_key, point_to.point_key), 0)

    def get_sitemap(self):
        """Return the union of ``(slug_from, slug_to)`` pairs from routes and searches."""
        return fetch_routes() | fetch_searches()

    def calc_distance(self, point_from, point_to):
        """Return the geobase distance between two points.

        If any coordinate is falsy (missing, or exactly 0), the pair is recorded
        in ``self.errors`` and None is returned.
        """
        if (point_from.latitude and point_from.longitude and
                point_to.latitude and point_to.longitude):

            return geobase.calculate_points_distance(
                point_from.latitude, point_from.longitude,
                point_to.latitude, point_to.longitude
            )

        self.errors.append('{}\t{}\n'.format(
            'point_from = {} latitude = {} longitude = {}'.format(point_from.point_key, point_from.latitude, point_from.longitude),
            'point_to = {} latitude = {} longitude = {}'.format(point_to.point_key, point_to.latitude, point_to.longitude),
        ))
        return None

    def calc_relevance(self, point_from, point_to, relevance_point):
        """Relevance of linking route ``point_from``-``relevance_point`` from ``point_from``-``point_to``.

        relevance = proximity_relevance / (dist(point_to, relevance_point) + distance_additive)
                    + stat(point_from, relevance_point)

        Returns None when the distance is falsy: coordinates are missing, or the
        distance is exactly 0.
        """
        distance = self.calc_distance(point_to, relevance_point)
        if not distance:
            return None

        return (self.proximity_relevance / (distance + self.distance_additive)
                + self.get_stat(point_from, relevance_point))

    def generate_from_sitemap(self):
        """Compute crosslinks for every sitemap route.

        :return: ``defaultdict(list)`` mapping each ``RouteItem`` to its list of
            ``RelevantRouteItem`` links, sorted by descending relevance.
        """
        settlements = Settlement.objects.all()
        # Besides settlements, only the Adler station is resolvable by slug here.
        stations = Station.objects.filter(id=Station.ADLER_ID)
        point_by_slug = {point.slug: point for point in chain(settlements, stations)}

        route_items = []
        for slug_from, slug_to in self.get_sitemap():
            point_from, point_to = point_by_slug.get(slug_from), point_by_slug.get(slug_to)
            if (point_from is None or point_to is None) and settings.APPLIED_CONFIG == 'testing':
                # Testing may lack slugs that come from production statistics.
                # NOTE(review): outside testing an unknown slug fails below with AttributeError.
                continue
            count = self.get_stat(point_from, point_to)
            route_items.append(RouteItem(point_from, point_to, count))

        routes_by_from = defaultdict(list)
        for route_item in route_items:
            routes_by_from[route_item.point_from].append(route_item)

        linked_routes_by_route = defaultdict(list)

        for routes_point_from in routes_by_from.values():
            # Most popular routes choose their links first.
            routes_point_from.sort(key=lambda value: value.count, reverse=True)
            # Routes already linked from some route with this departure point.
            viewed = set()
            for route in routes_point_from:  # A-B
                relevant_routes = []
                for relevant_route in routes_point_from:  # A-C
                    if route != relevant_route:
                        relevance = self.calc_relevance(route.point_from, route.point_to, relevant_route.point_to)
                        if not relevance:
                            continue
                        relevant_routes.append(RelevantRouteItem(relevant_route, relevance))
                relevant_routes.sort(key=lambda x: x.relevance, reverse=True)

                linked_routes_by_route[route] = self._select_links(relevant_routes, viewed)

        return linked_routes_by_route

    def _select_links(self, relevant_routes, viewed):
        """Pick at most ``links_count`` links for one route.

        :param relevant_routes: ``RelevantRouteItem`` candidates sorted by
            descending relevance.
        :param viewed: routes already linked from sibling routes; updated in
            place with every route picked here.
        :return: picked candidates sorted by descending relevance.
        """
        links = []

        # Phase 1: the most relevant candidates nobody links to yet. The limit
        # is checked before appending so that a limit of 0 really means none,
        # and it never exceeds the total number of links.
        unviewed_limit = min(self.unviewed_links_count, self.links_count)
        for candidate in relevant_routes:
            if len(links) >= unviewed_limit:
                break
            if candidate.route not in viewed:
                links.append(candidate)
                viewed.add(candidate.route)

        # Phase 2: fill the remaining slots at random, weighted by relevance.
        added = {link.route for link in links}
        rest = [candidate for candidate in relevant_routes if candidate.route not in added]
        # Clamped: a negative size would make np.random.choice raise.
        deficient_links_count = max(self.links_count - len(links), 0)

        if len(rest) > deficient_links_count:
            total_relevance = sum(candidate.relevance for candidate in rest)
            probabilities = [candidate.relevance / total_relevance for candidate in rest]
            rand_ids = list(np.random.choice(
                len(rest),
                deficient_links_count,
                p=probabilities,
                replace=False
            ))
            for rand_id in rand_ids:
                links.append(rest[rand_id])
                viewed.add(rest[rand_id].route)
        else:
            links.extend(rest)
            viewed.update(candidate.route for candidate in rest)

        links.sort(key=lambda x: x.relevance, reverse=True)
        return links

    def prepare_dicts(self, route_links):
        """Flatten ``{route: [RelevantRouteItem, ...]}`` into one row dict per link."""
        result = []
        for route, relevant_routes in route_links.items():
            for relevant_route in relevant_routes:
                result.append(
                    {
                        'from_key': route.from_key,
                        'to_key': route.to_key,
                        'crosslink_from_key': relevant_route.route.from_key,
                        'crosslink_to_key': relevant_route.route.to_key,
                        'crosslink_from_slug': relevant_route.route.point_from.slug,
                        'crosslink_to_slug': relevant_route.route.point_to.slug,
                        'crosslink_from_nominative': relevant_route.route.point_from.L_title(),
                        'crosslink_to_nominative': relevant_route.route.point_to.L_title(),
                        'crosslink_relevance': relevant_route.relevance
                    }
                )
        return result

    def write_links(self, route_links, yt_path):
        """Replace the table at ``yt_path`` with ``route_links`` inside one YT transaction."""
        with self.yt_client.Transaction():
            if self.yt_client.exists(yt_path):
                self.yt_client.remove(yt_path)
            self.yt_client.write_table(yt_path, route_links)


def generate_links(
        links_count=6, unviewed_links_count=2,
        proximity_relevance=1, distance_additive=0.1,
        stat_days_period=30,
        crosslinks_yt_table=None, dry_run=False
):
    """Entry point: build route crosslinks and, unless ``dry_run``, write them to YT.

    :param links_count: maximum number of crosslinks per route.
    :param unviewed_links_count: links reserved for routes not yet linked elsewhere.
    :param proximity_relevance: numerator of the proximity term of the relevance.
    :param distance_additive: added to the distance in the proximity term's denominator.
    :param stat_days_period: value passed as ``days_period`` to ``stat_yql``.
    :param crosslinks_yt_table: destination table; falls back to
        ``settings.CROSS_LINKS_YT_TABLE`` when falsy.
    :param dry_run: compute everything but skip writing to YT.
    """
    target_table = crosslinks_yt_table if crosslinks_yt_table else settings.CROSS_LINKS_YT_TABLE

    runner_options = dict(
        links_count=links_count,
        unviewed_links_count=unviewed_links_count,
        proximity_relevance=proximity_relevance,
        distance_additive=distance_additive,
        stat=stat_yql(days_period=stat_days_period),
    )
    runner = Runner(settings.YT_PROXY, settings.YT_TOKEN, **runner_options)
    runner.generate_links(yt_path=target_table, dry_run=dry_run)
