import json
import logging
import os
import re
from collections import defaultdict
from multiprocessing import Pool, cpu_count
from urlparse import urlparse

from sandbox.projects.yabs.qa.tasks.YabsServerB2BFuncShootCmp.utils.compare import abbreviate_diff, REPORT_CHUNKS_COUNT

logger = logging.getLogger(__name__)


def _process_single_case(test_id, request, pre_data, test_data, diff_data):
    diff = 'Status diff:\n' + diff_data['StatusDiff'] + '\n' if len(diff_data['StatusDiff']) else ''
    diff += 'Headers diff:\n' + diff_data['HeadersDiff'] + '\n' if len(diff_data['HeadersDiff']) else ''
    diff += 'Entities diff:\n' + diff_data['EntitiesDiff'] + '\n' if len(diff_data['EntitiesDiff']) else ''
    diff += 'Exts diff:\n' + diff_data['ExtsDiff'] + '\n' if len(diff_data['ExtsDiff']) else ''
    diff += 'Logs diff:\n' + diff_data['LogsDiff'] + '\n' if len(diff_data['LogsDiff']) else ''

    tags = set(['StatusDiff']) if len(diff_data['StatusDiff']) else set()
    tags |= set(['HeadersDiff']) if len(diff_data['HeadersDiff']) else set()
    tags |= set(['EntitiesDiff']) if len(diff_data['EntitiesDiff']) else set()
    tags |= set(['ExtsDiff']) if len(diff_data['ExtsDiff']) else set()
    tags |= set(['LogsDiff']) if len(diff_data['LogsDiff']) else set()

    tags |= set([diff_data['ProcessingType']])

    url = pre_data['Url']
    url_path = urlparse(url).path

    versioned_handler_match = re.match(r'^/v\d+/\w+', url_path)
    if versioned_handler_match:
        url_path = versioned_handler_match.group(0)
    report_handler_match = re.match(r'^/(\w*/)?report', url_path)
    if report_handler_match:
        url_path = report_handler_match.group(0)

    try:
        diff_out = {'diff': abbreviate_diff(diff)}
        json.dumps(diff_out)
    except:
        diff_out = 'Failed to print diff preview in UTF-8. Take a look on full diff.'

    try:
        json.dumps(diff)
    except:
        diff = 'Failed to print full diff in UTF-8. Look for it diff table.'

    test_out = {
        'status': ('failed' if diff_data['HasDiff'] else 'passed'),
        'id': int(test_id),
        'diff': diff,
        'request': request,
        'name': str(int(test_id)),
        'ft_shoot': {
            'pre.code': pre_data['HttpCode'],
            'test.code': test_data['HttpCode'],
            'handler': url_path,
        },
    }

    result = {
        'test_id': str(int(test_id)),
        'has_diff': diff_data['HasDiff'],
        'unique_changed_logs': {'unique': {'pre': set(), 'test': set()}, 'changed': {}},
        'validation_results': {},
        'log_statistics': {},
        'pre_code': pre_data['HttpCode'],
        'test_code': test_data['HttpCode'],
        'handler': url_path,
        'tags': tags,
    }

    return result, test_out, diff_out


def _process_chunk(chunk, tests_dir, diffs_dir):
    """Process every case of one chunk and dump its report files.

    Each item of ``chunk`` is an argument tuple for ``_process_single_case``
    whose first element is the test id. All ids of a chunk must share the same
    remainder modulo ``REPORT_CHUNKS_COUNT``; that remainder names the output
    file ``<remainder>.json`` written into both ``tests_dir`` and ``diffs_dir``.

    :return: list of per-case summary dicts, in chunk order.
    :raises RuntimeError: if the chunk mixes several remainders or is empty.
    """
    remainders = [int(case[0]) % REPORT_CHUNKS_COUNT for case in chunk]
    if len(set(remainders)) != 1:
        raise RuntimeError("Malformed chunk: %s" % remainders)
    file_name = '{}.json'.format(remainders[0])

    summaries = []
    tests_by_id = {}
    diffs_by_id = {}

    for case in chunk:
        summary, test_entry, diff_entry = _process_single_case(*case)
        key = str(int(case[0]))
        summaries.append(summary)
        tests_by_id[key] = test_entry
        diffs_by_id[key] = diff_entry

    # Files are written only after the whole chunk has been processed.
    for directory, payload in ((tests_dir, tests_by_id), (diffs_dir, diffs_by_id)):
        with open(os.path.join(directory, file_name), 'w') as out:
            json.dump(payload, out)

    return summaries


