# -*- coding: utf-8 -*-
import ujson
from collections import defaultdict
from datetime import datetime
from logging import getLogger
from typing import Dict, Any, Set

from django.conf import settings

from travel.avia.library.python.common.utils.date import MSK_TZ
from travel.avia.library.python.ticket_daemon.memo import CacheWithKeyTTL, memoize
from travel.avia.library.python.ticket_daemon.ydb.banned_variants.cache import BannedVariantsCache
from travel.avia.library.python.ticket_daemon.ydb.django.utils import session_manager
from travel.avia.ticket_daemon.ticket_daemon.api.models_utils import Country
from travel.avia.ticket_daemon.ticket_daemon.api.result.add_price_prediction import safe_add_price_prediction_category
from travel.avia.ticket_daemon.ticket_daemon.api.result.filters.too_fast_moscow_erevan import TooFastMoscowErevanFilter
from travel.avia.ticket_daemon.ticket_daemon.lib import feature_flags
from travel.avia.ticket_daemon.ticket_daemon.lib.currency import Price
from travel.avia.ticket_daemon.ticket_daemon.lib.utils import fix_flight_number
from travel.avia.ticket_daemon.ticket_daemon.lib.yt_loggers import (
    unknown_companies_logger,
    log_filtering,
    unknown_station_codes_logger,
)
from travel.avia.ticket_daemon.ticket_daemon.models import QueryBlackList

log = getLogger(__name__)
filtered_variants_logger = getLogger('filtered_variants')

# Country ids that do not make a domestic (Russia -> Russia) route count as
# visa-requiring when one of its segments touches them; consulted by
# visa_required_for_inner_flights.
NO_VISA_REQUIRED = frozenset((
    Country.RUSSIA_ID,
    168,  # Armenia
    Country.BELARUS_ID,
    Country.KAZAKHSTAN_ID,
    171,  # Uzbekistan
    207,  # Kyrgyzstan
))

# Price floor: FiltersApplier keeps only variants whose national tariff is
# strictly greater than this amount converted to the national currency.
MIN_TICKET_PRICE = Price(value=99, currency='RUR')


def prepare_variants_for_result(query, partner, rates, variants):
    """Complete, filter and enrich one partner's variants for the search result.

    Every segment is completed in place. A failure is logged as a warning
    and does not drop the variant. Segments (or their operating segments)
    that still have no company are reported to ``unknown_companies_logger``.
    The variants then go through ``FiltersApplier``. Variants whose tag is
    falsy or raises are dropped, and price prediction categories are
    attached.

    :returns: ``variants`` itself when it is empty, otherwise a new list.
    """
    if not variants:
        return variants

    for variant in variants:
        for segment in variant.all_segments:
            try:
                segment.complete(partner_code=partner.code, is_charter=variant.is_charter)
            except Exception as exc:
                log.warning('Flight.complete: %r', exc)

            if not segment.company:
                unknown_companies_logger.log(query.id, partner.code, segment)

            if segment.operating and not segment.operating.company:
                unknown_companies_logger.log(query.id, partner.code, segment.operating)

    applier = FiltersApplier(
        query,
        partner,
        rates,
        BannedVariantsCache(session_manager=session_manager),
    )
    # Consuming the tag generator drives the whole lazy filter pipeline.
    prepared = list(_handle_variants_tags(applier.apply(variants), applier.msg))

    safe_add_price_prediction_category(prepared, query)
    return prepared


