# -*- coding: utf-8 -*-
from abc import ABCMeta, abstractmethod
from collections import namedtuple
from datetime import datetime
from itertools import chain

from more_itertools import pairwise


class CachedDateTimeDeserializer(object):
    """Parse datetime strings with one fixed ``strptime`` format, memoizing results.

    Each distinct raw string is parsed at most once per instance; parsed values
    are kept in an unbounded per-instance dict keyed by the raw string.
    """
    # TODO: move to utils

    def __init__(self, fmt):
        self._format = fmt
        self._cache = {}

    def deserialize(self, raw_datetime):
        """Return ``raw_datetime`` parsed with the configured format (cached)."""
        parsed = self._cache.get(raw_datetime)
        if parsed is None:
            # strptime never returns None, so None here always means "not cached yet".
            parsed = datetime.strptime(raw_datetime, self._format)
            self._cache[raw_datetime] = parsed
        return parsed


class Filter(object):
    """Abstract base class for fare filters.

    A filter is a callable taking ``(fares, datum)`` and returning an iterable
    of the fares that pass it; the subclasses in this module return either the
    input unchanged or a lazy generator.
    """
    # NOTE: ``__metaclass__`` is only honored by Python 2; under Python 3 this
    # attribute has no effect and ``__call__`` is not enforced as abstract.
    __metaclass__ = ABCMeta

    @abstractmethod
    def __call__(self, fares, datum):
        """Return the subset of ``fares`` that passes this filter.

        ``datum`` is the search result the fares belong to; subclasses use it
        to look up flights referenced by the fares' routes.

        :type fares: List[ticket_daemon_lib.protobuf_converting.big_wizard.search_result_pb2.Fare]
        :type datum: ticket_daemon_lib.protobuf_converting.big_wizard.search_result_pb2.SearchResult
        """


class BaggageFilter(Filter):
    """Keep only fares that include baggage, when requested.

    Filtering is enabled only when ``with_baggage`` is the literal ``True``
    (identity check); any other value passes ``fares`` through unchanged.

    When enabled, a fare is kept as-is if every baggage entry in both
    directions is non-empty and starts with ``'1'`` (presumably a "1 piece"
    style code -- TODO confirm). A fare with no baggage entries at all also
    passes, because ``all()`` of an empty sequence is true. Otherwise, if the
    fare carries a non-empty ``tariffs.with_baggage`` alternative, the fare is
    rewritten IN PLACE with that tariff and kept; all other fares are dropped.
    """

    def __init__(self, with_baggage):
        self.with_baggage = with_baggage

    def __call__(self, fares, datum):
        if self.with_baggage is not True:
            return fares
        return self._filter(fares)

    def _filter(self, fares):
        for fare in fares:
            entries = chain(fare.baggage.forward, fare.baggage.backward)
            if all(entry and entry.startswith('1') for entry in entries):
                yield fare
                continue

            alternative = fare.tariffs.with_baggage
            # ByteSize() == 0 means the alternative tariff message is empty.
            if not alternative.ByteSize():
                continue

            # NOTE(review): the alternative tariff's baggage is not re-checked
            # against the '1' prefix rule above; confirm that is intended.
            self._apply_tariff(fare, alternative)
            yield fare

    @staticmethod
    def _apply_tariff(fare, tariff):
        """Overwrite ``fare``'s baggage, partner, price and timestamps with ``tariff``'s."""
        for target, source in (
            (fare.baggage.forward, tariff.baggage.forward),
            (fare.baggage.backward, tariff.baggage.backward),
        ):
            del target[:]
            target.extend(source)

        fare.partner = tariff.partner
        fare.conversion_partner = tariff.conversion_partner
        fare.tariff.value = tariff.price.value
        fare.tariff.currency = tariff.price.currency
        fare.created_at = tariff.created_at
        fare.expire_at = tariff.expire_at


