#!/usr/bin/python
# -*- coding: utf-8 -*-

import argparse
import operator
import os
import json
import subprocess
import yt.wrapper as yt

from collections import Counter

class AnswerKeysFilter:
    """Mapper that reduces a raw begemot answer to the compared keys.

    Instances are passed to ``yt.run_map`` (see ``pre_calc``).  For every
    input row the mapper emits one row into output table 0 (``reqid`` plus
    the JSON-encoded filtered answer) and one row into output table 1
    (a hash of that JSON string).

    keys -- list of dotted key paths to extract when ``full`` is false.
    full -- when true, ``keys`` is ignored and every ``rules.<name>``
            entry of the answer is extracted instead.
    """

    def __init__(self, keys, full):
        self.keys = keys
        self.full = full

    @classmethod
    def get_value(cls, x, key):
        """Return ``repr`` of the value found at dotted path ``key`` in ``x``.

        Returns None when any path component is missing, maps to None, or
        when the path runs through a value that has no ``get`` method
        (for example a string or a list) instead of raising AttributeError.
        """
        for sub_key in key.split('.'):
            getter = getattr(x, 'get', None)
            if getter is None:
                # The path descends into a non-mapping value: treat as absent.
                return None
            x = getter(sub_key)
            if x is None:
                return None
        return repr(x)

    def get_values(self, begemot_answer):
        """Return ``{key: repr(value)}`` for every selected key that is present.

        Keys whose value resolves to None are omitted from the result.
        """
        if self.full:
            try:
                keys = ['rules.' + name for name in dict(begemot_answer['rules'])]
            except (KeyError, TypeError, ValueError):
                # No usable 'rules' mapping in this answer: nothing to compare.
                keys = []
        else:
            keys = self.keys

        values = {}
        for key in keys:
            value = self.get_value(begemot_answer, key)
            if value is not None:
                values[key] = value
        return values

    def __call__(self, row):
        """Yield the prepared-answer row and the hash row for one input row.

        ``row`` must provide ``reqid`` and ``begemot_answer`` (a JSON string
        that decodes to an iterable of dicts carrying a ``type`` field).
        """
        # prepared answers for calculating diff
        yield yt.create_table_switch(0)
        answer_row = dict()
        answer_row['reqid'] = row['reqid']
        begemot_answer = json.loads(row['begemot_answer'])
        # Keep only the first entry whose type is "wizard"; if there is none,
        # the decoded value is used as is.
        for ans_type in begemot_answer:
            if ans_type["type"] == "wizard":
                begemot_answer = ans_type
                break
        # '.version' is removed before key extraction so it never takes part
        # in the comparison.
        if 'rules' in begemot_answer and '.version' in begemot_answer['rules']:
            del begemot_answer['rules']['.version']
        answer_row['begemot_answer'] = json.dumps(self.get_values(begemot_answer))
        yield answer_row
        # table with hashes for calculating share of unique answers
        yield yt.create_table_switch(1)
        hash_row = dict()
        # NOTE(review): built-in hash() of a string is only comparable between
        # job processes when string hash randomization is disabled (the
        # default on Python 2; needs a fixed PYTHONHASHSEED on Python 3) --
        # confirm which interpreter the jobs run under.
        hash_row['hash'] = hash(answer_row['begemot_answer'])
        yield hash_row


def pre_calc(input_table, prepared_answers_table, hashes_table, keys_to_compare=None, full=False):
    """Start the map operation that prepares answers and their hashes.

    Runs ``AnswerKeysFilter`` over ``input_table`` and writes to two output
    tables: ``prepared_answers_table`` (index 0) and ``hashes_table``
    (index 1).  The operation is started with ``sync=False``; the value
    returned by ``yt.run_map`` is handed back so the caller can wait on it.

    keys_to_compare -- dotted key paths to keep; None means an empty list.
    full            -- forwarded to ``AnswerKeysFilter`` as its ``full`` flag.
    """
    # A None sentinel replaces the former shared mutable default ([]).
    if keys_to_compare is None:
        keys_to_compare = []
    return yt.run_map(
        AnswerKeysFilter(keys_to_compare, full),
        source_table=input_table,
        destination_table=[prepared_answers_table, hashes_table],
        spec=common_spec, client=client,
        sync=False,
    )


