# -*- coding: utf-8 -*-
import logging
from datetime import datetime, timedelta
from itertools import ifilter, imap, izip

from travel.avia.avia_api.avia.v1.model.filters import AirportFilter, TimeFilter, TimeHelper, TimeOfDay, TransferFilter

logger = logging.getLogger(__name__)


class VariantsFilter(object):
    """
    Variant filter. Initialised with an avia.v1.model.filters.Filter and allows a whole
    collection of variants to be filtered through filter_variants.
    """

    def __init__(self, _filter):
        """
        :param avia.v1.model.filters.Filter _filter: flight variants filter
        """
        self._filter = _filter

    def filter_variants(self, variants):
        """
        Lazily select the variants that pass the configured filter.

        :param typing.Iterable[dict[str, list[dict[str, any]]]] variants: flight variants
        :return: iterator over the variants for which test() returns True
        """
        # A generator expression is as lazy as itertools.ifilter but does not rely on
        # a Python 2-only helper.
        return (variant for variant in variants if self.test(variant))

    def test(self, variant):
        """
        Check whether a variant matches the configured filter.

        The checks run in this order: baggage, time of day of the first departure / last
        arrival, departure / arrival airports, then the per-segment checks
        (see filter_segments). The backward direction is checked only when the variant
        has a non-empty 'backward_segments' list.

        :param dict[str, list[dict[str, any]]] variant: flight variant
        :return: True if the variant passes every configured check, otherwise False
        """
        # ----- Check baggage filter -----
        if self._filter.with_baggage and not variant['with_baggage']:
            return False
        # --------------------------------

        fwd_segments = variant['forward_segments']
        bwd_segments = variant.get('backward_segments')

        # ----- Check arrival and departure filter -----
        time_filter = self._filter.time_filters
        if time_filter:
            assert isinstance(time_filter, TimeFilter)

            if not self.filter_time_of_day(self._departure_dt(fwd_segments), time_filter.forward_departure):
                return False

            if not self.filter_time_of_day(self._arrival_dt(fwd_segments), time_filter.forward_arrival):
                return False

            if bwd_segments:
                if not self.filter_time_of_day(self._departure_dt(bwd_segments), time_filter.backward_departure):
                    return False

                if not self.filter_time_of_day(self._arrival_dt(bwd_segments), time_filter.backward_arrival):
                    return False
        # ---------------------------------------------

        # ----- Check airport arrival and departure filters -----
        airport_filters = self._filter.airport_filters
        if airport_filters:
            assert isinstance(airport_filters, AirportFilter)

            if not self.filter_airport(fwd_segments[0]['departure_station_id'], airport_filters.forward_departure):
                return False

            if not self.filter_airport(fwd_segments[-1]['arrival_station_id'], airport_filters.forward_arrival):
                return False

            if bwd_segments:
                if not self.filter_airport(bwd_segments[0]['departure_station_id'], airport_filters.backward_departure):
                    return False

                if not self.filter_airport(bwd_segments[-1]['arrival_station_id'], airport_filters.backward_arrival):
                    return False
        # -------------------------------------------------------

        # ----- Check segments-related filters -----
        # Transfer airport restrictions live on the airport filter; without one the
        # segment checks run with no transfer restriction.
        f_forward_transfer = None
        f_backward_transfer = None
        if airport_filters:
            f_forward_transfer = airport_filters.forward_transfers
            f_backward_transfer = airport_filters.backward_transfers

        if not self.filter_segments(fwd_segments, f_forward_transfer):
            return False

        if bwd_segments and not self.filter_segments(bwd_segments, f_backward_transfer):
            return False
        # -------------------------------------------

        return True

    @staticmethod
    def _departure_dt(segments):
        """Return the offset-shifted departure datetime of the first segment."""
        first = segments[0]
        return get_datetime_with_offset(first['departure_time'], first['departure_offset'])

    @staticmethod
    def _arrival_dt(segments):
        """Return the offset-shifted arrival datetime of the last segment."""
        last = segments[-1]
        return get_datetime_with_offset(last['arrival_time'], last['arrival_offset'])

    def filter_segments(self, segments, f_transfer=None):
        """
        Check the segments of one direction: transfer count, airlines, per-transfer
        restrictions and the transfer airports.

        NOTE: ``segments`` is sorted in place by offset-shifted arrival time, so the
        caller's list may be reordered as a side effect (unless an earlier check
        already rejected the segments).

        :param list[dict[str, any]] segments: list of segments of the flight
        :param typing.Iterator[str] f_transfer: set of airport codes that may be used for transfers
        :return: True if the filter is passed, otherwise False
        """
        # Filter transfer count
        transfer_filter = self._filter.transfer_filters
        if transfer_filter and transfer_filter.count is not None:
            if len(segments) - 1 > transfer_filter.count:
                return False

        # Materialise the allowed transfer airports once; None means "no restriction".
        allowed_transfers = set(f_transfer) if f_transfer else None

        if not all(self.filter_airlines(segment) for segment in segments):
            return False

        segments.sort(key=lambda s: get_datetime_with_offset(s['arrival_time'], s['arrival_offset']))

        transfer_stations = set()
        for cur, _next in zip(segments, segments[1:]):
            if not self.filter_airport_change(cur, _next):
                return False

            if allowed_transfers:
                # Each neighbouring pair contributes both ends of the transfer: where
                # the current segment lands and where the next one takes off.
                transfer_stations.add(cur['arrival_station_id'])
                transfer_stations.add(_next['departure_station_id'])

        # Reject only when there are transfers and none of them uses an allowed airport.
        if allowed_transfers and transfer_stations and not allowed_transfers & transfer_stations:
            return False

        return True

    def filter_airlines(self, cur):
        """
        :param dict[str,any] cur: flight segment carrying the airline id
        :return: True if the filter is passed, otherwise False
        """
        airlines_filter = self._filter.airlines
        if airlines_filter and cur['airline_id'] not in airlines_filter:
            return False
        return True

    def filter_airport_change(self, cur, _next):
        """
        Check a transfer between two neighbouring segments against the transfer filter.

        :param dict[str,any] cur: flight segment
        :param dict[str,any] _next: flight segment following ``cur``
        :return: True if the filter is passed, otherwise False
        """
        if not _next:
            return True

        transfer_filter = self._filter.transfer_filters
        if not transfer_filter:
            return True
        assert isinstance(transfer_filter, TransferFilter)

        # Only an explicit False forbids something; None leaves the check switched off.
        if (
            transfer_filter.has_airport_change is False and
            cur['arrival_station_id'] != _next['departure_station_id']
        ):
            return False

        if transfer_filter.has_night is False:
            if cur['arrival_offset'] != _next['departure_offset']:
                logger.warning(
                    'Departure and arrival offsets of neighbour segments dont match: %s %s',
                    cur,
                    _next,
                )
            arr_dt = get_datetime_with_offset(cur['arrival_time'], cur['arrival_offset'])
            dep_dt = get_datetime_with_offset(_next['departure_time'], _next['departure_offset'])

            if TimeHelper.is_night_in_range(dt_start=arr_dt, dt_stop=dep_dt):
                return False

        return True

    @staticmethod
    def filter_time_of_day(segment_time, time_of_day):
        """
        :param datetime segment_time: moment to check
        :param avia.v1.model.filters.TimeOfDay time_of_day: requested time of day,
            None means no restriction
        :return: True if the filter is passed, otherwise False
        """
        if time_of_day is None:
            return True
        f_time_of_day = TimeOfDay(time_of_day)
        return bool(TimeHelper.is_time_of_day(segment_time, f_time_of_day))

    @staticmethod
    def filter_airport(airport, filter_value):
        """
        :param airport: station id to check
        :param filter_value: container of allowed station ids, None means no restriction
        :return: True if the filter is passed, otherwise False
        """
        # An unset direction must not reject (or crash on) every variant; this mirrors
        # the None handling in filter_time_of_day.
        if filter_value is None:
            return True
        return airport in filter_value


def get_datetime_with_offset(timestamp, offset):
    """
    Convert a POSIX timestamp to a naive datetime and shift it by ``offset`` seconds.

    :param timestamp: POSIX timestamp, interpreted as UTC
    :param offset: shift in seconds added to the UTC datetime
        (presumably the local UTC offset of the station; confirm against the caller)
    :return: naive datetime equal to the UTC moment plus the offset
    """
    utc_moment = datetime.utcfromtimestamp(timestamp)
    shift = timedelta(seconds=offset)
    return utc_moment + shift
