# -*- coding: utf-8 -*-

from six.moves import zip

from sandbox.projects.common.differ.ut import common
from sandbox.projects.common.differ import differ
from sandbox.projects.common.differ import printers


class PrinterMock(printers.PrinterBase):
    """Printer stub that records how a differ drives it.

    Remembers every title handed to ``on_new_pair`` and counts calls to
    ``finalize``, ``get_parallel_subprinter`` and ``join_parallel_subprinters``.
    Subprinters it hands out never support parallel processing themselves and
    are marked with ``is_subprinter = True``.
    """

    def __init__(self, supports_parallel, join_parallel_subprinters_count=0):
        # Configuration supplied by the test.
        self.supports_parallel = supports_parallel
        # Stored for callers; not read by any method of this mock.
        self.join_parallel_subprinters_count = join_parallel_subprinters_count
        self.is_subprinter = False

        # Call recorders inspected by the assertions.
        self.on_new_pair_calls = []
        self.finalize_calls = 0
        self.create_subprinter_calls = 0
        self.join_parallel_subprinters_calls = 0

    def _do_indent(self, i, diff_class):
        """No-op: the tests never look at indentation output."""

    def _print_info(self, line, diff_class):
        """No-op: the tests never look at printed lines."""

    def on_new_pair(self, title):
        """Record the title of a newly started comparison pair."""
        self.on_new_pair_calls.append(title)

    def finalize(self):
        """Count how many times the printer was finalized."""
        self.finalize_calls = self.finalize_calls + 1

    @property
    def supports_parallel_processing(self):
        """Whether this printer was configured to support parallel processing."""
        return self.supports_parallel

    def get_parallel_subprinter(self):
        """Hand out a fresh, non-parallel subprinter and count the request.

        Asserts that parallel processing is supported by this printer; the
        counter is bumped before the check, exactly as the differ observes it.
        """
        self.create_subprinter_calls += 1
        assert self.supports_parallel_processing
        child = PrinterMock(supports_parallel=False)
        child.is_subprinter = True
        return child

    def join_parallel_subprinters(self, subprinters):
        """Check that exactly the subprinters handed out are joined back.

        Every element must be flagged as a subprinter, and their number must
        equal the number of ``get_parallel_subprinter`` calls so far.
        """
        self.join_parallel_subprinters_calls += 1
        assert self.supports_parallel_processing
        joined = 0
        for joined, subprinter in enumerate(subprinters, 1):
            assert subprinter.is_subprinter
        assert self.create_subprinter_calls == joined


class DifferMock(differ.DifferBase):
    """Differ stub that records every single-pair comparison request.

    Each ``_do_compare_single`` call is appended to ``compare_single_calls``
    as an ``(obj1, obj2, title)`` tuple and answered with ``has_diff_ret``.
    """

    # Answer returned by every _do_compare_single call; tests may override it
    # per instance (e.g. ``differ.has_diff_ret = True``).
    has_diff_ret = False

    def __init__(self, *args, **kwargs):
        # The call log must belong to the instance: a class-level list would be
        # shared by every DifferMock, accumulating calls across tests and making
        # exact-sequence assertions depend on test execution order.
        self.compare_single_calls = []
        super(DifferMock, self).__init__(*args, **kwargs)

    def _do_compare_single(self, obj1, obj2, title):
        """Record the requested comparison and return the configured answer."""
        self.compare_single_calls.append((obj1, obj2, title))
        return self.has_diff_ret


class TestDifferBase(object):
    """Tests for ``differ.DifferBase`` driven through DifferMock and PrinterMock.

    The mode-dependent scenarios accept an optional flag. ``test_all`` passes
    each value explicitly. Calling them without an argument checks both the
    sequential and the parallel mode. Because the flag has a default, a test
    collector can run these methods directly instead of trying to resolve
    the parameter as a fixture.
    """

    @staticmethod
    def _make_cmp_data():
        """Return a fresh iterable of two ``(obj1, obj2, title)`` triples."""
        return zip(
            [1, 2],
            [1, 2],
            ["1", "2"]
        )

    def test_compare_single_pair(self):
        """compare_single_pair forwards to the differ and counts only real diffs."""
        printer = PrinterMock(False)
        differ_mock = DifferMock(printer)
        differ_mock.compare_single_pair(1, 2, "title")

        common.assert_equal(differ_mock.compare_single_calls, [(1, 2, "title")])
        common.assert_equal(printer.on_new_pair_calls, ["title"])
        common.assert_equal(printer.finalize_calls, 0)
        common.assert_equal(differ_mock.get_diff_count(), 0)

        differ_mock.has_diff_ret = True
        differ_mock.compare_single_pair(3, 4, "name")

        common.assert_equal(differ_mock.compare_single_calls, [(1, 2, "title"), (3, 4, "name")])
        common.assert_equal(printer.on_new_pair_calls, ["title", "name"])
        common.assert_equal(printer.finalize_calls, 0)
        common.assert_equal(differ_mock.get_diff_count(), 1)

    def test_does_not_using_parallel_processing(self):
        """With parallel mode off, no subprinters are created even if supported."""
        printer = PrinterMock(True)
        differ_mock = DifferMock(printer)
        differ_mock.compare_pairs(self._make_cmp_data(), False)

        common.assert_equal(printer.create_subprinter_calls, 0)
        common.assert_equal(printer.finalize_calls, 1)

    def test_does_not_call_finalize(self, do_in_parallel=None):
        """Passing False as compare_pairs' third argument leaves the printer unfinalized.

        :param do_in_parallel: comparison mode to check; ``None`` checks both.
        """
        if do_in_parallel is None:
            for mode in (False, True):
                self.test_does_not_call_finalize(mode)
            return

        printer = PrinterMock(do_in_parallel)
        differ_mock = DifferMock(printer)
        differ_mock.compare_pairs(self._make_cmp_data(), do_in_parallel, False)

        common.assert_equal(printer.finalize_calls, 0)

    def test_compare_pairs(self, cmp_parallel=None):
        """compare_pairs drives the printer correctly in the given mode.

        :param cmp_parallel: whether the printer supports parallel processing;
            ``None`` checks both modes.
        """
        if cmp_parallel is None:
            for mode in (False, True):
                self.test_compare_pairs(mode)
            return

        printer = PrinterMock(cmp_parallel)
        differ_mock = DifferMock(printer)

        differ_mock.compare_pairs(self._make_cmp_data())

        # In parallel mode the pairs are reported to subprinters, not the root.
        common.assert_equal(len(printer.on_new_pair_calls), 0 if cmp_parallel else 2)
        common.assert_equal(printer.finalize_calls, 1)
        if cmp_parallel:
            common.assert_equal(printer.create_subprinter_calls, 2)
            common.assert_equal(printer.join_parallel_subprinters_calls, 1)

    def test_all(self):
        """Run every scenario, each mode-dependent one in both modes."""
        self.test_compare_single_pair()
        self.test_does_not_using_parallel_processing()
        self.test_does_not_call_finalize(False)
        self.test_does_not_call_finalize(True)
        self.test_compare_pairs(False)
        self.test_compare_pairs(True)


def test_all():
    """Module-level entry point: run every TestDifferBase scenario."""
    TestDifferBase().test_all()
