from copy import deepcopy
from collections import defaultdict
from base64 import b64decode


import six


def check_failed(unique_logs, changed_logs):
    """Return True if a log comparison found any difference at all.

    :param unique_logs: dict whose values are collections of log names present on
        only one side (e.g. the ``{'pre': set, 'test': set}`` built by diff_logs).
    :param changed_logs: mapping of log name -> change summary, each checked with
        check_changed.
    :returns: True when any unique log exists or any summary records a change.
    """
    # ``.values()`` works on both Python 2 and 3; ``.itervalues()`` is Python 2 only.
    return any(unique_logs.values()) or any(check_changed(item) for item in changed_logs.values())


def check_changed(changed):
    """Return True if a per-log change summary records any change.

    :param changed: summary dict such as ``{'+': set, '-': set, '!': set, 'len': int}``;
        any non-empty set or non-zero value counts as a change.
    """
    # ``.values()`` works on both Python 2 and 3; ``.itervalues()`` is Python 2 only.
    return any(changed.values())


def merge_unique(unique_base, unique_to_merge):
    """Union the 'pre' and 'test' name sets of *unique_to_merge* into *unique_base*.

    *unique_base* is modified in place; nothing is returned.
    """
    unique_base['pre'] |= unique_to_merge['pre']
    unique_base['test'] |= unique_to_merge['test']


def merge_changed(changed_base, changed_to_merge):
    """Fold the per-log change summaries of *changed_to_merge* into *changed_base*.

    Both arguments map a log name to a summary dict with '+', '-' and '!' sets
    and a numeric 'len'. Logs missing from *changed_base* start from an empty
    summary. The sets are unioned, and 'len' keeps whichever value has the
    larger magnitude (on a tie the existing base value wins). *changed_base* is
    updated in place.
    """
    empty_summary = {'+': set(), '-': set(), '!': set(), 'len': 0}
    for log_name in changed_to_merge:
        incoming = changed_to_merge[log_name]
        if log_name not in changed_base:
            # Fresh sets per log so summaries never share state.
            changed_base[log_name] = deepcopy(empty_summary)
        target = changed_base[log_name]
        for symbol in ('+', '-', '!'):
            target[symbol] |= incoming[symbol]
        target['len'] = max(target['len'], incoming['len'], key=abs)


def wrap_pre_test(func):
    """Curry *func* so it is applied once to 'pre' arguments and once to 'test' arguments.

    ``wrap_pre_test(f)(*pre_args)(*test_args)`` calls ``f(*pre_args)`` and then
    ``f(*test_args)``. Both calls happen only when the second set of arguments
    is supplied. The result is the tuple ``(pre_result, test_result)``.
    """
    def bind_pre(*pre_args, **pre_kwargs):
        def bind_test(*test_args, **test_kwargs):
            pre_result = func(*pre_args, **pre_kwargs)
            test_result = func(*test_args, **test_kwargs)
            return pre_result, test_result
        return bind_test
    return bind_pre


def check_complex(value):
    """Return True if *value* is a dict or a list (subclasses included)."""
    return isinstance(value, (dict, list))


def diff_log_entry(pre_entry, test_entry, log_name, pre_index, test_index, changed_dict):
    """Render one paired pre/test log entry as a diff hunk and record changed fields.

    Field names found only in *pre_entry* are added to ``changed_dict['-']``.
    Those found only in *test_entry* go to ``changed_dict['+']``. Shared fields
    whose values differ go to ``changed_dict['!']``; the '___MD5___' field is
    never compared.

    :param pre_entry: entry dict from the baseline log.
    :param test_entry: entry dict from the log under test.
    :param log_name: log name shown in the hunk header and lines.
    :param pre_index: position of *pre_entry* in its log.
    :param test_index: position of *test_entry* in its log.
    :param changed_dict: summary dict with '+', '-', '!' sets; updated in place.
    :returns: unicode hunk text. It holds an ``@@ -pre_index +test_index @@``
        header, the full pre entry as a '-' line and the full test entry as a
        '+' line, followed by a blank line.
    """
    # set(d) instead of set(d.iterkeys()): works on both Python 2 and 3.
    pre_keys = set(pre_entry)
    test_keys = set(test_entry)

    changed_dict['-'].update(pre_keys - test_keys)
    changed_dict['+'].update(test_keys - pre_keys)
    for key in pre_keys & test_keys:
        # '___MD5___' is skipped: diff_log only pairs entries whose checksums
        # differ, so it would always be flagged.
        if key != '___MD5___' and pre_entry[key] != test_entry[key]:
            changed_dict['!'].add(key)

    return u''.join([
        u'@@ -{pre_index} +{test_index} @@ log name: {log_name}\n'.format(
            log_name=log_name, pre_index=pre_index, test_index=test_index),
        u'-[{log_name}][{pre_index}]: {joined_pre_entry}\n'.format(
            log_name=log_name,
            pre_index=pre_index,
            joined_pre_entry=join_log_entry(pre_entry),
        ),
        u'+[{log_name}][{test_index}]: {joined_test_entry}\n\n'.format(
            log_name=log_name,
            test_index=test_index,
            joined_test_entry=join_log_entry(test_entry),
        ),
    ])


