# -*- coding: utf-8 -*-
import os
import re

import ujson
from django.utils.encoding import force_text
from library.python import resource
from lxml import etree

from travel.avia.ticket_daemon_api.jsonrpc.carry_on_size_bucket import get_carry_on_bucket_size
from travel.avia.ticket_daemon_api.jsonrpc.lib.date import TicketDaemonFlightDateTimeDeserializer, aware_to_timestamp
from travel.avia.ticket_daemon_api.jsonrpc.lib.fare_families import (
    XML_TEMPLATE, FARE_FAMILIES_PATH, FARE_FAMILIES_EXPRESSIONS_PATH, get_all_resources
)
from travel.avia.library.python.ticket_daemon.memo import memoize
from travel.avia.ticket_daemon_api.jsonrpc.models_utils.geo import get_station_by_id, get_country_by_id


def _is_company_resource(filename, company_id):
    """Return True if resource ``filename`` belongs to ``company_id``.

    The file name must start with the company id, and the id must not be
    immediately followed by another digit; otherwise company ``1`` would also
    pick up the resources of company ``12``.
    """
    prefix = str(company_id)
    if not filename.startswith(prefix):
        return False
    # ''.isdigit() is False, so a file named exactly after the id still matches.
    return not filename[len(prefix):len(prefix) + 1].isdigit()


def get_data_by_company_id(base_path, company_id, default):
    """Load the ``{id: text}`` expressions of the first resource for a company.

    :param base_path: resource directory listed via ``get_all_resources``.
    :param company_id: company whose resource file is looked up.
    :param default: value returned when no resource belongs to the company.
    :return: dict mapping each top-level XML element's ``id`` attribute to its
        text, read from the first matching file; ``default`` if none matches.
    """
    for filename in get_all_resources(base_path):
        if not _is_company_resource(filename, company_id):
            continue
        content = force_text(resource.find(os.path.join(base_path, filename)))
        parser = etree.XMLParser(remove_comments=True)
        tree = etree.fromstring(content, parser=parser)
        return {expression.get('id'): expression.text for expression in tree}
    return default


@memoize(lambda: None)
def get_data_for_all_companies():
    """Load every fare-family resource and index it by company id.

    The company id is the integer before the first underscore in the resource
    file name. Each loaded list is post-processed in place by ``_fill_rules``.

    :return: dict mapping company id to its list of fare families.
    """
    by_company = {}
    for resource_name in get_all_resources(FARE_FAMILIES_PATH):
        company_id = int(resource_name.partition('_')[0])
        raw = resource.find(os.path.join(FARE_FAMILIES_PATH, resource_name))
        fare_families = ujson.loads(force_text(raw))
        by_company[company_id] = fare_families
        _fill_rules(fare_families, company_id)
    return by_company


def _fill_rules(fare_families, company_id):
    """Resolve external xpath references and drop ignored rules, in place.

    For every term of every fare family:
      * rules with ``ignore`` set to ``True`` are removed;
      * ``external_xpath_ref`` is replaced by the referenced expression stored
        under ``xpath`` (looked up via ``get_xpath_expressions``);
      * rules of a ``carry_on`` term get their size bucket via
        ``_update_carry_on_mark``.
    """
    ext_xpaths = get_xpath_expressions(company_id)
    for fare_family in fare_families:
        for term in fare_family['terms']:
            carry_on_term = term.get('code') == 'carry_on'
            active_rules = []
            for rule in term['rules']:
                if rule.get('ignore', False) is True:
                    continue
                if 'external_xpath_ref' in rule:
                    # The reference is assumed to always resolve; otherwise
                    # the tests would not have passed.
                    ref_id = rule['external_xpath_ref']
                    rule['xpath'] = ext_xpaths[ref_id]
                    del rule['external_xpath_ref']
                if carry_on_term:
                    _update_carry_on_mark(rule)
                active_rules.append(rule)
            term['rules'] = active_rules


def _max_dimension(size):
    """Return the largest integer dimension of an ``AxBxC`` size string.

    Parts that are not valid integers count as 0; a missing or empty size
    yields 0, as does a size whose integer parts are all non-positive.
    """
    max_part_size = 0
    if size:
        for part in size.split('x'):
            try:
                value = int(part)
            except ValueError:
                # Non-numeric part (UnicodeEncodeError on Python 2 is a
                # ValueError subclass too): ignore it.
                continue
            if value > max_part_size:
                max_part_size = value
    return max_part_size


def _update_carry_on_mark(rule):
    """Set ``rule['carry_on_size_bucket']`` from the rule's ``size``.

    The bucket is computed by ``get_carry_on_bucket_size`` from the largest
    integer dimension of ``rule['size']`` (0 when the size is missing or has
    no usable integer parts). The bucket thresholds live in that helper.
    """
    max_part_size = _max_dimension(rule.get('size'))
    rule['carry_on_size_bucket'] = get_carry_on_bucket_size(max_part_size)


@memoize(lambda company_id, klass: (company_id, klass))
def get_tariffs(company_id, klass):
    """Return the fare families configured for ``company_id``.

    For ``klass == 'business'`` only tariffs whose ``base_class`` is
    ``'BUSINESS'`` are kept; otherwise every tariff of the company is returned
    (an empty list for unknown companies).

    The result is memoized per ``(company_id, klass)``. It is therefore a list
    rather than a lazy ``filter`` object: on Python 3 a cached iterator would
    be exhausted by the first caller, and every later call would see no
    tariffs.
    """
    data = get_data_for_all_companies().get(company_id, [])
    if klass == 'business':
        data = [tariff for tariff in data if tariff['base_class'] == 'BUSINESS']
    return data


