# coding: utf-8
from __future__ import unicode_literals, absolute_import, division, print_function

import logging
from itertools import groupby

from django.db.models import Q
from django.utils.lru_cache import lru_cache

from common.models.geo import StationMajority, Station
from common.models.schedule import RThread, RThreadType, RTStation
from common.models.transport import TransportType, TrainPseudoStationMap
from common.models_utils import fetch_related
from route_search.helpers import LimitConditions
from travel.rasp.train_api.tariffs.train.segment_builder.helpers.title_common import fill_segment_title_common
from travel.rasp.train_api.tariffs.train.segment_builder.helpers.black_list import redirect_blacklisted_trains_to_ufs

log = logging.getLogger(__name__)


def build_price_segments(segments, train_query):
    """
    Run the price-segment building pipeline.

    The first step resolves threads and stations (and may drop segments); every
    following step takes the previous step's result and returns the segments
    to pass on.

    :param segments: raw price segments.
    :param train_query: the train query the segments were obtained for.
    :return: the segments returned by the last pipeline step.
    """
    result = fill_segment_threads_and_stations(segments, train_query)
    for enrich in (fill_segment_title_common,
                   fill_thread_first_last_country_code,
                   fill_stations_countries,
                   redirect_blacklisted_trains_to_ufs):
        result = enrich(result)
    return result


class ThreadVariant(object):
    """
    A thread paired with one departure/arrival station id combination.

    Only the station ids are known at construction time; the ``Station``
    objects stay ``None`` until :meth:`fill_stations` loads them.
    """
    station_from = None
    station_to = None

    def __init__(self, thread, station_from_id, station_to_id):
        self.thread, self.station_from_id, self.station_to_id = thread, station_from_id, station_to_id

    @classmethod
    def fill_stations(cls, thread_variants):
        """
        Load ``station_from`` / ``station_to`` for every variant with one ``in_bulk`` query.

        Ids not found in the database result in ``None``.
        """
        wanted_ids = {
            station_id
            for variant in thread_variants
            for station_id in (variant.station_from_id, variant.station_to_id)
        }
        by_id = Station.objects.in_bulk(wanted_ids)

        for variant in thread_variants:
            variant.station_from = by_id.get(variant.station_from_id)
            variant.station_to = by_id.get(variant.station_to_id)


def fill_segment_threads_and_stations(segments, train_query):
    """
    Attach threads and departure/arrival stations to price segments.

    If neither of the query's express points is a fake express station, these
    two stations are set on every segment. Otherwise stations are restored per
    segment:

    1. from the segment's own express codes (lookup failures are ignored);
    2. from the matched thread variant;
    3. for segments without a thread variant, from the pseudo-station map.

    Threads and thread variants are matched in every case. Segments still lacking
    either station are logged and dropped.

    :return: a new list of the segments that have both stations.
    """
    @lru_cache()
    def station_by_express_code(express_code):
        return Station.get_by_code('express', express_code)

    point_from = station_by_express_code(train_query.departure_point_code)
    point_to = station_by_express_code(train_query.arrival_point_code)

    fake_majority = StationMajority.EXPRESS_FAKE_ID
    request_points_are_real = point_from.majority_id != fake_majority and point_to.majority_id != fake_majority

    if request_points_are_real:
        for segment in segments:
            segment.station_from, segment.station_to = point_from, point_to
    else:
        for segment in segments:
            for attr_name in ('station_from', 'station_to'):
                express_code = getattr(segment, attr_name + '_express_code')
                if not express_code:
                    continue
                try:
                    station = station_by_express_code(express_code)
                except Station.DoesNotExist:
                    # Best effort: the thread-variant / pseudo-station steps below may still fill it.
                    continue
                setattr(segment, attr_name, station)

    number_to_thread_variants, ambiguous_numbers = get_number_to_thread_variants(segments, train_query)
    unmatched_segments = fill_inplace_segments_thread_and_thread_variant(segments, number_to_thread_variants)

    if not request_points_are_real:
        fill_inplace_segments_stations_from_thread_variant(segments, train_query, point_from, point_to,
                                                           ambiguous_numbers)
        if unmatched_segments:
            fill_inplace_segments_stations_from_pseudo_stations(unmatched_segments, train_query,
                                                                point_from, point_to)

    complete_segments = []
    for segment in segments:
        if segment.station_from and segment.station_to:
            complete_segments.append(segment)
        else:
            log.error('Не смогли восстановить станции отправления и(или) прибытия не показываем цены для %s, %s',
                      segment.original_number, train_query)
    return complete_segments