def get_md5_general(log):
    """Return the '___MD5___' value stored in the last entry of *log*.

    diff_log compares these trailer values to skip logs that are identical.
    Raises IndexError for an empty log and KeyError if the trailer lacks the field.
    """
    trailer = log[-1]
    return trailer['___MD5___']


def get_md5_dict(log):
    """Map each entry checksum to the set of positions where it occurs in *log*.

    The final entry (the trailer read by get_md5_general) is excluded. The
    result is a ``defaultdict(set)``, so looking up an unknown checksum yields
    an empty set.
    """
    md5_to_indices = defaultdict(set)
    body = log[:-1]
    for position, entry in enumerate(body):
        md5_to_indices[entry['___MD5___']].add(position)
    return md5_to_indices


def iterate_unique_entries(md5_dict, common_md5_set):
    """Yield ``(md5, index)`` for every entry whose checksum is not in *common_md5_set*.

    :param md5_dict: mapping of checksum -> iterable of entry positions
        (as produced by get_md5_dict).
    :param common_md5_set: checksums present in both logs; their entries are skipped.
    """
    # ``.items()`` works on both Python 2 and 3; ``.iteritems()`` is Python 2 only.
    for md5, index_set in md5_dict.items():
        if md5 in common_md5_set:
            continue
        for index in index_set:
            yield md5, index


def diff_log(pre_log, test_log, log_name):
    """Diff two versions of the same log and summarise what changed.

    Each log is a sequence of entry dicts carrying a '___MD5___' checksum. Its
    last element is a trailer, and the trailers are compared first so that
    identical logs can be skipped. An entry whose checksum occurs on only one
    side is "unique". The i-th unique pre entry (in index order) is paired with
    the i-th unique test entry and rendered with diff_log_entry. Unpaired
    leftovers are then listed as removals ('-') or additions ('+').

    NOTE: an entry whose checksum appears in both logs counts as matched even if
    it occurs a different number of times. If the trailers differ but every
    entry checksum is matched, the returned text is just ``u'\\n'``.

    :param pre_log: the baseline log.
    :param test_log: the log under test.
    :param log_name: log name used in the diff headers.
    :returns: ``(diff_text, changed_dict)``. ``diff_text`` is ``u''`` when the
        trailer checksums match. ``changed_dict`` holds the '+', '-', '!'
        field-name sets filled by diff_log_entry, plus 'len', which is
        ``len(test_log) - len(pre_log)`` (0 when the trailers match).
    """
    changed_dict = {'+': set(), '-': set(), '!': set(), 'len': 0}
    if get_md5_general(pre_log) == get_md5_general(test_log):
        return u'', changed_dict

    pre_md5_dict = get_md5_dict(pre_log)
    test_md5_dict = get_md5_dict(test_log)
    # set(d) instead of set(d.iterkeys()): works on both Python 2 and 3.
    common_md5_set = set(pre_md5_dict) & set(test_md5_dict)

    def _unique_indices(md5_dict):
        # Ascending positions of entries whose checksum is absent from the other log.
        return sorted(index for _, index in iterate_unique_entries(md5_dict, common_md5_set))

    unique_pre = _unique_indices(pre_md5_dict)
    unique_test = _unique_indices(test_md5_dict)

    parts = [u'\n']
    for pre_index, test_index in zip(unique_pre, unique_test):
        parts.append(diff_log_entry(
            pre_log[pre_index],
            test_log[test_index],
            log_name,
            pre_index,
            test_index,
            changed_dict,
        ))

    # At most one side has unpaired unique entries left over.
    paired = min(len(unique_pre), len(unique_test))
    if len(unique_pre) > paired:
        leftover, diff_symbol, source_log = unique_pre[paired:], '-', pre_log
    else:
        leftover, diff_symbol, source_log = unique_test[paired:], '+', test_log
    if leftover:
        # NOTE(review): the header reads "+0,{cnt}" even when the leftovers are removals.
        parts.append(u'@@ -0,0 +0,{cnt} @@ log name: {log_name}, unique entry diff\n'.format(
            log_name=log_name, cnt=len(leftover)))
        for true_index in leftover:
            parts.append(u'{diff_symbol}[{log_name}][{true_index}]: {joined_log_entry}\n'.format(
                diff_symbol=diff_symbol,
                log_name=log_name,
                true_index=true_index,
                joined_log_entry=join_log_entry(source_log[true_index]),
            ))

    changed_dict['len'] = len(test_log) - len(pre_log)
    return u''.join(parts), changed_dict


