# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

from collections import namedtuple
from datetime import datetime, time, timedelta
from functools import partial

from travel.rasp.library.python.common23.date import environment
from travel.library.python.tracing.instrumentation import traced_function
from travel.rasp.wizards.suburban_wizard_api.lib.direction.tariffs_cache import TariffsCache
from travel.rasp.wizards.wizard_lib.event_date_query import EventDateQuery, make_fixed_date_query
from travel.rasp.wizards.wizard_lib.serialization.thread_express_type import ThreadExpressType
from travel.rasp.wizards.wizard_lib.utils.functional import tuplify


# A variant wraps a single segment (see split_segments / join_variants below).
SuburbanVariant = namedtuple('SuburbanVariant', 'segment')

# One suburban train trip between two stations. The *_local_dt fields hold
# departure/arrival times converted to the respective station's timezone
# (see make_segments).
SuburbanSegment = namedtuple(
    'SuburbanSegment',
    'train_number train_title duration '
    'departure_station departure_local_dt '
    'arrival_station arrival_local_dt '
    'price '
    'thread_express_type thread_start_date thread_transport_subtype_id thread_uid',
)


def filter_next_segments(raw_segments_iter, local_tz, min_count=3):
    """Lazily yield the "next" segments, starting from the first one in the stream.

    The first segment is always yielded. Its departure time, converted to
    ``local_tz``, defines two borders:

    * 04:00 local time on the day after the first departure, and
    * exactly one day after the first departure.

    The earlier of the two is the *left* border and the later one is the
    *right* border. Every segment departing before the left border is
    yielded. The first segment departing at or after the left border stops
    the stream when either of these holds:

    * at least ``min_count`` segments have already been yielded, or
    * it departs at or after the right border.

    Otherwise the window is widened up to the right border, and segments keep
    flowing until one departs at or after it. The segment that triggers the
    stop is consumed from the input but not yielded.

    NOTE(review): the cut-off logic presumably relies on the input being
    ordered by ``departure_dt`` — confirm against callers.

    :param raw_segments_iter: iterable (or iterator) of objects with a
        timezone-aware ``departure_dt`` attribute. Only consumed as far as
        needed.
    :param local_tz: timezone providing ``localize`` (pytz-style), used to
        compute the 04:00 border.
    :param min_count: once the left border is reached, stop if at least this
        many segments have already been yielded.
    """
    # Accept any iterable, not only iterators: ``next()`` on a list raises
    # TypeError. ``iter()`` on an existing iterator returns it unchanged.
    segments = iter(raw_segments_iter)
    try:
        first_segment = next(segments)
    except StopIteration:
        return

    local_departure_dt = first_segment.departure_dt.astimezone(local_tz)
    left_border, right_border = sorted((
        # 04:00 local time of the next day (naive, then localized).
        local_tz.localize(datetime.combine(local_departure_dt.date(), time.min) + timedelta(days=1, hours=4)),
        local_departure_dt + timedelta(days=1),
    ))
    max_departure_dt = left_border

    yield first_segment

    # ``count`` equals the number of segments yielded so far: every loop
    # iteration that does not return yields exactly one segment.
    for count, raw_segment in enumerate(segments, start=1):
        departure_dt = raw_segment.departure_dt
        if departure_dt >= max_departure_dt:
            if count >= min_count or departure_dt >= right_border:
                return

            # Too few segments before the left border: extend the window.
            max_departure_dt = right_border

        yield raw_segment


@traced_function
def make_next_query(local_date, local_tz):
    """Build the event-date query used to look up segments.

    With an explicit ``local_date``, build the query with
    ``make_fixed_date_query(local_date, local_tz)``. Without one, build an
    ``EventDateQuery`` anchored at the current aware time
    (``environment.now_aware()``). Its results are narrowed by
    ``filter_next_segments`` with ``local_tz`` bound.
    """
    if local_date is not None:
        return make_fixed_date_query(local_date, local_tz)

    now = environment.now_aware()
    next_segments_filter = partial(filter_next_segments, local_tz=local_tz)
    return EventDateQuery(now, next_segments_filter)


@traced_function
@tuplify
def make_segments(raw_segments):
    """Convert raw segments into ``SuburbanSegment`` records.

    The generator output is wrapped by ``@tuplify`` (presumably collected into
    a tuple — confirm in ``wizard_lib.utils.functional``).

    For each raw segment:

    * Departure and arrival times are converted into the ``pytz`` timezone of
      the corresponding station.
    * ``duration`` is the difference between the arrival and departure times.
    * The price comes from ``TariffsCache.get_tariff``.
    * A falsy ``thread.express_type`` becomes ``None``; otherwise it is
      wrapped in ``ThreadExpressType``.
    """
    for raw in raw_segments:
        thread = raw.thread
        start_dt, finish_dt = raw.departure_dt, raw.arrival_dt
        from_station, to_station = raw.departure_station, raw.arrival_station
        local_start_dt = start_dt.astimezone(from_station.pytz)

        yield SuburbanSegment(
            train_number=thread.number,
            train_title=thread.title,
            duration=finish_dt - start_dt,
            departure_station=from_station,
            departure_local_dt=local_start_dt,
            arrival_station=to_station,
            arrival_local_dt=finish_dt.astimezone(to_station.pytz),
            price=TariffsCache.get_tariff(raw),
            # A missing/falsy express type is normalised to None.
            thread_express_type=(
                None if not thread.express_type else ThreadExpressType(thread.express_type)
            ),
            thread_start_date=raw.thread_start_dt.date(),
            thread_transport_subtype_id=thread.t_subtype_id,
            thread_uid=thread.uid,
        )


def split_segments(segments):
    """Wrap every segment in its own single-segment ``SuburbanVariant``.

    Returns a tuple of variants in input order.
    """
    variants = [SuburbanVariant(segment=seg) for seg in segments]
    return tuple(variants)


def join_variants(variants):
    """Unwrap variants back into a tuple of their ``segment`` values, in order."""
    segments = [item.segment for item in variants]
    return tuple(segments)