class TransferFilter(Filter, namedtuple('TransferFilter', (
    'count', 'min_duration', 'max_duration', 'has_airport_change', 'has_night'
))):
    """Filter fares by properties of the transfers between consecutive segments.

    Fields:

    * ``count`` -- maximum number of transfers per direction (``None`` disables);
    * ``min_duration`` / ``max_duration`` -- allowed transfer length in minutes
      (``None`` means no explicit bound; see ``_duration_filter`` for fallbacks);
    * ``has_airport_change`` -- only ``False`` enables the check: drops fares
      where a segment arrives at a different airport than the next one departs;
    * ``has_night`` -- only ``False`` enables the check: drops fares with a
      night transfer.

    Active sub-filters are chained lazily as generators; if none is active,
    ``fares`` is returned unchanged.
    """

    def __call__(self, fares, datum):
        if self.count is not None:
            fares = self._segments_filter(fares)
        if self.has_airport_change is False:
            fares = self._airport_change_filter(fares, datum)
        if (
            self.min_duration is not None
            or self.max_duration is not None
            or self.has_night is False
        ):
            fares = self._duration_filter(fares, datum)

        return fares

    def _segments_filter(self, fares):
        """Drop fares with more than ``count`` transfers in either direction."""
        # N transfers means N + 1 flight segments.
        segments_count = self.count + 1

        for fare in fares:
            if (
                len(fare.route.forward) <= segments_count
                and len(fare.route.backward) <= segments_count
            ):
                yield fare

    @staticmethod
    def _airport_change_filter(fares, datum):
        """
        Drop fares that change airport during a transfer.
        """
        for fare in fares:
            for r1, r2 in chain(pairwise(fare.route.forward), pairwise(fare.route.backward)):
                if datum.flights[r1].to_id != datum.flights[r2].from_id:
                    break
            else:
                yield fare

    def _duration_filter(self, fares, datum):
        """
        Drop fares with a transfer whose duration is outside the allowed range.

        Night transfers are dropped here as well when ``has_night`` is False.
        Durations are computed from the flights' local times, so this assumes
        arrival and the next departure share a time zone -- NOTE(review): confirm.
        """
        min_duration = self.min_duration or 0
        # A falsy max_duration (None or 0) falls back to a one-week cap, in minutes.
        max_duration = self.max_duration or 60 * 24 * 7

        dt_deserializer = CachedDateTimeDeserializer('%Y-%m-%dT%H:%M:%S')
        for fare in fares:
            for r1, r2 in chain(pairwise(fare.route.forward), pairwise(fare.route.backward)):
                transfer_arrival = dt_deserializer.deserialize(datum.flights[r1].arrival.local)
                transfer_departure = dt_deserializer.deserialize(datum.flights[r2].departure.local)

                duration = (transfer_departure - transfer_arrival).total_seconds() / 60  # minutes

                if not (min_duration <= duration <= max_duration):
                    break

                if self.has_night is False and self._is_night_transfer(
                    transfer_arrival, transfer_departure
                ):
                    break
            else:
                yield fare

    @staticmethod
    def _is_night_transfer(arrival, departure):
        """Return True if a transfer spans midnight or touches hours 00:00-06:59.

        Whole calendar dates are compared (not just the day of month), so a
        transfer lasting an exact number of months is still detected.
        """
        return (
            arrival.date() != departure.date()
            or arrival.hour <= 6
            or departure.hour <= 6
        )


