# -*- coding: utf-8 -*-

from django.db.models import Q

from common.models.transport import TransportType
from common.utils.date import FuzzyDateTime


# Transport types (plane, helicopter) whose threads are matched against
# ZTablo2 records by add_z_tablos_to_segments; segments of any other
# transport type are skipped by that lookup.
TRANSPORT_TYPE_IDS_IN_TABLO = (
    TransportType.PLANE_ID,
    TransportType.HELICOPTER_ID,
)


def add_z_tablos_to_segments(segments, separated_segments=False):
    """
    Attach matching ZTablo2 records to segments, in place.

    For every segment, sets ``segment.departure_z_tablo`` and
    ``segment.arrival_z_tablo`` to the ZTablo2 row whose
    (thread id, station id, original departure/arrival time) equals the
    segment's (thread id, station_from/station_to id, departure/arrival
    time with tzinfo stripped), or to ``None`` when no such row exists.

    Only threads whose transport type is in TRANSPORT_TYPE_IDS_IN_TABLO
    are looked up. If ``segments`` is empty, or none of the segments has
    such a thread, the function returns without setting any attribute.

    :param segments: sequence of segment objects exposing ``thread``
        (with ``id`` and ``t_type_id``), ``station_from``, ``station_to``,
        ``departure`` and ``arrival``.
    :param separated_segments: the segments are scattered in time or across
        stations. Query only the exact departure/arrival times instead of a
        [min, max] time interval, to fetch fewer rows.
    """
    if not segments:
        return

    # Imported locally so that importing this module does not fail.
    # Presumably a circular import or app-loading order issue; confirm.
    from stationschedule.models import ZTablo2

    def _ensure_datetime(dt):
        """Return a naive plain datetime suitable for ORM filters.

        Django's DateTimeField fails to deserialize a FuzzyDateTime when
        building filters, so the wrapped ``.dt`` value is used instead.
        The wrapper is unwrapped *before* tzinfo is stripped, so this does not
        depend on FuzzyDateTime supporting ``replace``.
        """
        if isinstance(dt, FuzzyDateTime):
            dt = dt.dt
        # NOTE(review): dropping tzinfo keeps the wall-clock time, which
        # assumes ZTablo2 stores naive local times. Confirm.
        return dt.replace(tzinfo=None)

    def _fetch_z_tablos_by_key(time_field, station_ids, times):
        """Query ZTablo2 rows for ``thread_ids`` at the given stations/times.

        Returns a dict keyed by (thread_id, station_id, <time_field value>).
        """
        if separated_segments:
            # Scattered segments: ask only for the exact times we need.
            time_filter = Q(**{time_field + '__in': list(set(times))})
        else:
            # min/max rather than first/last segment: the segments are not
            # guaranteed to be ordered by this field (e.g. ordering by
            # departure does not order by arrival), and an inverted range
            # silently matches nothing.
            time_filter = Q(**{
                time_field + '__isnull': False,
                time_field + '__range': (min(times), max(times)),
            })

        z_tablos = ZTablo2.objects.filter(
            time_filter,
            station__in=list(set(station_ids)),
            thread__in=thread_ids,
        )
        return {
            (z.thread_id, z.station_id, getattr(z, time_field)): z
            for z in z_tablos
        }

    def _get_segment_departure_key(segment):
        return segment.thread.id, segment.station_from.id, _ensure_datetime(segment.departure)

    def _get_segment_arrival_key(segment):
        return segment.thread.id, segment.station_to.id, _ensure_datetime(segment.arrival)

    # Only these segments can have board records. Building the filters from
    # them alone keeps the queries narrow; rows for other threads would be
    # excluded by ``thread__in`` anyway.
    tablo_segments = [
        s for s in segments
        if s.thread.t_type_id in TRANSPORT_TYPE_IDS_IN_TABLO
    ]
    if not tablo_segments:
        return

    thread_ids = list({s.thread.id for s in tablo_segments})

    departure_z_tablo_by_keys = _fetch_z_tablos_by_key(
        'original_departure',
        [s.station_from.id for s in tablo_segments],
        [_ensure_datetime(s.departure) for s in tablo_segments],
    )
    arrival_z_tablo_by_keys = _fetch_z_tablos_by_key(
        'original_arrival',
        [s.station_to.id for s in tablo_segments],
        [_ensure_datetime(s.arrival) for s in tablo_segments],
    )

    # Every segment gets both attributes; those without a match (including
    # segments of other transport types) get None.
    for segment in segments:
        segment.departure_z_tablo = departure_z_tablo_by_keys.get(
            _get_segment_departure_key(segment)
        )
        segment.arrival_z_tablo = arrival_z_tablo_by_keys.get(
            _get_segment_arrival_key(segment)
        )