class FiltersApplier(object):
    """Chain of per-variant filters applied to one partner's answer to a query.

    Filters are composed lazily: ``apply`` only builds the pipeline, and each
    filter's statistics are logged (``filtered_variants_logger`` and
    ``log_filtering``) once its input has been fully consumed.
    """

    # Shared by all instances; created once when the class is defined.
    _too_fast_moscow_erevan_filter = TooFastMoscowErevanFilter()

    def __init__(self, query, partner, rates, banned_variants_cache=None):
        """
        :param query: search query; ``id`` and ``national_version`` are read.
        :param partner: partner whose variants are filtered; ``code`` is read.
        :param rates: currency rates used to convert prices to the national currency.
        :param banned_variants_cache: cache handed to ``BannedVariantsFilter``.
        """
        self.__applied_filter_count = 0
        self.__query = query
        self.__partner = partner
        self._currency_rates = rates
        self._min_price_national = Price.convert_to_national(
            MIN_TICKET_PRICE, self.__query.national_version, self._currency_rates,
        )
        self.msg = '{qid}_{partner_code}'.format(qid=query.id, partner_code=partner.code)
        self._banned_variants = BannedVariantsFilter(banned_variants_cache, query, partner.code)

    def apply(self, variants):
        """Build the lazy filtering pipeline over ``variants``.

        :param variants: sized collection of variants (``len`` is taken and it is
            iterated once for unknown-station logging before filtering).
        :returns: an iterable of the variants that pass every filter. No
            filtering happens until it is consumed.
        """
        log_filtering(self.__query, self.msg, 0, len(variants), self.__partner.code, number=0)
        _log_unknown_station_codes(self.__query, self.__partner, variants)
        vs = self._check_bad(variants)

        vs = self._apply_filter_fn(
            'by_currency_filter',
            (lambda v: filter_by_currency(v, self.__query)),
            vs,
        )

        vs = self._apply_filter_fn('station_postfilter', self.__query.station_postfilter, vs)

        # Take the current Moscow time directly. Localizing the naive
        # process-local time as MSK is only right when the process itself
        # runs in the Moscow time zone.
        now = datetime.now(MSK_TZ)
        vs = self._apply_filter_fn('booking_available', (lambda v: v.booking_available(now)), vs)

        vs = self._apply_filter_fn(
            'bad_flight_number',
            (lambda v: good_flight_numbers(v, self.__partner.code, self.__query.id)),
            vs,
        )

        vs = self._apply_filter_fn('visa_required_filter', (lambda v: not visa_required_for_inner_flights(v)), vs)

        vs = self._blacklisted_filter_applier(vs)

        vs = self._apply_filter_fn(
            filter_name='aeroflot_banned_variants',
            predicate=(lambda v: not self._banned_variants.is_banned_variant(v)),
            variants=vs,
        )

        if settings.ENABLE_TOO_FAST_MOSCOW_EREVAN_FILTER:
            vs = self._apply_filter_fn(
                filter_name='too_fast_moscow_erevan_variants',
                predicate=(lambda v: not self._too_fast_moscow_erevan_filter.is_too_fast_variant(v)),
                variants=vs
            )

        return vs

    def _apply_filter_fn(self, filter_name, predicate, variants):
        """Wrap ``variants`` in a lazy filter keeping items where ``predicate`` is truthy.

        The filter number is assigned here, when the pipeline is built, so
        numbers follow the order in which filters are applied. The unfiltered
        input is logged as number 0 in ``apply``. Previously the number was
        taken inside the generator body. That body first runs for the
        outermost filter, so the numbers came out reversed.

        :returns: a generator over the accepted variants.
        """
        self.__applied_filter_count += 1
        return self._iter_filtered(filter_name, self.__applied_filter_count, predicate, variants)

    def _iter_filtered(self, filter_name, filter_number, predicate, variants):
        """Yield variants accepted by ``predicate``; log statistics when exhausted.

        A variant whose predicate raises is logged and dropped. It is counted
        as filtered out so that each filter's total equals the number of
        variants that entered it.
        """
        bad_variants_count = 0
        good_variants_count = 0

        for v in variants:
            try:
                accepted = bool(predicate(v))
            except Exception:
                log.exception('Exception on %s variants filter', filter_name)
                bad_variants_count += 1
                continue

            if accepted:
                good_variants_count += 1
                yield v
            else:
                bad_variants_count += 1

        if bad_variants_count:
            filtered_variants_logger.info(
                u'%s filtered %s out of %s by %s. left variants %s',
                self.msg,
                bad_variants_count,
                bad_variants_count + good_variants_count,
                filter_name,
                good_variants_count,
            )
        log_filtering(
            self.__query,
            self.msg,
            bad_variants_count,
            good_variants_count,
            self.__partner.code,
            filter_name=filter_name,
            number=filter_number,
        )

    def _blacklisted_filter_applier(self, variants):
        """Apply the active QueryBlackList rules for this partner plus global rules.

        Global rules are stored under the ``None`` partner code. When any
        applicable rule has ``allow`` set, only those whitelist rules are used:
        variants for which ``check_variant`` is truthy are kept, and blacklist
        rules are ignored. Otherwise variants for which ``check_variant`` on the
        blacklist rules is truthy are dropped. With no rules at all, ``variants``
        is returned untouched and no filter number is consumed.
        """
        national_version = self.__query.national_version
        rules_by_pcode = _queryblacklist_rules_by_partner_code()
        # `+` builds a new list, so the memoized per-partner lists are never mutated.
        rules = rules_by_pcode.get(self.__partner.code, []) + (rules_by_pcode.get(None, []))
        if not rules:
            return variants

        whitelist_rules = [r for r in rules if r.allow]
        if whitelist_rules:
            return self._apply_filter_fn(
                'whitelist_filter',
                lambda v: v.check_variant(self._currency_rates, national_version, whitelist_rules),
                variants
            )
        else:
            blacklist_rules = [r for r in rules if not r.allow]
            return self._apply_filter_fn(
                'blacklist_filter',
                lambda v: not v.check_variant(self._currency_rates, national_version, blacklist_rules),
                variants
            )

    def _check_bad(self, vs):
        """Chain the structural/consistency filters (routes, dates, classes, price)."""
        vs = self._apply_filter_fn('no_forward', (lambda v: v.forward_exists()), vs)

        vs = self._apply_filter_fn('bad_forward_transfers_times', (lambda v: v.forward_transfer_ok()), vs)

        vs = self._apply_filter_fn('date_forward_does_not_match_query', (lambda v: v.date_forward_ok(self.__query)), vs)

        vs = self._apply_filter_fn('no_backward_for_roundtrip', (lambda v: v.backward_fill_if_need(self.__query)), vs)

        vs = self._apply_filter_fn('bad_backward_transfers_times', (lambda v: v.backward_transfer_ok()), vs)

        vs = self._apply_filter_fn(
            'backward_departure_is_earlier_than_forward_arrival',
            (lambda v: v.departure_after_arrival_at_different_steps()),
            vs,
        )

        vs = self._apply_filter_fn(
            'date_backward_does_not_match_query', (lambda v: v.date_backward_ok(self.__query)), vs
        )

        vs = self._apply_filter_fn(
            'segments_service_class_does_not_match_query', (lambda v: v.is_segment_service_ok(self.__query)), vs
        )

        vs = self._apply_filter_fn(
            'variant_service_class_does_not_match_query', (lambda v: v.variant_service_ok(self.__query)), vs
        )

        vs = self._apply_filter_fn('zero_or_negative_price', (lambda v: v.price_ok()), vs)

        # Keep only variants strictly more expensive than MIN_TICKET_PRICE in national currency.
        vs = self._apply_filter_fn('price_is_too_small', lambda v: v.national_tariff > self._min_price_national, vs)

        vs = self._apply_filter_fn('completeness_filter', (lambda v: v.completed_ok), vs)

        if feature_flags.filter_circle_routes():
            vs = self._apply_filter_fn('circle_flight', (lambda v: v.is_not_circle_variant()), vs)

        return vs


