# coding: utf-8

import logging
import os

import marisa_trie

from vins_core.nlu.flow_nlu_factory.transition_model import MarkovTransitionModel
from vins_core.utils.data import load_data_from_file

from uhura.lib.vins.transition import intents
from uhura.lib.vins.transition.transition_rules import create_transition_rules, ChainRule

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
# File name looked up in every directory under the configured custom-rules path
# (see UhuraTransitionModel._parse_transition_rules).
TRANSITION_RULES_FILENAME = 'transition_rules.yaml'


class UhuraTransitionModel(MarkovTransitionModel):
    """Markov transition model for Uhura intents with rule-based dynamic boosts.

    Static transition scores are computed by ``_compute_transition_score``. They
    cover every ordered pair of intents, plus entering an intent with no previous
    intent. ``__init__`` precomputes them and registers them through the
    base-class ``add_transition``.

    At call time, ``__call__`` combines the base-class score with
    context-dependent adjustments:
      * intents listed in the session's ``skip_intents``;
      * image-only restrictions;
      * custom transition rules loaded from ``transition_rules.yaml`` files.
    """

    def __init__(
            self, intents, base_path, internal_boost=10.0, active_slot_ellipsis_boost=1.0,
            custom_rules=None,
            **kwargs
    ):
        """Build the model and precompute the static transition scores.

        :param intents: iterable of intent objects. Each must expose ``name`` and
            ``reset_form``, which ``_compute_transition_score`` reads.
        :param base_path: directory that ``custom_rules['path']`` is resolved against.
        :param internal_boost: score assigned to allowed transitions into an
            internal intent.
        :param active_slot_ellipsis_boost: stored and exposed through the property
            of the same name. It is not used elsewhere in this class.
        :param custom_rules: optional mapping with ``enable`` and ``path`` keys
            that tells where to look for rule files. A falsy value disables
            custom rules.
        :param kwargs: forwarded to ``MarkovTransitionModel.__init__``.
        """
        super(UhuraTransitionModel, self).__init__(**kwargs)

        self._internal_boost = internal_boost
        self._base_path = base_path
        self._active_slot_ellipsis_boost = active_slot_ellipsis_boost
        # _parse_transition_rules reads self._base_path, so that must be assigned first.
        self._custom_rules = create_transition_rules(self._parse_transition_rules(custom_rules)) if custom_rules else []
        self._precompute_transition_probs(intents)

    @property
    def internal_boost(self):
        """Score given to permitted transitions into an internal intent."""
        return self._internal_boost

    @property
    def active_slot_ellipsis_boost(self):
        """Value passed as ``active_slot_ellipsis_boost`` to the constructor."""
        return self._active_slot_ellipsis_boost

    def __call__(self, intent_name, session, req_info=None, base_score=1):
        """Return the weight of transitioning to ``intent_name`` in the current context.

        The result is ``base_score * score * boost``:
          * ``score`` comes from the base-class model;
          * ``boost`` is the product of the boosts of all matching custom rules.

        Returns 0.0 when the intent is banned in this context.

        :param intent_name: candidate next intent.
        :param session: dialog session. Its ``skip_intents`` entry and its
            ``intent_name`` attribute are read here.
        :param req_info: optional request info. ``additional_options['message']``
            is inspected for an attached document or photo.
        :param base_score: multiplier for the final weight. It is forced to 1 when
            an image intent handles an image message, or when a ChainRule matches.
        """
        # Some intents are not allowed in certain contexts
        intents_to_skip = session.get('skip_intents', [])
        if intent_name in intents_to_skip:
            return 0.0

        # Basic score
        score = super(UhuraTransitionModel, self).__call__(intent_name, session, req_info)

        # Dynamic boosts
        boost = 1.0

        if req_info:
            message = req_info.additional_options.get('message')
            # if user sent an image, we should ban intents that can't process it
            if message and (message.get('document') or message.get('photo')):
                if not intents.is_image_intent(intent_name):
                    return 0.0
                base_score = 1

        # Use transition rules.
        # A real list is required here. On Python 3, ``filter`` returns a lazy
        # iterator: it is always truthy (which would wrongly trigger the
        # module-change ban below) and has no ``len()``.
        active_rules = [
            rule for rule in self._custom_rules
            if rule.check(intent_name, session, req_info)
        ]
        if active_rules:
            logger.debug('Scenarios %s->%s active rules: %s', session.intent_name, intent_name, active_rules)
            if session.intent_name:
                # While rules are active, switching to another module is only
                # allowed into the 'utils' module.
                module_prev = intents.get_intent_module(session.intent_name)
                module_next = intents.get_intent_module(intent_name)
                if module_prev != module_next and module_next != 'utils':
                    return 0.0
            if len(active_rules) > 1:
                logger.debug('More than one rule is active, their boosts will be multipled')
            for active_rule in active_rules:
                if isinstance(active_rule, ChainRule):
                    base_score = 1  # even if base_score == 0 we must force this intent
                boost *= active_rule.boost
            logger.debug('Active rules boost: %s', boost)

        if boost != 1.0:
            logger.debug('Dynamic transition model boost %s', boost)

        logger.debug('Weights: %s, %f', intent_name, score * base_score * boost)
        return base_score * score * boost

    def _precompute_transition_probs(self, intents):
        """Register static scores for all intent transitions.

        Two kinds of transitions are registered:
          * from no previous intent (``None``) to each intent;
          * for every ordered pair ``(prev_intent, intent)``.

        :param intents: iterable of intent objects. Inside this method the
            parameter shadows the module-level ``intents`` helper, which is not
            needed here.
        """
        # Materialize once. The intents are traversed several times below (trie
        # construction plus nested loops), and a generator would be silently
        # exhausted after the first pass.
        intents = list(intents)
        intent_trie = marisa_trie.Trie(intent.name for intent in intents)

        for intent in intents:
            self.add_transition(
                None,
                intent.name,
                self._compute_transition_score(intent_trie, None, intent)
            )

        for prev_intent in intents:
            for intent in intents:
                self.add_transition(
                    prev_intent.name,
                    intent.name,
                    self._compute_transition_score(intent_trie, prev_intent, intent)
                )

    @staticmethod
    def _intent_info(intent_trie, intent_name):
        """Return ``(parent_name, is_internal, has_internal)`` for ``intent_name``.

        ``has_internal`` is True in two cases:
          * the intent is itself internal;
          * the trie contains some key starting with
            ``parent_name + INTERNAL_INTENT_SEPARATOR``.
        """
        parent_name = intents.parent_intent_name(intent_name)
        is_internal = intents.is_internal(intent_name)

        if is_internal:
            has_internal = True
        else:
            keys = intent_trie.keys(parent_name + intents.INTERNAL_INTENT_SEPARATOR)
            has_internal = bool(keys)

        return parent_name, is_internal, has_internal

    def _parse_transition_rules(self, custom_rules):
        """Collect raw rule definitions from every rules file under the configured path.

        When ``custom_rules['enable']`` is truthy, the method walks
        ``base_path/custom_rules['path']``. It loads the ``rules`` list from each
        ``transition_rules.yaml`` it finds and concatenates the lists.

        :returns: list of raw rule definitions (empty when disabled).
        """
        rules = []
        if custom_rules['enable']:
            path = os.path.join(self._base_path, custom_rules['path'])
            for directory, dirnames, filenames in os.walk(path):
                # Sorting in place makes os.walk visit subdirectories in a stable
                # order, so rules are collected deterministically regardless of
                # filesystem listing order.
                dirnames.sort()
                if TRANSITION_RULES_FILENAME in filenames:
                    rules += load_data_from_file(os.path.join(directory, TRANSITION_RULES_FILENAME))['rules']

        return rules

    def _compute_transition_score(self, intent_trie, prev_intent, intent):
        """Return the static score for transitioning from ``prev_intent`` to ``intent``.

        :param intent_trie: trie of all intent names, used to detect internal sub-intents.
        :param prev_intent: previous intent object, or None when there is none.
        :param intent: candidate next intent object. Its ``reset_form`` is read.
        :returns: 0.0 (forbidden), 1.0, or ``self.internal_boost``.
        """
        # The main idea behind this model:
        #  * P(intent_X | intent_X) = 1
        #  * P(intent_X | intent_X__internal) = 1
        #  * P(intent_X__internal | intent_Y) = 0 because we don't want 'internal' and elliptic utterances to trigger the intent.  # noqa
        #  * P(intent_X__internal | intent_X) > P(intent_Y | intent_X) because when we started doing something,
        #      we're more likely to continue than to switch to another task
        #  * P(intent_X__internal | intent_X__internal) > P(intent_Y | intent_X__internal) for the same reason
        # NOTE(review): the final return below yields 0.0 for X -> X and
        # X__internal -> X when X has internal intents and reset_form is not
        # set. The first two bullets therefore only hold for intents without
        # internal intents (or with reset_form). Confirm which behavior is intended.

        prev_intent_name = prev_intent and prev_intent.name
        intent_name = intent.name
        parent_name, is_internal, has_internal = self._intent_info(intent_trie, intent_name)

        # There is no previous intent
        if prev_intent_name is None:
            return 0.0 if is_internal else 1.0

        prev_parent_name, _, _ = self._intent_info(intent_trie, prev_intent_name)

        # submit_form -> submit_form
        if intents.is_submit_form_intent(prev_intent_name) and not (
                intents.is_submit_form_intent(intent_name) and intents.have_same_parent(intent_name, prev_intent_name)
        ):
            return 0.0

        # Transitioning into an internal intent
        if is_internal:
            # Can only transition from another internal intent with the same parent or the parent itself
            if parent_name == prev_parent_name:
                return self.internal_boost
            else:
                return 0.0

        # We can enter intents with reset_form=true from anywhere
        if intent.reset_form:
            return 1.0

        # Intents with internal intents can only be entered from outside
        return 1.0 if not has_internal or parent_name != prev_parent_name else 0.0


def create_transition_model(*args, **kwargs):
    """Factory for :class:`UhuraTransitionModel`.

    All positional and keyword arguments go unchanged to
    ``UhuraTransitionModel.__init__``. The newly constructed model is returned.
    """
    model = UhuraTransitionModel(*args, **kwargs)
    return model
