# -*- coding: utf-8 -*-

import re
from datetime import datetime
from common.utils.date import MSK_TZ, RunMask
from travel.rasp.library.python.common23.date import environment


def match_supplier_route(blacklist, supplier_route):
    """Decide whether ``supplier_route`` is covered by the ``blacklist`` rule.

    Every individual criterion check must accept the route; evaluation stops
    at the first check that rejects it.

    :param blacklist: the blacklist rule whose criteria are tested.
    :param supplier_route: supplier-side route wrapper; its ``get_route()``
        result is what the criteria are evaluated against.
    :return: True if all checks pass, False otherwise.
    """
    route = supplier_route.get_route()

    # Order matters only for short-circuiting: cheap attribute comparisons
    # run before the time / per-station checks.
    checks = (
        _check_t_type,
        _check_thread_type,
        _check_number,
        _check_supplier_number,
        _check_supplier_title,
        _check_start_time,
        _check_finish_time,
        _check_start_station,
        _check_finish_station,
        _check_blacklist_thread_stations,
    )

    return all(check(blacklist, route, supplier_route) for check in checks)


def _check_t_type(blacklist, route, supplier_route):
    """Return True when the blacklist's transport type equals the route's.

    This criterion is mandatory: unlike the other checks, it is not skipped
    when the blacklist field is empty.
    """
    expected_t_type_id = blacklist.t_type_id
    actual_t_type_id = route.t_type_id
    return expected_t_type_id == actual_t_type_id


def _check_thread_type(blacklist, route, supplier_route):
    """Match the type of the route's first thread against the blacklist.

    An empty ``blacklist.thread_type_id`` disables this criterion. In that case
    the route's threads are not inspected at all.
    """
    wanted_type_id = blacklist.thread_type_id
    if not wanted_type_id:
        return True

    first_thread = route.threads[0]
    return wanted_type_id == first_thread.type_id


def _check_number(blacklist, route, supplier_route):
    """Require every thread number of the route to match the blacklist regex.

    The pattern is applied with ``re.match`` semantics: it is anchored at the
    start of the string but not at the end, and matching is case-insensitive.
    A thread without a number is tested as the empty string. An empty
    ``blacklist.number`` disables this criterion. A route with no threads
    passes vacuously.
    """
    pattern_source = blacklist.number
    if not pattern_source:
        return True

    number_re = re.compile(pattern_source, re.U | re.I)

    for thread in route.threads:
        if not number_re.match(thread.number or u""):
            return False

    return True


def _check_supplier_number(blacklist, route, supplier_route):
    """Match the supplier-side route number against the blacklist regex.

    Uses ``re.match`` (anchored at the start only), case-insensitively. A
    missing supplier number is tested as the empty string. An empty
    ``blacklist.supplier_number`` disables this criterion.
    """
    pattern = blacklist.supplier_number
    if not pattern:
        return True

    supplier_number = supplier_route.get_supplier_number() or u""
    return re.match(pattern, supplier_number, re.U | re.I) is not None


def _check_supplier_title(blacklist, route, supplier_route):
    """Match the supplier-side route title against the blacklist regex.

    Uses ``re.match`` (anchored at the start only), case-insensitively. A
    missing supplier title is tested as the empty string. An empty
    ``blacklist.supplier_title`` disables this criterion.
    """
    pattern = blacklist.supplier_title
    if not pattern:
        return True

    supplier_title = supplier_route.get_supplier_title() or u""
    return re.match(pattern, supplier_title, re.U | re.I) is not None


def _check_start_time(blacklist, route, supplier_route):
    """Check that the route's start time falls into the blacklist start window.

    The first thread's ``tz_start_time`` is combined with the date returned by
    ``RunMask.first_run(thread.year_days, today)``, or with today if that
    returns nothing. The result is localized in ``thread.pytz`` and converted
    either to Moscow time (when ``blacklist.is_moscow_time``) or to the first
    station's timezone before comparing.

    A blacklist without ``start_time_start`` matches any start time. When
    ``start_time_end`` is unset, the window collapses to the single moment
    ``start_time_start``. A window whose start is later than its end wraps
    past midnight (see ``time_in_range``).
    """
    if blacklist.start_time_start is None:
        return True

    thread = route.threads[0]

    if blacklist.is_moscow_time:
        tz = MSK_TZ
    else:
        tz = thread.rtstations[0].station.pytz

    # Snapshot "today" once. Calling environment.today() twice could yield
    # different dates if the clock crosses midnight between the calls.
    today = environment.today()
    first_run_date = RunMask.first_run(thread.year_days, today) or today

    naive_start_dt = datetime.combine(first_run_date, thread.tz_start_time)
    start_time = thread.pytz.localize(naive_start_dt).astimezone(tz).time()

    range_end = blacklist.start_time_end
    if range_end is None:
        range_end = blacklist.start_time_start

    return time_in_range(start_time, (blacklist.start_time_start, range_end))


