# -*- coding: utf-8 -*-

import logging
from itertools import chain
import time as time_module

from urllib import urlencode
from datetime import datetime, time, timedelta, date
from itertools import izip, groupby
from xml.etree import cElementTree as ET

from django.conf import settings
from django.core.cache import cache
from django.utils.functional import cached_property

from travel.avia.library.python.common.models.geo import Station
from travel.avia.library.python.common.models.schedule import RThread, RTStation
from travel.avia.library.python.common.models.transport import TransportType
from travel.avia.library.python.common.utils.date import get_msk_time, MSK_TZ
from travel.avia.library.python.common.utils.http import urlopen
from travel.avia.library.python.common.utils.text import md5_hex
from travel.avia.library.python.common.utils import environment
from travel.avia.library.python.common.views.tariffs import DisplayInfo
from travel.avia.library.python.route_search.models import ZNodeRoute2, BaseThreadSegment


log = logging.getLogger(__name__)

# Pathfinder uses this date as the starting point for many of its calculations.
# For us it indicates an error, since we have no data for 2008.
# https://github.yandex-team.ru/rasp/pathfinder/blob/26dce3bb7b8b0e0324dc3c8299cd239750dd1a9e/pathfinderlib/pf.h#L62
PATHFINDER_ERROR_DATE = date(2008, 1, 1)

# Minimum transfer time in minutes (overridable via settings.MIN_TRANSFER_MINUTES)
MIN_TRANSFER_MINUTES = getattr(settings, 'MIN_TRANSFER_MINUTES', 2 * 60)
# Maximum transfer time in minutes (overridable via settings.MAX_TRANSFER_MINUTES)
MAX_TRANSFER_MINUTES = getattr(settings, 'MAX_TRANSFER_MINUTES', 23 * 60)

# Maps a search type code to the TransportType ids sent to pathfinder as `ttype`.
# 'river', 'sea' and 'water' all share the same list of water transport type ids.
SEARCH_TYPE_MAP = {
    'train': [TransportType.TRAIN_ID],
    'plane': [TransportType.PLANE_ID],
    'bus': [TransportType.BUS_ID],
    'suburban': [TransportType.SUBURBAN_ID],
    'river': TransportType.WATER_TTYPE_IDS,
    'sea': TransportType.WATER_TTYPE_IDS,
    'water': TransportType.WATER_TTYPE_IDS,
}

# Every transport type id present in SEARCH_TYPE_MAP; used when no t_type filter is given.
ALL_TRANSPORT_TYPES = set(chain.from_iterable(SEARCH_TYPE_MAP.values()))


def find_routes(from_point, to_point, departure_date, t_type, delta=120):
    """Query pathfinder for routes with transfers between two points.

    :param from_point: departure point; its ``type`` and ``id`` are sent to pathfinder.
    :param to_point: arrival point; its ``type`` and ``id`` are sent to pathfinder.
    :param departure_date: departure date; the search starts at its midnight,
        converted via ``get_msk_time`` relative to ``from_point``.
    :param t_type: a search type code (a key of SEARCH_TYPE_MAP) or a
        list/tuple/set of such codes. A falsy value means "all transport types".
        Unknown codes are ignored; if none are known, nothing is searched.
    :param delta: passed to pathfinder as the ``delta`` request parameter.
    :return: ``(groups, response)`` where ``groups`` is the list produced by
        ``parse_response`` and ``response`` is the raw XML text (an empty
        ``<groups />`` document when nothing could be obtained).
    """
    empty_response = u'<?xml version="1.0" encoding="utf-8" ?><groups />'

    departure_datetime = get_msk_time(from_point, local_datetime=datetime.combine(departure_date, time()))

    if t_type:
        search_types = t_type if isinstance(t_type, (list, tuple, set, frozenset)) else [t_type]

        t_types_ids = set()
        for search_type in search_types:
            t_types_ids.update(SEARCH_TYPE_MAP.get(search_type, ()))

        # no valid transport types specified
        # RASPAPI-551
        if not t_types_ids:
            return [], empty_response
    else:
        t_types_ids = ALL_TRANSPORT_TYPES

    request_params = [
        ('from_type', from_point.type),
        ('from_id', from_point.id),
        ('to_type', to_point.type),
        ('to_id', to_point.id),
        ('date', departure_datetime.strftime("%Y-%m-%d %H:%M:%S")),
        # Sorted so that the same set of types always yields the same request
        # string and therefore the same cache key (set order is not guaranteed).
        ('ttype', sorted(t_types_ids)),
        ('boarding', 1440),
        ('min_delay', MIN_TRANSFER_MINUTES),
        ('max_delay', MAX_TRANSFER_MINUTES),
        ('optimize', 'time'),
        ('delta', delta),
        ('max_transfers', 1),
    ]

    # doseq=True expands the `ttype` list into repeated parameters.
    request = urlencode(request_params, True)

    if settings.DEBUG:
        log.info("Request: %s", request)

    # The same request may already have been answered and cached.
    cache_key = settings.CACHEROOT + 'pathfinder/search/' + md5_hex(request)
    response = cache.get(cache_key)

    if not response and settings.PATHFINDER_URL:
        url = settings.PATHFINDER_URL + '?' + request

        try:
            response = urlopen(url, timeout=settings.PATHFINDER_TIMEOUT).read()
        except Exception:
            # Pathfinder being unavailable is not fatal: fall back to an empty
            # result, but keep the traceback for diagnostics.
            response = None
            log.exception('url: %s', url)

        if response:
            cache.set(cache_key, response, settings.CACHES['default']['LONG_TIMEOUT'])

    if not response:
        response = empty_response

    return parse_response(response), response


