#!/usr/bin/env python
# -*- coding: utf-8 -*-

import collections
import sys

import antlr4
from antlr4.error import Errors

from crypta.lib.python.word_rule.WordRuleLexer import WordRuleLexer
from crypta.lib.python.word_rule.WordRuleParser import WordRuleParser
from crypta.profile.utils.segment_utils.visitor import Visitor

# TODO(unretrofied): CRYPTA-16120 remove debug logging after fix
import logging


# Value returned by YqlWordFilter.get_yql_query: `query_text` is the rendered
# YQL query, `word_conditions` is the newline-joined list of word regexps the
# query's `$mm.<id>` references point into.
WordFilterQuery = collections.namedtuple('WordFilterQuery', ['query_text', 'word_conditions'])

# Module-level logger used by YqlWordFilter.parse_condition_string.
# NOTE(review): the level is forced to DEBUG at import time, presumably for the
# temporary CRYPTA-16120 diagnostics mentioned in the TODO above - confirm and
# drop together with that logging.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


# Name of the YQL variable that holds the multi-match result inside `$rule_ids`;
# individual word filters are rendered as `<MATCH_OBJECT>.<filter id>`.
MATCH_OBJECT = u'$mm'

# YQL query skeleton rendered with str.format by YqlWordFilter.get_yql_query.
# Placeholders: data_size_per_job, word_rules_file, rule_conditions,
# condition_id_to_rule_id, output_table, input_table.
# Doubled braces are str.format escapes that yield literal YQL lambda braces.
# NOTE(review): this is not a raw string, so the '\n' passed to
# String::JoinFromList reaches YQL as an actual line break inside the quoted
# literal - confirm that this is the intended separator.
yql_request_template = u"""
PRAGMA yt.DefaultMemoryLimit = "2048M";
PRAGMA yt.DataSizePerJob = "{data_size_per_job}";
PRAGMA File("word_conditions.txt", "yt://{word_rules_file}");

$rules = ParseFile("String", "word_conditions.txt");
$match = Hyperscan::MultiMatch(String::JoinFromList($rules, '\n'));

$m = ($lemmas) -> {{
    RETURN $match($lemmas);
}};

$rule_ids = ($lemmas) -> {{
    $mm = $m($lemmas);
    RETURN AsList(
        {rule_conditions}
    )
}};

$condition_id_to_rule_id = AsDict(
{condition_id_to_rule_id}
);

$convert_condition_results_to_rule_ids = ($condition_results) -> {{
    RETURN ListFlatMap(
        ListEnumerate($condition_results),
        ($condition_id_and_result) -> {{
            RETURN CASE $condition_id_and_result.1
                    WHEN True THEN  $condition_id_to_rule_id[$condition_id_and_result.0]
                    ELSE NULL
                    END;
        }}
    )
}};

$process_rules = ($lemmas) -> {{
    $condition_results = $rule_ids($lemmas);
    RETURN $convert_condition_results_to_rule_ids($condition_results);
}};


INSERT INTO `{output_table}` WITH TRUNCATE
SELECT DISTINCT
    yandexuid,
    rule_ids AS rule_id,
FROM (
    SELECT
        yandexuid,
        $process_rules(lemmas) as rule_ids
    FROM `{input_table}`
)
FLATTEN LIST BY rule_ids
"""


class BinaryFilter(object):
    """Base class for boolean filters that combine several operands.

    Operands that are ``LemmaFilter`` instances are collapsed into one combined
    regexp; every other operand is serialized on its own. Subclasses supply
    the two delimiters below.
    """

    # Separator placed between the sorted lemmas inside the combined regexp.
    regex_delimiter = None
    # Separator placed between the serialized operands of the expression.
    operator_delimiter = None

    def __init__(self, operands):
        """Split ``operands`` into plain lemmas and nested filter expressions.

        :param operands: iterable of filter objects; ``LemmaFilter`` items
            contribute their ``operand`` value, everything else is kept as is.
        """
        self.expr_operands = []
        self.lemma_operands = []

        for operand in operands:
            if isinstance(operand, LemmaFilter):
                self.lemma_operands.append(operand.operand)
            else:
                self.expr_operands.append(operand)

    def serialize(self, yql_word_filter):
        """Render this filter as a boolean expression string.

        :param yql_word_filter: object exposing an ``all_words`` set and a
            ``_get_word_filter(regexp)`` method; lemmas are recorded in the
            former and the combined regexp is registered through the latter.
        :return: operands joined with ``operator_delimiter``; wrapped in
            parentheses when there is more than one operand.
        """
        operands = [x.serialize(yql_word_filter) for x in self.expr_operands]

        if self.lemma_operands:
            yql_word_filter.all_words.update(self.lemma_operands)
            # All lemmas of this node share a single regexp, built in sorted order.
            regexp = u'.* {} .*'.format(self.regex_delimiter.join(sorted(self.lemma_operands)))
            operands.append(yql_word_filter._get_word_filter(regexp))

        result = self.operator_delimiter.join(operands)
        if len(operands) > 1:
            result = '({})'.format(result)
        return result


