from baobab import common as baobab_common
import tamus

# `from baobab.common import EventContext, EventsParser, ShowAndClicksJoiner`
# currently fails with "ImportError: No module named common", so the names are
# taken from the already-imported module object instead:
EventContext, EventsParser, ShowAndClicksJoiner = (
    baobab_common.EventContext,
    baobab_common.EventsParser,
    baobab_common.ShowAndClicksJoiner,
)

# Default BaobabMatcher slice_size: the most rules passed to one tamus.check_rules call.
DEFAULT_RULES_SLICE_SIZE = 64
# Factor used by BaobabMatcher.get_keys when the caller supplies none.
DEFAULT_FACTOR = ''
# An event matching any of these rules makes BaobabMatcher.get_keys raise
# SkipEventException instead of returning matches.
SKIP_RULES = [
    '/$page[@service = "web" and @subservice = "granny"]',
]


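# Layout of BaobabMatcher.factored_rules:
# {
#     <factor>: {                     # caller-chosen group, e.g. some TLD
#         'slices': [                 # rules split into dicts of <= slice_size entries
#             {
#                 <rule_key>: <rule>, # e.g. 'r0': an element of counters_list
#             },
#         ],
#         'key_map': {
#             <rule_key>: {
#                 'index': 0,         # index of the slice that holds this rule
#                 'count': 0,         # number of events this rule has matched
#             },
#         },
#     },
# }
# A rule whose count reaches max_keys_per_factor is deleted from both its slice
# and 'key_map' for that factor.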
class BaobabMatcher:
    """Match baobab events against a fixed list of counter rules.

    Every rule is stored under a synthetic key ('r0', 'r1', ...). For each
    factor passed to ``get_keys``, the rules are split into slices of at most
    ``slice_size`` entries and each rule's match count is tracked separately
    for that factor. When a rule's count reaches ``max_keys_per_factor``, it is
    removed from that factor's slices and is no longer checked for it.
    """

    def __init__(self, counters_list, max_keys_per_factor=10, slice_size=DEFAULT_RULES_SLICE_SIZE):
        """
        :param counters_list: sequence of rules for ``tamus.check_rules``;
            the element at position i is stored under key 'r<i>' and is what
            ``get_keys`` reports when that rule matches.
        :param max_keys_per_factor: match count at which a rule is dropped
            for a given factor.
        :param slice_size: maximum number of rules passed to a single
            ``tamus.check_rules`` call.
        """
        self.rules = convert_list_to_rules(counters_list)
        self.skip_rules = convert_list_to_rules(SKIP_RULES)
        # factor -> {'slices': [...], 'key_map': {...}}, built lazily per factor.
        self.factored_rules = {}
        self.max_keys_per_factor = max_keys_per_factor
        self.slice_size = slice_size

    def get_keys(self, event_json_str, factor=DEFAULT_FACTOR):
        """Return the counters whose rules match the given event.

        :param event_json_str: raw event JSON string.
        :param factor: grouping key; match counts and rule removal are
            tracked independently per factor.
        :returns: set of matched elements of ``counters_list``; empty when
            the event cannot be wrapped in an EventContext or the parser
            yields no event.
        :raises SkipEventException: if the event matches any SKIP_RULES rule.
        """
        try:
            context = EventContext(
                None,  # uid
                0,  # ts
                None,  # referer
                None,  # ip
                None,  # url
                event_json_str,  # json-str
            )
        except Exception:
            # Input that EventContext rejects (presumably malformed JSON) counts
            # as "no matches". Catch Exception, not everything, so
            # KeyboardInterrupt/SystemExit still propagate.
            return set()

        event = EventsParser().parse_event(context)
        if event is None:
            return set()

        joiner = ShowAndClicksJoiner()
        joiner.set_show(event)
        joiner.join()

        if tamus.check_rules(self.skip_rules, joiner).markers2blocks:
            raise SkipEventException()

        if factor not in self.factored_rules:
            self.factored_rules[factor] = get_rule_slices(self.rules, self.slice_size)

        factor_rules = self.factored_rules[factor]
        slices = factor_rules['slices']
        key_map = factor_rules['key_map']

        matched_keys = set()
        for rules in slices:
            # Slices can empty out as exhausted rules are deleted below.
            if not rules:
                continue
            marks = tamus.check_rules(rules, joiner)
            matched_keys.update(marks.markers2blocks.keys())

        result = set()
        for key in matched_keys:
            key_meta = key_map[key]
            key_meta['count'] += 1
            if key_meta['count'] >= self.max_keys_per_factor:
                # Quota reached: stop checking this rule for this factor.
                del slices[key_meta['index']][key]
                del key_map[key]
            result.add(self.rules[key])

        return result


class SkipEventException(Exception):
    """Raised by BaobabMatcher.get_keys when an event matches a SKIP_RULES rule."""


def convert_list_to_rules(counters_list):
    """Key every element of ``counters_list`` by its position.

    Returns a new dict mapping 'r0', 'r1', ... to the element at that index.
    """
    return {
        'r%d' % position: counter
        for position, counter in enumerate(counters_list)
    }


def get_rule_slices(rules, slice_size):
    """Split ``rules`` into consecutive dicts of at most ``slice_size`` entries.

    Keys are taken in the iteration order of ``rules``. A slice is closed
    once it holds exactly ``slice_size`` entries, so a non-positive
    ``slice_size`` puts every rule into one slice.

    :returns: ``{'slices': [dict, ...], 'key_map': {key: {'index': i,
        'count': 0}}}``, where ``index`` is the position of the slice
        holding ``key``.
    """
    slices = []
    key_map = {}
    start_new = True

    for key in rules:
        if start_new:
            slices.append({})
            start_new = False
        bucket = slices[-1]
        bucket[key] = rules[key]
        key_map[key] = {'index': len(slices) - 1, 'count': 0}
        start_new = len(bucket) == slice_size

    return {'slices': slices, 'key_map': key_map}