@memoize(lambda company_id: company_id)
def get_xpath_expressions(company_id):
    """Return the external ``{id: xpath}`` expressions defined for a company.

    An empty dict is returned when the company has no expressions resource.
    """
    expressions = get_data_by_company_id(
        base_path=FARE_FAMILIES_EXPRESSIONS_PATH,
        company_id=company_id,
        default={},
    )
    return expressions


class FareFamilies(object):
    """Match a flight's fare code against the configured fare families."""

    def __init__(self):
        # Per-instance memoization of _get_tariff, keyed by fare code,
        # flight key and requested service class.
        self.get_tariff = memoize(
            keyfun=(lambda code_tariff, flight, query: (code_tariff, flight['key'], query.klass))
        )(self._get_tariff)
        self._dt_deserializer = TicketDaemonFlightDateTimeDeserializer()

    def _get_tariff(self, code_tariff, flight, query):
        """Return the first fare family whose pattern matches ``code_tariff``.

        :param code_tariff: fare code of the flight; falsy values yield None.
        :param flight: flight dict; ``flight['company']`` selects the tariffs.
        :type query: jsonrpc.query.Query
        :return: the matched tariff with resolved terms and key, or None when
            no tariff pattern matches.
        """
        if not code_tariff:
            return None

        company_id = flight['company']
        tariffs = get_tariffs(company_id, query.klass)

        if not tariffs:
            return None

        for tariff in tariffs:
            # re.match anchors at the start of code_tariff only.
            if re.match(tariff['tariff_code_pattern'], code_tariff, re.I | re.U):
                return self._match_tariff_rules(tariff, code_tariff, flight)

        return None

    def _match_tariff_rules(self, fare_tariff, code_tariff, flight):
        """Resolve each term of ``fare_tariff`` to its first matching rule.

        A term is kept only if one of its rules matches; its ``id`` is the
        index of that rule within the term. Returns a shallow copy of
        ``fare_tariff`` with the resolved ``terms`` and a ``key`` built by
        ``_format_tariff_key``.

        Example term in the source data::

            {"code": "open_return_date", "rules": [{"availability": "FREE"}]}
        """
        terms = []
        # The fare XML depends only on code_tariff and flight, so build it at
        # most once per call, and only when some rule actually has an xpath.
        fare_xml = None

        for term in fare_tariff['terms']:
            result_term = {
                'code': term['code']
            }

            for rule_idx, rule in enumerate(term['rules']):
                # Identity check: lxml elements without children are falsy.
                if fare_xml is None and rule.get('xpath'):
                    fare_xml = self._format_fare_xml(code_tariff, flight)
                if self.check_rule(rule, code_tariff, flight, tree=fare_xml):
                    result_term.update({
                        'id': rule_idx,
                        'rule': rule,
                    })
                    terms.append(result_term)
                    break

        tariff = fare_tariff.copy()
        tariff['terms'] = terms
        tariff['key'] = self._format_tariff_key(tariff, flight)

        return tariff

    def check_rule(self, rule, code_tariff, flight, tree=None):
        """Return True if ``rule`` applies to the fare.

        Rules without an ``xpath`` always apply. Otherwise the xpath is
        evaluated against ``tree`` and the rule applies when it selects
        anything. When ``tree`` is not supplied it is built from
        ``code_tariff`` and ``flight`` via ``_format_fare_xml``.
        """
        xpath = rule.get('xpath')
        if not xpath:
            return True
        if tree is None:
            tree = self._format_fare_xml(code_tariff, flight)

        return bool(tree.xpath(xpath))

    @staticmethod
    def _format_tariff_key(tariff, flight):
        """Build ``company;base_class;brand;code=id;...`` for a matched tariff."""
        return '%d;%s;%s;%s' % (
            flight['company'],
            tariff['base_class'],
            tariff['brand'],
            ';'.join(['%s=%d' % (term['code'], term['id']) for term in tariff['terms']])
        )

    def _format_fare_xml(self, code_tariff, flight):
        """Render XML_TEMPLATE for the fare and parse it with lxml.

        The template receives the fare code and its first three characters,
        the origin/destination country codes (empty if the country is not
        found), the origin/destination station IATA codes, and the arrival
        timestamp (empty if the flight has no arrival).
        """
        station_from = get_station_by_id(flight['from'])
        station_to = get_station_by_id(flight['to'])
        country_from = get_country_by_id(station_from.country_id)
        country_to = get_country_by_id(station_to.country_id)
        arrival_ts = aware_to_timestamp(
            self._dt_deserializer.deserialize(flight['arrival'])
        ) if flight['arrival'] else ''
        # NOTE(review): values are substituted with str.format without XML
        # escaping; depending on XML_TEMPLATE, a fare code containing '<' or
        # '&' could break parsing. Confirm fare codes are always plain text.
        xml = XML_TEMPLATE.format(
            FareCodePrefix=code_tariff[:3],
            Code=code_tariff,
            FromCountry=country_from.code if country_from else '',
            ToCountry=country_to.code if country_to else '',
            FromAirport=station_from.iata,
            ToAirport=station_to.iata,
            ArrivalTs=arrival_ts,
        )

        return etree.fromstring(xml)