def calc_repeated(input_table, repeated_counts_table):
    """Start a reduce that counts rows and repeated rows per hash value.

    Rows of ``input_table`` are grouped by ``hash``.  The reducer emits a
    single ``{'count', 'repeated'}`` row per invocation, where ``repeated``
    adds up, for every group, the number of rows beyond the first one.
    The operation is started with ``sync=False`` and the value returned by
    ``yt.run_reduce`` is handed back to the caller.
    """

    @yt.reduce_aggregator
    def repeated_counter(row_groups):
        total_rows = 0
        total_repeated = 0
        for _key, group in row_groups:
            # Count the group's rows without materializing them.
            group_size = 0
            for _row in group:
                group_size += 1
            total_rows += group_size
            total_repeated += group_size - 1
        yield {'count': total_rows, 'repeated': total_repeated}

    operation = yt.run_reduce(
        repeated_counter,
        input_table,
        repeated_counts_table,
        reduce_by='hash',
        job_count=1,
        sync=False,
        spec=common_spec,
        client=client,
    )
    return operation


def cache_guess(repeated_counts_table):
    """Return the share of repeated answers recorded in ``repeated_counts_table``.

    Sums the ``count`` and ``repeated`` columns over all rows of the table
    and returns ``repeated / count`` as a float.  Returns 0.0 when the table
    is empty or the total count is zero, instead of raising
    ZeroDivisionError.
    """
    count, repeated = 0, 0
    for row in yt.read_table(repeated_counts_table, client=client, format='json'):
        count += row['count']
        repeated += row['repeated']
    if not count:
        # Nothing was counted, so there is nothing that could repeat.
        return 0.0
    return float(repeated) / count


def calc_matches(answers_old, answers_new, matches_counts_table, diff_table):
    """Compare prepared answers of two tables reqid by reqid.

    Runs a synchronous reduce over ``[answers_old, answers_new]`` grouped by
    ``reqid``.  Output table 0 (``matches_counts_table``) receives rows
    ``{'count', 'match_count'}``; output table 1 (``diff_table``) receives
    one row per reqid whose two answers differ, describing the changed keys.

    The reducer raises Exception when a reqid does not have exactly two rows.
    """

    def split_relev_like_field(relev):
        """Parse a ';'-separated list of 'name=value' items into a dict.

        None yields an empty dict.  The value is stringified, leading '['
        and trailing ']' characters are stripped, and a leading ``u'`` is
        removed together with the last character.  Items without '=' map
        to an empty string.
        """
        if relev is None:
            return dict()
        relev = str(relev).lstrip('[').rstrip(']')
        if relev.startswith("u'"):
            relev = relev[2:-1]
        parsed = {}
        for item in relev.split(';'):
            if not item:
                continue
            name, _, value = item.partition('=')
            parsed[name] = value
        return parsed

    def calc_changes(dict1, dict2):
        """Yield ``(key, 'old --> new')`` for keys whose values differ.

        Both dicts are stringified IN PLACE first.  The keys 'relev',
        'rearr', 'snip' and 'rules' are skipped here because the caller
        reports them separately.  A key that differs on both sides may be
        yielded twice with the same text.
        """
        for d in [dict1, dict2]:
            for k in d:
                try:
                    d[k] = str(d[k])
                except Exception:
                    # Fallback for values str() cannot convert.
                    d[k] = str(json.dumps(d[k]).encode('utf-8'))
        for k, _ in set(dict1.items()).symmetric_difference(set(dict2.items())):
            if k not in ['relev', 'rearr', 'snip', 'rules']:
                yield k, '%s --> %s' % (dict1.get(k), dict2.get(k))

    @yt.reduce_aggregator
    def diff(row_groups):
        match_count, count = 0, 0
        # Diff rows go to output table 1.
        yield yt.create_table_switch(1)
        for k, rows in row_groups:
            count += 1
            rows = list(rows)
            if len(rows) != 2:
                raise Exception(
                    'Expected exactly two begemot answers for reqid %s, found %d' % (k['reqid'], len(rows))
                )
            # NOTE(review): the "old --> new" direction assumes the row from
            # answers_old arrives first within a group -- confirm.
            ans1, ans2 = (json.loads(row['begemot_answer']) for row in rows)
            if ans1 == ans2:
                match_count += 1
                continue

            # Must run before the per-field parsing below: it stringifies
            # the values of ans1/ans2 in place.
            changes = dict(calc_changes(ans1, ans2))
            for field in ['relev', 'rearr', 'snip']:
                field_diff = dict(calc_changes(
                    split_relev_like_field(ans1.get(field)),
                    split_relev_like_field(ans2.get(field)),
                ))
                if field_diff:
                    changes[field] = json.dumps(field_diff)
            rules1 = ans1.get('rules') or dict()
            rules2 = ans2.get('rules') or dict()
            try:
                rules_diff = dict(calc_changes(rules1, rules2))
            except Exception:
                changes['rules'] = 'Something went wrong. Task failed to find'
            else:
                # Only a successfully computed rules diff is reported; the
                # error marker above is never overwritten by a stale value.
                if rules_diff:
                    changes['rules'] = rules_diff
            changes['reqid'] = k['reqid']
            yield changes
        # The per-job counters go to output table 0.
        yield yt.create_table_switch(0)
        yield dict(count=count, match_count=match_count)

    yt.run_reduce(
        diff,
        [answers_old, answers_new],
        [matches_counts_table, diff_table],
        reduce_by=['reqid'],
        spec=common_spec,
        client=client,
        sync=True,
    )


