# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from collections import defaultdict
from itertools import groupby, chain

from django.conf import settings
from yt.wrapper import YtClient

from common.models.geo import Settlement, Station
from common.models_utils.geo import Point
from common.models.transport import TransportType
from route_search.models import ZNodeRoute2


class PointInfo(object):
    """Plain snapshot of the display fields of a geo point.

    Copies ``slug`` and the results of ``L_title()`` / ``L_popular_title()``
    from ``obj`` at construction time, so the values can be used later
    without touching the source object again.
    """

    def __init__(self, obj):
        # Evaluate in the same order as the attributes are declared below.
        slug, title, popular_title = obj.slug, obj.L_title(), obj.L_popular_title()
        self.slug = slug
        self.title = title
        self.popular_title = popular_title

    def __repr__(self):
        fields = (type(self).__name__, self.slug, self.title, self.popular_title)
        return '<{} slug={!r} title={!r} popular_title={!r}>'.format(*fields)


class Runner(object):
    """Builds the "searched direction -> canonical direction" table for train routes and writes it to YT.

    https://st.yandex-team.ru/TRAVELORGANIC-129

    A direction is a ``(from_point_key, to_point_key)`` pair where every point is
    either a station or a settlement.  The pipeline is:

    1. ``generate_canonical_by_search`` walks the train rows of ``ZNodeRoute2`` and
       decides the canonical direction for every searchable direction;
    2. ``get_point_by_key`` loads slugs/titles of all points met on step 1;
    3. ``prepare_dicts`` turns the mapping into flat rows;
    4. ``write_canonical`` replaces the YT table with those rows.
    """
    def __init__(self, yt_proxy, yt_token):
        """Create a YT client for ``yt_proxy`` authenticated with ``yt_token``."""
        self.yt_client = YtClient(proxy=yt_proxy, token=yt_token)
        # Ids collected by generate_canonical_by_search(); get_point_by_key() loads exactly these objects.
        self.settlements = set()
        self.stations = set()

    def generate_canonical(self, yt_path, dry_run=False):
        """Run the whole pipeline.

        :param yt_path: path of the YT table to (re)write.
        :param dry_run: when true, everything is computed but nothing is written to YT.
        """
        canonical_by_search = self.generate_canonical_by_search()
        point_by_key = self.get_point_by_key()
        prepared = self.prepare_dicts(canonical_by_search, point_by_key)

        if not dry_run:
            self.write_canonical(prepared, yt_path)
        logging.info('All done')

    def get_znr_objs(self):
        """Return distinct train rows of ``ZNodeRoute2`` as 6-tuples.

        Each row is ``(settlement_from_id, settlement_to_id, station_from_id,
        station_to_id, good_for_start, good_for_finish)``.  Rows are ordered by the
        settlement pair, which is what lets the caller group them with ``groupby``.
        """
        return (
            ZNodeRoute2.objects.filter(t_type_id=TransportType.TRAIN_ID)
            .values_list(
                'settlement_from_id', 'settlement_to_id',
                'station_from_id', 'station_to_id',
                'good_for_start', 'good_for_finish'
            )
            .distinct()
            .order_by('settlement_from_id', 'settlement_to_id')
        )

    def get_key(self, model, model_id):
        """Return the point key for ``model`` / ``model_id`` (delegates to ``Point.get_point_key``)."""
        return Point.get_point_key(model, model_id)

    def generate_canonical_by_search(self):
        """Build ``{(search_from_key, search_to_key): (canonical_from_key, canonical_to_key)}``.

        Rows are processed in groups sharing the same ``(settlement_from_id,
        settlement_to_id)`` pair; see ``_canonicalize_group`` for the per-group rules.
        As a side effect, ``self.settlements`` and ``self.stations`` are filled with
        every id that was met.
        """
        canonical_by_search = {}

        # groupby() only merges adjacent rows, so it relies on the ordering applied in get_znr_objs().
        # The key is built by index instead of a tuple-parameter lambda: that syntax was removed
        # from Python 3 (PEP 3113), while indexing works on both major versions.
        grouped = groupby(self.get_znr_objs(), lambda row: (row[0], row[1]))
        for (settlement_from_id, settlement_to_id), znr_objs_group in grouped:
            self._canonicalize_group(
                canonical_by_search, settlement_from_id, settlement_to_id, set(znr_objs_group)
            )

        return canonical_by_search

    def _canonicalize_group(self, canonical_by_search, settlement_from_id, settlement_to_id, rows):
        """Add to ``canonical_by_search`` all directions of one settlement pair.

        Rules, applied in this order (later ones overwrite earlier ones):

        * every station -> station direction, and every settlement -> station /
          station -> settlement direction whose settlement id is set, is canonical
          to itself;
        * when both settlement ids are set, settlement -> settlement is canonical to
          itself; if it is one and the same settlement, nothing else is done;
        * a station pair is widened to ``(from_station, to_settlement)`` when
          ``to_station`` is the only ``good_for_finish`` destination recorded for
          ``from_station`` in the group; otherwise to ``(from_settlement, to_station)``
          when ``from_station`` is the only ``good_for_start`` origin recorded for
          ``to_station``;
        * when both settlement ids are set and the group has exactly one
          ``good_for_start`` origin station and/or exactly one ``good_for_finish``
          destination station, directions through that station are canonical to
          settlement -> settlement.

        :param rows: iterable of 6-tuples as produced by ``get_znr_objs``.
        """
        # NOTE(review): the ids may be empty (falsy) here; the keys built from them are only
        # ever used under an ``if settlement_..._id`` guard, and the ids are still added to
        # ``self.settlements`` exactly as before - confirm the ``id__in`` filter tolerates that.
        settlement_from_key = self.get_key(Settlement, settlement_from_id)
        settlement_to_key = self.get_key(Settlement, settlement_to_id)
        self.settlements.add(settlement_from_id)
        self.settlements.add(settlement_to_id)

        from_to_stations = defaultdict(set)  # from_station -> destinations that are good_for_finish
        to_from_stations = defaultdict(set)  # to_station -> origins that are good_for_start
        settlement_to_stations = set()  # all good_for_finish destination stations of the group
        settlement_from_stations = set()  # all good_for_start origin stations of the group
        station_pairs = set()

        for (_, _, station_from_id, station_to_id, good_for_start, good_for_finish) in rows:
            self.stations.add(station_from_id)
            self.stations.add(station_to_id)
            station_from_key = self.get_key(Station, station_from_id)
            station_to_key = self.get_key(Station, station_to_id)

            canonical_by_search[(station_from_key, station_to_key)] = (station_from_key, station_to_key)
            station_pairs.add((station_from_key, station_to_key))
            if good_for_finish:
                from_to_stations[station_from_key].add(station_to_key)
                settlement_to_stations.add(station_to_key)
            if good_for_start:
                to_from_stations[station_to_key].add(station_from_key)
                settlement_from_stations.add(station_from_key)

            if settlement_from_id:
                canonical_by_search[(settlement_from_key, station_to_key)] = (settlement_from_key, station_to_key)

            if settlement_to_id:
                canonical_by_search[(station_from_key, settlement_to_key)] = (station_from_key, settlement_to_key)

        if not settlement_from_id and not settlement_to_id:
            # No settlement on either side: station pairs stay canonical to themselves.
            return

        settlement_pair = (settlement_from_key, settlement_to_key)
        if settlement_from_id and settlement_to_id:
            canonical_by_search[settlement_pair] = settlement_pair
            if settlement_from_id == settlement_to_id:
                # Travel inside a single settlement: never widen stations to the settlement.
                return

        for (from_key, to_key) in station_pairs:
            stations_to_keys = from_to_stations[from_key]
            stations_from_keys = to_from_stations[to_key]

            if settlement_to_id and stations_to_keys == {to_key}:
                canonical_by_search[(from_key, to_key)] = (from_key, settlement_to_key)
            elif settlement_from_id and stations_from_keys == {from_key}:
                canonical_by_search[(from_key, to_key)] = (settlement_from_key, to_key)

        if settlement_from_id and settlement_to_id:
            has_uniq_from = len(settlement_from_stations) == 1
            has_uniq_to = len(settlement_to_stations) == 1

            if has_uniq_from:
                (uniq_from_station,) = settlement_from_stations
                canonical_by_search[(uniq_from_station, settlement_to_key)] = settlement_pair
            if has_uniq_to:
                (uniq_to_station,) = settlement_to_stations
                canonical_by_search[(settlement_from_key, uniq_to_station)] = settlement_pair
            if has_uniq_from and has_uniq_to:
                canonical_by_search[(uniq_from_station, uniq_to_station)] = settlement_pair

    def get_point_by_key(self):
        """Load every collected settlement and station and return ``{point_key: PointInfo}``.

        Must be called after ``generate_canonical_by_search``, which fills
        ``self.settlements`` and ``self.stations``.
        """
        # NOTE(review): PointInfo calls L_title()/L_popular_title() on both models; if those read
        # columns not listed in only(), every object triggers an extra query - confirm.
        settlements = Settlement.objects.filter(id__in=self.settlements).only('id', 'slug', 'title_ru')
        stations = Station.objects.filter(id__in=self.stations).only('id', 'slug', 'title_ru', 'popular_title_ru')

        return {s.point_key: PointInfo(s) for s in chain(stations, settlements)}

    def prepare_dicts(self, canonical_by_search, point_by_key):
        """Convert the key mapping into a list of flat row dicts.

        :param canonical_by_search: mapping built by ``generate_canonical_by_search``.
        :param point_by_key: mapping ``point_key -> object`` with ``slug``, ``title`` and
            ``popular_title`` attributes (see ``get_point_by_key``).
        :return: one dict per searched direction with the searched slugs and the
            canonical slugs/titles.
        :raises KeyError: if a key of ``canonical_by_search`` is missing in ``point_by_key``.
        """
        result = []
        for (search_from_key, search_to_key), (canonical_from_key, canonical_to_key) in canonical_by_search.items():
            search_from_point = point_by_key[search_from_key]
            search_to_point = point_by_key[search_to_key]
            from_point = point_by_key[canonical_from_key]
            to_point = point_by_key[canonical_to_key]
            result.append(
                {
                    'from_slug': search_from_point.slug,
                    'to_slug': search_to_point.slug,
                    'canonical_from_slug': from_point.slug,
                    'canonical_to_slug': to_point.slug,
                    'canonical_from_popular_title': from_point.popular_title,
                    'canonical_to_popular_title': to_point.popular_title,
                    'canonical_from_title': from_point.title,
                    'canonical_to_title': to_point.title
                }
            )
        return result

    def write_canonical(self, canonical, yt_path):
        """Replace the table at ``yt_path`` with the rows of ``canonical``.

        The removal of the old table and the write of the new one are issued inside
        one ``Transaction()`` context of the YT client.
        """
        with self.yt_client.Transaction():
            if self.yt_client.exists(yt_path):
                self.yt_client.remove(yt_path)
            self.yt_client.write_table(yt_path, canonical)


def generate_canonical(canonical_yt_table=None, dry_run=False):
    """Entry point: build the canonical-direction table using the Django settings.

    :param canonical_yt_table: YT table path; when empty, ``settings.CANONICAL_YT_TABLE`` is used.
    :param dry_run: passed through to ``Runner.generate_canonical``.
    """
    if canonical_yt_table:
        yt_path = canonical_yt_table
    else:
        yt_path = settings.CANONICAL_YT_TABLE
    Runner(settings.YT_PROXY, settings.YT_TOKEN).generate_canonical(yt_path=yt_path, dry_run=dry_run)