class Group(object):
    """A group of transfer route variants parsed from one pathfinder ``<group>`` element."""

    def __init__(self, element):
        self.variants = [Variant(node) for node in element.findall('variant')]

    @property
    def best_variant(self):
        """The variant listed first (presumably pathfinder puts the best one first)."""
        return self.variants[0]

    @property
    def best_departure(self):
        """Earliest departure among all variants."""
        return min(variant.departure for variant in self.variants)

    @property
    def first_arrival(self):
        """Earliest arrival among all variants."""
        return min(variant.arrival for variant in self.variants)

    @property
    def last_arrival(self):
        """Latest arrival among all variants."""
        return max(variant.arrival for variant in self.variants)

    @property
    def station_from(self):
        return self.best_variant.station_from

    @property
    def station_to(self):
        return self.best_variant.station_to

    def __repr__(self):
        best = self.best_variant
        parts = (
            "best_time=%r" % best.duration,
            "departure=%r" % best.departure,
            "first_arrival=%r" % self.first_arrival,
            "last_arrival=%r" % self.last_arrival,
            "has %d variants" % len(self.variants),
        )
        return '<Group %s>' % " ".join(parts)

    def summary(self):
        """Per-leg summary of the group.

        Returns one list per leg (city to city); each list holds the distinct
        segments used on that leg across all variants, in sorted order.
        """
        return [sorted(leg_segments) for leg_segments in self.variants_by_segment()]

    def variants_by_segment(self):
        """Distinct segments per leg position across all variants.

        The i-th set contains the i-th segment of every variant. Like ``zip``,
        the result is truncated to the shortest variant.
        """
        segment_lists = [variant.segments for variant in self.variants]
        return [set(column) for column in zip(*segment_lists)]

    @property
    def transfers(self):
        """Transfer points with the number of distinct segments leading to each.

        Returns a list of ``(point, count)`` tuples, one per leg except the last.
        ``point`` is the settlement of the leg's arrival station, or the station
        itself when it has no settlement.
        """
        result = []

        for candidates in self.variants_by_segment()[:-1]:
            station = next(iter(candidates)).station_to
            result.append((station.settlement or station, len(candidates)))

        return result