class AndFilter(BinaryFilter):
    """Conjunction: operands are joined with ``AND``.

    Lemma operands are merged into one regexp with `` .* `` between them.
    """

    operator_delimiter = u' AND '
    regex_delimiter = u' .* '


class OrFilter(BinaryFilter):
    """Disjunction: operands are joined with ``OR``.

    Lemma operands are merged into one regexp with `` | `` between them.
    """

    operator_delimiter = u' OR '
    regex_delimiter = u' | '


class NotFilter(object):
    """Negation of a single operand."""

    def __init__(self, operands):
        """Store the only element of ``operands``; exactly one is expected."""
        operand_count = len(operands)
        assert operand_count == 1

        self.operand = operands[0]

    def serialize(self, yql_word_filter):
        """Return the serialized operand prefixed with ``NOT``."""
        inner = self.operand.serialize(yql_word_filter)
        return u'NOT {}'.format(inner)


class LemmaFilter(object):
    """Leaf filter that wraps a single lemma."""

    def __init__(self, operand):
        """Remember the lemma this leaf stands for."""
        self.operand = operand

    def serialize(self, yql_word_filter):
        """Record the lemma and return the word filter for its regexp.

        The lemma is added to ``yql_word_filter.all_words`` and the pattern
        ``.* <lemma> .*`` is passed to ``yql_word_filter._get_word_filter``.
        """
        lemma = self.operand
        yql_word_filter.all_words.add(lemma)
        pattern = u'.* {} .*'.format(lemma)
        return yql_word_filter._get_word_filter(pattern)


class RaiseRecursionLimitContextManager(object):
    """Temporarily raise the interpreter recursion limit.

    On entry the limit becomes ``custom_limit`` unless the current limit is
    already higher; on exit the limit that was in effect at entry is restored.
    """

    def __init__(self, custom_limit):
        """
        :param custom_limit: recursion limit wanted inside the ``with`` body.
        """
        self.custom_limit = custom_limit
        # Kept for backward compatibility; refreshed again in __enter__.
        self.old_recursion_limit = sys.getrecursionlimit()

    def __enter__(self):
        # Snapshot at entry, not at construction, so a limit changed in between
        # is the one that gets restored.
        self.old_recursion_limit = sys.getrecursionlimit()
        # Only ever raise: lowering the limit could break callers deeper in the stack.
        sys.setrecursionlimit(max(self.custom_limit, self.old_recursion_limit))

    def __exit__(self, exc_type, exc_value, exc_traceback):
        sys.setrecursionlimit(self.old_recursion_limit)


# Dispatch table used by YqlWordFilter.make_filter: maps the operator passed
# in by the Visitor to the filter class that is constructed from its operands.
FILTERS_BY_OPS = {
    'NOT': NotFilter,
    'AND': AndFilter,
    'OR': OrFilter,
    Visitor.LEAF: LemmaFilter,
}


