# -*- coding: utf-8 -*-

import heapq
import logging
import time as os_time
from collections import defaultdict
from datetime import datetime, timedelta
from itertools import takewhile, islice, groupby, imap

import pytz
from django.db.models import Q

from travel.avia.library.python.common.models.geo import Station, Settlement
from travel.avia.library.python.common.models.schedule import RThread, RTStation
from travel.avia.library.python.common.models.transport import TransportType
from travel.avia.library.python.common.utils import environment
from travel.avia.library.python.common.utils.caching import cache_method_result
from travel.avia.library.python.common.utils.date import MSK_TZ, RunMask
from travel.avia.library.python.route_search.helpers import _t_type2t_type_list, fetch_unrelated, LimitConditions, remove_duplicates, \
    remove_through_trains
from travel.avia.library.python.route_search.models import IntervalRThreadSegment, RThreadSegment, AllDaysRThreadSegment, RThreadSegmentGroup


log = logging.getLogger(__name__)


# NOTE(review): presumably the TransportType / thread-type ids of "train" and
# "basic" threads; not referenced in this part of the module — confirm usage.
TRAIN_T_TYPE_ID = 1
BASIC_THREAD_TYPE_ID = 1

# Column names pulled from the www_znoderoute2 table into RThread querysets via
# ``QuerySet.extra(select=...)``; each name becomes an attribute on the thread.
_NODEROUTE_EXTRA_FIELDS = (
    'stops_translations',
    'rtstation_from_id',
    'rtstation_to_id',
    'station_from_id',
    'station_to_id',
    'settlement_from_id',
    'settlement_to_id',
)


def _noderoute_extra():
    """Build a new ``{field: 'www_znoderoute2.<field>'}`` select mapping."""
    return {field: 'www_znoderoute2.' + field for field in _NODEROUTE_EXTRA_FIELDS}


# Extra select for regular threads.
NODEROUTE_THREAD_EXTRA = _noderoute_extra()

# Extra select for interval threads (same columns, separate dict object).
I_NODEROUTE_THREAD_EXTRA = _noderoute_extra()

MINUTES_IN_DAY = 24 * 60


def add_time(time, **kwargs):
    """
    Return ``time`` shifted by ``timedelta(**kwargs)``, wrapping around midnight.

    The shift is performed on a fixed, arbitrary calendar date and only the
    resulting time of day is kept, so e.g. 23:30 plus 45 minutes gives 00:15.
    """
    anchored = datetime.combine(datetime(2000, 1, 1), time)
    shifted = anchored + timedelta(**kwargs)
    return shifted.time()


class PreSegment(object):
    """
    Date-independent part of a search result for one thread.

    Everything computed here depends only on the thread schedule (its start
    time and the offsets of its from/to rtstations), so it is computed once
    per thread and later bound to concrete dates by the search code.

    Instances order by ``rts_from_departure_time``.
    """
    def __init__(self, thread):
        rts_from = thread.rtstation_from
        rts_to = thread.rtstation_to

        self.thread = thread
        # Time of day of departure from the "from" rtstation:
        # tz_start_time shifted by tz_departure minutes, wrapped at midnight.
        self.rts_from_departure_time = add_time(thread.tz_start_time,
                                                minutes=rts_from.tz_departure)
        self.rts_tz_to = pytz.timezone(rts_to.time_zone)
        # Number of whole days between the thread start and the departure from
        # the "from" rtstation. Explicit floor division: ``/`` only floors for
        # Python 2 ints.
        self.mask_shift = (
            thread.tz_start_time.hour * 60 + thread.tz_start_time.minute + rts_from.tz_departure
        ) // MINUTES_IN_DAY
        # Travel time derived from the schedule offsets (tz_arrival - tz_departure).
        self.pseudo_duration_td = timedelta(minutes=rts_to.tz_arrival - rts_from.tz_departure)

    def get_cmp_key(self):
        return self.rts_from_departure_time

    def __lt__(self, other):
        # Used by sorted()/heapq; works on both Python 2 and Python 3.
        return self.get_cmp_key() < other.get_cmp_key()

    def __cmp__(self, other):
        # Python 2 three-way comparison without relying on the ``cmp`` builtin.
        mine, theirs = self.get_cmp_key(), other.get_cmp_key()
        return (mine > theirs) - (mine < theirs)