class Variant(object):
    """A single travel variant: a chain of segments with transfers between them."""

    @property
    def duration(self):
        """Total travel time (``arrival - departure``), or None if either is unknown."""
        if self.arrival is None or self.departure is None:
            return None

        return self.arrival - self.departure

    @property
    def transport_types(self):
        """Transport types of the segments' threads, in travel order."""
        return [segment.thread.t_type for segment in self.segments]

    @property
    def display_t_codes(self):
        """``display_t_code`` of every segment, in travel order."""
        return [segment.display_t_code for segment in self.segments]

    @property
    def station_from(self):
        return self.segments[0].station_from

    @property
    def station_to(self):
        return self.segments[-1].station_to

    @property
    def rtstation_from(self):
        return self.segments[0].rtstation_from

    @property
    def rtstation_to(self):
        return self.segments[-1].rtstation_to

    @property
    def departure(self):
        """Departure of the first segment, or None when there are no segments."""
        return self.segments[0].departure if self.segments else None

    @property
    def arrival(self):
        """Arrival of the last segment, or None when there are no segments."""
        return self.segments[-1].arrival if self.segments else None

    @property
    def msk_departure(self):
        return self.segments[0].msk_departure

    def transfer_title_point(self, from_segment, to_segment):
        """Point to display for the transfer between two consecutive segments.

        RASP-11372: when both segments are suburban and the transfer happens at
        the same station, that station is shown. Otherwise the settlement of
        ``from_segment``'s arrival station is shown, or the station itself if it
        has no settlement.
        """
        if from_segment.t_type.code == to_segment.t_type.code == 'suburban':
            if from_segment.station_to == to_segment.station_from:
                return from_segment.station_to

        station = from_segment.station_to

        if station.settlement:
            return station.settlement

        return station

    @property
    def transfers(self):
        """Transfer points, one per junction between consecutive segments."""
        return [
            self.transfer_title_point(from_segment, to_segment)
            for from_segment, to_segment in zip(self.segments, self.segments[1:])
        ]

    def _remove_empty_segments(self):
        """Drop segments without a thread (pathfinder's 'NULL' placeholders)."""
        self.segments = [segment for segment in self.segments if not segment.is_empty]

    def _add_transfers_info(self):
        """Store transfer details in ``display_info['transfer']`` of each non-last segment.

        Safe on an empty or single-segment variant (nothing to do).
        """
        for prev_segment, segment in zip(self.segments, self.segments[1:]):
            in_ = segment.station_from.settlement or segment.station_from

            transfer_info = {
                'in': in_,
                # Waiting time between arriving and departing again.
                'duration': segment.departure - prev_segment.arrival,
                'price': prev_segment.transfer_price,
                'convenience': prev_segment.transfer_convenience
            }
            prev_segment.display_info['transfer'] = transfer_info

            # The transfer point differs from the arrival station (e.g. it is a
            # settlement): also expose the concrete stations involved.
            if in_ != prev_segment.station_to:
                transfer_info.update({
                    'from': prev_segment.station_to,
                    'to': segment.station_from,
                })

    def _add_transfer_data(self):
        """Copy price/convenience of each transfer segment onto the segment before it.

        Safe on an empty or single-segment variant (nothing to do).
        """
        for prev_segment, segment in zip(self.segments, self.segments[1:]):
            if segment.is_transfer:
                prev_segment.transfer_price = segment.price
                prev_segment.transfer_convenience = segment.convenience

    def __nonzero__(self):
        return bool(self.segments)

    # Same truthiness under Python 3.
    __bool__ = __nonzero__

    def __init__(self, element):
        """Build a variant from a pathfinder ``<variant>`` element.

        Reads the ``price`` and ``tr`` attributes and one TransferSegment per
        ``<route>`` child; empty (thread-less) segments are dropped after their
        transfer data has been copied to the preceding segment.
        """
        now = environment.now_aware()

        self.display_info = DisplayInfo()

        self.gone = False

        try:
            self.price = float(element.get('price'))
        except (TypeError, ValueError):
            # Missing (None) or non-numeric price.
            self.price = None
        else:
            if self.price <= 0:
                self.price = None

        if settings.SHOW_TRANSFER_PRICES:
            self.display_info.set_tariff(self.price)

        # Kept as the raw attribute string; an empty or missing value becomes None
        # (the original `<= 0` check compared a str with an int).
        self.convenience = element.get('tr') or None

        self.segments = [TransferSegment(route_element, now) for route_element in element.findall('route')]

        self._add_transfer_data()

        self._remove_empty_segments()