def get_answers_diff(matches_counts_table):
    """Return ``(differing, total)`` summed over ``matches_counts_table``.

    ``total`` is the sum of the ``count`` column; ``differing`` is that sum
    minus the sum of the ``match_count`` column.
    """
    total = 0
    matching = 0
    records = yt.read_table(matches_counts_table, client=client, format='json')
    for record in records:
        total = total + record['count']
        matching = matching + record['match_count']
    differing = total - matching
    return differing, total


def update_fields_by_rule(fields_by_rule, name, factor):
    """Record ``factor`` in the set stored under ``name`` and return the mapping.

    The set is created on first use; ``fields_by_rule`` is modified in place
    and the same object is returned.
    """
    fields_by_rule.setdefault(name, set()).add(factor)
    return fields_by_rule


def bgfactors_diff(bgfactors, changed_fields, fields_by_rule, rp_process):
    """Account for every bgfactor that differs between the two sides.

    ``bgfactors`` is an ``'<old> --> <new>'`` string.  Each side is written
    as one line to ``rp_process.stdin`` and one line of JSON is read back
    from ``rp_process.stdout``.  Every factor that is missing on one side or
    has a different value is counted in ``changed_fields`` as
    ``relev.bgfactors.<factor>`` and registered under the 'QueryFactors'
    rule in ``fields_by_rule``.

    Returns ``(changed_fields, fields_by_rule)``.
    Raises Exception when ``bgfactors`` does not split into exactly two parts.
    """
    sides = bgfactors.split(' --> ')
    if len(sides) != 2:
        raise Exception("Bgfactors not parsed")

    decoded = []
    for side in sides:
        rp_process.stdin.write(side + '\n')
        rp_process.stdin.flush()
        reply = rp_process.stdout.readline()
        decoded.append(json.loads(reply))
    before, after = decoded

    # Factors removed from, or changed in, the second side.
    for factor in before:
        if factor in after and before[factor] == after[factor]:
            continue
        field = "relev.bgfactors." + factor
        changed_fields[field] += 1
        fields_by_rule = update_fields_by_rule(fields_by_rule, 'QueryFactors', field)

    # Factors that only exist in the second side.
    for factor in after:
        if factor in before:
            continue
        field = "relev.bgfactors." + factor
        changed_fields[field] += 1
        fields_by_rule = update_fields_by_rule(fields_by_rule, 'QueryFactors', field)

    return changed_fields, fields_by_rule