def _check_finish_time(blacklist, route, supplier_route):
    """Check that arrival at the route's last station falls into the blacklist finish window.

    The anchor datetime is the first thread's ``tz_start_time`` combined with
    the date returned by ``RunMask.first_run(thread.year_days, today)``, or
    with today if that returns nothing. The anchor is passed to
    ``last_rts.get_arrival_dt``. The arrival is expressed either in Moscow
    time (when ``blacklist.is_moscow_time``) or in the last station's
    timezone.

    A blacklist without ``finish_time_start`` matches any finish time. When
    ``finish_time_end`` is unset, the window collapses to the single moment
    ``finish_time_start``. A window whose start is later than its end wraps
    past midnight (see ``time_in_range``).
    """
    if blacklist.finish_time_start is None:
        return True

    thread = route.threads[0]
    last_rts = thread.rtstations[-1]

    # Snapshot "today" once. Calling environment.today() twice could yield
    # different dates if the clock crosses midnight between the calls.
    today = environment.today()
    first_run_date = RunMask.first_run(thread.year_days, today) or today

    # NOTE(review): presumably naive in the thread's own timezone, as
    # get_arrival_dt expects. Confirm against RTStation.get_arrival_dt.
    naive_start_dt = datetime.combine(first_run_date, thread.tz_start_time)

    if blacklist.is_moscow_time:
        tz = MSK_TZ
    else:
        tz = last_rts.station.pytz

    finish_time = last_rts.get_arrival_dt(naive_start_dt, out_tz=tz).time()

    range_end = blacklist.finish_time_end
    if range_end is None:
        range_end = blacklist.finish_time_start

    return time_in_range(finish_time, (blacklist.finish_time_start, range_end))


def _check_start_station(blacklist, route, supplier_route):
    """Require the first station of the route's first thread to be the blacklisted one.

    An empty ``blacklist.start_station_id`` disables this criterion.
    """
    station_id = blacklist.start_station_id
    if not station_id:
        return True

    first_rts = route.threads[0].rtstations[0]
    return station_id == first_rts.station_id


def _check_finish_station(blacklist, route, supplier_route):
    """Require the last station of the route's first thread to be the blacklisted one.

    An empty ``blacklist.finish_station_id`` disables this criterion.
    """
    station_id = blacklist.finish_station_id
    if not station_id:
        return True

    last_rts = route.threads[0].rtstations[-1]
    return station_id == last_rts.station_id


def _check_blacklist_thread_stations(blacklist, route, supplier_route):
    """Require every per-station rule of the blacklist to be satisfied by the first thread.

    Each rule must find at least one matching station (with matching times)
    in the thread's stop list. Evaluation stops at the first unsatisfied rule.
    A blacklist with no per-station rules passes.
    """
    # NOTE: "backlist" (sic) is the attribute's actual name on the blacklist
    # object, so the spelling must be kept.
    if not blacklist.backlist_thread_stations:
        return True

    thread = route.threads[0]
    rtstations = thread.rtstations

    return all(
        _check_blacklist_thread_station(station_rule, thread, rtstations)
        for station_rule in blacklist.backlist_thread_stations
    )


def _check_blacklist_thread_station(bl_tt, thread, rtstations):
    """Return True if some stop in ``rtstations`` satisfies the per-station rule ``bl_tt``.

    A stop qualifies when its ``station_id`` equals the rule's and its
    arrival/departure times match the rule's windows. Stops are scanned in
    order, and scanning stops at the first qualifying one.
    """
    candidates = (rts for rts in rtstations if rts.station_id == bl_tt.station_id)
    return any(_check_bl_tt_times_match(bl_tt, thread, rts) for rts in candidates)


def _check_bl_tt_times_match(bl_tt, thread, rts):
    """Return True when stop ``rts`` satisfies both time windows of rule ``bl_tt``.

    Both the arrival and the departure checks are always evaluated (arrival
    first) before the results are combined.
    """
    outcomes = [
        _check_bl_tt_arrival_match(bl_tt, thread, rts),
        _check_bl_tt_departure_match(bl_tt, thread, rts),
    ]
    return all(outcomes)


