# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import itertools
from collections import defaultdict, namedtuple
from contextlib import contextmanager
from datetime import datetime, timedelta
from functools import partial, total_ordering

import ujson
from django.conf import settings
from django.utils.functional import cached_property

from common.models.geo import Point, Station, Station2Settlement, StationMajority
from common.models.schedule import RThread, RThreadType, RTStation
from common.models.transport import TransportType
from travel.rasp.library.python.common23.date import environment
from common.utils.date import RunMask
from common.utils.title_generator import TitleGenerator
from route_search.helpers import LimitConditions
from travel.rasp.wizards.wizard_lib.utils.functional import tuplify
from travel.rasp.wizards.wizard_lib.utils.originalbox import OriginalBox
from travel.rasp.wizards.wizard_lib.utils.shrinked_timezones import ShrinkedTimezones

# Number of minutes in a day; times of day are compared modulo this value.
DAY_MINUTES = 24 * 60
# Number of consecutive daily run slots scanned by find_segments() / find_station_segments().
SEARCH_DAYS = 100

_DAY = timedelta(days=1)
# Used when the caller passes no thread_predicate: accepts every thread.
_DEFAULT_THREAD_PREDICATE = lambda thread: True  # noqa: E731 (a lambda is shorter than a def)
# One field per configured model language (settings.MODEL_LANGUAGES), e.g. to hold a per-language title.
Translations = namedtuple('Translations', settings.MODEL_LANGUAGES)
# Stations for which a single thread may yield several segments (one per station-pair group)
# instead of only the single best one; see _find_best_segments().
STATIONS_RETURNS_MANY_SEGMENTS = {
    9607193, 9607190,  # Tomsk
}


class RThreadBoxMixin(object):
    """Helpers mixed into RThreadBox instances to keep precomputed per-thread data."""

    def cache_title(self, title_getter):
        """Store ``title_getter(lang)`` for every language of ``Translations`` as ``self.title``."""
        per_language = [title_getter(lang) for lang in Translations._fields]
        self.title = Translations(*per_language)

    def cache_stops(self, departure_stop, arrival_stop):
        """Remember the given departure and arrival stops on the instance."""
        self.departure_stop, self.arrival_stop = departure_stop, arrival_stop


class RTStationBoxMixin(object):
    """Helpers mixed into RTStationBox instances to keep precomputed per-stop data."""

    # Class-level default; replaced only when cache_stops_text() finds stored translations.
    stops_text = None

    def cache_tzinfo(self, timezones):
        """Resolve ``self.time_zone`` through ``timezones.get`` and store the result as ``self.tzinfo``."""
        zone_name = self.time_zone
        self.tzinfo = timezones.get(zone_name)

    def cache_stops_text(self):
        """Decode the schedule's JSON ``stops_translations`` (if any) into a ``Translations`` tuple."""
        raw_translations = self.schedule.stops_translations
        if raw_translations is None:
            return
        by_language = ujson.loads(raw_translations)
        # Languages missing from the JSON become None.
        self.stops_text = Translations(*[by_language.get(lang) for lang in Translations._fields])


# Boxed projection of RThread used by ThreadsCache.
# NOTE(review): assumes OriginalBox(...) builds a lightweight record type from the listed field names,
# with `'field' | OriginalBox(...)` describing the fields taken from a related object — confirm against
# travel.rasp.wizards.wizard_lib.utils.originalbox. instance_mixins adds RThreadBoxMixin's caching helpers.
RThreadBox = OriginalBox(
    'id',
    'year_days',
    'tz_start_time',
    'type_id',
    'type' | OriginalBox('code'),
    'title_common',
    'number',
    'uid',
    'tariff_type_id',
    'express_type',
    't_subtype_id',
    instance_mixins=[RThreadBoxMixin]
)
# Boxed projection of RTStation (a thread stop); RTStationBoxMixin adds tzinfo/stops-text caching.
# The nested 'station' and 'schedule' entries expose station.majority_id and schedule.stops_translations,
# which are read by _filter_segments / _find_best_segment and RTStationBoxMixin.cache_stops_text.
RTStationBox = OriginalBox(
    'id',
    'station_id',
    'station' | OriginalBox('majority_id'),
    'tz_departure',
    'tz_arrival',
    'time_zone',
    'is_searchable_from',
    'is_searchable_to',
    'departure_code_sharing',
    'arrival_code_sharing',
    'platform',
    'schedule' | OriginalBox('stops_translations'),
    'departure_direction_id',
    'departure_subdir',
    'in_station_schedule',
    instance_mixins=[RTStationBoxMixin]
)


