import logging
import os
from multiprocessing import Pool, cpu_count

from response_differ import ResponseDiffer

logger = logging.getLogger(__name__)

# Name template for the working directory handed to ResponseDiffer; it is
# formatted with the current process id so each process uses its own name.
WORKING_DIR_TEMPLATE = 'comparison_tmp_{}'
# chunksize argument passed to Pool.map when comparisons run in parallel.
CHUNK_SIZE = 128


def _compare_result(
    request_id,
    response_pre,
    response_test,
    headers_to_replace,
    body_substitutes,
    json_keys_to_delete,
    xml_keys_to_delete,
    css_substitutes,
    html_substitutes,
    compare_bodies,
    window_size,
    base64_prefixes,
    html_tags,
    html_tags_remove_to_convert_xml,
    filter_by_pre_headers,
    filter_by_test_headers,
    json_padding_name,
    html_differ_path,
    html_differ_output_path,
):
    """Build a ResponseDiffer for one request and diff its pre/test responses.

    Args:
        request_id: id of the request; used only to name the per-request
            HTML differ output location.
        response_pre: response data from the "pre" run.
        response_test: response data from the "test" run.
        html_differ_output_path: base directory for HTML differ output, or
            ``None`` to pass ``None`` through to ResponseDiffer.
        remaining args: forwarded positionally to ``ResponseDiffer``.

    Returns:
        Whatever ``ResponseDiffer.diff_response_data`` returns for the pair.
    """
    # Name the scratch directory after the current PID so separate worker
    # processes do not share one.
    scratch_dir = WORKING_DIR_TEMPLATE.format(os.getpid())

    if html_differ_output_path is None:
        per_request_output = None
    else:
        # One output location per request, under the configured base path.
        per_request_output = os.path.join(html_differ_output_path, str(request_id))

    # NOTE(review): base64_prefixes is passed 4th here, unlike this function's
    # own parameter order -- confirm it matches ResponseDiffer's signature.
    differ = ResponseDiffer(
        headers_to_replace,
        body_substitutes,
        json_keys_to_delete,
        base64_prefixes,
        xml_keys_to_delete,
        css_substitutes,
        html_substitutes,
        scratch_dir,
        compare_bodies,
        window_size,
        html_differ_path,
        per_request_output,
        html_tags,
        html_tags_remove_to_convert_xml,
        filter_by_pre_headers,
        filter_by_test_headers,
        json_padding_name,
    )
    return differ.diff_response_data(response_pre, response_test)


def _compare_result_wrap(args):
    """Call ``_compare_result`` with a packed argument tuple, logging failures.

    ``map``/``Pool.map`` hand over a single item per call, so the arguments
    arrive as one tuple whose first element is the request id.

    Args:
        args: tuple of positional arguments for ``_compare_result``.

    Returns:
        The result of ``_compare_result(*args)``.

    Raises:
        Any ``Exception`` raised by the comparison, re-raised after logging.
        KeyboardInterrupt/SystemExit propagate without being logged.
    """
    try:
        return _compare_result(*args)
    except Exception:
        # Include the request id so the failing request can be identified;
        # guard against a malformed (empty) argument tuple.
        request_id = args[0] if args else None
        logger.exception('Failed to compare result for request %s', request_id)
        raise


def compare_results(pre_data, test_data, n_jobs, *args):
    """Compare every request present in both ``pre_data`` and ``test_data``.

    Args:
        pre_data: mapping of request id -> record; each record's ``'Data'``
            entry is used as the "pre" response.
        test_data: mapping of the same shape for the "test" run.
        n_jobs: number of worker processes. ``1`` runs in the current
            process; ``None`` uses ``int(cpu_count() * 1.5)`` workers.
        *args: remaining positional arguments of ``_compare_result``
            (``headers_to_replace`` onwards), appended to every call.

    Returns:
        dict mapping each request id found in both inputs to the value
        returned by ``_compare_result_wrap`` for it. Ids present in only one
        input are logged and skipped.
    """
    pre_request_ids = set(pre_data)
    test_request_ids = set(test_data)
    common_ids = pre_request_ids & test_request_ids

    missing_in_test = pre_request_ids - common_ids
    if missing_in_test:
        logger.info('ATTENTION!!! Missing test ids in test: %s', list(missing_in_test))
    missing_in_pre = test_request_ids - common_ids
    if missing_in_pre:
        logger.info('ATTENTION!!! Missing test ids in pre: %s', list(missing_in_pre))

    # Fix the iteration order once so results can be zipped back to their ids
    # without relying on a set being iterated identically twice.
    request_ids = list(common_ids)
    if not request_ids:
        # Nothing to compare: avoid spawning a process pool for no work.
        return {}

    compare_function_args = (
        (request_id, pre_data[request_id]['Data'], test_data[request_id]['Data']) + args
        for request_id in request_ids
    )

    if n_jobs == 1:
        results = [_compare_result_wrap(item) for item in compare_function_args]
    else:
        if n_jobs is None:
            n_jobs = int(cpu_count() * 1.5)
        pool = Pool(n_jobs)
        try:
            results = pool.map(_compare_result_wrap, compare_function_args, CHUNK_SIZE)
        except BaseException:
            # Do not leave worker processes behind when a comparison fails.
            pool.terminate()
            pool.join()
            raise
        pool.close()
        pool.join()

    return dict(zip(request_ids, results))