class PreSegmentAware(object):
    """
    A presegment bound to a concrete departure date.

    Produced by ``PlainSegmentSearch.get_single_zone_iter``. The producer sets
    ``thread``, ``thread_start_date``, ``loc_departure_dt`` / ``loc_arrival_dt``
    (aware datetimes converted to the from/to points' time zones) and
    ``duration``.

    Instances order by departure, then arrival, duration and thread start date.
    """
    def get_cmp_key(self):
        return (self.loc_departure_dt, self.loc_arrival_dt, self.duration, self.thread_start_date)

    def __lt__(self, other):
        # Used by sorted()/heapq.merge; works on both Python 2 and Python 3.
        return self.get_cmp_key() < other.get_cmp_key()

    def __cmp__(self, other):
        # Python 2 three-way comparison without relying on the ``cmp`` builtin.
        mine, theirs = self.get_cmp_key(), other.get_cmp_key()
        return (mine > theirs) - (mine < theirs)


class PreSegmentAllDaysAware(object):
    """
    A presegment for the "all days" view.

    Produced by ``PlainSegmentSearch.get_all_day_presegments_by_zone``. The
    producer sets ``thread``, ``thread_start_date``, the aware
    ``loc_departure_dt`` / ``loc_arrival_dt``, ``duration`` and their
    time-of-day parts ``loc_departure_time`` / ``loc_arrival_time``.

    Instances order by time of day of departure, then of arrival, then by
    duration and thread start date.
    """
    def get_cmp_key(self):
        return (self.loc_departure_time, self.loc_arrival_time, self.duration, self.thread_start_date)

    def __lt__(self, other):
        # Used by list.sort()/heapq.merge; works on both Python 2 and Python 3.
        return self.get_cmp_key() < other.get_cmp_key()

    def __cmp__(self, other):
        # Python 2 three-way comparison without relying on the ``cmp`` builtin.
        mine, theirs = self.get_cmp_key(), other.get_cmp_key()
        return (mine > theirs) - (mine < theirs)


