# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import itertools
from collections import namedtuple
from datetime import datetime, time, timedelta
from functools import partial

from travel.rasp.library.python.common23.date import environment

# Descriptor of an event-date segment query:
#   min_event_dt    -- timezone-aware lower-bound datetime for segment events;
#   raw_filter_func -- callable taking an iterable of raw segments and returning
#                      the (possibly truncated) iterable of segments to keep.
EventDateQuery = namedtuple('EventDateQuery', ['min_event_dt', 'raw_filter_func'])


def _get_event_date(raw_segment):
    event_dt = raw_segment.event_dt
    event_tz = raw_segment.event_station.pytz
    return event_dt.astimezone(event_tz).date()


def limit_segments_with_event_date(raw_segments_iter):
    """Keep only the leading run of segments sharing the first segment's event date.

    Segments are grouped by consecutive equal event dates (each segment's date
    in its own station timezone, see ``_get_event_date``); only the first group
    is returned. Later segments that happen to have the same date but come after
    a different date are not included.

    :param raw_segments_iter: iterable of raw segments.
    :return: a lazy iterator over the first group, or an empty tuple when the
        input yields nothing.
    """
    runs_by_date = itertools.groupby(raw_segments_iter, key=_get_event_date)
    for _event_date, leading_run in runs_by_date:
        # Only the very first group is wanted; the group iterator keeps the
        # underlying groupby alive, so returning it directly is safe.
        return leading_run
    return ()


def limit_segments_by_max_event_dt(raw_segments_iter, max_event_dt):
    """Yield segments until the first one whose ``event_dt`` is not before ``max_event_dt``.

    Iteration stops at the first segment failing the check (``takewhile``), so
    any later segments are dropped even if they would pass. Presumably the input
    is ordered by ``event_dt``; verify against callers.

    :param raw_segments_iter: iterable of objects exposing ``event_dt``.
    :param max_event_dt: exclusive upper bound compared against ``event_dt``.
    :return: lazy iterator over the accepted leading segments.
    """
    def _is_before_limit(segment):
        return segment.event_dt < max_event_dt

    return itertools.takewhile(_is_before_limit, raw_segments_iter)


def make_fixed_date_query(local_date, local_tz):
    """Build an ``EventDateQuery`` restricted to a single local calendar day.

    The lower bound is the current moment when ``local_date`` is today in
    ``local_tz``. Otherwise it is local midnight at the start of ``local_date``.
    The filter keeps segments until the first one whose ``event_dt`` reaches
    local midnight of the following day.

    :param local_date: ``datetime.date`` to query.
    :param local_tz: pytz-style timezone (must provide ``localize``) in which
        the day boundaries are computed.
    :return: ``EventDateQuery`` with the lower bound and the day-limiting filter.
    """
    def _start_of(day):
        return local_tz.localize(datetime.combine(day, time.min))

    local_now = environment.now_aware().astimezone(local_tz)
    if local_now.date() == local_date:
        # Today: no point in looking at events that have already happened.
        min_dt = local_now
    else:
        min_dt = _start_of(local_date)
    next_day_start = _start_of(local_date + timedelta(days=1))
    raw_filter_func = partial(limit_segments_by_max_event_dt, max_event_dt=next_day_start)
    return EventDateQuery(min_dt, raw_filter_func)


def make_next_date_query(local_tz):
    """Build an ``EventDateQuery`` for the nearest event date starting from now.

    The lower bound is the current aware moment. The filter keeps only the
    leading run of segments that share the first segment's event date.

    :param local_tz: not used here; it seems to be kept only so that the
        signature parallels ``make_fixed_date_query``.
    :return: ``EventDateQuery``.
    """
    now = environment.now_aware()
    return EventDateQuery(min_event_dt=now, raw_filter_func=limit_segments_with_event_date)


def make_event_date_query(local_date, local_tz):
    """Choose the event-date query strategy based on whether a date was given.

    :param local_date: ``datetime.date`` to restrict to, or ``None`` to take
        the nearest event date starting from now.
    :param local_tz: timezone passed through to the selected builder.
    :return: ``EventDateQuery``.
    """
    if local_date is not None:
        return make_fixed_date_query(local_date, local_tz)
    return make_next_date_query(local_tz)