class TransferSegment(BaseThreadSegment):
    """One leg of a transfer route, built from a pathfinder ``<route>`` element.

    ``thread``, ``station_from``/``station_to`` and ``rtstation_from``/``rtstation_to``
    start as None and are filled in afterwards from the database.
    """

    def is_valid(self):
        """Return a truthy value when every field needed to use the segment is resolved.

        Segments whose arrival/departure fall on PATHFINDER_ERROR_DATE are rejected:
        that date signals a pathfinder error (RASPAPI-392).
        """
        return (
            self.thread
            and self.station_from is not None
            and self.station_to is not None
            and self.rtstation_from is not None
            and self.rtstation_to is not None
            and self.arrival is not None
            and self.departure is not None
            and self.start_date is not None

            # RASPAPI-392
            and self.arrival.date() != PATHFINDER_ERROR_DATE
            and self.departure.date() != PATHFINDER_ERROR_DATE
        )

    def __eq__(self, other):
        # Assumes a thread that left a given station at a given time cannot
        # arrive at a different time or at a different station.

        return (
            self.thread == other.thread
            and self.departure == other.departure
            and self.station_from == other.station_from
        )

    def __ne__(self, other):
        # Python 2 does not derive `!=` from `__eq__`; without this, `!=` would
        # fall back to `__cmp__` and compare departures only.
        return not self.__eq__(other)

    def __hash__(self):
        return hash((self.thread, self.departure, self.station_from))

    def __cmp__(self, other):
        # Ordering (e.g. in sorted()) is by departure time.
        return cmp(self.departure, other.departure)

    def _parse_datetime(self, value):
        """Parse a ``YYYY-MM-DD HH:MM`` pathfinder timestamp and localize it to MSK_TZ."""
        return MSK_TZ.localize(datetime.strptime(value, '%Y-%m-%d %H:%M'))

    def _parse_thread(self, thread_uid):
        """Record the thread uid; the literal 'NULL' marks an empty (transfer-only) leg."""
        if thread_uid != 'NULL':
            self.thread_uid = thread_uid
            self.is_empty = False
            self.is_transfer = False
        else:
            # Keep thread_uid defined so that __repr__ works for empty legs.
            self.thread_uid = None
            self.thread = None
            self.is_empty = True
            self.is_transfer = True

    def __init__(self, element, now):
        self.now = now
        self.msk_departure = self._parse_datetime(element.get('departure_datetime'))
        self.msk_arrival = self._parse_datetime(element.get('arrival_datetime'))

        self.thread = None
        self._parse_thread(element.get('thread_id'))

        self.msk_start_date = datetime.strptime(element.get('start_date'), '%Y-%m-%d').date()

        self.station_from = None
        self.station_from_id = int(element.get('departure_station_id'))

        self.station_to = None
        self.station_to_id = int(element.get('arrival_station_id'))

        self.rtstation_from = None
        self.rtstation_to = None

        try:
            self.price = float(element.get('price'))
        except (TypeError, ValueError):
            # Missing (None) or non-numeric price.
            self.price = None
        else:
            if self.price <= 0:
                self.price = None

        # Transfer convenience is kept as the raw attribute string; empty or
        # missing values become None.
        self.convenience = element.get('tr') or None

        self.is_transfer_segment = True

        self.display_info = DisplayInfo()

    @cached_property
    def start_date(self):
        """Thread start date derived from the departure rtstation's mask shift, or None."""
        shift = self.rtstation_from.get_departure_mask_shift(self.thread.tz_start_time)

        if shift is None:
            # The departure station was not found: the wrong rtstation was picked.
            return None

        return self.departure.astimezone(self.rtstation_from.pytz).date() - timedelta(shift)

    @cached_property
    def arrival(self):
        """Arrival time in the arrival station's time zone."""
        return self.msk_arrival.astimezone(self.station_to.pytz)

    @cached_property
    def departure(self):
        """Departure time in the departure station's time zone."""
        return self.msk_departure.astimezone(self.station_from.pytz)

    @cached_property
    def gone(self):
        """True when the segment departed before ``now``."""
        return self.departure < self.now

    def __repr__(self):
        return "<Segment thread_uid=%r>" % (self.thread_uid or 'NULL')


def add_stops(segments):
    """Attach ``stops_translations`` from ZNodeRoute2 to every segment that has a thread.

    All matching ZNodeRoute2 rows are fetched in a single query. Segments with a
    thread but no matching row get ``stops_translations = None``; segments without
    a thread are left untouched.
    """
    with_thread = [segment for segment in segments if segment.thread]

    thread_ids = set(segment.thread.id for segment in with_thread)
    from_ids = set(segment.station_from.id for segment in with_thread)
    to_ids = set(segment.station_to.id for segment in with_thread)

    noderoutes = ZNodeRoute2.objects.filter(
        station_from__id__in=list(from_ids),
        station_to__id__in=list(to_ids),
        thread__id__in=list(thread_ids)
    ).only('station_from', 'station_to', 'thread', 'stops_translations')

    stops_by_key = dict(
        ((route.station_from_id, route.station_to_id, route.thread_id), route.stops_translations)
        for route in noderoutes
    )

    for segment in segments:
        if not segment.thread:
            continue

        key = (segment.station_from.id, segment.station_to.id, segment.thread.id)
        segment.stops_translations = stops_by_key.get(key)


