# -*- coding: utf-8 -*-

import itertools
import logging
import multiprocessing
import traceback
from six.moves import zip, map


class DifferBase(object):
    """
        Base class for comparing a sequence of object pairs.

        To implement your own differ, follow these steps:
        1. Derive from DifferBase
        2. Implement the _do_compare_single() method
        3. Override the _finalize_printer() method if you want to perform
           additional work during the printer finalization step

        The output printer is used through these members:
        on_new_pair(title), finalize(), supports_parallel_processing,
        get_parallel_subprinter() and join_parallel_subprinters(subprinters).
    """

    def __init__(self, output_printer=None):
        """
            Differ constructor
            * output_printer - printer with PrinterBase interface. Can also be set
              later with the set_printer() method
        """
        self._output_printer = output_printer
        # number of compared pairs for which _do_compare_single() reported a difference
        self._diff_count = 0

    def set_printer(self, output_printer):
        """
            Replaces the printer used by this differ.
            In parallel mode each worker process receives its own (pickled) copy
            of the differ and points it at a dedicated parallel subprinter.
            * output_printer - printer with PrinterBase interface
        """
        self._output_printer = output_printer

    def compare_single_pair(self, obj1, obj2, title=""):
        """
            Compares one pair of objects.
            The caller of this method is responsible for calling printer.finalize().
            * obj1, obj2 - objects to compare
            * title - title of the pair
        """
        self._output_printer.on_new_pair(title)
        has_diff = self._do_compare_single(obj1, obj2, title)
        if has_diff:
            self._diff_count += 1

    def compare_pairs(self, cmp_data_generator, process_parallel=True, finalize_printer=True):
        """
            Compares all pairs from the given sequence.
            * cmp_data_generator - iterable that yields tuples (obj1, obj2, title)
            * process_parallel - compare in a process pool if the printer supports
              parallel processing; otherwise pairs are compared sequentially
            * finalize_printer - call _finalize_printer() once all pairs are compared
        """
        logging.info("Start comparison. Parallel = %s", process_parallel)
        # short-circuit: the printer is only asked about parallel support when requested
        cmp_parallel = process_parallel and self._output_printer.supports_parallel_processing
        if cmp_parallel:
            self._do_compare_parallel(cmp_data_generator)
        else:
            self._do_compare_sequentially(cmp_data_generator)

        if finalize_printer:
            self._finalize_printer()

    def get_diff_count(self):
        """
            Returns the number of compared pairs whose objects are not equal
        """
        return self._diff_count

    def _do_compare_single(self, obj1, obj2, title):
        """
            Compare function implementation.
            Must return the has_diff flag.
        """
        raise NotImplementedError

    def _finalize_printer(self):
        """
            Printer finalization.
            You can do additional work in this method if needed.
        """
        self._output_printer.finalize()

    #
    # base class implementation
    #

    def _do_compare_single_packed(self, cmp_data):
        """
            Worker-side comparison of one (obj1, obj2, title) tuple.
            Finalizes the current (sub)printer after the pair and returns has_diff.
        """
        obj1, obj2, title = cmp_data

        self._output_printer.on_new_pair(title)
        has_diff = self._do_compare_single(obj1, obj2, title)
        self._finalize_printer()
        return has_diff

    def _do_compare_parallel(self, cmp_data_generator):
        """
            Compares all pairs in a pool of worker processes.
            Each pair is sent to a worker together with a fresh parallel subprinter
            and a copy of this differ. The returned subprinters are joined into the
            main printer once every pair has been processed.
            On any error the pool is terminated instead of waiting for the
            remaining queued pairs to be compared.
        """
        cpu_count = multiprocessing.cpu_count()
        logging.debug("Comparing parallel with %d processes", cpu_count)

        pool = multiprocessing.Pool(processes=cpu_count)
        try:
            # Lazily pair every (obj1, obj2, title) tuple with its own subprinter
            # and with this differ; pool.imap consumes the bundle on demand.
            cmp_data_bundle = zip(
                cmp_data_generator,
                map(_get_parallel_subprinter, itertools.repeat(self._output_printer)),
                itertools.repeat(self),
            )
            subprinters = []
            for printer, has_diff in pool.imap(_parallel_processor_function, cmp_data_bundle):
                if has_diff:
                    self._diff_count += 1
                subprinters.append(printer)

            self._output_printer.join_parallel_subprinters(subprinters)
        except BaseException:
            # close() + join() would keep submitting and comparing all remaining
            # pairs before the error (or Ctrl+C) could propagate; stop workers now.
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            pool.join()

    def _do_compare_sequentially(self, cmp_data_generator):
        """
            Compares all pairs one by one in the current process
        """
        for obj1, obj2, title in cmp_data_generator:
            self.compare_single_pair(obj1, obj2, title)


def _get_parallel_subprinter(printer):
    """
        Returns a new subprinter obtained from printer.get_parallel_subprinter()
        * printer - printer with PrinterBase interface
    """
    return printer.get_parallel_subprinter()

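#
# This function must be picklable: multiprocessing pickles it by its
# qualified name to send it to worker processes, so it has to stay a
# module-level function.
#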


def _parallel_processor_function(args):
    """
        Worker entry point used by DifferBase._do_compare_parallel().
        * args - tuple (cmp_data, printer, differ), where cmp_data is
          (obj1, obj2, title) and printer is the parallel subprinter for this pair
        Returns tuple (printer, has_diff).
        Exceptions are logged with their traceback in the worker and re-raised.
    """
    try:
        cmp_data, printer, differ = args
        differ.set_printer(printer)
        has_diff = differ._do_compare_single_packed(cmp_data)
        return printer, has_diff
    except Exception as ex:
        # Log in the worker: depending on the Python version, the worker-side
        # traceback may be lost when pool.imap() re-raises it in the parent.
        logging.exception("Exception in parallel processor: %s", ex)
        raise