def get_number_to_thread_variants(ufs_segments, train_query):
    """
    Find train thread variants for all numbers the segments may have.

    Train threads with any of the segments' ``possible_numbers`` are narrowed by
    ``LimitConditions`` built from the query's departure/arrival points. Each row
    yields a ``ThreadVariant`` with the station ids taken from ``www_znoderoute2``
    (presumably the table joined by ``filter_threads_qs`` — not visible here).

    Variants are ranked: non-through-train threads first, then by departure
    station majority, then by arrival station majority. The best-ranked variant
    wins for each number.

    :return: ``(dict number -> ThreadVariant, set of numbers that had more than one variant)``.
    """
    departure_point, arrival_point = train_query.departure_point, train_query.arrival_point
    train_type = TransportType.objects.get(pk=TransportType.TRAIN_ID)
    limit_conditions = LimitConditions(departure_point, arrival_point, [train_type])

    possible_numbers = []
    for segment in ufs_segments:
        possible_numbers.extend(segment.possible_numbers)

    candidate_threads = RThread.objects.filter(t_type__id=TransportType.TRAIN_ID, number__in=possible_numbers)
    threads_qs = limit_conditions.filter_threads_qs(candidate_threads)
    rows = threads_qs.extra(select={'station_from_id': 'www_znoderoute2.station_from_id',
                                    'station_to_id': 'www_znoderoute2.station_to_id'})

    thread_variants = [ThreadVariant(row, row.station_from_id, row.station_to_id) for row in rows]
    ThreadVariant.fill_stations(thread_variants)

    def variant_priority(variant):
        # False sorts before True, so non-through threads outrank through-train ones.
        return (variant.thread.type_id == RThreadType.THROUGH_TRAIN_ID,
                variant.station_from.majority_id,
                variant.station_to.majority_id)

    number_to_thread_variants = {}
    ambiguous_variant_numbers = set()
    for variant in sorted(thread_variants, key=variant_priority):
        number = variant.thread.number
        if number in number_to_thread_variants:
            ambiguous_variant_numbers.add(number)
        else:
            number_to_thread_variants[number] = variant

    return number_to_thread_variants, ambiguous_variant_numbers


def fill_inplace_segments_thread_and_thread_variant(segments, number_to_thread_variants):
    """
    Set ``thread`` and ``thread_variant`` on each segment in place.

    A segment takes the first of its ``possible_numbers`` that maps to a truthy
    variant in ``number_to_thread_variants``. Segments with no such number are
    left untouched.

    :return: the segments for which no variant was found, in input order.
    """
    unmatched = []
    for segment in segments:
        candidates = (number_to_thread_variants.get(number) for number in segment.possible_numbers)
        thread_variant = next((variant for variant in candidates if variant), None)
        if thread_variant is None:
            unmatched.append(segment)
            continue

        segment.thread = thread_variant.thread
        segment.thread_variant = thread_variant

    return unmatched


def fill_inplace_segments_stations_from_thread_variant(ufs_segments, train_query, express_point_from, express_point_to,
                                                       ambiguous_variant_numbers):
    """
    Fill missing departure/arrival stations from each segment's thread variant.

    Segments without a ``thread_variant`` are skipped, and stations that are
    already set are kept. If the thread number is in ``ambiguous_variant_numbers``,
    the query's express point is first resolved through ``resolve_pseudo_station``.
    The variant's station is used when that resolution yields nothing.
    """
    for segment in ufs_segments:
        thread = segment.thread
        variant = segment.thread_variant
        if not variant:
            continue

        for attr_name, express_point in (('station_from', express_point_from),
                                         ('station_to', express_point_to)):
            if getattr(segment, attr_name):
                continue
            resolved = None
            if thread.number in ambiguous_variant_numbers:
                resolved = resolve_pseudo_station(segment, express_point, train_query)
            setattr(segment, attr_name, resolved or getattr(variant, attr_name))


