# -*- coding: utf-8 -*-
from __future__ import unicode_literals

import json
import os
import re

import ujson
from django.utils.encoding import force_text
from library.python import resource
from lxml import etree

from travel.avia.ticket_daemon_api.jsonrpc.lib.fare_families import (
    XML_TEMPLATE, FARE_FAMILIES_PATH, get_all_resources
)
from travel.avia.ticket_daemon_api.jsonrpc.lib.fare_families.fare_families import get_xpath_expressions


class ErrorsCollector(object):
    """Accumulate validation error messages together with their location.

    The validator updates the current file / tariff / term context through the
    ``set_*`` methods; each ``add_*_error`` call then renders a message naming
    the place in the fare-families data where the problem was found.
    """

    def __init__(self):
        self._errors = []
        # Current validation context, updated by set_* and add_rule_error.
        self._filename = None
        self._tariff = None
        self._term = None
        self._rule_id = None

    def empty(self):
        """Return True if no error has been collected yet."""
        return not self._errors

    @property
    def errors(self):
        """Collected error messages, in the order they were added."""
        return self._errors

    def set_filename(self, filename):
        self._filename = filename

    def get_filename(self):
        return self._filename

    def set_tariff(self, tariff):
        self._tariff = tariff

    def set_term(self, term):
        self._term = term

    def add_term_error(self, reason):
        """Record an error about the current term (file, tariff and term must be set)."""
        self._errors.append(self._term_error(reason))

    def add_rule_error(self, rule_id, reason):
        """Record an error about rule number ``rule_id`` of the current term."""
        self._rule_id = rule_id
        self._errors.append(self._rule_error(reason))

    def add_file_error(self, reason):
        """Record an error about the current file (the filename must be set)."""
        self._errors.append(self._file_error(reason))

    def add_tariff_error(self, reason):
        """Record an error about the current tariff."""
        self._errors.append(self._tariff_error(reason))

    def _check_context_initialized(self):
        assert self._filename is not None
        assert self._tariff is not None
        assert self._term is not None

    def _tariff_name(self):
        """Return the English ``tariff_group_name`` of the current tariff, or None.

        Errors are often reported precisely because the tariff name is missing
        or incomplete, so this lookup must never raise.
        """
        if not isinstance(self._tariff, dict):
            return None
        names = self._tariff.get('tariff_group_name')
        if not isinstance(names, dict):
            return None
        return names.get('en')

    def _file_error(self, reason):
        assert self._filename is not None

        return 'File Error: "%s" for "%s" file' % (
            reason, self._filename
        )

    def _tariff_error(self, reason):
        return 'Tariff Error: "%s" in "%s" file, "%s" tariff.' % (
            reason, self._filename, self._tariff_name(),
        )

    def _term_error(self, reason):
        self._check_context_initialized()

        return 'Term Error: "%s" in "%s" file, "%s" tariff, "%s" term' % (
            reason, self._filename, self._tariff_name(),
            self._term['code'],
        )

    def _rule_error(self, reason):
        self._check_context_initialized()
        assert self._rule_id is not None

        return 'Rule Error: "%s" in "%s" file, "%s" tariff, "%s" term, rule id = %i' % (
            reason, self._filename, self._tariff_name(),
            self._term['code'], self._rule_id,
        )