class PlainSegmentSearch(object):
    """
    == Search for regular (non-interval) threads ==

    Finds RThreads between ``point_from`` and ``point_to`` (cancelled and
    interval threads are excluded) and turns them into segments, either for a
    concrete departure range (``search`` / ``gen_from`` / ``count``) or for
    the "all days" view (``all_days_search`` / ``all_days_group``).
    """
    def __init__(self, point_from, point_to, t_type=None, threads_filter=None):
        """
        :param point_from: departure point
        :type point_from: Settlement | Station

        :param point_to: arrival point
        :type point_to: Settlement | Station

        :param t_type: transport type code, transport type, or a list of codes
            or of transport types
        :type t_type: str|unicode | TransportType | list[str|unicode] | list[TransportType]

        :param threads_filter: Q object that narrows the set of RThreads fetched
            from the database (an empty/falsy value means "no extra filter")
        """

        assert isinstance(point_from, (Station, Settlement))
        assert isinstance(point_to, (Station, Settlement))

        self.point_from = point_from
        self.point_to = point_to
        self.t_types = _t_type2t_type_list(t_type)
        self.threads_filter = Q() if not threads_filter else threads_filter

    @cache_method_result
    def _get_threads(self):
        """
        Fetch matching threads and group them by the time zone of their
        departure rtstation.

        :returns: ``defaultdict(list)`` mapping ``rtstation_from.time_zone`` to
            the list of threads departing in that zone.
        """
        start = os_time.time()

        # Restrict by importance and transport types.
        limit_conditions = LimitConditions(self.point_from, self.point_to, self.t_types)

        threads_qs = limit_conditions.filter_threads_qs(RThread.objects.all())
        threads_qs = threads_qs.exclude(Q(type__code='cancel') | Q(type__code='interval'))
        threads_qs = threads_qs.filter(self.threads_filter)

        if set(self.t_types) != set(TransportType.objects.all_cached()):
            threads_qs = threads_qs.filter(t_type__in=self.t_types)

        threads = list(threads_qs.extra(select=NODEROUTE_THREAD_EXTRA).prefetch_related('supplier'))

        log.debug(u'_get_threads ff %s (total %s)', os_time.time() - start, len(threads))

        fetch_unrelated(threads, Station, 'station_from', 'station_to')
        log.debug(u'_get_threads rf st %s', os_time.time() - start)

        fetch_unrelated(threads, RTStation, 'rtstation_from', 'rtstation_to')

        log.debug(u'_get_threads fl %s', os_time.time() - start)

        # Remove duplicates: threads going through different stations of the
        # same city (RASP-5190). Done after the stations are fetched because
        # station priority is needed.
        # NOTE(review): an older comment said this should happen before the
        # rtstation fetch (to reduce it), but rtstations are already fetched
        # above — confirm the intended order.
        threads = list(remove_duplicates(threads))

        by_rts_time_zone = defaultdict(list)

        for t in threads:
            by_rts_time_zone[t.rtstation_from.time_zone].append(t)

        log.debug(u'_get_threads %s', os_time.time() - start)

        return by_rts_time_zone

    @cache_method_result
    def get_single_zone_presegments(self, zone):
        """
        Return the PreSegments of threads departing in ``zone``, with through
        trains removed, sorted by departure time of day.
        """
        threads = self._get_threads()[zone]

        start = os_time.time()

        results = map(PreSegment, threads)
        results = sorted(remove_through_trains(results))

        log.debug(u'get_single_zone_presegments %s', os_time.time() - start)

        return results

    @staticmethod
    def _localize_presegment(presegm_awr, presegm, loc_dt, rts_tz, out_tz_from, out_tz_to):
        """
        Set ``loc_departure_dt``, ``loc_arrival_dt`` and ``duration`` on
        ``presegm_awr``.

        ``loc_dt`` is the naive departure moment in the departure rtstation's
        zone ``rts_tz``. The arrival is ``loc_dt`` plus the presegment's pseudo
        duration, interpreted in the arrival rtstation's zone. Both results are
        converted to the from/to points' zones.
        """
        presegm_awr.loc_departure_dt = rts_tz.localize(loc_dt).astimezone(out_tz_from)
        presegm_awr.loc_arrival_dt = presegm.rts_tz_to.localize(
            loc_dt + presegm.pseudo_duration_td
        ).astimezone(out_tz_to)
        # Travel time is arrival minus departure (aware datetimes, so time zone
        # differences are accounted for).
        presegm_awr.duration = presegm_awr.loc_arrival_dt - presegm_awr.loc_departure_dt

    def get_single_zone_iter(self, zone, from_dt_awr):
        """
        Lazily yield PreSegmentAware objects for threads departing in ``zone``,
        starting at ``from_dt_awr``.

        Days are walked forward from the ``zone``-local date of ``from_dt_awr``.
        Within a day, presegments come in local departure-time order. Iteration
        stops once the date passes the start date + 100 days.
        """
        notawr_presegments = self.get_single_zone_presegments(zone)

        out_tz_from = self.point_from.pytz
        out_tz_to = self.point_to.pytz

        rts_tz = pytz.timezone(zone)

        start_dt = from_dt_awr.astimezone(rts_tz).replace(tzinfo=None)

        start_date = start_dt.date()

        # Do not produce a schedule more than 100 days ahead.
        achtung_limit = start_date + timedelta(100)

        while True:
            for presegm in notawr_presegments:
                loc_dt = datetime.combine(start_date, presegm.rts_from_departure_time)

                # Only relevant on the first day: skip departures before the
                # requested moment.
                if loc_dt < start_dt:
                    continue

                # The departure from rtstation_from happens mask_shift days
                # after the thread's start date.
                thread_start_date = start_date - timedelta(presegm.mask_shift)

                if not presegm.thread.runs_at(thread_start_date):
                    continue

                presegm_awr = PreSegmentAware()
                presegm_awr.thread = presegm.thread
                presegm_awr.thread_start_date = thread_start_date
                self._localize_presegment(presegm_awr, presegm, loc_dt,
                                          rts_tz, out_tz_from, out_tz_to)

                yield presegm_awr

            start_date += timedelta(1)

            if start_date > achtung_limit:
                return

    def get_all_day_presegments_by_zone(self, zone):
        """
        Return sorted PreSegmentAllDaysAware objects, one per thread departing
        in ``zone``, for the "all days" view.

        Each is anchored at ``thread.first_run(...)`` computed from today's date
        in ``zone`` shifted back by ``mask_shift``. Threads for which
        ``first_run`` returns a falsy value are skipped.
        """
        notawr_presegments = self.get_single_zone_presegments(zone)

        out_tz_from = self.point_from.pytz
        out_tz_to = self.point_to.pytz

        rts_tz = pytz.timezone(zone)

        tz_today = MSK_TZ.localize(environment.now()).astimezone(rts_tz).date()

        result = []

        for presegm in notawr_presegments:
            thread = presegm.thread

            thread_start_date = thread.first_run(tz_today - timedelta(days=presegm.mask_shift))
            if not thread_start_date:
                continue

            loc_dt = datetime.combine(thread_start_date + timedelta(days=presegm.mask_shift),
                                      presegm.rts_from_departure_time)

            presegm_awr = PreSegmentAllDaysAware()
            presegm_awr.thread = thread
            presegm_awr.thread_start_date = thread_start_date
            self._localize_presegment(presegm_awr, presegm, loc_dt,
                                      rts_tz, out_tz_from, out_tz_to)

            presegm_awr.loc_departure_time = presegm_awr.loc_departure_dt.time()
            # Arrival time of day must come from the arrival datetime (it was
            # previously copied from the departure, breaking the sort key).
            presegm_awr.loc_arrival_time = presegm_awr.loc_arrival_dt.time()

            result.append(presegm_awr)

        result.sort()

        return result

    def gen_presegments_from_dt(self, from_dt_aware):
        """
        Merge the per-zone PreSegmentAware streams into one ordered iterator
        starting at ``from_dt_aware``.
        """
        threads_by_rts_time_zone = self._get_threads()

        threads_iterators = [
            self.get_single_zone_iter(zone, from_dt_aware)
            for zone in threads_by_rts_time_zone
        ]

        return heapq.merge(*threads_iterators)

    @staticmethod
    def _fill_segment(segment, presegm_awr):
        """Copy thread, stations and departure/arrival data from ``presegm_awr`` onto ``segment``."""
        thread = presegm_awr.thread

        segment.station_from = thread.station_from
        segment.station_to = thread.station_to

        segment.thread = thread

        segment.departure = presegm_awr.loc_departure_dt
        segment.arrival = presegm_awr.loc_arrival_dt

        segment.start_date = presegm_awr.thread_start_date

        segment.rtstation_from = thread.rtstation_from
        segment.rtstation_to = thread.rtstation_to

        segment.stops_translations = thread.stops_translations

    def gen_from(self, from_dt_aware):
        """Lazily yield initialized RThreadSegment objects departing from ``from_dt_aware`` on."""
        now = MSK_TZ.localize(environment.now())

        for presegm in self.gen_presegments_from_dt(from_dt_aware):
            segment = RThreadSegment()
            self._fill_segment(segment, presegm)
            segment.now = now
            segment._init_data()

            yield segment

    @staticmethod
    def _check_aware_range(from_dt_aware, to_dt_aware):
        """Raise ValueError unless both bounds of the departure range are timezone-aware."""
        if from_dt_aware.tzinfo is None:
            raise ValueError('Start of departure range must be aware datetime object')

        if to_dt_aware.tzinfo is None:
            raise ValueError('End of departure range must be aware datetime object')

    def search(self, from_dt_aware, to_dt_aware, add_z_tablos=False, max_count=None):
        """
        Return RThreadSegment objects departing between ``from_dt_aware`` and
        ``to_dt_aware`` (both inclusive).

        :param add_z_tablos: not used here; kept for interface compatibility.
        :param max_count: cap on the number of generated segments (None: no cap).
        :raises ValueError: if either bound is a naive datetime.
        """
        self._check_aware_range(from_dt_aware, to_dt_aware)

        start_t = os_time.time()

        segments = list(
            takewhile(lambda s: s.departure <= to_dt_aware,
                      islice(self.gen_from(from_dt_aware), max_count))
        )

        log.debug(
            u'Поиск %s %s %s %s %s отработал за %s',
            self.point_from.L_title(),
            self.point_to.L_title(),
            [t.code for t in self.t_types],
            from_dt_aware,
            to_dt_aware,
            os_time.time() - start_t
        )

        return segments

    def all_days_search(self):
        """
        Return AllDaysRThreadSegment objects for every thread shown on
        "all days" pages, merged across departure time zones.
        """
        threads_by_rts_time_zone = self._get_threads()

        per_zone_presegments = [
            self.get_all_day_presegments_by_zone(zone)
            for zone in threads_by_rts_time_zone
        ]

        segments = []

        for presegment_awr in heapq.merge(*per_zone_presegments):
            if not presegment_awr.thread.show_in_alldays_pages:
                continue

            segment = AllDaysRThreadSegment()
            self._fill_segment(segment, presegment_awr)
            segment._init_data()

            segments.append(segment)

        return segments

    def all_days_group(self, title_grouping=True):
        """
        Group "all days" segments and yield RThreadSegmentGroup objects
        (RASPADMIN-618).

        Segments are grouped by departure time, arrival time, number and
        (when ``title_grouping`` is true) title. Each group's mask is the union
        of its threads' run masks, shifted by each segment's ``mask_shift``.
        """

        def group_key(segment):
            return (
                segment.departure.time(),
                segment.arrival.time(),
                segment.number,
                segment.title if title_grouping else None,
            )

        segments = self.all_days_search()
        # groupby needs the input sorted by the grouping key; departure makes
        # the order within a group deterministic.
        segments.sort(
            key=lambda segment: (group_key(segment), segment.departure)
        )

        for (departure_time, arrival_time,
             number, title), group_segments in groupby(segments, group_key):
            group_segments = list(group_segments)
            mask = RunMask(today=group_segments[0].departure.date())

            for segment in group_segments:
                thread_mask = RunMask(segment.thread.year_days)

                if segment.mask_shift:
                    thread_mask.today = mask.today
                    thread_mask = thread_mask.shifted(segment.mask_shift)

                mask |= thread_mask

            yield RThreadSegmentGroup(
                number, title, departure_time, arrival_time,
                mask, group_segments
            )

    def count(self, from_dt_aware, to_dt_aware):
        """
        Count presegments departing between ``from_dt_aware`` and
        ``to_dt_aware`` (both inclusive).

        :raises ValueError: if either bound is a naive datetime.
        """
        self._check_aware_range(from_dt_aware, to_dt_aware)

        in_range = takewhile(
            lambda prs: prs.loc_departure_dt <= to_dt_aware,
            self.gen_presegments_from_dt(from_dt_aware)
        )

        return sum(1 for _ in in_range)

    def get_service_types(self):
        """Return the transport type codes of all found threads."""
        t_types_ids = {
            t.t_type_id
            for tlist in self._get_threads().values()
            for t in tlist
        }

        return [TransportType.objects.get(pk=i).code for i in t_types_ids]