def _fetch_station_settlement_ids():
    """
    Map every station id to the set of settlement ids it belongs to.

    Both the station's own ``settlement_id`` and the extra Station2Settlement links are taken into account.

    :return: plain dict ``{station_id: set(settlement_id, ...)}``.
    """
    own_links = Station.objects.exclude(settlement_id=None).values_list('id', 'settlement_id')
    extra_links = Station2Settlement.objects.values_list('station_id', 'settlement_id')

    settlements_by_station = defaultdict(set)
    for station_id, settlement_id in itertools.chain(own_links, extra_links):
        settlements_by_station[station_id].add(settlement_id)
    return dict(settlements_by_station)


def _cache_thread_titles(threads):
    """
    Render and store the localized title of every thread.

    All points referenced by the titles are fetched with a single ``Point.in_bulk`` call before rendering.

    :param threads: re-iterable collection of thread boxes exposing ``title_common`` and ``cache_title``.
        It is iterated twice, so a one-shot iterator would leave the titles uncached.
    """
    parsed_titles = [ujson.loads(thread.title_common) for thread in threads]

    referenced_keys = set()
    for parsed_title in parsed_titles:
        referenced_keys.update(TitleGenerator.extract_point_keys(parsed_title))
    prefetched_points = Point.in_bulk(referenced_keys)

    for thread, parsed_title in itertools.izip(threads, parsed_titles):
        # cache_title() calls the getter right away, so closing over the loop variable is safe.
        thread.cache_title(
            lambda lang: TitleGenerator.L_title(parsed_title, prefetched_points=prefetched_points, lang=lang)
        )


@tuplify
def _filter_segments(segments, majority_limit):
    """
    Keep only the (departure, arrival) pairs that are allowed in search.

    A pair is kept when:
      * the departure stop is ``is_searchable_from`` and its station's ``majority_id`` does not exceed
        ``majority_limit.from_max_majority_id``;
      * the arrival stop is ``is_searchable_to`` and its station's ``majority_id`` does not exceed
        ``majority_limit.to_max_majority_id``;
      * the pair is not code-sharing at both ends at once.

    :return: tuple of the kept pairs, in input order (the original pair objects are returned).
    """
    departure_limit = majority_limit.from_max_majority_id
    arrival_limit = majority_limit.to_max_majority_id

    for pair in segments:
        departure, arrival = pair
        # Conditions are evaluated left to right and stop at the first failing one.
        if (departure.is_searchable_from
                and not departure.station.majority_id > departure_limit
                and arrival.is_searchable_to
                and not arrival.station.majority_id > arrival_limit
                and not (departure.departure_code_sharing and arrival.arrival_code_sharing)):
            yield pair


