# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

from collections import OrderedDict, namedtuple
from datetime import datetime, time, timedelta

from common.models.schedule import DeLuxeTrain
from travel.rasp.library.python.common23.date import environment
from common.utils.date import UTC_TZ
from travel.library.python.tracing.instrumentation import traced_function
from travel.rasp.wizards.wizard_lib.event_date_query import EventDateQuery, limit_segments_with_event_date
from travel.rasp.wizards.wizard_lib.utils.functional import tuplify


# One train segment paired with one of its places groups; ``places_group`` is
# None for a segment that has no places (see split_segments).
TrainVariant = namedtuple('TrainVariant', 'segment places_group')

# A schedule segment enriched with tariff and train-info data (built by
# make_segments). Fields sourced from a missing tariff/train-info match are None.
TrainSegment = namedtuple('TrainSegment', [
    'train_number',
    'train_title',
    'train_brand',
    'thread_type',
    'duration',            # arrival_dt - departure_dt
    'departure_station',
    'departure_local_dt',  # departure in the departure station's timezone
    'arrival_station',
    'arrival_local_dt',    # arrival in the arrival station's timezone
    'electronic_ticket',
    'places',
    'facilities_ids',
    'updated_at',
    'display_number',
    'has_dynamic_pricing',
    'two_storey',
    'is_suburban',
    'coach_owners',
    'first_country_code',
    'last_country_code',
    'broken_classes',
    'provider',
    'raw_train_name',
    't_subtype_id',
])


def make_tomorrow_query(local_tz):
    """Build an EventDateQuery anchored at the start of tomorrow in ``local_tz``.

    "Tomorrow" is the calendar day after today's date in ``local_tz``. The
    current date is taken in local time and then advanced by one calendar day.
    Adding 24 hours in UTC first is wrong on nights when the zone's UTC offset
    changes (e.g. DST): the result can land two local days ahead, or stay on
    the same local day.

    :param local_tz: timezone exposing ``localize`` (presumably a pytz zone --
        confirm against callers).
    :return: ``EventDateQuery`` for local midnight of tomorrow, limited with
        ``limit_segments_with_event_date``.
    """
    local_today = environment.now_utc(aware=True).astimezone(local_tz).date()
    tomorrow_date = local_today + timedelta(days=1)
    # localize() picks the UTC offset valid on that date instead of a fixed one.
    tomorrow_min_dt = local_tz.localize(datetime.combine(tomorrow_date, time.min))
    return EventDateQuery(tomorrow_min_dt, limit_segments_with_event_date)


def _field_or_none(obj, name):
    """Return attribute ``name`` of ``obj``, or None when ``obj`` is missing (falsy)."""
    return getattr(obj, name) if obj else None