class BannedVariantsFilter(object):
    """Decides whether a variant is banned for a (query, partner) pair.

    Only partners listed in ``USE_PARTNER_CODES`` are checked. For them, the
    banned tag set is read from the cache on first use. Every other partner
    gets an empty set, so nothing is banned.
    """

    USE_PARTNER_CODES = {'aeroflot'}

    def __init__(self, banned_variants_cache, query, partner_code):
        # type: (BannedVariantsCache, Query, basestring)->None
        self.banned_variants_cache = banned_variants_cache
        self.query = query
        self.partner_code = partner_code
        # None marks "not loaded yet"; the banned_tags property fills it lazily.
        self._banned_tags = None if partner_code in BannedVariantsFilter.USE_PARTNER_CODES else set()

    @property
    def banned_tags(self):
        """Set of banned variant tags, loaded from the cache on first access."""
        if self._banned_tags is not None:
            return self._banned_tags
        self._banned_tags = self._query_banned_variant_tag_set(
            banned_variants_cache=self.banned_variants_cache,
            query=self.query,
            partner_code=self.partner_code,
        )
        return self._banned_tags

    def is_banned_variant(self, variant):
        # type: (Variant)->bool
        """Return True when ``variant.tag`` is among the banned tags."""
        banned = self.banned_tags
        return variant.tag in banned

    @staticmethod
    def _query_banned_variant_tag_set(banned_variants_cache, query, partner_code):
        # type: (BannedVariantsCache, Query, basestring)->Set[basestring]
        """Load the keys of ``payload['tags']`` from the cached JSON record.

        A missing record, malformed JSON or a missing key yields an empty set
        silently. Any other error is logged and also yields an empty set.
        """
        try:
            record = banned_variants_cache.get(
                query=query, partner_code=partner_code,
            )
            payload = ujson.loads(record['payload'])  # type: Dict[basestring, Dict[basestring, Any]]
            return set(payload['tags'].keys())
        except (TypeError, ValueError, KeyError):
            return set()
        except Exception:
            log.exception('Unexpected error loading banned variants')
            return set()