def _find_best_segments(thread_stops):
    """
    Yield the best segment of a thread, or several of them for stations in STATIONS_RETURNS_MANY_SEGMENTS.

    Pairs are grouped by ``(departure_station_id, arrival_station_id)``. Only station ids listed in
    STATIONS_RETURNS_MANY_SEGMENTS take part in the key; any other station counts as ``None``.
    If there is a single group, the best pair of all ``thread_stops`` is yielded. Otherwise the best pair of
    every non-default group is yielded. The ``(None, None)`` group is appended to the first of those groups,
    so that its pairs are not lost in case one of them is the best.

    :param thread_stops: re-iterable sequence of ``(departure, arrival)`` stop pairs.
    """
    fallback_key = (None, None)
    groups = {}
    for departure, arrival in thread_stops:
        group_key = (
            departure.station_id if departure.station_id in STATIONS_RETURNS_MANY_SEGMENTS else None,
            arrival.station_id if arrival.station_id in STATIONS_RETURNS_MANY_SEGMENTS else None,
        )
        groups.setdefault(group_key, []).append((departure, arrival))

    if len(groups) == 1:
        yield _find_best_segment(thread_stops)
        return

    leftovers = groups.pop(fallback_key, [])
    for group in groups.values():
        if leftovers:
            group = group + leftovers
            leftovers = []
        yield _find_best_segment(group)


def _find_best_segment(thread_stops):
    """
    Получаем первым элементом рейс от станции с максимальным приоритетом до станции с максимальным приоритетом,
    если станций с максимальным приоритетом несколько, то от самой поздней до самой ранней из них.
    """
    return min(thread_stops, key=lambda (departure, arrival): (departure.station.majority_id,
                                                               arrival.station.majority_id,
                                                               departure.tz_arrival is not None,
                                                               -(departure.tz_departure or 0),
                                                               arrival.tz_arrival))


def _time_as_minutes(time_obj):
    return time_obj.hour * 60 + time_obj.minute


def _iter_thread_runs(thread, min_thread_start_dt):
    """
    Infinitely iterate daily start-datetime candidates of ``thread``, beginning at ``min_thread_start_dt``.

    The first candidate is ``thread.tz_start_time`` on the date of ``min_thread_start_dt``, or on the next
    date if that time of day is earlier than ``min_thread_start_dt``'s. Each following candidate is one day
    later. For each candidate the generator yields the naive start datetime if ``RunMask.runs_at`` says the
    thread runs then, and ``None`` otherwise, so every item corresponds to one day slot.
    """
    run_days = thread.year_days
    start_time = thread.tz_start_time

    first_date = min_thread_start_dt.date()
    if start_time < min_thread_start_dt.time():
        # The run on that date would start before the lower bound.
        first_date += _DAY
    candidate_dt = datetime.combine(first_date, start_time)

    while True:
        if RunMask.runs_at(run_days, candidate_dt):
            yield candidate_dt
        else:
            yield None
        candidate_dt += _DAY


def _exclude_through_trains(thread_segments):
    """
    Drop through-train segments that duplicate a basic (non through-train) segment.

    Two segments are duplicates when they share departure station, arrival station, and departure and
    arrival minute of day. The minute of day is computed as thread start time plus stop offset, modulo 24h.
    A through-train segment (``RThreadType.THROUGH_TRAIN_ID``) is dropped if a matching basic segment was
    seen before or is seen later.

    :param thread_segments: iterable of ``(thread, (departure, arrival))``.
    :return: tuple with all kept basic segments followed by the kept through-train segments.
    """
    basic_by_key = defaultdict(list)
    through_by_key = defaultdict(list)

    for item in thread_segments:
        thread, (departure, arrival) = item
        base_minutes = _time_as_minutes(thread.tz_start_time)
        key = (
            departure.station_id,
            arrival.station_id,
            (base_minutes + departure.tz_departure) % DAY_MINUTES,
            (base_minutes + arrival.tz_arrival) % DAY_MINUTES,
        )

        if thread.type_id != RThreadType.THROUGH_TRAIN_ID:
            # A basic train supersedes any through train already collected under the same key.
            through_by_key.pop(key, None)
            basic_by_key[key].append(item)
        elif key not in basic_by_key:
            through_by_key[key].append(item)

    kept = []
    for group in itertools.chain(basic_by_key.itervalues(), through_by_key.itervalues()):
        kept.extend(group)
    return tuple(kept)


