# -*- coding: utf-8 -*-
import json
import logging
from six.moves import range

# TODO: move the JSON comparison code from custom_proto_comparators into this file; write tests.
from sandbox.projects.common import string
from sandbox.projects.common.differ import differ
from sandbox.projects.common.differ import printers


class JsonDiffer(differ.DifferBase):
    """Differ for a sequence of JSON documents.

    Every pair of documents is diffed with ``cmp_json``, with element-wise
    comparison of JSON arrays switched on.
    """

    def _do_compare_single(self, obj1, obj2, title):
        """Diff a single pair of JSON documents into the output printer.

        :param obj1: first (old) JSON document.
        :param obj2: second (new) JSON document.
        :param title: not used by this implementation (presumably part of the
            DifferBase hook signature -- confirm against the base class).
        :return: True if the comparison added at least one diff line to the
            printer, False otherwise.
        """
        printer = self._output_printer
        lines_before = printer.get_diff_lines_count()
        cmp_json(obj1, obj2, printer, cmp_lists=True)
        lines_after = printer.get_diff_lines_count()
        return lines_after > lines_before


def cmp_json(string1, string2, output, cmp_lists=False):
    """Compare two JSON documents and report the differences to ``output``.

    Both inputs are parsed with ``json.loads``. If parsing either one fails,
    the raw inputs are compared as plain strings instead (via ``cmp_strings``).

    :param string1: first (old) JSON document.
    :param string2: second (new) JSON document.
    :param output: diff printer; it receives ``sched``/``desched`` calls to
        scope nested keys and is called as ``output(text, diff_type)`` for
        every difference found.
    :param cmp_lists: if true, JSON arrays are compared element by element;
        otherwise two arrays are compared as whole stringified values.
    """
    try:
        # Do not pass ``encoding=``. Python 3 has ignored it since 3.1 and
        # removed it in 3.9. There the call raises TypeError, which the
        # except below swallowed, silently turning every structural diff into
        # a whole-string diff. Python 2's json.loads already defaults to UTF-8.
        item1 = json.loads(string1)
        item2 = json.loads(string2)
    except Exception as e:
        # Deliberate best-effort fallback: inputs that are not valid JSON are
        # still diffed, just as opaque strings.
        logging.debug("Can't compare strings\n%s\nand\n%s\nas jsons:\n%s", string1, string2, e)
        cmp_strings(string1, string2, output)
        return
    _cmp_value(item1, item2, output, cmp_lists)


def _cmp_value(item1, item2, output, cmp_lists):
    """Compare two parsed JSON values, choosing the strategy by their types.

    If both values are dicts, they are diffed key by key. If both are lists
    and ``cmp_lists`` is true, they are diffed position by position.
    Any other combination is stringified and compared as text.
    """
    both_dicts = isinstance(item1, dict) and isinstance(item2, dict)
    if both_dicts:
        _cmp_dicts(item1, item2, output, cmp_lists)
        return

    if cmp_lists and isinstance(item1, list) and isinstance(item2, list):
        _cmp_lists(item1, item2, output)
        return

    text_1 = string.all_to_str(item1)
    text_2 = string.all_to_str(item2)
    cmp_strings(text_1, text_2, output)


def _cmp_dicts(dict_1, dict_2, output, cmp_lists):
    """Diff two dicts key by key, writing the results to ``output``.

    Keys are reported in three passes, each in sorted key order:
      1. keys only in ``dict_2``: the whole value is emitted as ADDED;
      2. keys only in ``dict_1``: the whole value is emitted as REMOVED;
      3. keys in both dicts: the two values are compared recursively.
    Each key's output is scoped with ``output.sched(key)`` / ``output.desched()``.
    """
    old_keys = set(dict_1)
    new_keys = set(dict_2)
    # All three key lists are computed up front, before anything is emitted.
    only_new = sorted(new_keys.difference(old_keys))
    only_old = sorted(old_keys.difference(new_keys))
    in_both = sorted(old_keys.intersection(new_keys))

    for key in only_new:
        output.sched(key)
        _output_value(dict_2[key], output, printers.DiffType.ADDED)
        output.desched()

    for key in only_old:
        output.sched(key)
        _output_value(dict_1[key], output, printers.DiffType.REMOVED)
        output.desched()

    for key in in_both:
        output.sched(key)
        _cmp_value(dict_1[key], dict_2[key], output, cmp_lists)
        output.desched()


def cmp_strings(string_1, string_2, output):
    """Report ``"<old> -> <new>"`` as a CHANGED line if the values differ.

    Equal values produce no output at all.
    """
    if string_1 == string_2:
        return
    output("{} -> {}".format(string_1, string_2), printers.DiffType.CHANGED)


def _output_value(item, output, diff_class):
    if isinstance(item, dict):
        for k in sorted(item.keys()):
            output.sched(k)
            _output_value(item[k], output, diff_class)
            output.desched()
    elif isinstance(item, list):
        for i in range(0, len(item)):
            output.sched(str(i))
            _output_value(item[i], output, diff_class)
            output.desched()
    else:
        output(string.all_to_str(item), diff_class)


def _cmp_lists(list_1, list_2, output):
    """Diff two lists position by position, writing the results to ``output``.

    Elements at indices present in both lists are compared recursively, with
    list comparison kept on. Trailing elements of the longer list are emitted
    in full: as ADDED when ``list_2`` is longer, as REMOVED when ``list_1`` is
    longer. Every element is scoped under its stringified index.
    """
    len_1, len_2 = len(list_1), len(list_2)
    overlap = min(len_1, len_2)

    for idx, (left, right) in enumerate(zip(list_1, list_2)):
        output.sched(str(idx))
        _cmp_value(left, right, output, True)
        output.desched()

    if len_1 == len_2:
        return

    if len_2 > len_1:
        longer, diff_class = list_2, printers.DiffType.ADDED
    else:
        longer, diff_class = list_1, printers.DiffType.REMOVED

    for idx in range(overlap, len(longer)):
        output.sched(str(idx))
        _output_value(longer[idx], output, diff_class)
        output.desched()
