from __future__ import division, print_function, absolute_import

from collections import defaultdict
import json
import logging
from multiprocessing import Pool, cpu_count
import numpy as np
import os


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE: this runs at import time and configures the root logger with DEBUG
# verbosity (it has no effect if the root logger already has handlers).
logging.basicConfig(level=logging.DEBUG)


class Aggregator(object):
    """Build ';'-separated grouping keys from an ordered list of field names."""

    def __init__(self, key_fields):
        """Remember the ordered field names that make up a key."""
        self.key_fields = key_fields

    def make_key(self, dictionary):
        """Return the key-field values of ``dictionary`` joined with ';'.

        Every value is rendered with ``'{}'.format``; a field missing from
        ``dictionary`` raises ``KeyError``.
        """
        rendered_values = [
            '{}'.format(dictionary[field_name])
            for field_name in self.key_fields
        ]
        return ';'.join(rendered_values)

    def make_caption(self):
        """Return the key-field names themselves joined with ';'."""
        return ';'.join(self.key_fields)


class TrafaretRankDiffer(object):
    """Accumulate bid/CPC differences between "pre" and "test" responses.

    ``self.diff`` has the shape
    ``{aggregation caption: {group key: {field name: numpy array}}}``.
    The field names are ``_DATA_FIELDS``; every processed target appends one
    value per field for each group key that occurs in it. ``str(differ)``
    renders the accumulated data as a standalone HTML report.
    """

    # Series tracked for every group of every aggregation.
    _DATA_FIELDS = ('pre_bid', 'pre_cpc', 'test_bid', 'test_cpc', 'diff_bid', 'diff_cpc')

    # Column captions of the report, in the order produced by ``_get_data``
    # (the first column, the group key, is captioned per aggregation).
    _TABLE_FIELDS = (
        "AVG BID", "AVG CPC", "MED BID", "MED CPC",
        "TOTAL BID DIFF", "TOTAL CPC DIFF", "TOTAL PRE BID", "TOTAL PRE CPC",
    )

    _HTML_BEGINNING = '''
<!DOCTYPE html>
<html>
<head>
    <title>Trafaret Rank Report</title>
    <style>
        body {
            font-family: Helvetica, Arial, sans-serif;
            font-size: 16px;
            color: #34495e;
        }
        table {
            width: 80%;
            margin: 0 auto;
        }
        .table-header {
            font-size: 1.1em;
            padding: 10px;
            height: 20px;
            width: 100%;
            position: sticky;
            top: 0;
            background-color: #FEEC8D;
        }
        .table-row {
            padding:10px;
            background-color: #F9FCFA
        }
        .table-row:hover {
            background-color: #E3E3E3
        }
    </style>
</head>
<body>
'''

    # The closing tag used to be misspelled as "</hmtl>".
    _HTML_ENDING = '''
</body>
<script>
function sortTable(n, table_n) {
  var table, rows, switching, i, x, y, shouldSwitch, dir, switchcount = 0;
  table = document.getElementById("table_" + table_n);
  switching = true;
  dir = "asc";
  while (switching) {
    switching = false;
    rows = table.rows;
    for (i = 1; i < (rows.length - 1); i++) {
      shouldSwitch = false;
      x = rows[i].getElementsByTagName("TD")[n];
      y = rows[i + 1].getElementsByTagName("TD")[n];
      if (dir == "asc") {
        if (parseFloat(x.innerHTML.toLowerCase()) > parseFloat(y.innerHTML.toLowerCase())) {
          shouldSwitch= true;
          break;
        }
      } else if (dir == "desc") {
        if (parseFloat(x.innerHTML.toLowerCase()) < parseFloat(y.innerHTML.toLowerCase())) {
          shouldSwitch = true;
          break;
        }
      }
    }
    if (shouldSwitch) {
      rows[i].parentNode.insertBefore(rows[i + 1], rows[i]);
      switching = true;
      switchcount ++;
    } else {
      if (switchcount == 0 && dir == "asc") {
        dir = "desc";
        switching = true;
      }
    }
  }
}
</script>
</html>
'''

    def __init__(self, key_fields_lst):
        """Create one aggregator (and one empty result table) per key-field list.

        :param key_fields_lst: iterable of field-name lists; each list defines
            one aggregation whose caption is the ';'-joined field names.
        """
        self.aggregators = [
            Aggregator(key_fields)
            for key_fields in key_fields_lst
        ]
        self.diff = {
            aggregator.make_caption(): {}
            for aggregator in self.aggregators
        }

    def __iadd__(self, other):
        """Merge the series collected by ``other`` into this differ.

        Only the aggregations present in ``self.diff`` are merged; ``other``
        must contain each of them (``KeyError`` otherwise).

        :raises RuntimeError: if ``other`` is not a ``TrafaretRankDiffer``.
        """
        if not isinstance(other, TrafaretRankDiffer):
            raise RuntimeError('Can\'t add object of type {} to TrafaretRankDiffer'.format(type(other)))
        for aggregation_id, own_groups in self.diff.items():
            for trafaret_rank_id, other_series in other.diff[aggregation_id].items():
                if trafaret_rank_id not in own_groups:
                    # Copy the per-field mapping: sharing it would make later
                    # updates of ``self`` silently modify ``other`` as well.
                    own_groups[trafaret_rank_id] = dict(other_series)
                    continue
                own_series = own_groups[trafaret_rank_id]
                for field, values in other_series.items():
                    own_series[field] = np.append(own_series[field], values)
        return self

    def diff_target(self, pre_target, test_target):
        """Record one pre/test target pair in every aggregation.

        For each aggregator the ``'Bid'`` and ``'Cpc'`` values of the entries
        in ``target['Clickometer']`` are summed per group key, separately for
        the pre and the test target; ``diff_*`` is the test sum minus the pre
        sum. One value per field is then appended to the group's series.
        """
        for aggregator in self.aggregators:
            # Per-target totals; committed to ``self.diff`` only after both
            # targets were read successfully for this aggregator.
            diff_buffer = {}

            for trafaret_rank in pre_target['Clickometer']:
                totals = diff_buffer.setdefault(
                    aggregator.make_key(trafaret_rank),
                    dict.fromkeys(self._DATA_FIELDS, 0)
                )
                totals['pre_bid'] += trafaret_rank['Bid']
                totals['diff_bid'] -= trafaret_rank['Bid']
                totals['pre_cpc'] += trafaret_rank['Cpc']
                totals['diff_cpc'] -= trafaret_rank['Cpc']

            for trafaret_rank in test_target['Clickometer']:
                totals = diff_buffer.setdefault(
                    aggregator.make_key(trafaret_rank),
                    dict.fromkeys(self._DATA_FIELDS, 0)
                )
                totals['test_bid'] += trafaret_rank['Bid']
                totals['diff_bid'] += trafaret_rank['Bid']
                totals['test_cpc'] += trafaret_rank['Cpc']
                totals['diff_cpc'] += trafaret_rank['Cpc']

            groups = self.diff[aggregator.make_caption()]
            for trafaret_rank_id, totals in diff_buffer.items():
                # Direct dict membership; ``in d.keys()`` scans a list on
                # Python 2 and made this loop quadratic.
                if trafaret_rank_id not in groups:
                    groups[trafaret_rank_id] = {
                        field: np.array([])
                        for field in self._DATA_FIELDS
                    }
                series = groups[trafaret_rank_id]
                for field in self._DATA_FIELDS:
                    series[field] = np.append(series[field], totals[field])

    def diff_response(self, pre_response, test_response):
        """Diff the targets of two responses pairwise, in order.

        ``zip`` stops at the shorter ``'Targets'`` list, so surplus targets of
        the longer response are ignored.
        """
        for pre_target, test_target in zip(pre_response['Targets'], test_response['Targets']):
            self.diff_target(pre_target, test_target)

    @staticmethod
    def _read_file(path):
        """Return the raw bytes of ``path``; the file is closed on return."""
        with open(path, 'rb') as response_file:
            return response_file.read()

    def diff_response_dir(self, pre_response_dir, test_response_dir, chunk_request_ids):
        """Diff the stored responses of the given requests.

        For every request id the file of that name is read from both
        directories and parsed as JSON. Requests whose ``code`` is not
        ``'200'`` in either response are skipped. Otherwise the first line of
        the decoded ``response_data`` payload is parsed as JSON and its
        ``'Targets'`` are passed to ``diff_response``.

        Missing keys are logged as errors and malformed payloads as warnings;
        in both cases processing continues with the next request. I/O errors
        and invalid JSON in the response files themselves propagate.
        """
        for request_id in chunk_request_ids:
            raw_responses = [
                self._read_file(os.path.join(directory, request_id))
                for directory in (pre_response_dir, test_response_dir)
            ]
            json_responses = [
                json.loads(raw_response.decode('utf-8', 'ignore'))
                for raw_response in raw_responses
            ]

            try:
                if any(
                    json_response['code']['diff_data']['$data'] != '200'
                    for json_response in json_responses
                ):
                    continue

                encoded_payloads = [
                    json_response['response_data']['diff_data']['$data']
                    for json_response in json_responses
                ]
                encodings = [
                    json_response['response_data']['diff_data']['$encoding']
                    for json_response in json_responses
                ]
                # NOTE(review): ``bytes(text).decode(encoding)`` relies on
                # Python 2 string semantics - confirm before porting.
                # Only the first line of the payload holds the targets JSON.
                targets_lines = [
                    bytes(encoded_data).decode(encoding).split('\n', 1)[0]
                    for encoded_data, encoding in zip(encoded_payloads, encodings)
                ]
                try:
                    decoded_jsons = [
                        json.loads(targets_line)
                        for targets_line in targets_lines
                    ]
                    pre_targets, test_targets = [
                        decoded_json['Targets']
                        for decoded_json in decoded_jsons
                    ]
                    self.diff_response(
                        {'RequestID': request_id, 'Targets': pre_targets},
                        {'RequestID': request_id, 'Targets': test_targets}
                    )
                except (ValueError, TypeError) as error:
                    # Still best-effort, but no longer silent.
                    logger.warning('Skipping request {}: {}'.format(request_id, error))
            except KeyError:
                logger.error('Can\'t parse request {}'.format(request_id))

    def _get_data(self, aggregator_id, trafaret_rank_id):
        """Return the report row for one group.

        The row is ``[group key, relative total bid change, relative total
        CPC change, relative median bid change, relative median CPC change,
        total bid diff, total CPC diff, total pre bid, total pre CPC]``; the
        four totals are truncated to ``int``. A zero pre total or pre median
        goes through numpy division, which yields inf/nan rather than raising.
        """
        dictionary = self.diff[aggregator_id][trafaret_rank_id]
        total_pre_bid = np.sum(dictionary['pre_bid'])
        total_pre_cpc = np.sum(dictionary['pre_cpc'])
        total_diff_bid = np.sum(dictionary['diff_bid'])
        total_diff_cpc = np.sum(dictionary['diff_cpc'])
        median_pre_bid = np.median(dictionary['pre_bid'])
        median_pre_cpc = np.median(dictionary['pre_cpc'])
        return [
            trafaret_rank_id,
            total_diff_bid / total_pre_bid,
            total_diff_cpc / total_pre_cpc,
            (np.median(dictionary['test_bid']) - median_pre_bid) / median_pre_bid,
            (np.median(dictionary['test_cpc']) - median_pre_cpc) / median_pre_cpc,
            int(total_diff_bid),
            int(total_diff_cpc),
            int(total_pre_bid),
            int(total_pre_cpc)
        ]

    @staticmethod
    def _escape_html(text):
        """Escape the characters that are special in HTML text and attributes."""
        return (
            text.replace('&', '&amp;')
            .replace('<', '&lt;')
            .replace('>', '&gt;')
            .replace('"', '&quot;')
        )

    def _format_cell(self, value):
        """Render one table cell: floats with 6 decimals, the rest escaped."""
        if isinstance(value, float):
            return '{:.6f}'.format(value)
        return self._escape_html(str(value))

    def __str__(self):
        """Render all aggregations as an HTML page with one sortable table each."""
        # Collect fragments and join once instead of repeated ``+=``.
        parts = [self._HTML_BEGINNING]
        for table_id, aggregator in enumerate(self.aggregators):
            aggregator_id = aggregator.make_caption()

            parts.append('<table id="table_{:d}">'.format(table_id))
            parts.append('<tr class="table-header">')
            for header_index, header_name in enumerate((aggregator_id,) + self._TABLE_FIELDS):
                parts.append('<th onclick="sortTable({:d},{:d})">{}</th>'.format(
                    header_index, table_id, self._escape_html(header_name)
                ))
            parts.append('</tr>\n')

            for trafaret_rank_id in self.diff[aggregator_id]:
                parts.append('<tr class="table-row"><td>')
                # Group keys are built from response data, so they are escaped.
                parts.append('</td><td>'.join(
                    self._format_cell(value)
                    for value in self._get_data(aggregator_id, trafaret_rank_id)
                ))
                parts.append('</td></tr>\n')

            parts.append('</table><br>\n')

        parts.append(self._HTML_ENDING)
        return ''.join(parts)