@total_ordering
class _BaseRawSegment(namedtuple(
    '_BaseRawSegment',
    'thread, thread_start_dt, departure, departure_station, arrival, arrival_station'
)):
    """
        Объект, который умеет сравнивать сегменты в одной таймзоне.
        Это позволяет не вызывать медленный localize без необходимости.
    """
    def __lt__(self, other):
        return (
            self._event_naive_dt < other._event_naive_dt if self.event_stop.tzinfo is other.event_stop.tzinfo else
            self.event_dt < other.event_dt
        )


class _RawDepartureSegment(_BaseRawSegment):
    """Raw segment whose ordering moment is the departure from its departure stop."""

    def __init__(self, *_args, **_kwargs):
        # Fields are already populated by namedtuple.__new__; only the naive ordering key is derived here.
        departure_offset = timedelta(minutes=self.departure.tz_departure)
        self._event_naive_dt = self.thread_start_dt + departure_offset

    @property
    def event_stop(self):
        return self.departure

    @property
    def event_station(self):
        return self.departure_station

    @property
    def event_dt(self):
        return self.departure_dt

    @cached_property
    def departure_dt(self):
        """Departure datetime localized to the departure stop's tzinfo (computed on first access)."""
        return self.departure.tzinfo.localize(self._event_naive_dt)

    @cached_property
    def arrival_dt(self):
        """Arrival datetime localized to the arrival stop's tzinfo (computed on first access)."""
        naive_arrival = self.thread_start_dt + timedelta(minutes=self.arrival.tz_arrival)
        return self.arrival.tzinfo.localize(naive_arrival)


class _RawArrivalSegment(_BaseRawSegment):
    """Raw segment whose ordering moment is the arrival at its arrival stop."""

    def __init__(self, *_args, **_kwargs):
        # Fields are already populated by namedtuple.__new__; only the naive ordering key is derived here.
        arrival_offset = timedelta(minutes=self.arrival.tz_arrival)
        self._event_naive_dt = self.thread_start_dt + arrival_offset

    @property
    def event_stop(self):
        return self.arrival

    @property
    def event_station(self):
        return self.arrival_station

    @property
    def event_dt(self):
        return self.arrival_dt

    @cached_property
    def departure_dt(self):
        """Departure datetime localized to the departure stop's tzinfo (computed on first access)."""
        naive_departure = self.thread_start_dt + timedelta(minutes=self.departure.tz_departure)
        return self.departure.tzinfo.localize(naive_departure)

    @cached_property
    def arrival_dt(self):
        """Arrival datetime localized to the arrival stop's tzinfo (computed on first access)."""
        return self.arrival.tzinfo.localize(self._event_naive_dt)