def fill_inplace_segments_stations_from_pseudo_stations(segments_without_thread_variant, train_query,
                                                        express_point_from, express_point_to):
    """
    Fill missing stations of unmatched segments through the pseudo-station map.

    For each missing departure/arrival station, the corresponding query express
    point is resolved with ``resolve_pseudo_station``. When resolution fails, an
    error is logged and ``can_supply_segments`` is set to ``False`` on the segment.
    """
    directions = (
        ('station_from', express_point_from,
         'Не удалось заполнить данные для нитки %s отправление из %s %s, %s'),
        ('station_to', express_point_to,
         'Не удалось заполнить данные для нитки %s прибытие в %s %s, %s'),
    )

    for segment in segments_without_thread_variant:
        for attr_name, express_point, error_template in directions:
            if getattr(segment, attr_name):
                continue

            station = resolve_pseudo_station(segment, express_point, train_query)
            if not station:
                log.error(error_template,
                          segment.original_number,
                          express_point.id, express_point.title,
                          train_query)
                segment.can_supply_segments = False
                continue

            setattr(segment, attr_name, station)


def resolve_pseudo_station(ufs_segment, station, train_query):
    """
    Map a possibly fake express station to the real station for this train.

    Stations whose majority is not ``EXPRESS_FAKE_ID`` are returned unchanged.
    For a fake one, ``TrainPseudoStationMap`` is looked up by the segment's
    ``original_number`` and the pseudo station.

    :return: the resolved station, or ``None`` (after a warning) when no mapping exists.
    """
    if station.majority_id == StationMajority.EXPRESS_FAKE_ID:
        try:
            mapping = TrainPseudoStationMap.objects.get(number=ufs_segment.original_number, pseudo_station=station)
            return mapping.station
        except TrainPseudoStationMap.DoesNotExist:
            log.warning('Не смогли разрешить псевдостанцию: %s %s: %s, %s',
                        station.id, station.title, ufs_segment.original_number, train_query)
            return None

    return station


def fill_thread_first_last_country_code(segments, predicate=lambda s: s.thread is not None):
    """
    Set ``first_country_code`` / ``last_country_code`` on the segments' threads.

    Only route stops with no arrival time or no departure time are considered.
    For each thread, the stop with the smallest ``RTStation.id`` gives the
    first country code and the one with the largest id gives the last. Threads
    without such stops are left untouched.

    :param segments: segments whose ``thread`` attribute is read.
    :param predicate: only segments satisfying this condition get their thread's
        countries filled (by default, segments that have a thread).
    :return: the same ``segments`` object, for chaining in a pipeline.
    """
    threads = [s.thread for s in segments if predicate(s)]
    rtstations = (
        RTStation.objects
        .filter(thread__in=threads)
        .filter(Q(tz_arrival=None) | Q(tz_departure=None))
        # Order by the thread's primary key, not by the ``thread`` relation itself:
        # ordering by a relation applies the related model's Meta.ordering (if any).
        # That ordering need not keep each thread's rows adjacent, which groupby below
        # relies on.
        .order_by('thread__id', 'id')
        # select_related() is intentionally absent: values() ignores it, and the
        # station/country join is implied by the requested value field.
        .values('thread_id', 'station__country__code')
    )
    country_codes = {}
    for thread_id, stops in groupby(rtstations, lambda stop: stop['thread_id']):
        stops = list(stops)
        country_codes[thread_id] = (stops[0]['station__country__code'], stops[-1]['station__country__code'])

    for thread in threads:
        if thread.id in country_codes:
            thread.first_country_code, thread.last_country_code = country_codes[thread.id]

    return segments


def fill_stations_countries(segments):
    """
    Load the ``country`` relation for every segment's departure and arrival station.

    All stations are passed to a single ``fetch_related(..., 'country', model=Station)``
    call (presumably a bulk load of the related countries). Departure stations
    come first, then arrival stations.

    :return: the same ``segments`` object, for chaining in a pipeline.
    """
    stations = [segment.station_from for segment in segments]
    stations.extend(segment.station_to for segment in segments)
    fetch_related(stations, 'country', model=Station)

    return segments
