# coding: utf-8
from __future__ import unicode_literals, absolute_import, division, print_function

import logging
from collections import defaultdict
from django.db.models import Q
from itertools import groupby

from common.models.geo import Station
from common.models.schedule import RTStation
from common.models.transport import TransportType
from route_search.shortcuts import search_routes
from travel.rasp.train_api.tariffs.train.base.utils import make_segment_train_keys
from travel.rasp.train_api.tariffs.train.segment_builder.helpers.merge_trains import make_meta_trains

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


def get_search_segments_with_keys(train_query):
    """Search train segments for *train_query* and tag each one with its train keys.

    ``search_routes`` is called for the query's departure/arrival points and
    departure date, limited to the train transport type (``TransportType.TRAIN_ID``)
    with ``exact_date=True``. Only the first element of its 3-tuple result is
    used. Those segments are merged by ``make_meta_trains``; every resulting
    segment then receives a ``train_keys`` attribute holding the set of keys
    produced by ``make_segment_train_keys``.

    :param train_query: object exposing ``departure_point``, ``arrival_point``
        and ``departure_date``.
    :return: the segments returned by ``make_meta_trains``, each with a
        ``train_keys`` set attached.
    """
    train_transport = TransportType.objects.get(pk=TransportType.TRAIN_ID)

    found_segments, _, _ = search_routes(
        train_query.departure_point,
        train_query.arrival_point,
        departure_date=train_query.departure_date,
        transport_types=[train_transport],
        exact_date=True,
        expanded_day=False,
        check_date=None,
        include_interval=False,
        add_z_tablos=False,
        max_count=None,
        threads_filter=None,
        prepared_threads=None
    )

    merged_segments = make_meta_trains(found_segments)
    for merged in merged_segments:
        merged.train_keys = set(make_segment_train_keys(merged))
    return merged_segments


def fill_start_and_end_stations_from_thread(search_segments):
    """Attach each segment's thread start/end stations to the segment.

    For every distinct ``segment.thread.id``, the ``RTStation`` rows whose
    ``tz_arrival`` or ``tz_departure`` is NULL are loaded, ordered by
    ``thread_id`` then ``id``. The first such row is used as the start station
    and the last as the end station. Their ids are written to
    ``segment.start_station_id`` / ``segment.end_station_id``. The matching
    ``Station`` objects, loaded in one ``in_bulk`` query, are written to
    ``segment.start_station`` / ``segment.end_station``.

    A thread that does not yield exactly two such rows is logged as a warning
    instead of aborting the whole batch. A segment whose thread yields no rows
    keeps any ``start_station_id`` / ``end_station_id`` it already had. If it
    had none, it gets ``start_station`` / ``end_station`` of ``None``.

    :param search_segments: re-iterable collection of segments (iterated three
        times here), each exposing ``thread.id``.
    :return: ``search_segments`` itself, mutated in place.
    """
    segments_by_thread_id = defaultdict(list)
    for segment in search_segments:
        segments_by_thread_id[segment.thread.id].append(segment)

    # Rows missing an arrival or a departure time; ordering by thread_id makes
    # groupby below see each thread as one contiguous run.
    terminal_rtstations = (
        RTStation.objects
            .filter(thread__in=list(segments_by_thread_id))
            .filter(Q(tz_arrival=None) | Q(tz_departure=None))
            .order_by('thread_id', 'id')
            .values('thread_id', 'station_id')
    )
    for thread_id, rows in groupby(terminal_rtstations, lambda row: row['thread_id']):
        station_ids = [row['station_id'] for row in rows]
        if len(station_ids) != 2:
            # Strict two-way unpacking would raise ValueError here and fail the
            # whole search because of one malformed thread.
            log.warning(
                'Thread %s has %d terminal rtstations, expected 2; '
                'using the first and the last one',
                thread_id, len(station_ids),
            )
        # groupby never yields an empty group, so indexing is safe.
        start_station_id, end_station_id = station_ids[0], station_ids[-1]
        for segment in segments_by_thread_id[thread_id]:
            segment.start_station_id = start_station_id
            segment.end_station_id = end_station_id

    # getattr: segments whose thread had no terminal rows were never assigned above.
    station_ids = set()
    for segment in search_segments:
        station_ids.add(getattr(segment, 'start_station_id', None))
        station_ids.add(getattr(segment, 'end_station_id', None))
    station_ids.discard(None)

    stations = Station.objects.in_bulk(station_ids)

    for segment in search_segments:
        segment.start_station = stations.get(getattr(segment, 'start_station_id', None))
        segment.end_station = stations.get(getattr(segment, 'end_station_id', None))

    return search_segments