class ThreadsCache(object):
    """
    In-memory index of the threads of a single transport type and of their stops.

    The indexes are built by ``using_precache()`` and exist only while that context is active. Outside of
    it every ``_*_cache`` attribute is ``None``.
    """
    _threads_cache = _station_stops_cache = _settlement_stops_cache = _stations_cache = None
    # Declared like the other caches so the attribute exists (as None) before the first build and after cleanup.
    _settlement_reachable_cache = None

    def __init__(self, t_type_id, threads_qs=None):
        """
        :param t_type_id: id of the transport type whose threads are cached.
        :param threads_qs: optional base ``RThread`` queryset; defaults to all threads.
        """
        if threads_qs is None:
            threads_qs = RThread.objects.all()
        self._threads_qs = threads_qs.filter(t_type_id=t_type_id)
        self._t_type_id = t_type_id

    def _build_caches(self, station_settlement_ids, timezones):
        """
        Load the threads and their stops and build the lookup tables.

        :param station_settlement_ids: mapping ``station_id -> iterable of settlement ids``.
        :param timezones: object whose ``get(time_zone)`` returns the tzinfo stored on each stop.
        :return: tuple ``(threads_cache, station_stops, settlement_stops, settlement_reachable)``:
            ``threads_cache`` maps thread id -> thread box; ``station_stops`` / ``settlement_stops`` map
            point id -> ``{thread_id: [stop boxes]}``; ``settlement_reachable`` maps settlement id -> set of
            other settlement ids that have indexed stops on at least one common thread.
        """
        threads_cache = {}
        station_stops = defaultdict(lambda: defaultdict(list))
        settlement_stops = defaultdict(lambda: defaultdict(list))

        # Threads of type CANCEL_ID, threads with an empty run mask, hidden threads and threads of hidden
        # routes are not cached.
        threads_qs = (self._threads_qs.exclude(type_id=RThreadType.CANCEL_ID)
                                      .exclude(year_days=RunMask.EMPTY_YEAR_DAYS)
                                      .exclude(hidden=True)
                                      .exclude(route__hidden=True)
                                      .order_by())
        # Technical stops and stops at stations with majority >= NOT_IN_SEARCH_ID are not loaded at all.
        rtstations_qs = (RTStation.objects.filter(is_technical_stop=False)
                                          .filter(station__majority__lt=StationMajority.NOT_IN_SEARCH_ID)
                                          .order_by())

        settlement_reachable_cache = defaultdict(set)

        for thread in RThreadBox.iter_queryset(threads_qs):
            thread_id = thread.id
            threads_cache[thread_id] = thread
            departure_stop = arrival_stop = None
            reachable_settlement_ids = set()

            # NOTE(review): this issues one RTStation query per thread.
            for rtstation in RTStationBox.iter_queryset(rtstations_qs.filter(thread_id=thread_id)):
                # A stop without an arrival time is remembered as the thread's departure stop,
                # a stop without a departure time as its arrival stop.
                if rtstation.tz_arrival is None:
                    departure_stop = rtstation
                elif rtstation.tz_departure is None:
                    arrival_stop = rtstation

                # Stops whose arrival and departure offsets are equal are not indexed
                # (presumably pass-through points without a real stop — confirm).
                if rtstation.tz_arrival == rtstation.tz_departure:
                    continue

                rtstation.cache_tzinfo(timezones)
                rtstation.cache_stops_text()

                station_id = rtstation.station_id
                station_stops[station_id][thread_id].append(rtstation)
                for settlement_id in station_settlement_ids.get(station_id, ()):
                    settlement_stops[settlement_id][thread_id].append(rtstation)
                    reachable_settlement_ids.add(settlement_id)

            # Every pair of settlements indexed on this thread becomes mutually reachable
            # (stop order along the thread is not taken into account).
            for first_id, second_id in itertools.combinations(reachable_settlement_ids, 2):
                settlement_reachable_cache[first_id].add(second_id)
                settlement_reachable_cache[second_id].add(first_id)

            thread.cache_stops(departure_stop, arrival_stop)  # TODO: move the stops cache into the suburban wizard

        _cache_thread_titles(threads_cache.viewvalues())

        return (
            threads_cache,
            {station_id: dict(stops) for station_id, stops in station_stops.iteritems()},
            {settlement_id: dict(stops) for settlement_id, stops in settlement_stops.iteritems()},
            dict(settlement_reachable_cache)
        )

    @contextmanager
    def using_precache(self):
        """
        Context manager that makes the caches available for its duration.

        The outermost entry builds every cache; nested entries reuse them. When the outermost entry exits,
        or if building fails part-way, all caches (including the settlement reachability cache) are reset
        to ``None``.
        """
        if self._threads_cache is None:
            station_settlement_ids = _fetch_station_settlement_ids()

            now_utc = environment.now_utc()
            timezones = ShrinkedTimezones(now_utc - timedelta(days=settings.DAYS_TO_PAST),
                                          now_utc + timedelta(days=366))

            try:
                # Built inside the try so that a failure in Station.objects.in_bulk does not leave
                # half-initialised caches behind (a later call would then skip the rebuild).
                (
                    self._threads_cache,
                    self._station_stops_cache,
                    self._settlement_stops_cache,
                    self._settlement_reachable_cache
                ) = self._build_caches(station_settlement_ids, timezones)
                self._stations_cache = Station.objects.in_bulk(self._station_stops_cache.keys())

                yield
            finally:
                self._threads_cache = self._station_stops_cache = self._settlement_stops_cache = \
                    self._stations_cache = self._settlement_reachable_cache = None
        else:
            yield

    def _get_point_thread_stops(self, point):
        """Return ``{thread_id: [stop boxes]}`` for a Station or (otherwise) a settlement, or None if unknown."""
        if isinstance(point, Station):
            return self._station_stops_cache.get(point.id)
        else:
            return self._settlement_stops_cache.get(point.id)

    @tuplify
    def _find_thread_segments(self, thread_departures, thread_arrivals, majority_limit, thread_predicate):
        """
        Collect ``(thread, (departure, arrival))`` pairs for threads stopping at both points.

        Only stop pairs with ``departure.id < arrival.id`` are considered; they are then filtered by
        ``_filter_segments``. When more than one pair remains, ``_find_best_segments`` picks the ones to keep.
        """
        if not (thread_departures and thread_arrivals):
            return

        if thread_predicate is None:
            thread_predicate = _DEFAULT_THREAD_PREDICATE

        for thread_id in thread_departures.viewkeys() & thread_arrivals.viewkeys():
            thread = self._threads_cache[thread_id]
            if not thread_predicate(thread):
                continue

            departures = thread_departures[thread_id]
            arrivals = thread_arrivals[thread_id]
            segments = tuple((departure, arrival,)
                             for departure in departures for arrival in arrivals
                             if departure.id < arrival.id)
            if not segments:
                continue

            segments = _filter_segments(segments, majority_limit)
            if not segments:
                continue

            if len(segments) > 1:
                for segment in _find_best_segments(segments):
                    yield (thread, segment)
            else:
                yield (thread, segments[0])

    def _iter_raw_segments(self, thread, segment, min_departure_dt):
        """
        Infinitely yield one departure-ordered raw segment (or None) per daily run slot of ``thread``.

        The first slot is the earliest run whose naive local departure from the segment's departure stop is
        not earlier than ``min_departure_dt``.
        """
        departure, arrival = segment
        raw_segment_factory = partial(
            _RawDepartureSegment,
            thread=thread,
            departure=departure,
            departure_station=self._stations_cache[departure.station_id],
            arrival=arrival,
            arrival_station=self._stations_cache[arrival.station_id]
        )

        for thread_start_dt in _iter_thread_runs(
            thread, min_departure_dt.astimezone(departure.tzinfo) - timedelta(minutes=departure.tz_departure)
        ):
            yield raw_segment_factory(thread_start_dt=thread_start_dt) if thread_start_dt is not None else None

    def find_segments(self, departure_point, arrival_point, min_departure_dt, thread_predicate=None):
        """
        Lazily yield raw segments from ``departure_point`` to ``arrival_point``.

        For each of the next SEARCH_DAYS daily run slots, the runs of that slot are yielded sorted by
        departure. Must be called inside ``using_precache()``.

        :param thread_predicate: optional callable; threads for which it returns a falsy value are skipped.
        :return: a generator, or ``()`` when LimitConditions leaves no transport types.
        """
        assert self._threads_cache is not None

        majority_limit = LimitConditions(departure_point, arrival_point,
                                         t_types=(TransportType.objects.get(id=self._t_type_id),))
        if not majority_limit.t_types:
            return ()

        thread_segments = self._find_thread_segments(self._get_point_thread_stops(departure_point),
                                                     self._get_point_thread_stops(arrival_point),
                                                     majority_limit,
                                                     thread_predicate)
        # Each element of raw_segments_gen holds the n-th run slot of every segment.
        raw_segments_gen = itertools.izip(*(self._iter_raw_segments(thread, segment, min_departure_dt)
                                            for thread, segment in _exclude_through_trains(thread_segments)))

        return (raw_segment
                for raw_segments in itertools.islice(raw_segments_gen, SEARCH_DAYS)
                for raw_segment in sorted(itertools.ifilter(None, raw_segments)))

    def get_settlement_reachable_cache(self):
        """Return the ``{settlement_id: set(settlement_id)}`` reachability map, or None outside the precache."""
        return self._settlement_reachable_cache

    def has_direction(self, from_id, to_id):
        """
        Return True if some cached thread has indexed stops in both settlements ``from_id`` and ``to_id``.

        The relation is symmetric: stop order along the thread is not taken into account.
        """
        reachable_ids = self._settlement_reachable_cache.get(from_id)
        if reachable_ids is None:
            return False
        return to_id in reachable_ids

    def _iter_raw_station_stops(self, thread, min_event_dt, stop, station):
        """
        Infinitely yield one raw segment (or None) per daily run slot for ``stop`` of ``thread``.

        A stop without a departure time yields arrival-ordered segments from the thread's departure stop;
        any other stop yields departure-ordered segments to the thread's arrival stop.
        """
        # NOTE(review): thread.departure_stop / thread.arrival_stop stay None when the terminal stop was
        # filtered out of rtstations_qs in _build_caches; the lookups below would then raise — confirm
        # that the data guarantees this cannot happen.
        if stop.tz_departure is None:
            segment_class = _RawArrivalSegment
            minutes_offset = stop.tz_arrival
            departure_stop = thread.departure_stop
            departure_station = self._stations_cache[departure_stop.station_id]
            arrival_stop, arrival_station = stop, station
        else:
            segment_class = _RawDepartureSegment
            minutes_offset = stop.tz_departure
            departure_stop, departure_station = stop, station
            arrival_stop = thread.arrival_stop
            arrival_station = self._stations_cache[arrival_stop.station_id]

        raw_segment_factory = partial(
            segment_class,
            thread=thread,
            departure=departure_stop,
            departure_station=departure_station,
            arrival=arrival_stop,
            arrival_station=arrival_station
        )

        for thread_start_dt in _iter_thread_runs(
            thread, min_event_dt.astimezone(stop.tzinfo) - timedelta(minutes=minutes_offset)
        ):
            yield raw_segment_factory(thread_start_dt=thread_start_dt) if thread_start_dt is not None else None

    def find_station_segments(self, station, min_event_dt, thread_predicate=None):
        """
        Lazily yield raw segments for the station schedule of ``station``.

        Only stops flagged ``in_station_schedule`` are used. Threads rejected by ``thread_predicate`` and
        threads of type INTERVAL_ID or THROUGH_TRAIN_ID are skipped. For each of the next SEARCH_DAYS daily
        run slots, the segments of that slot are yielded sorted by their event (departure or arrival) moment.
        """
        if thread_predicate is None:
            thread_predicate = _DEFAULT_THREAD_PREDICATE

        segment_iterators = []

        for thread_id, thread_stops in self._station_stops_cache.get(station.id, {}).iteritems():
            thread = self._threads_cache[thread_id]
            if not thread_predicate(thread) or (
                thread.type_id in (RThreadType.INTERVAL_ID, RThreadType.THROUGH_TRAIN_ID)
            ):
                continue

            segment_iterators.extend(
                self._iter_raw_station_stops(
                    thread=thread, min_event_dt=min_event_dt, stop=stop, station=station
                )
                for stop in thread_stops
                if stop.in_station_schedule
            )

        return (
            raw_segment
            for raw_segments in itertools.islice(itertools.izip(*segment_iterators), SEARCH_DAYS)
            for raw_segment in sorted(itertools.ifilter(None, raw_segments))
        )