@traced_function
@tuplify
def make_segments(raw_segments, tariff_direction_info, trains_info):
    """Build TrainSegment records from raw schedule segments.

    Each raw segment is matched with:
      * a tariff segment keyed by (departure station id, arrival station id,
        train number, ISO string of the local departure time);
      * a train-info record keyed by (train number, naive UTC departure).
    Fields taken from a source that has no match are None.

    :param raw_segments: iterable of schedule segments exposing ``thread``,
        ``departure_dt``, ``arrival_dt``, ``departure_station`` and
        ``arrival_station``.
    :param tariff_direction_info: pair ``(tariff_segments, updated_info)``;
        ``updated_info`` provides ``get_updated_at(departure_dt)``.
    :param trains_info: iterable of train-info records exposing ``number`` and
        ``departure_at``.
    :return: the yielded TrainSegment records (presumably collected into a
        tuple by ``tuplify``).
    """
    tariff_segments, updated_info = tariff_direction_info

    # Index both sources once so each raw segment costs two dict lookups.
    key_to_tariff_segment = {
        (s.departure_station_id, s.arrival_station_id, s.number, s.departure_dt): s
        for s in tariff_segments
    }
    key_to_train_info = {
        (info.number, info.departure_at): info for info in trains_info
    }

    for raw_segment in raw_segments:
        thread = raw_segment.thread
        departure_dt = raw_segment.departure_dt
        arrival_dt = raw_segment.arrival_dt
        train_number = thread.number
        departure_station = raw_segment.departure_station
        arrival_station = raw_segment.arrival_station
        departure_local_dt = departure_dt.astimezone(departure_station.pytz)

        # Tariff segments are keyed by the ISO string of the local departure time.
        tariff_segment = key_to_tariff_segment.get(
            (departure_station.id, arrival_station.id, train_number, departure_local_dt.isoformat())
        )
        # Train info is keyed by a naive UTC departure datetime.
        train_info = key_to_train_info.get(
            (train_number, departure_dt.astimezone(UTC_TZ).replace(tzinfo=None))
        )

        # Prefer the train info's electronic-ticket flag; fall back to the tariff segment's.
        electronic_ticket = _field_or_none(train_info, 'electronic_ticket')
        if electronic_ticket is None:
            electronic_ticket = _field_or_none(tariff_segment, 'electronic_ticket')

        yield TrainSegment(
            train_number=train_number,
            train_title=thread.title,
            train_brand=DeLuxeTrain.get_by_number(train_number),
            thread_type=thread.type.code,
            duration=arrival_dt - departure_dt,
            departure_station=departure_station,
            departure_local_dt=departure_local_dt,
            arrival_station=arrival_station,
            arrival_local_dt=arrival_dt.astimezone(arrival_station.pytz),
            places=_field_or_none(tariff_segment, 'places'),
            broken_classes=_field_or_none(tariff_segment, 'broken_classes'),
            electronic_ticket=electronic_ticket,
            facilities_ids=_field_or_none(train_info, 'facilities_ids'),
            updated_at=updated_info.get_updated_at(departure_dt),
            display_number=_field_or_none(tariff_segment, 'display_number'),
            has_dynamic_pricing=_field_or_none(tariff_segment, 'has_dynamic_pricing'),
            two_storey=_field_or_none(tariff_segment, 'two_storey'),
            is_suburban=_field_or_none(tariff_segment, 'is_suburban'),
            coach_owners=_field_or_none(tariff_segment, 'coach_owners'),
            first_country_code=_field_or_none(tariff_segment, 'first_country_code'),
            last_country_code=_field_or_none(tariff_segment, 'last_country_code'),
            provider=_field_or_none(tariff_segment, 'provider'),
            raw_train_name=_field_or_none(tariff_segment, 'raw_train_name'),
            t_subtype_id=thread.t_subtype_id,
        )


@traced_function
def split_segments(segments):
    """Expand segments into TrainVariants, one per places group.

    A segment whose ``places`` is None or empty still produces exactly one
    variant, with ``places_group=None``.

    :param segments: iterable of TrainSegment.
    :return: tuple of TrainVariant in input order.
    """
    variants = []
    for segment in segments:
        groups = segment.places or (None,)
        variants.extend(
            TrainVariant(segment=segment, places_group=group) for group in groups
        )
    return tuple(variants)


@traced_function
def join_variants(variants):
    """Regroup TrainVariants (e.g. from split_segments) back into segments.

    Variants are grouped by segment object identity (``id``), keeping the order
    in which segments are first seen. Each segment's non-None places groups are
    collected into a tuple. A segment with no such groups gets ``places=None``.

    :param variants: iterable of TrainVariant.
    :return: tuple of segments with ``places`` replaced.
    """
    # id() keys group by identity, not equality; the stored segment keeps the
    # object alive, so its id cannot be reused while this dict exists.
    grouped = OrderedDict()
    for variant in variants:
        segment = variant.segment
        places_group = variant.places_group
        key = id(segment)
        if key not in grouped:
            grouped[key] = (segment, [])
        if places_group is not None:
            grouped[key][1].append(places_group)

    joined = []
    for segment, collected in grouped.values():
        joined.append(segment._replace(places=tuple(collected) or None))
    return tuple(joined)