@memoize(lambda: True, CacheWithKeyTTL(180))
def _queryblacklist_rules_by_partner_code():
    """Group active QueryBlackList rules by partner code (memoized, 180 s TTL).

    Rules without a partner are grouped under the ``None`` key.
    """
    rules_by_code = defaultdict(list)

    # prefetch_related issues one additional query to load the partners.
    active_rules = QueryBlackList.objects.filter(active=True).prefetch_related('partner')
    for rule in active_rules:
        partner_code = rule.partner.code if rule.partner else None
        rules_by_code[partner_code].append(rule)

    return dict(rules_by_code)


def good_flight_numbers(v, p_code, qid):
    """Normalize every segment's flight number in place.

    Numbers that ``fix_flight_number`` rewrites are stored back through
    ``set_number`` and logged. Returns False on the first number that cannot
    be fixed; later segments are left untouched. Returns True otherwise.
    """
    for segment in v.all_segments:
        current_number = segment.number
        fixed = fix_flight_number(current_number, bool(v.charter))

        if fixed is None:
            log.warning(u'cant fix flight number: "%s" %s %s', current_number, p_code, qid)
            return False

        if fixed == current_number:
            continue

        log.warning(u'convert flight number: "%s" to "%s" %s %s', current_number, fixed, p_code, qid)
        segment.set_number(fixed)

    return True


def filter_by_currency(v, query):
    """Accept a variant priced in the national currency of the query.

    Otherwise, the variant is accepted only when its partner allows foreign
    currencies and the tariff currency is in the allowed foreign-currency set
    for the query's national version.
    """
    tariff_currency = v.tariff.currency
    national_version = query.national_version

    if tariff_currency == settings.AVIA_NATIONAL_CURRENCIES[national_version]:
        return True

    foreign_currency_allowed = v.partner.foreign_currency
    if not foreign_currency_allowed:
        return foreign_currency_allowed

    allowed_foreign = settings.AVIA_NATIONAL_VERSION_ALLOWED_FOREIGN_CURRENCIES.get(national_version, set())
    return tariff_currency in allowed_foreign


def _handle_variants_tags(variants, msg):
    """Lazily yield only the variants whose ``tag`` is truthy.

    A variant whose tag computation raises is logged with ``msg`` and skipped.
    """
    for variant in variants:
        try:
            has_tag = bool(variant.tag)
        except Exception:
            log.exception('Making vtag exception %s', msg)
            continue

        if has_tag:
            yield variant


def visa_required_for_inner_flights(v):
    """Tell whether a Russia-to-Russia route passes through a visa-requiring country.

    The check applies only when the forward direction starts and ends in
    Russia, and so does the backward direction if it has segments.
    Otherwise the function returns False. When it applies, it returns True
    as soon as a segment's arrival or departure country is outside
    ``NO_VISA_REQUIRED``.
    """
    def _starts_and_ends_in_russia(segments):
        first_station = segments[0].station_from
        last_station = segments[-1].station_to
        # all() short-circuits, so the arrival country is only read when the
        # departure is in Russia.
        return all(
            station.country_id == Country.RUSSIA_ID
            for station in (first_station, last_station)
        )

    if not _starts_and_ends_in_russia(v.forward.segments):
        return False

    backward_segments = v.backward.segments
    if backward_segments and not _starts_and_ends_in_russia(backward_segments):
        return False

    return any(
        segment.station_to.country_id not in NO_VISA_REQUIRED
        or segment.station_from.country_id not in NO_VISA_REQUIRED
        for segment in v.iter_all_segments()
    )


def _log_unknown_station_codes(query, partner, variants):
    """Report incomplete segments with an unresolved departure or arrival station.

    Each distinct (query id, partner code, company IATA, flight number,
    station IATA) tuple is sent once to ``unknown_station_codes_logger``.
    """
    # (attribute holding the resolved station, attribute holding its raw IATA code)
    station_attrs = (
        ('station_from', 'station_from_iata'),
        ('station_to', 'station_to_iata'),
    )

    unique_records = set()
    for variant in variants:
        for segment in variant.iter_all_segments():
            if segment.is_complete:
                continue
            for station_attr, iata_attr in station_attrs:
                if not getattr(segment, station_attr):
                    unique_records.add((
                        query.id,
                        partner.code,
                        segment.company_iata,
                        segment.number,
                        getattr(segment, iata_attr),
                    ))

    for record in unique_records:
        unknown_station_codes_logger.log(*record)