def compare_trafaret_chunk(args_tuple):
    """Diff one chunk of requests; entry point for worker processes.

    :param args_tuple: ``(pre_path, test_path, chunk, key_fields_lst)`` packed
        into a single tuple so the function can be used with ``map``.
    :return: a ``TrafaretRankDiffer`` holding the results for the chunk.

    Any exception is logged with its traceback and then re-raised.
    """
    try:
        pre_dir, test_dir, request_ids, key_fields_lst = args_tuple
        chunk_differ = TrafaretRankDiffer(key_fields_lst)
        chunk_differ.diff_response_dir(pre_dir, test_dir, request_ids)
    except Exception:
        logger.exception("Failed to compare bsrank result chunks")
        raise
    else:
        return chunk_differ


def compare_trafaret_ranks(pre_path, test_path, test_ids, bad_requests_ids, key_fields_lst, n_jobs, report_chunk_size):
    """Diff the stored pre/test responses of all good requests.

    :param pre_path: directory with the "pre" responses, one file per request id.
    :param test_path: directory with the "test" responses, one file per request id.
    :param test_ids: request ids to compare; each must be convertible with ``int``.
    :param bad_requests_ids: container of integer request ids to skip.
    :param key_fields_lst: key-field lists passed on to ``TrafaretRankDiffer``.
    :param n_jobs: ``1`` to run in-process, ``None`` for ``1.5 * cpu_count()``
        worker processes, otherwise the number of worker processes.
    :param report_chunk_size: width of the id range that forms one chunk.
    :return: a single ``TrafaretRankDiffer`` with the merged results of all
        chunks (an empty one if no request ids remain after filtering).
    """
    chunks = defaultdict(list)
    for test_id in test_ids:
        numeric_id = int(test_id)
        if numeric_id in bad_requests_ids:
            continue
        # Floor division is required: this module enables true division, so a
        # plain ``/`` produced a distinct float key (and thus a one-element
        # chunk) for almost every id.
        chunks[numeric_id // report_chunk_size].append(test_id)

    compare_args = [
        (pre_path, test_path, chunk, key_fields_lst)
        for chunk in chunks.values()
    ]
    if not compare_args:
        # Nothing to compare; return an empty result instead of failing on
        # ``result_differs[0]`` below.
        return TrafaretRankDiffer(key_fields_lst)

    if n_jobs == 1:
        result_differs = [compare_trafaret_chunk(args) for args in compare_args]
    else:
        worker_count = n_jobs if n_jobs is not None else int(cpu_count() * 1.5)
        process_pool = Pool(worker_count)
        try:
            result_differs = process_pool.map(compare_trafaret_chunk, compare_args)
        except BaseException:
            # Do not leave worker processes behind when a chunk fails.
            process_pool.terminate()
            raise
        else:
            process_pool.close()
        finally:
            process_pool.join()

    joined_differ = result_differs[0]
    for result_differ in result_differs[1:]:
        joined_differ += result_differ

    return joined_differ
