#!/usr/bin/env python
# -*- coding: utf-8 -*-

import antlr4
from antlr4.error import Errors

from crypta.lib.python.word_rule.WordRuleLexer import WordRuleLexer
from crypta.lib.python.word_rule.WordRuleParser import WordRuleParser
from crypta.profile.utils.segment_utils.visitor import Visitor


# Second element of the pair returned by NotNode/AndNode/OrNode.evaluate().
# NOTE(review): its meaning is defined by whoever consumes evaluate() results — confirm.
VALUE = 1


class Segments(object):
    """A pair of segment-id sets: the ``included`` ones and the ``excluded`` ones.

    ``~s`` swaps the two sets; ``a & b`` unions them component-wise.
    Both operators return new instances and never mutate their operands.
    """

    def __init__(self, included=None, excluded=None):
        # Falsy arguments (None or an empty set) are replaced by a fresh set.
        self.included = included or set()
        self.excluded = excluded or set()

    def __invert__(self):
        return Segments(included=set(self.excluded), excluded=set(self.included))

    def __and__(self, other):
        return Segments(included=self.included | other.included, excluded=self.excluded | other.excluded)

    def __eq__(self, other):
        try:
            return self.included == other.included and self.excluded == other.excluded
        except AttributeError:
            # Not Segments-like: let Python fall back to its default comparison
            # instead of crashing.
            return NotImplemented

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, two equal
        # Segments would compare as "not equal" by identity.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __repr__(self):
        return 'Segments(included=%r, excluded=%r)' % (self.included, self.excluded)


class NotNode(object):
    """Logical negation of exactly one child expression node."""

    def __init__(self, children):
        assert len(children) == 1
        self.child = children[0]

    def evaluate(self, evaluator, root_export_id):
        """Return ``(not <first item of child's result>, VALUE)``."""
        child_outcome = self.child.evaluate(evaluator, root_export_id)
        negated = not child_outcome[0]
        return negated, VALUE

    @property
    def dependencies(self):
        """Export ids referenced by the child (same object the child returns)."""
        return self.child.dependencies

    def get_full_dependencies(self, evaluator, root_export_id):
        """Return the child's full dependencies inverted with ``~``.

        For Segments this swaps the included and excluded sets.
        """
        child_segments = self.child.get_full_dependencies(evaluator, root_export_id)
        return ~child_segments

    def has_terms_with_socdem_only(self, evaluator, root_export_id):
        # Any negation is conservatively reported as True instead of
        # analysing the child — simpler than enumerating the cases.
        return True


class AndNode(object):
    """Conjunction: true only when every child expression is true."""

    def __init__(self, children):
        self.children = children

    def evaluate(self, evaluator, root_export_id):
        """Return ``(all children true, VALUE)``, stopping at the first false child."""
        outcomes = (node.evaluate(evaluator, root_export_id)[0] for node in self.children)
        return all(outcomes), VALUE

    @property
    def dependencies(self):
        """Union of all children's referenced export ids, as a new set."""
        merged = set()
        for node in self.children:
            merged.update(node.dependencies)
        return merged

    def get_full_dependencies(self, evaluator, root_export_id):
        """Fold all children's full dependencies together with ``&``, starting from an empty Segments."""
        combined = Segments()
        for node in self.children:
            combined = combined & node.get_full_dependencies(evaluator, root_export_id)
        return combined

    def has_terms_with_socdem_only(self, evaluator, root_export_id):
        """True only if every child reports True (vacuously True with no children)."""
        for node in self.children:
            if not node.has_terms_with_socdem_only(evaluator, root_export_id):
                return False
        return True


class OrNode(object):
    """Disjunction: true when at least one child expression is true."""

    def __init__(self, children):
        self.children = children

    def evaluate(self, evaluator, root_export_id):
        """Return ``(any child true, VALUE)``, stopping at the first true child."""
        satisfied = False
        for node in self.children:
            if node.evaluate(evaluator, root_export_id)[0]:
                satisfied = True
                break
        return satisfied, VALUE

    @property
    def dependencies(self):
        """Union of all children's referenced export ids, as a new set."""
        return {dep for node in self.children for dep in node.dependencies}

    def get_full_dependencies(self, evaluator, root_export_id):
        """Fold all children's full dependencies together with ``&`` (same combination as AndNode)."""
        union = Segments()
        for node in self.children:
            union = union & node.get_full_dependencies(evaluator, root_export_id)
        return union

    def has_terms_with_socdem_only(self, evaluator, root_export_id):
        """True if at least one child reports True (False with no children)."""
        for node in self.children:
            if node.has_terms_with_socdem_only(evaluator, root_export_id):
                return True
        return False