def analyze_diff(diff_table, rp_path, relev_map, limit):
    """Aggregate the rows of ``diff_table`` into per-rule change statistics.

    diff_table -- table of changed keys per reqid.
    rp_path    -- path to the response-parser binary; it is made executable
                  and started as a line-oriented subprocess for bgfactors.
    relev_map  -- mapping from a relev field name to its rule name; unmapped
                  fields are attributed to '<unknown rule>'.
    limit      -- when positive, at most ``limit`` rows are analyzed;
                  otherwise all rows are.

    Returns a list of ``{'rule', 'count', 'fields'}`` dicts sorted by
    descending ``count`` (number of rows the rule was involved in), where
    ``fields`` is a list of ``(field, change_count)`` pairs sorted by
    descending count.
    """
    changed_fields = Counter()
    rules_impact = Counter()
    fields_by_rule = {}
    rows_left = limit

    rp_args = [rp_path, 'QueryFactors.Factors', '--filter']
    os.chmod(rp_args[0], 0o755)
    rp_process = subprocess.Popen(rp_args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, universal_newlines=True)
    try:
        for row in yt.read_table(diff_table, client=client, format='json'):
            if limit > 0 and rows_left == 0:
                break
            rows_left -= 1
            rules = set()
            for k in row:
                if k == 'reqid':
                    continue
                if k == 'relev':
                    relev = json.loads(row[k])
                    for field in relev:
                        name = relev_map.get(field, '<unknown rule>')
                        rules.add(str(name))
                        if field == 'bgfactors':
                            changed_fields, fields_by_rule = bgfactors_diff(
                                relev[field], changed_fields, fields_by_rule, rp_process
                            )
                        else:
                            factor = k + '.' + field
                            changed_fields[factor] += 1
                            fields_by_rule = update_fields_by_rule(fields_by_rule, name, factor)
                else:
                    changed_fields[k] += 1
                    field = k.split('.')
                    # 'rules.<name>...' is attributed to <name>; any other key
                    # (including a bare 'rules') to its first component.
                    if field[0] == 'rules' and len(field) > 1:
                        name = field[1]
                    else:
                        name = field[0]
                    rules.add(str(name))
                    fields_by_rule = update_fields_by_rule(fields_by_rule, name, k)
            for rule in rules:
                rules_impact[rule] += 1
    finally:
        # Always release the helper process, even if a row fails to parse.
        rp_process.stdin.close()
        rp_process.stdout.close()
        rp_process.wait()

    ans = []
    for rule, cnt in sorted(rules_impact.items(), key=operator.itemgetter(1), reverse=True):
        # A rule may be counted without any field registered under its own
        # name (bgfactors are always registered under 'QueryFactors'), so a
        # missing entry means "no fields" rather than an error.
        factors = dict(
            (field, changed_fields[field]) for field in fields_by_rule.get(rule, ())
        )
        ans.append({
            'rule': rule,
            'count': cnt,
            'fields': sorted(factors.items(), key=operator.itemgetter(1), reverse=True),
        })
    return ans