class FareFamiliesDataValidator(object):
    """Check the fare-families JSON resources for structural and syntax errors.

    Each resource under ``FARE_FAMILIES_PATH`` is expected to be a JSON list of
    tariffs; ``validate()`` walks all of them and returns the collected error
    messages (an empty list means every file passed).
    """

    # Parsed XML_TEMPLATE; rule XPath expressions are evaluated against it only
    # to detect XPath syntax errors.
    _tree = etree.fromstring(XML_TEMPLATE)

    @classmethod
    def validate(cls):
        """Validate every fare-families resource.

        :return: list of human-readable error messages, empty when all are valid.
        """
        collector = ErrorsCollector()

        for filename, tariffs in cls._get_all_fare_families():
            collector.set_filename(filename)
            bad_format = cls._check_json_format(filename, collector)
            cls._check_filename(filename, collector)
            if bad_format:
                continue
            if tariffs is None:
                # The stdlib json module accepted the content but ujson could
                # not parse it; report it rather than silently skipping the file.
                collector.add_file_error('Unable to parse file content')
                continue

            for tariff in tariffs:
                collector.set_tariff(tariff)
                cls._check_tariff(tariff, collector)

                # A missing "terms" field is already reported by _check_tariff.
                for term in tariff.get('terms') or ():
                    collector.set_term(term)
                    cls._check_term(term, collector)
                    cls._check_term_refundable_have_availability(term, collector)

            cls._check_fare_family_uniqueness(tariffs, collector)
            cls._check_fare_family_order(tariffs, collector)
            cls._check_tariff_code_pattern(tariffs, collector)

        return collector.errors

    @classmethod
    def _check_filename(cls, filename, errors_collector):
        """Require filenames of the form ``<digit company id>_<rest>``."""
        filename_parts = filename.split('_', 1)

        if len(filename_parts) < 2:
            errors_collector.add_file_error(
                'Company id and other part of filename must be join by "_"'
            )
        elif not filename_parts[0].isdigit():
            errors_collector.add_file_error(
                'Company id "%s" is not digit in file name' % filename_parts[0]
            )

    @classmethod
    def _check_tariff(cls, tariff, errors_collector):
        """Report missing required tariff fields and tariff_group_name translations."""
        for req_field in ('base_class', 'brand', 'tariff_code_pattern', 'tariff_group_name', 'terms'):
            if not tariff.get(req_field):
                errors_collector.add_tariff_error(
                    'Required field "%s" is not set' % req_field
                )

        # Missing tariff_group_name was reported above; still check translations
        # without raising KeyError.
        group_names = tariff.get('tariff_group_name') or {}
        for lang in ('ru', 'en'):
            if not group_names.get(lang):
                errors_collector.add_tariff_error(
                    'Required "%s" tariff_group_name translation is not set' % lang
                )

    @classmethod
    def _check_term(cls, term, errors_collector):
        """Validate a term's rules and the placement of its default rule.

        A default rule is a non-ignored rule with neither ``xpath`` nor
        ``external_xpath_ref``; at most one is allowed and it must be last.
        """
        rules = term['rules']
        # _check_rules is a generator: list() forces all rule checks to run.
        default_rules_idx = list(cls._check_rules(rules, errors_collector))

        if len(default_rules_idx) > 1:
            errors_collector.add_term_error('Many default rules')
        elif len(default_rules_idx) == 1 and default_rules_idx[0] != len(rules) - 1:
            errors_collector.add_term_error('Default rule is not last')

    @classmethod
    def _check_term_refundable_have_availability(cls, term, errors_collector):
        """Report each rule of a ``refundable`` term that lacks ``availability``."""
        if term['code'] == 'refundable':
            for rule in term['rules']:
                if 'availability' not in rule:
                    errors_collector.add_term_error('Refundable term is missing availability')

    @classmethod
    def _check_rules(cls, rules, errors_collector):
        """Check rule XPaths and yield the indexes of default rules.

        Ignored rules are skipped. External references are resolved through
        ``get_xpath_expressions`` for the company id taken from the filename
        prefix before the first ``_``.
        """
        company_id = errors_collector.get_filename().split('_', 1)[0]
        expressions = get_xpath_expressions(company_id)

        for rule_idx, rule in enumerate(rules):
            if rule.get('ignore', False):
                continue

            if 'xpath' in rule:
                try:
                    cls._tree.xpath(rule['xpath'])
                except etree.XPathError:
                    errors_collector.add_rule_error(
                        rule_idx, 'XPath Syntax Error'
                    )
            elif 'external_xpath_ref' in rule:
                ref = rule['external_xpath_ref']

                if ref not in expressions:
                    errors_collector.add_rule_error(
                        rule_idx,
                        'External xpath referrence(%s) doesn\'t exist' % ref
                    )
                else:
                    try:
                        cls._tree.xpath(expressions[ref])
                    except etree.XPathError:
                        errors_collector.add_rule_error(
                            rule_idx, 'External XPath(ref="%s") Syntax Error' % ref
                        )
            else:
                yield rule_idx

    @staticmethod
    def _get_content(filename):
        """Return the text of resource ``filename`` under FARE_FAMILIES_PATH."""
        return force_text(resource.find(os.path.join(FARE_FAMILIES_PATH, filename)))

    @classmethod
    def _check_json_format(cls, filename, errors_collector):
        """Report a file error if the content is not valid JSON.

        :return: True when the content is malformed, False otherwise.
        """
        content = cls._get_content(filename)
        try:
            json.loads(content)
        except ValueError as e:
            errors_collector.add_file_error(e)
            return True
        return False

    @classmethod
    def _get_all_fare_families(cls):
        """Yield ``(filename, tariffs)`` for every fare-families resource.

        ``tariffs`` is None when the content cannot be parsed, so one malformed
        file does not abort the whole validation; the parse error is reported
        by ``_check_json_format``.
        """
        for filename in get_all_resources(FARE_FAMILIES_PATH):
            content = cls._get_content(filename)
            try:
                tariffs = ujson.loads(content)
            except ValueError:
                tariffs = None
            yield filename, tariffs

    @staticmethod
    def _check_fare_family_uniqueness(tariffs, errors_collector):
        """Report each repeated (base_class, brand) pair after its first occurrence."""
        fare_families = set()
        for tariff in tariffs:
            fare_family = '{} {}'.format(tariff.get('base_class'), tariff.get('brand'))
            if fare_family in fare_families:
                errors_collector.add_file_error('Fare family "{}" is not unique'.format(fare_family))
            fare_families.add(fare_family)

    @staticmethod
    def _check_fare_family_order(tariffs, errors_collector):
        """Report adjacent tariffs whose base classes are ranked in descending order.

        Ranks: ECONOMY < PREMIUM_ECONOMY == COMFORT < BUSINESS; unknown or
        missing base classes rank lowest (0).
        """
        base_classes = {'ECONOMY': 1, 'PREMIUM_ECONOMY': 2, 'COMFORT': 2, 'BUSINESS': 3}
        for prev_tariff, curr_tariff in zip(tariffs, tariffs[1:]):
            prev_base_class = prev_tariff.get('base_class')
            curr_base_class = curr_tariff.get('base_class')
            if base_classes.get(prev_base_class, 0) > base_classes.get(curr_base_class, 0):
                errors_collector.add_file_error(
                    'Fare families are not sorted. "{}" comes before "{}"'.format(prev_base_class, curr_base_class),
                )

    @staticmethod
    def _check_tariff_code_pattern(tariffs, errors_collector):
        """Report tariff_code_pattern values that are not valid regular expressions."""
        for tariff in tariffs:
            pattern = tariff.get('tariff_code_pattern')
            if not pattern:
                # Missing pattern is already reported as a required field.
                continue
            try:
                re.compile(pattern)
            except (re.error, TypeError) as e:
                fare_family = '{} {}'.format(tariff.get('base_class'), tariff.get('brand'))
                errors_collector.add_file_error('Error "{}" in tariff_code_pattern for {}'.format(e, fare_family))