class LeafNode(object):
    """Terminal node that refers to a single export by its id."""

    def __init__(self, export_id_to_evaluate):
        self.export_id_to_evaluate = export_id_to_evaluate

    def evaluate(self, evaluator, root_export_id):
        """Delegate to ``evaluator.evaluate_independent`` for the referenced export."""
        target = self.export_id_to_evaluate
        return evaluator.evaluate_independent(target, root_export_id)

    @property
    def dependencies(self):
        """A new one-element set holding the referenced export id."""
        return {self.export_id_to_evaluate}

    def get_full_dependencies(self, evaluator, root_export_id):
        # Here ``evaluator`` is a plain callable, not an object with methods.
        target = self.export_id_to_evaluate
        return evaluator(target, root_export_id)

    def has_terms_with_socdem_only(self, evaluator, root_export_id):
        # Same callable-style evaluator as get_full_dependencies.
        target = self.export_id_to_evaluate
        return evaluator(target, root_export_id)


# Operator name -> node class. ExpressionParser.make_node looks operators up
# here while the Visitor builds a tree. Visitor.LEAF presumably marks a reference
# to another export (LeafNode takes a single export id) — confirm against Visitor.
NODES_BY_OPS = {
    'NOT': NotNode,
    'AND': AndNode,
    'OR': OrNode,
    Visitor.LEAF: LeafNode,
}


class ExpressionParser(object):
    """Parse export expressions into node trees and drop those that cannot be evaluated.

    ``exports_expressions`` maps export id -> an object whose ``expressions``
    attribute is the text fed to the WordRule grammar. ``logger`` must provide
    ``exception`` and ``warning`` methods (a ``logging.Logger`` works).
    """

    def __init__(self, exports_expressions, logger):
        self.exports_expressions = exports_expressions
        self.logger = logger

    @staticmethod
    def make_node(operator, operands):
        """Instantiate the node class registered for ``operator`` in NODES_BY_OPS."""
        return NODES_BY_OPS[operator](operands)

    def build_trees(self):
        """Return ``{export_id: root_node}`` for every export that can be evaluated.

        Exports whose expression fails to parse are logged and skipped. Exports
        that depend, directly or transitively, on a skipped or unknown export,
        or that take part in a dependency cycle, are then removed by
        ``remove_errors``.
        """
        trees = {}

        # .items() works on Python 2 and 3; .iteritems() is Python 2 only.
        for export_id, export in self.exports_expressions.items():
            try:
                trees[export_id] = self.parse(export.expressions)
            except Errors.ParseCancellationException:
                self.logger.exception("Failed to parse expression for export: %s", export_id)

        self.remove_errors(trees)

        return trees

    def parse(self, expression):
        """Parse one expression string into a node tree built via ``make_node``.

        Raises:
            Errors.ParseCancellationException: on the first syntax error, since
                the parser uses BailErrorStrategy instead of error recovery.
        """
        text = antlr4.InputStream(expression)
        lexer = WordRuleLexer(text)
        stream = antlr4.CommonTokenStream(lexer)
        parser = WordRuleParser(stream)
        parser._errHandler = antlr4.BailErrorStrategy()
        return Visitor(self.make_node).visit(parser.root())

    def remove_errors(self, trees):
        """Delete, in place, every tree that reaches a missing or cyclic dependency.

        Each removed tree is logged as a warning so it does not disappear silently.
        """
        working_set = set()
        bad_set = set()
        finished_set = set()
        for export_id in trees:
            self.has_errors(trees, export_id, working_set, finished_set, bad_set)

        for export_id in bad_set:
            # bad_set can also hold referenced ids that never had a tree
            # (unknown or unparsable exports); only real trees are removed.
            if export_id in trees:
                del trees[export_id]
                self.logger.warning(
                    "Dropping export %s: it has unresolvable or cyclic dependencies", export_id)

    def has_errors(self, trees, export_id, working_set, finished_set, bad_set):
        """Depth-first check of ``export_id``; return True if it cannot be evaluated.

        ``working_set`` holds ids on the current DFS path, so reaching one again
        means a cycle. ``finished_set`` caches ids known to be fine, and
        ``bad_set`` collects ids known to be broken (missing tree, cycle, or
        broken dependency). A direct self-reference is not treated as a cycle.
        """
        if export_id in finished_set:
            return False
        if export_id in working_set or export_id in bad_set:
            return True

        root = trees.get(export_id)
        if root is None:
            bad_set.add(export_id)
            return True

        working_set.add(export_id)

        for child in root.dependencies:
            if child != export_id and self.has_errors(trees, child, working_set, finished_set, bad_set):
                working_set.remove(export_id)
                bad_set.add(export_id)
                return True

        working_set.remove(export_id)
        finished_set.add(export_id)

        return False