class IntervalSegmentSearch(object):
    """
    == Search for interval threads ==

    Handles threads whose type code is ``interval``. They carry
    ``begin_time`` / ``end_time`` and are searched by a local date rather
    than by a departure range.
    """
    def __init__(self, point_from, point_to, t_type=None):
        """
        :param point_from: departure point
        :type point_from: Settlement | Station

        :param point_to: arrival point
        :type point_to: Settlement | Station

        :param t_type: transport type code, transport type, or a list of codes
            or of transport types
        :type t_type: str|unicode | TransportType | list[str|unicode] | list[TransportType]
        """
        for point in (point_from, point_to):
            assert isinstance(point, (Station, Settlement))

        self.point_from = point_from
        self.point_to = point_to
        self.t_types = _t_type2t_type_list(t_type)

    @cache_method_result
    def _get_threads(self):
        """Fetch non-cancelled interval threads, deduplicated and sorted by ``begin_time``."""
        conditions = LimitConditions(self.point_from, self.point_to, self.t_types)

        queryset = (
            conditions.filter_threads_qs(RThread.objects.all())
            .exclude(type__code='cancel')
            .filter(type__code='interval')
        )

        if set(self.t_types) != set(TransportType.objects.all_cached()):
            queryset = queryset.filter(t_type__in=self.t_types)

        fetched = list(queryset.extra(select=I_NODEROUTE_THREAD_EXTRA).prefetch_related('supplier'))

        fetch_unrelated(fetched, Station, 'station_from', 'station_to')
        fetch_unrelated(fetched, RTStation, 'rtstation_from', 'rtstation_to')

        # Drop duplicates: threads going through different stations of the
        # same city (RASP-5190). Station priority is needed, so this runs after
        # the stations are fetched.
        return sorted(remove_duplicates(fetched), key=lambda thr: thr.begin_time)

    @staticmethod
    def _make_segment(thread, loc_date, with_last_departure):
        """
        Build and initialize an IntervalRThreadSegment for ``thread``.

        When ``with_last_departure`` is true, ``last_departure`` (``end_time``
        on ``loc_date`` in the departure station's zone) and its Moscow-time
        copy ``msk_last_departure`` are set before initialization.
        """
        segment = IntervalRThreadSegment()

        segment.station_from = thread.station_from
        segment.station_to = thread.station_to
        segment.thread = thread
        segment.rtstation_from = thread.rtstation_from
        segment.rtstation_to = thread.rtstation_to
        segment.stops_translations = thread.stops_translations

        if with_last_departure:
            last_departure = thread.station_from.pytz.localize(
                datetime.combine(loc_date, thread.end_time)
            )
            segment.last_departure = last_departure
            segment.msk_last_departure = last_departure.astimezone(MSK_TZ)

        segment._init_data()

        travel_minutes = segment.rtstation_to.tz_arrival - segment.rtstation_from.tz_departure
        segment.duration = timedelta(minutes=travel_minutes)

        return segment

    def search_by_day(self, loc_date):
        """
        Return interval segments running on ``loc_date``.

        Interval threads are assumed to be scheduled in local time, and their
        running days are local as well, so ``loc_date`` is expected to be a
        local date.
        """
        found = []

        for thread in self._get_threads():
            if thread.runs_at(loc_date):
                found.append(self._make_segment(thread, loc_date, True))

        return found

    def all_days_search(self):
        """Return interval segments for every thread shown on "all days" pages."""
        found = []

        for thread in self._get_threads():
            if thread.show_in_alldays_pages:
                found.append(self._make_segment(thread, None, False))

        return found

    def count(self, loc_date):
        """Number of interval segments running on ``loc_date``."""
        matching = self.search_by_day(loc_date)
        return len(matching)

    def get_service_types(self):
        """Return the transport type codes of all found interval threads."""
        type_ids = {thr.t_type_id for thr in self._get_threads()}

        codes = [TransportType.objects.get(pk=type_id).code for type_id in type_ids]
        return codes
