# coding: utf-8

import html
import re
from base64 import b64decode
from contextlib import ContextDecorator
from functools import wraps
from time import sleep
from typing import Union
from xml.etree import ElementTree

from django.conf import settings
from django.utils.timezone import now
from faker import Factory
from zeep import Client

from procu.api import models


class spark_client(ContextDecorator):
    """Open an authenticated SPARK SOAP session for the duration of a block.

    On entry a zeep ``Client`` is built for ``settings.SPARK_URL``, the
    service's ``Authmethod`` is called with ``settings.SPARK_LOGIN`` and
    ``settings.SPARK_PASSWORD``, and the client's ``service`` proxy is
    returned to the ``with`` target. On exit the service's ``End()``
    operation is called. Exceptions from the block are never suppressed.
    If ``__enter__`` itself raises, Python does not call ``__exit__``.
    """

    def __enter__(self):
        self.client = Client(settings.SPARK_URL)
        service = self.client.service
        service.Authmethod(settings.SPARK_LOGIN, settings.SPARK_PASSWORD)
        return service

    def __exit__(self, *exc):
        self.client.service.End()
        # Falsy return value: let any in-flight exception propagate.
        return False


# Risk-level key -> case-insensitive pattern for the Russian level word.
# ``get_indicator`` searches risk description text against these in this
# order: 'high', 'low', 'moderate'.
INDICATORS = {
    level: re.compile(pattern, re.I | re.U)
    for level, pattern in (
        ('high', r'высокий'),
        ('low', r'низкий'),
        ('moderate', r'средний'),
    )
}


class NoData(Exception):
    """Raised when a SPARK XML response lacks the expected report node."""


def get_indicator(text):
    """Classify a risk description by the risk level it mentions.

    :param text: description text, e.g. an XML attribute value that may be
        missing (``None``).
    :returns: the first key of ``INDICATORS`` (in iteration order) whose
        pattern is found in ``text``, or ``''`` when none matches or ``text``
        is ``None``/empty.
    """
    # Missing attributes arrive as None; re.search(None) would raise TypeError.
    if not text:
        return ''

    return next(
        (key for key, regexp in INDICATORS.items() if regexp.search(text)),
        '',
    )


def get_short_risk(text, n=2):
    """Return the first ``n`` whitespace-separated words of ``text``, lowercased.

    For example ``'Низкий риск ...'`` becomes ``'низкий риск'`` with the
    default ``n``. ``None`` or empty ``text`` (e.g. a missing XML attribute)
    yields ``''`` instead of raising.
    """
    if not text:
        return ''

    return ' '.join(text.split()[:n]).lower()


def spark_cached(func):
    """Cache a ``vat_id -> risk data`` function in ``models.SparkCache``.

    The returned wrapper has the signature ``(vat_id, update=False)``.
    Unless ``update`` is true, it first returns the stored ``risk_data``
    for ``vat_id``. On a cache miss, or with ``update=True``, it calls
    ``func(vat_id)``, stores the result with a fresh ``updated_at`` via
    ``update_or_create`` and returns it. Rows are never expired
    automatically; pass ``update=True`` to refresh one.

    NOTE(review): the cache key is ``vat_id`` exactly as the caller passed
    it, before any normalisation ``func`` performs.
    """

    @wraps(func)
    def wrapped(vat_id, update=False):
        if not update:
            try:
                return models.SparkCache.objects.values_list(
                    'risk_data', flat=True
                ).get(vat_id=vat_id)
            except (models.SparkCache.DoesNotExist, ValueError):
                # Cache miss, or a vat_id the lookup rejects: fall through and
                # compute. Calling ``func`` outside this handler keeps its
                # errors from being chained onto the lookup failure.
                pass

        data = func(vat_id)

        models.SparkCache.objects.update_or_create(
            vat_id=vat_id, defaults={'risk_data': data, 'updated_at': now()}
        )
        return data

    return wrapped