def _check_bl_tt_arrival_match(bl_tt, thread, rts):
    """Check that arrival at stop ``rts`` falls into the rule's arrival window.

    A rule without ``arrival_time_start`` matches any stop. A stop with no
    ``tz_arrival`` never matches a rule that does constrain arrival.

    The arrival is computed by ``rts.get_arrival_dt`` from the thread's
    ``tz_start_time``, combined with the date returned by
    ``RunMask.first_run(thread.year_days, today)``, or with today if that
    returns nothing. It is expressed in Moscow time (``bl_tt.is_moscow_time``)
    or in the stop's station timezone. When ``arrival_time_end`` is unset, the
    window collapses to ``arrival_time_start``. A window whose start is later
    than its end wraps past midnight (see ``time_in_range``).
    """
    if bl_tt.arrival_time_start is None:
        return True

    if rts.tz_arrival is None:
        return False

    # Snapshot "today" once. Calling environment.today() twice could yield
    # different dates if the clock crosses midnight between the calls.
    today = environment.today()
    first_run_date = RunMask.first_run(thread.year_days, today) or today

    naive_start_dt = datetime.combine(first_run_date, thread.tz_start_time)

    if bl_tt.is_moscow_time:
        tz = MSK_TZ
    else:
        tz = rts.station.pytz

    arrival_time = rts.get_arrival_dt(naive_start_dt, out_tz=tz).time()

    range_end = bl_tt.arrival_time_end
    if range_end is None:
        range_end = bl_tt.arrival_time_start

    return time_in_range(arrival_time, (bl_tt.arrival_time_start, range_end))


def _check_bl_tt_departure_match(bl_tt, thread, rts):
    """Check that departure from stop ``rts`` falls into the rule's departure window.

    A rule without ``departure_time_start`` matches any stop. A stop with no
    ``tz_departure`` never matches a rule that does constrain departure.

    The departure is computed by ``rts.get_departure_dt`` from the thread's
    ``tz_start_time``, combined with the date returned by
    ``RunMask.first_run(thread.year_days, today)``, or with today if that
    returns nothing. It is expressed in Moscow time (``bl_tt.is_moscow_time``)
    or in the stop's station timezone. When ``departure_time_end`` is unset,
    the window collapses to ``departure_time_start``. A window whose start is
    later than its end wraps past midnight (see ``time_in_range``).
    """
    if bl_tt.departure_time_start is None:
        return True

    if rts.tz_departure is None:
        return False

    # Snapshot "today" once. Calling environment.today() twice could yield
    # different dates if the clock crosses midnight between the calls.
    today = environment.today()
    first_run_date = RunMask.first_run(thread.year_days, today) or today

    naive_start_dt = datetime.combine(first_run_date, thread.tz_start_time)

    if bl_tt.is_moscow_time:
        tz = MSK_TZ
    else:
        tz = rts.station.pytz

    departure_time = rts.get_departure_dt(naive_start_dt, out_tz=tz).time()

    range_end = bl_tt.departure_time_end
    if range_end is None:
        range_end = bl_tt.departure_time_start

    return time_in_range(departure_time, (bl_tt.departure_time_start, range_end))


def time_in_range(start_time, time_range):
    """Return True if ``start_time`` lies inside ``time_range``, bounds included.

    ``time_range`` is a ``(range_start, range_end)`` pair of values comparable
    with ``start_time`` (e.g. ``datetime.time``). If ``range_start`` is later
    than ``range_end``, the window is taken to wrap past midnight.

    Examples (with ``from datetime import time``)::

        time_in_range(time(11, 0), (time(11, 0), time(11, 0)))      -> True
        time_in_range(time(11, 0), (time(10, 0), time(12, 0)))      -> True
        time_in_range(time(11, 0), (time(22, 0), time(12, 0)))      -> True
        time_in_range(time(11, 0), (time(22, 0), time(23, 0)))      -> False
    """
    range_start, range_end = time_range[0], time_range[1]

    if range_start <= range_end:
        # Ordinary window within a single day.
        return range_start <= start_time <= range_end

    # Wrapping window, e.g. 22:00-02:00: it covers the tail of the day and
    # the head of the next one.
    return (start_time <= range_end) or (range_start <= start_time)