if __name__ == '__main__':
    # Command-line entry point: prepares both answer tables, estimates the
    # cache hit share for each, diffs them and prints one JSON object with
    # the result to stdout.
    parser = argparse.ArgumentParser(
        description='Calculates begemot answers diff and cache guess, prints json with result to stdout'
    )
    parser.add_argument('--old', dest='answers_old', help='Begemot answers old', required=True)
    parser.add_argument('--new', dest='answers_new', help='Begemot answers new', required=True)
    parser.add_argument('--output_path', dest='output_path', help='Intermediate tables path', required=True)
    parser.add_argument('--keys', dest='keys', help='Keys to compare in begemot answers - json array', required=True)
    parser.add_argument('--yt_proxy', dest='yt_proxy', required=True)
    parser.add_argument('--yt_pool', dest='yt_pool', required=True)
    parser.add_argument('--rp_path', dest='rp_path', help='path to response-parser tool', required=True)
    parser.add_argument('--relev_map', dest='relev_map', help='relev_factor->rule_name mapping', required=True)
    parser.add_argument('--limit', dest='limit', help='Analyzed answers limit', required=False, default=-1)
    parser.add_argument('--full', dest='full_check', action='store_true')
    parser.add_argument('--detailed', dest='detailed', action='store_true')
    parser.add_argument('--fail_on_error', dest='fail_on_error', action='store_true')
    parser.add_argument('--debug', dest='debug', action='store_true')
    args = parser.parse_args()

    # `client` and `common_spec` are module-level names read by the helper
    # functions defined in this file.
    token = os.environ['YT_TOKEN']
    client = yt.YtClient(
        args.yt_proxy,
        token,
        config=dict(
            pickling={
                'module_filter': lambda lib: hasattr(lib, '__file__') and not lib.__file__.endswith('.so')
                                             and 'hashlib' not in getattr(lib, '__name__', ''),
                'force_using_py_instead_of_pyc': True,
            },
            pool=args.yt_pool,
        ),
    )
    common_spec = {'max_failed_job_count': 0}
    keys = json.loads(args.keys)
    relev_map = json.loads(args.relev_map)

    prepared_old = yt.ypath_join(args.output_path, 'prepared_old')
    prepared_new = yt.ypath_join(args.output_path, 'prepared_new')
    old_hashes = yt.ypath_join(args.output_path, 'old_hashes')
    new_hashes = yt.ypath_join(args.output_path, 'new_hashes')
    matches_table = yt.ypath_join(args.output_path, 'matches_counts')
    diff_table = yt.ypath_join(args.output_path, 'diff')

    if args.debug:
        # Debug mode reuses already prepared tables and prints nothing.
        calc_matches(prepared_old, prepared_new, matches_table, diff_table)
        if args.full_check:
            analyze_diff(diff_table, args.rp_path, relev_map, int(args.limit))
    else:
        # Both preparation maps run concurrently.
        pre_calc_operation1 = pre_calc(args.answers_old, prepared_old, old_hashes, full=args.full_check, keys_to_compare=keys)
        pre_calc_operation2 = pre_calc(args.answers_new, prepared_new, new_hashes, full=args.full_check, keys_to_compare=keys)
        pre_calc_operation1.wait()
        pre_calc_operation2.wait()

        # All four sorts are started together; only the hash sorts are
        # awaited before the repeated-count reduces.
        sort_hashes_old = yt.run_sort(old_hashes, old_hashes, sort_by='hash', spec=common_spec, client=client, sync=False)
        sort_hashes_new = yt.run_sort(new_hashes, new_hashes, sort_by='hash', spec=common_spec, client=client, sync=False)
        sort_prepared_old = yt.run_sort(prepared_old, prepared_old, sort_by='reqid', spec=common_spec, client=client, sync=False)
        sort_prepared_new = yt.run_sort(prepared_new, prepared_new, sort_by='reqid', spec=common_spec, client=client, sync=False)
        sort_hashes_old.wait()
        sort_hashes_new.wait()

        repeated_old = yt.ypath_join(args.output_path, 'repeated_counts_old')
        repeated_new = yt.ypath_join(args.output_path, 'repeated_counts_new')
        operation1 = calc_repeated(old_hashes, repeated_old)
        operation2 = calc_repeated(new_hashes, repeated_new)
        operation1.wait()
        operation2.wait()
        cache_guess_old = cache_guess(repeated_old)
        cache_guess_new = cache_guess(repeated_new)

        sort_prepared_old.wait()
        sort_prepared_new.wait()
        calc_matches(prepared_old, prepared_new, matches_table, diff_table)

        diff_rules = {}
        if args.detailed:
            try:
                diff_rules = analyze_diff(diff_table, args.rp_path, relev_map, int(args.limit))
            except Exception:
                # Without --fail_on_error the detailed analysis is best
                # effort; interpreter-exit signals such as KeyboardInterrupt
                # are no longer swallowed.
                if args.fail_on_error:
                    raise
                diff_rules = {}
        diff_parsed = bool(diff_rules)

        answers_diff_count, answers_count = get_answers_diff(matches_table)
        # Report 0.0 rather than crashing when no answers were compared.
        if answers_count:
            answers_diff = float(answers_diff_count) / answers_count
        else:
            answers_diff = 0.0
        result = dict(
            cache_guess_old=cache_guess_old,
            cache_guess_new=cache_guess_new,
            answers_diff=answers_diff,
            diff_parsed=diff_parsed,
            diff_rules=diff_rules,
            answers_count=answers_count,
            answers_diff_count=answers_diff_count,
        )
        print(json.dumps(result))