class YqlWordFilter(object):
    """Collects word-rule condition strings and renders them into a YQL query.

    Conditions are registered with ``add_condition_string`` and turned into a
    ``WordFilterQuery`` by ``get_yql_query``. Word regexps discovered while
    serializing conditions are accumulated on the instance across calls.
    """

    def __init__(self, logger=None):
        """
        :param logger: optional logger used to report conditions that fail to
            parse; when omitted such failures are printed instead.
        """
        self.logger = logger
        # Every lemma seen while serializing conditions.
        self.all_words = set()
        # regexp -> numeric id, in registration order.
        self.word_filter_id_dict = collections.OrderedDict()
        # regexp -> rendered `<MATCH_OBJECT>.<id>` reference.
        self.existent_word_filters_dict = collections.OrderedDict()
        # rule revision id -> raw condition string.
        self.conditions = collections.OrderedDict()

    def parse_condition_string(self, condition_string):
        """Parse one condition string and return its serialized expression.

        :raises antlr4.error.Errors.ParseCancellationException: on a syntax
            error (the parser uses ``BailErrorStrategy``).
        """
        # TODO(unretrofied): CRYPTA-16120 remove debug logging after fix
        logger.debug("Preparing to parse condition_string: %s", condition_string)
        text = antlr4.InputStream(condition_string)
        lexer = WordRuleLexer(text)
        stream = antlr4.CommonTokenStream(lexer)
        parser = WordRuleParser(stream)
        parser._errHandler = antlr4.BailErrorStrategy()
        with RaiseRecursionLimitContextManager(10000):  # roughly 3x recursion levels are required for each top-level binary operator
            return Visitor(self.make_filter).visit(parser.root()).serialize(self)

    def make_filter(self, operator, operands):
        """Build the filter object registered for ``operator`` in FILTERS_BY_OPS."""
        return FILTERS_BY_OPS[operator](operands)

    def add_condition_string(self, rule_revision_id, condition_string):
        """Register (or overwrite) the condition string of a rule revision."""
        self.conditions[rule_revision_id] = condition_string

    def get_yql_query(
        self,
        input_table,
        output_table,
        word_rules_file,
        rule_revision_ids=None,
        data_size_per_job="1G",
    ):
        """Render the YQL query for the registered conditions.

        :param input_table: table path substituted into the query's FROM clause.
        :param output_table: table path substituted into the INSERT INTO clause.
        :param word_rules_file: path substituted into the ``PRAGMA File`` line.
        :param rule_revision_ids: optional container of rule revision ids to
            include; when empty or ``None`` every registered rule is used.
        :param data_size_per_job: value for ``yt.DataSizePerJob``.
        :return: ``WordFilterQuery``, or ``None`` when no condition was
            selected or all selected conditions failed to parse.
        """
        rule_conditions = []
        condition_id_to_rule_revision_id = []

        # `.items()` rather than `.iteritems()`: works on both Python 2 and 3.
        for rule_revision_id, condition_string in self.conditions.items():
            if rule_revision_ids and rule_revision_id not in rule_revision_ids:
                continue

            try:
                condition = self.parse_condition_string(condition_string)
            except Errors.ParseCancellationException as exception:
                # Best effort: an unparsable rule is reported and skipped.
                if self.logger:
                    self.logger.exception(u'%s %s %s', exception, rule_revision_id, condition_string)
                else:
                    print(exception, rule_revision_id, condition_string)
                continue

            rule_conditions.append('-- rule {}\n{}'.format(rule_revision_id, condition))
            # The position of the condition in the list is its condition id.
            condition_id_to_rule_revision_id.append(
                '({}, {}ul)'.format(len(condition_id_to_rule_revision_id), rule_revision_id)
            )

        if not rule_conditions:
            return None

        return WordFilterQuery(
            query_text=yql_request_template.format(
                input_table=input_table,
                output_table=output_table,
                word_rules_file=word_rules_file,
                rule_conditions=',\n'.join(rule_conditions),
                condition_id_to_rule_id=',\n'.join(condition_id_to_rule_revision_id),
                data_size_per_job=data_size_per_job,
            ),
            word_conditions=u'\n'.join(self.word_filter_id_dict.keys()),
        )

    def _get_word_filter(self, regexp):
        """Return the `<MATCH_OBJECT>.<id>` reference for ``regexp``.

        A new id (the current number of known regexps) is assigned the first
        time a regexp is seen; later calls reuse the stored reference.
        """
        if regexp not in self.existent_word_filters_dict:
            word_filter_id = len(self.word_filter_id_dict)
            self.word_filter_id_dict[regexp] = word_filter_id
            self.existent_word_filters_dict[regexp] = u'{}.{}'.format(MATCH_OBJECT, word_filter_id)
        return self.existent_word_filters_dict[regexp]