if settings.IS_PRODUCTION:

    @spark_cached
    def get_risks_summary(vat_id: Union[int, str]):
        """Fetch a company's risk indices and risk factors from SPARK.

        :param vat_id: company INN, as an int or a string of 9 or 10 digits.
        :returns: dict with the keys:
            - ``vat_id`` (int)
            - ``name`` (company short name, or ``''``)
            - ``updated_at`` (ISO-8601 timestamp)
            - ``risks`` (list of risk-index dicts)
            - ``factors`` (list of ``{'name': ..., 'info': {...}}`` dicts)
        :raises ValueError: if ``vat_id`` is not 9 or 10 digits.
        :raises NoData: if the SPARK response has no ``Data/Report`` node.
        """
        vat_id = str(vat_id)

        if not (len(vat_id) in (9, 10) and vat_id.isdigit()):
            raise ValueError('VAT ID is not valid: %s' % vat_id)

        # NOTE(review): int() drops any leading zero. This int is both sent
        # to SPARK and returned as summary['vat_id']; presumably that is why
        # 9-digit IDs are accepted above. Confirm SPARK handles it.
        vat_id = int(vat_id)

        # ----------------------------------------------------------------------

        with spark_client() as client:
            risk_info = client.GetCompanyRiskFactors(inn=vat_id)

        tree = ElementTree.fromstring(risk_info['xmlData'].encode('utf-8'))
        report = tree.find('Data/Report')

        if report is None:
            raise NoData

        company = report.find('ShortName')
        # An empty <ShortName/> has text None, which html.unescape rejects.
        company_name = (
            html.unescape(company.text or '') if company is not None else ''
        )

        summary = {
            'vat_id': vat_id,
            'name': company_name,
            'updated_at': now().isoformat(),
        }

        # ----------------------------------------------------------------------
        # Risks

        risks = []

        # Overall risk: described by text only, reported without a score.
        node = report.find('ConsolidatedIndicator')
        if node is not None:
            # Default to '' so a missing attribute cannot crash the helpers.
            description = node.get('Description', '')
            risks.append(
                {
                    'short_name': 'Риск',
                    'full_name': 'Риск',
                    'description': get_short_risk(description, 1),
                    'indicator': get_indicator(description),
                }
            )

        # Concrete risk indices, each given as:
        # (XML tag, short name, full name, description attr, score attr).
        risk_indices = (
            (
                'IndexOfDueDiligence',
                'ИДО',
                'Индекс должной осмотрительности',
                'IndexDesc',
                'Index',
            ),
            (
                'FailureScore',
                'ИФР',
                'Индекс финансового риска',
                'FailureScoreDesc',
                'FailureScoreValue',
            ),
            (
                'PaymentIndex',
                'ИПД',
                'Индекс платежной дисциплины',
                'PaymentIndexDesc',
                'PaymentIndexValue',
            ),
        )
        for tag, short_name, full_name, desc_attr, score_attr in risk_indices:
            node = report.find(tag)
            if node is None:
                continue

            description = node.get(desc_attr, '')
            risks.append(
                {
                    'short_name': short_name,
                    'full_name': full_name,
                    'indicator': get_indicator(description),
                    'description': get_short_risk(description),
                    'score': node.get(score_attr),
                }
            )

        summary['risks'] = risks

        # ----------------------------------------------------------------------
        # Risk Factors

        factors = []

        factors_node = report.find('RiskFactors')
        if factors_node is not None:
            for factor in factors_node:
                info = factor.find('AddInfo')
                # Compare with None explicitly: an Element's truth value
                # reflects its child count and is deprecated in Python 3.12.
                items = info if info is not None else ()
                factors.append(
                    {
                        'name': factor.get('Name'),
                        'info': {item.get('Name'): item.text for item in items},
                    }
                )

        summary['factors'] = factors

        # ----------------------------------------------------------------------

        return summary


else:

    @spark_cached
    def get_risks_summary(vat_id: Union[int, str]):

        vat_id = str(vat_id)

        if not (len(vat_id) in (9, 10) and vat_id.isdigit()):
            raise ValueError('VAT ID is not valid: %s' % vat_id)

        fake = Factory.create('ru_RU')
        fake.seed(vat_id)

        summary = {
            'vat_id': vat_id,
            'name': fake.company(),
            'risks': [
                {'short_name': 'Риск', 'full_name': 'Риск'},
                {
                    'short_name': 'ИДО',
                    'full_name': 'Индекс должной осмотрительности',
                },
                {'short_name': 'ИФР', 'full_name': 'Индекс финансового риска'},
                {
                    'short_name': 'ИПД',
                    'full_name': 'Индекс платежной дисциплины',
                },
            ],
            'factors': [],
            'updated_at': now().isoformat(),
        }

        for i, risk in enumerate(summary['risks']):
            description = fake.random.choice(
                ['Низкий риск', 'Средний риск', 'Высокий риск']
            )

            risk['indicator'] = get_indicator(description)
            risk['description'] = get_short_risk(description, 2 if i else 1)

            if i:
                risk['score'] = fake.random.randint(0, 100)

        for i in range(fake.random.randint(1, 5)):
            data = {
                'name': fake.sentence(
                    nb_words=5, variable_nb_words=True, ext_word_list=None
                ),
                'info': {'Number': fake.random.randint(0, 100)},
            }
            summary['factors'].append(data)

        sleep(round(fake.random.random() * 7, 1))

        return summary


def get_risks_report(vat_id):
    """Download a company's SPARK risks report.

    :param vat_id: company INN. It is passed to SPARK unchanged; unlike
        ``get_risks_summary``, this function does no validation.
    :returns: the bytes obtained by base64-decoding the text of
        ``Data/Report/SparkRisksReport``.
    :raises NoData: if the response has no such node, or the node is empty.
    """
    with spark_client() as client:
        response = client.GetCompanySparkRisksReport(inn=vat_id)

    tree = ElementTree.fromstring(response['xmlData'].encode('utf-8'))
    report = tree.find('Data/Report/SparkRisksReport')

    # An empty element has text None; b64decode(None) would raise TypeError.
    if report is None or report.text is None:
        raise NoData

    return b64decode(report.text)