class TimeFilter(Filter, namedtuple('TimeFilter', (
    'forward_departure_min', 'forward_departure_max',
    'forward_arrival_min', 'forward_arrival_max',
    'backward_departure_min', 'backward_departure_max',
    'backward_arrival_min', 'backward_arrival_max',
))):
    """Filter fares by departure/arrival time windows for each direction.

    Each bound may be ``None`` (ignored). The forward window is checked on the
    first segment's departure and the last segment's arrival; the backward
    windows are only checked when the fare has a backward route.

    Timestamps (``.local``) are compared to the bounds with plain ``<``/``>``;
    presumably both are ISO-formatted strings so that lexicographic order
    matches chronological order -- TODO confirm bound format with callers.
    The forward route is assumed non-empty (``[0]``/``[-1]`` indexing).
    """

    def __call__(self, fares, datum):
        return self._filter(fares, datum)

    def _filter(self, fares, datum):
        flights = datum.flights
        for fare in fares:
            outbound = fare.route.forward
            if self._is_filtered(flights[outbound[0]].departure.local,
                                 self.forward_departure_min,
                                 self.forward_departure_max):
                continue
            if self._is_filtered(flights[outbound[-1]].arrival.local,
                                 self.forward_arrival_min,
                                 self.forward_arrival_max):
                continue

            inbound = fare.route.backward
            if inbound and (
                self._is_filtered(flights[inbound[0]].departure.local,
                                  self.backward_departure_min,
                                  self.backward_departure_max)
                or self._is_filtered(flights[inbound[-1]].arrival.local,
                                     self.backward_arrival_min,
                                     self.backward_arrival_max)
            ):
                continue

            yield fare

    def _is_filtered(self, timestamp, _min, _max):
        """Return True if ``timestamp`` is before ``_min`` or after ``_max``."""
        if _min is not None and timestamp < _min:
            return True
        return _max is not None and timestamp > _max


class AirlineFilter(Filter, namedtuple('AirlineFilter', ('airlines',))):
    """Keep only fares whose every flight segment is from an allowed company.

    ``airlines`` is a collection compared against each flight's ``company``
    value; a falsy value (``None``/empty) disables the filter and returns
    ``fares`` unchanged.
    """

    def __call__(self, fares, datum):
        if self.airlines:
            return self._filter(fares, datum)

        return fares

    def _filter(self, fares, datum):
        # Build the lookup set once: O(1) membership per segment instead of a
        # linear scan of ``airlines``, and safe even if ``airlines`` is a
        # one-shot iterable.
        allowed = frozenset(self.airlines)
        for fare in fares:
            if all(
                datum.flights[route_segment].company in allowed
                for route_segment in chain(fare.route.forward, fare.route.backward)
            ):
                yield fare


class AirportFilter(Filter, namedtuple('AirportFilter', (
    'forward_departure', 'forward_arrival', 'forward_transfers',
    'backward_departure', 'backward_arrival', 'backward_transfers'
))):
    """Filter fares by departure, arrival and transfer airports per direction.

    Each field is a collection of airport ids; a falsy value disables that
    check. Endpoint checks accept empty routes. A transfer check keeps a fare
    only if at least one of that direction's transfer airports is listed.
    NOTE(review): a route with no transfers (one segment, or empty) is
    rejected whenever its transfer collection is non-empty, unlike the
    endpoint check, which accepts empty routes -- confirm this is intended.
    """

    def __call__(self, fares, datum):
        return self._filter(fares, datum)

    def _filter(self, fares, datum):
        for fare in fares:
            outbound, inbound = fare.route.forward, fare.route.backward
            rejected = (
                self._invalid_endpoints(datum, outbound, self.forward_departure, self.forward_arrival)
                or self._invalid_endpoints(datum, inbound, self.backward_departure, self.backward_arrival)
                or self._invalid_transfers(datum, outbound, self.forward_transfers)
                or self._invalid_transfers(datum, inbound, self.backward_transfers)
            )
            if not rejected:
                yield fare

    def _invalid_endpoints(self, datum, routes, departures, arrivals):
        """Return True if the route starts/ends outside the allowed airports."""
        if len(routes) == 0:
            return False

        if departures and datum.flights[routes[0]].from_id not in departures:
            return True

        return bool(arrivals) and datum.flights[routes[-1]].to_id not in arrivals

    def _invalid_transfers(self, datum, routes, transfers):
        """Return True if none of the route's transfer airports is in ``transfers``."""
        if not transfers:
            return False

        seen = set()
        for prev_leg, next_leg in pairwise(routes):
            # Both sides of each connection count as transfer airports.
            seen.update((datum.flights[prev_leg].to_id, datum.flights[next_leg].from_id))

        return not seen.intersection(transfers)