def join_log_entry(entry):
    """Render a log entry as comma-separated ``key=value`` pairs sorted by key.

    The '___MD5___' field is omitted. Each value is converted to text with
    decode_value.
    """
    # sorted(entry) orders the (unique) keys exactly as sorting the items by key
    # did, and works on Python 3, where dict.iteritems() no longer exists.
    return u', '.join(
        u'{}={}'.format(key, decode_value(entry[key]))
        for key in sorted(entry)
        if key != '___MD5___'
    )


def decode_value(value):
    """Convert a log field value to unicode text for display.

    A dict with exactly the keys '$data' and '$encoding', where '$encoding' is
    'base64', is base64-decoded first. Text is returned unchanged. Byte strings
    are decoded as UTF-8 with invalid sequences replaced, and any other value
    is converted with ``str()``.

    Raises whatever ``b64decode`` raises for malformed base64 data.
    """
    if (isinstance(value, dict)
            and set(value) == {'$data', '$encoding'}
            and value['$encoding'] == 'base64'):
        value = b64decode(value['$data'])
    if isinstance(value, six.text_type):
        return value
    # Decode bytes explicitly. On Python 3, six.text_type(str(value), 'utf-8', ...)
    # raised TypeError because str(value) is already text.
    if isinstance(value, bytes):
        return value.decode('utf-8', 'replace')
    text = str(value)
    if isinstance(text, bytes):  # Python 2: str() yields a byte string
        return text.decode('utf-8', 'replace')
    return text


def unique_logs_diff(unique_log_set, diff_symbol):
    """Render one ``{diff_symbol}[{log}]`` line per log name in *unique_log_set*.

    Lines follow the iteration order of *unique_log_set*.
    """
    # The template must be a unicode literal. On Python 2, formatting a non-ASCII
    # unicode log name into a byte-string template raises UnicodeEncodeError.
    return u''.join(
        u'{diff_symbol}[{log}]\n'.format(diff_symbol=diff_symbol, log=log)
        for log in unique_log_set
    )


def diff_logs(pre_logs, test_logs):
    """Diff two collections of named logs.

    :param pre_logs: mapping of log name -> baseline log.
    :param test_logs: mapping of log name -> log under test.
    :returns: ``(diff, unique_logs, changed_logs)``.

        - ``diff`` is unicode text. It lists logs present on only one side
          (under a " Unique log diff:" heading), followed by the per-log diffs
          from diff_log.
        - ``unique_logs`` is ``{'pre': names only in pre_logs, 'test': names
          only in test_logs}``; both sets are empty when the name sets match.
        - ``changed_logs`` maps each common log whose diff_log text is
          non-empty to its change summary.
    """
    # set(d) instead of set(d.iterkeys()): works on both Python 2 and 3.
    pre_logs_set = set(pre_logs)
    test_logs_set = set(test_logs)
    common_logs_set = pre_logs_set & test_logs_set
    unique_logs = {'pre': pre_logs_set - test_logs_set, 'test': test_logs_set - pre_logs_set}

    diff = u''
    if unique_logs['pre'] or unique_logs['test']:
        diff += u' Unique log diff:\n'
        diff += (unique_logs_diff(unique_logs['pre'], '-')
                 + unique_logs_diff(unique_logs['test'], '+')
                 + u'\n')

    changed_logs = {}
    for log_name in common_logs_set:
        diff_log_str, changed_log = diff_log(pre_logs[log_name], test_logs[log_name], log_name)
        if diff_log_str:
            diff += u'\n' + diff_log_str
            changed_logs[log_name] = changed_log
    return diff, unique_logs, changed_logs