def _process_chunk_wrap(args):
    """Call ``_process_chunk(*args)``, logging any failure before re-raising.

    Takes a single tuple so it can be used with ``map``-style APIs that pass
    one argument per item.
    """
    try:
        chunk_results = _process_chunk(*args)
    except BaseException:
        # Log the traceback here, then let the error propagate unchanged.
        logger.exception("Failed to process result chunks")
        raise
    return chunk_results


def _split_to_chunks(requests, pre_data, test_data, diff_data, bad_requests_ids):
    """Group per-test data into chunks keyed by ``int(test_id) % REPORT_CHUNKS_COUNT``.

    :param requests: mapping ``test_id -> request``; its keys define the tests.
    :param pre_data: mapping ``test_id -> data`` of the baseline run.
    :param test_data: mapping ``test_id -> data`` of the tested run.
    :param diff_data: mapping ``test_id -> diff`` data.
    :param bad_requests_ids: collection of test ids to leave out.
    :return: ``defaultdict(list)`` mapping chunk number to a list of
        ``(test_id, request, pre, test, diff)`` tuples.
    """
    # Build the exclusion set once so each membership test is O(1) even when
    # the caller passes a list.
    if isinstance(bad_requests_ids, (set, frozenset, dict)):
        excluded = bad_requests_ids
    else:
        excluded = frozenset(bad_requests_ids)

    chunks = defaultdict(list)
    for test_id in requests:
        if test_id in excluded:
            continue
        chunks[int(test_id) % REPORT_CHUNKS_COUNT].append(
            (test_id, requests[test_id], pre_data[test_id], test_data[test_id], diff_data[test_id]),
        )

    return chunks


def _ensure_dir(path):
    """Create directory ``path`` unless it is already there."""
    if os.path.isdir(path):
        return
    try:
        os.mkdir(path)
    except OSError:
        # Tolerate a concurrent creation; re-raise any other failure.
        if not os.path.isdir(path):
            raise


def process_results(requests, pre_data, test_data, diff_data, bad_requests_ids, n_jobs, report_dir):
    """Process all comparison results and write chunked report files.

    Test data is split into chunks, each chunk is processed by
    ``_process_chunk_wrap`` and its JSON files are written under
    ``<report_dir>/tests`` and ``<report_dir>/diff``.

    :param requests: mapping ``test_id -> request``.
    :param pre_data: mapping ``test_id -> data`` of the baseline run.
    :param test_data: mapping ``test_id -> data`` of the tested run.
    :param diff_data: mapping ``test_id -> diff`` data.
    :param bad_requests_ids: test ids to skip.
    :param n_jobs: number of worker processes; ``1`` runs in the current
        process, ``None`` uses ``int(cpu_count() * 1.5)`` workers.
    :param report_dir: existing directory that receives the output folders.
    :return: flat list of per-test summary dicts from all chunks.
    """
    chunks = _split_to_chunks(requests, pre_data, test_data, diff_data, bad_requests_ids)

    tests_dir = os.path.join(report_dir, 'tests')
    _ensure_dir(tests_dir)

    diffs_dir = os.path.join(report_dir, 'diff')
    _ensure_dir(diffs_dir)

    process_function_args = [
        (chunk, tests_dir, diffs_dir) for chunk in chunks.values()
    ]

    if n_jobs == 1:
        result_chunks = [_process_chunk_wrap(args) for args in process_function_args]
    else:
        if n_jobs is None:
            n_jobs = int(cpu_count() * 1.5)
        pool = Pool(n_jobs)
        try:
            result_chunks = pool.map(_process_chunk_wrap, process_function_args)
        except BaseException:
            # Do not leave worker processes behind when a chunk fails.
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            pool.join()

    results = []
    for res_chunk in result_chunks:
        results.extend(res_chunk)
    return results
