# -*- coding: utf-8 -*-
import itertools
import logging
from collections import defaultdict, namedtuple

import typing

from travel.avia.library.python.common.models.partner import RegionalizePartnerQueryRule

from travel.avia.library.python.ticket_daemon.memo import memoize, SimpleWarmGroup


logger = logging.getLogger(__name__)
# Group that `_rules_by_partner_code` below is registered with (via decorator);
# presumably used to pre-warm / refresh memoized regionalization data — confirm
# against SimpleWarmGroup.
regionalization_warm_group = SimpleWarmGroup('regionalization')
# RegionalizePartnerQueryRule fields loaded for every rule; also the field list
# of RegionalizationRuleT, so each loaded row maps 1:1 onto the namedtuple.
_rule_attrs = [
    'id', 'exclude',
    'settlement_from_id', 'country_from_id',
    'settlement_to_id', 'country_to_id',
    'start_date', 'end_date', 'week_days'
]
# Immutable, DB-detached snapshot of one RegionalizePartnerQueryRule row.
RegionalizationRuleT = namedtuple('RegionalizationRuleT', _rule_attrs)


class PartnerRulesContainer(object):
    """Regionalization rules of one partner, indexed by (origin, destination) region.

    A region key is ``('settlement', id)``, ``('country', id)`` or ``None``
    (the rule does not restrict that side). Settlement and country ids are
    tagged because they come from separate id sequences: without the tag a
    rule for settlement 5 would also match queries from country 5, and a
    "settlement 5 -> country 5" rule would be mistaken for a same-region rule.

    Attributes:
        only_exclude_rules: True when every supplied rule (including ones
            skipped as incorrect) has ``exclude`` set; also True when no
            rules are supplied.
    """

    def __init__(self, rules):
        """Index *rules* by their (origin, destination) region keys.

        Args:
            rules: iterable of rule records exposing ``exclude``,
                ``settlement_from_id``, ``country_from_id``,
                ``settlement_to_id`` and ``country_to_id``.
        """
        # Materialize first: the rules are traversed twice below, and a
        # one-shot iterator would otherwise leave the index silently empty.
        rules = list(rules)
        self._rules = {}
        self.only_exclude_rules = all(rule.exclude for rule in rules)

        for rule in rules:
            rule_from = self._rule_region_key(rule.settlement_from_id, rule.country_from_id)
            rule_to = self._rule_region_key(rule.settlement_to_id, rule.country_to_id)
            # Same region on both sides (including unrestricted -> unrestricted)
            # is treated as a misconfiguration and ignored.
            if rule_from == rule_to:
                logger.warning('Skip incorrect rule %r.', rule)
                continue
            self._rules.setdefault((rule_from, rule_to), []).append(rule)

    @staticmethod
    def _rule_region_key(settlement_id, country_id):
        """Return the region key for one side of a rule; a settlement takes precedence over a country."""
        if settlement_id:
            return 'settlement', settlement_id
        if country_id:
            return 'country', country_id
        return None

    @staticmethod
    def _point_region_keys(point):
        """Return the region keys a query point belongs to, most specific first.

        Missing (falsy) ids are skipped, so no key is listed twice; the
        trailing ``None`` matches rules that do not restrict this side.
        """
        keys = []
        if point.settlement_id:
            keys.append(('settlement', point.settlement_id))
        if point.country_id:
            keys.append(('country', point.country_id))
        keys.append(None)
        return keys

    def get_applicable_rules(self, query):
        """Return the indexed rules matching the query's origin and destination.

        Args:
            query: object with ``point_from`` and ``point_to``, each exposing
                ``settlement_id`` and ``country_id``.

        Returns:
            A new list of rules ordered by origin key (settlement, country,
            unrestricted) and, within each, by destination key in the same
            order. Every (origin, destination) bucket is visited at most once.
        """
        applicable_rules = []
        for key in itertools.product(
                self._point_region_keys(query.point_from),
                self._point_region_keys(query.point_to),
        ):
            applicable_rules.extend(self._rules.get(key, ()))
        return applicable_rules


@regionalization_warm_group
@memoize()
def _rules_by_partner_code():
    """Load every regionalization rule and build one container per partner.

    Returns:
        dict mapping partner code to a PartnerRulesContainer holding that
        partner's rules as RegionalizationRuleT tuples. Partners without
        any rule are absent from the dict.
    """
    grouped = defaultdict(list)
    rows = RegionalizePartnerQueryRule.objects.values('partner__code', *_rule_attrs)
    for row in rows:
        # Remove the partner code so the remaining keys are exactly _rule_attrs.
        partner_code = row.pop('partner__code')
        grouped[partner_code].append(RegionalizationRuleT(**row))

    containers = {}
    for partner_code, partner_rules in grouped.iteritems():
        containers[partner_code] = PartnerRulesContainer(partner_rules)
    return containers


def get_partner_rules(partner_code):
    # type: (basestring) -> typing.Optional[PartnerRulesContainer]
    """Return the rule container of *partner_code*, or None if it has no rules."""
    rules_by_code = _rules_by_partner_code()
    return rules_by_code.get(partner_code)