def parse_response(response):
    """Parse a pathfinder XML response into a list of Group objects.

    :param response: either a cp1251-encoded byte string (pathfinder / cache) or
        an already decoded unicode string.
    :return: groups that contain at least one valid variant with transfers;
        an empty list when the response cannot be decoded or parsed.
    """
    try:
        # Byte strings are cp1251; unicode input is used as is (decoding it again
        # would implicitly ASCII-encode it first and fail on non-ASCII text).
        if isinstance(response, str):
            response = response.decode('cp1251')

        log.debug(u'Response:\n%s', response)

        tree = ET.fromstring(response.encode('utf-8'))
    except (SyntaxError, UnicodeError):
        # ElementTree's ParseError is a SyntaxError subclass.
        log.exception('bad response')
        return []

    groups = [Group(group_element) for group_element in tree.findall('group')]

    segments = [
        segment
        for group in groups
        for variant in group.variants
        for segment in variant.segments
    ]

    # Resolve threads/stations/rtstations; unresolved segments are dropped here
    # and later make their variants invalid.
    segments = fill_segments(segments)

    # Drop invalid variants. A variant with no segments would pass `all()`
    # vacuously and has nothing to show, so it is dropped too.
    for group in groups:
        group.variants = [
            v for v in group.variants
            if v.segments and all(s.is_valid() for s in v.segments)
        ]

    variant_id = 0

    for group in groups:
        for variant in group.variants:
            variant.display_info['variant_id'] = variant_id
            variant._add_transfers_info()

            variant_id += 1

        # RASP-1898: remove groups without transfers.
        # Drop variants without transfers.
        group.variants = [v for v in group.variants if v.transfers]

    # Drop empty groups as well (they can appear after filtering out invalid variants).
    groups = [group for group in groups if group.variants and group.transfers]

    add_stops(segments)

    return groups


def fill_segments(segments):
    """Resolve threads, stations and rtstations for ``segments`` from the database.

    Returns only the segments whose thread and both stations were found.
    """
    resolved, threads, stations = fetch_threads_and_stations(segments)
    fetch_rtstaitons(resolved, threads, stations)
    return resolved


def fetch_threads_and_stations(segments):
    """Load the RThread and Station objects referenced by ``segments``.

    Sets ``thread``, ``station_from`` and ``station_to`` on each segment whose data
    was found; segments with any missing piece are logged and skipped.

    :return: ``(resolved_segments, threads, stations)``: the resolved segments, plus
        lists of all loaded threads and stations.
    """
    uids = set(segment.thread_uid for segment in segments)

    ids = set()
    for segment in segments:
        ids.update((segment.station_from_id, segment.station_to_id))

    thread_queryset = RThread.objects.select_related('route').filter(uid__in=uids)
    thread_by_uid = {thread.uid: thread for thread in thread_queryset}
    station_by_id = Station.objects.in_bulk(list(ids))

    resolved = []

    for segment in segments:
        try:
            segment.thread = thread_by_uid[segment.thread_uid]
            segment.station_from = station_by_id[segment.station_from_id]
            segment.station_to = station_by_id[segment.station_to_id]
        except KeyError:
            log.warning(u'Данные, которые отдал пересадочник не нашли в базе', exc_info=True)
        else:
            resolved.append(segment)

    return resolved, list(set(thread_by_uid.values())), list(set(station_by_id.values()))


def fetch_rtstaitons(segments, threads, stations):
    """Attach ``rtstation_from`` / ``rtstation_to`` to every segment.

    Loads every RTStation for the given threads and stations in one query. For
    each (thread, station) pair, the row with the smallest id is used as the
    arrival rtstation and the row with the largest id as the departure rtstation.
    Segments whose pair has no rows get None.

    :param segments: segments with ``thread``, ``station_from`` and ``station_to`` set.
    :param threads: RThread instances to load rtstations for.
    :param stations: Station instances to load rtstations for.
    """
    # groupby needs rows to be contiguous per (thread_id, station_id). Ordering by
    # the raw ids (not by the related models, whose Meta.ordering may differ)
    # guarantees that; the trailing 'id' fixes the order inside each pair.
    rtstations_iter = RTStation.objects.filter(
        thread__in=threads,
        station__in=stations
    ).order_by('thread__id', 'station__id', 'id')

    rtstations = {}

    for (thread_id, station_id), group in groupby(rtstations_iter, key=lambda rts: (rts.thread_id, rts.station_id)):
        # groupby never yields an empty group, so both ends exist. When several
        # rows share a pair: 't' (arrival) is the first one, 'f' (departure) the last.
        group = list(group)

        rtstations[thread_id, station_id, 't'] = group[0]
        rtstations[thread_id, station_id, 'f'] = group[-1]

    for segment in segments:
        segment.rtstation_from = rtstations.get((segment.thread.id, segment.station_from.id, 'f'))
        segment.rtstation_to = rtstations.get((segment.thread.id, segment.station_to.id, 't'))


def get_variants(point_from, point_to, when, type_=None):
    """Yield transfer variants from ``point_from`` to ``point_to`` on ``when``.

    RASP-5908: variants that transfer in one of the requested points themselves
    are skipped.
    """
    groups, _xml = find_routes(point_from, point_to, when, type_)

    for group in groups:
        for variant in group.variants:
            transfer_points = variant.transfers

            if point_from in transfer_points or point_to in transfer_points:
                continue

            yield variant
