import math
import google.protobuf
import re
import logging

from collections import defaultdict
from multiprocessing import Pool
from functools import partial
from multiprocessing.pool import ThreadPool

from sandbox.projects.common.base_search_quality.tree import meta_pb2


# Separator between consecutive responses in a dump file: a run of 20-100
# dashes immediately followed by one or more digits.  Written as a raw string
# so that the regex escapes are not treated as (invalid) Python string escapes.
delimiter = re.compile(r"\-{20,100}\d+")


class Result(object):
    """Sorted, report-ready views over per-factor diff statistics.

    Every attribute is a list of plain tuples built from objects exposing
    ``FactorName``, ``FactorNumber``, ``Diff``, ``DiffPercent``,
    ``DiffWeight``, ``DiffSortValue``, ``Zeroes`` and ``ZeroesPercent``:

    * ``AllFactorsDiff`` / ``AllFactorsZeroed`` are built from ``all_factors``
      and additionally carry one fstr column per entry of ``_FSTR_SOURCES``.
    * ``MangoDiff`` / ``MangoZeroed`` and ``FreshDiff`` / ``FreshZeroed`` are
      built from ``mango`` and ``fresh`` without fstr columns.

    "Diff" lists are ordered by ``DiffSortValue`` (descending) and keep only
    entries with a non-zero ``DiffPercent``; "Zeroed" lists are ordered by
    ``ZeroesPercent`` (descending) and keep only entries where it is non-zero.
    """

    # Keys of ``fstrs`` looked up for every "all factors" row, in column order.
    _FSTR_SOURCES = ("Ru.info", "RuFresh.info", "Tr.info", "TrFresh.info")
    # Column value used when a source or a factor number is absent from ``fstrs``.
    _UNDEFINED_FSTR = "UNDEFINED"

    def __init__(self, all_factors, fresh, mango, fstrs):
        """Build all report lists.

        :param all_factors: iterable of diff objects for the main factor set
        :param fresh: iterable of diff objects for the fresh factor set
        :param mango: iterable of diff objects for the mango factor set
        :param fstrs: mapping of source name (see ``_FSTR_SOURCES``) to a
            mapping keyed by the factor number converted with ``unicode()``
        """
        # Must be assigned first: the "all factors" builders below read it.
        self.Fstrs = fstrs

        self.AllFactorsDiff = self.get_diff_top(all_factors)
        self.AllFactorsZeroed = self.get_zeroed_top(all_factors)

        self.MangoDiff = self.get_simple_diff_top(mango)
        self.MangoZeroed = self.get_simple_zeroed(mango)

        self.FreshDiff = self.get_simple_diff_top(fresh)
        self.FreshZeroed = self.get_simple_zeroed(fresh)

    def _fstr_columns(self, factor_number):
        """Return the tuple of fstr values for ``factor_number``, one per source."""
        key = unicode(factor_number)
        return tuple(
            self.Fstrs.get(source, {}).get(key, self._UNDEFINED_FSTR)
            for source in self._FSTR_SOURCES
        )

    def get_simple_diff_top(self, diff_list):
        """Return ``(name, number, Diff, DiffPercent, DiffWeight, DiffSortValue)`` rows."""
        ranked = sorted(diff_list, key=lambda diff: diff.DiffSortValue, reverse=True)
        return [
            (
                diff.FactorName,
                diff.FactorNumber,
                diff.Diff,
                diff.DiffPercent,
                diff.DiffWeight,
                diff.DiffSortValue,
            )
            for diff in ranked if diff.DiffPercent
        ]

    def get_simple_zeroed(self, diff_list):
        """Return ``(name, number, Zeroes, ZeroesPercent)`` rows."""
        ranked = sorted(diff_list, key=lambda diff: diff.ZeroesPercent, reverse=True)
        return [
            (
                diff.FactorName,
                diff.FactorNumber,
                diff.Zeroes,
                diff.ZeroesPercent,
            )
            for diff in ranked if diff.ZeroesPercent
        ]

    def get_zeroed_top(self, diff_list):
        """Return ``(name, number, <fstr columns>, Zeroes, ZeroesPercent)`` rows."""
        ranked = sorted(diff_list, key=lambda diff: diff.ZeroesPercent, reverse=True)
        return [
            (diff.FactorName, diff.FactorNumber)
            + self._fstr_columns(diff.FactorNumber)
            + (diff.Zeroes, diff.ZeroesPercent)
            for diff in ranked if diff.ZeroesPercent
        ]

    def get_diff_top(self, diff_list):
        """Return ``(name, number, <fstr columns>, Diff, DiffPercent, DiffWeight, DiffSortValue)`` rows."""
        ranked = sorted(diff_list, key=lambda diff: diff.DiffSortValue, reverse=True)
        return [
            (diff.FactorName, diff.FactorNumber)
            + self._fstr_columns(diff.FactorNumber)
            + (diff.Diff, diff.DiffPercent, diff.DiffWeight, diff.DiffSortValue)
            for diff in ranked if diff.DiffPercent
        ]


class FactorDiff(object):
    """Difference statistics for a single factor.

    The raw counters (``Diff``, ``Zeroes``, ``Accumulated``) are incremented
    by callers; ``Finalize`` then derives the percentages and sort key.
    """

    def __init__(self):
        # Identification, filled in by Finalize().
        self.FactorName = None
        self.FactorNumber = None

        # Raw counters updated from outside the class.
        self.Diff = 0
        self.Zeroes = 0
        self.Accumulated = 0.0

        # Values derived by Finalize().
        self.DiffWeight = 0.0
        self.DiffPercent = 0.0
        self.ZeroesPercent = 0.0
        self.DiffSortValue = None

        # Not touched by this class.
        self.Fstr1 = None
        self.Fstr2 = None

    def Finalize(self, factor_name, factor_number, common_urls_number):
        """Attach the factor identity and compute the derived statistics.

        ``DiffWeight`` is ``Accumulated / Diff`` (left as is when ``Diff`` is
        zero).  ``DiffPercent`` and ``ZeroesPercent`` are the counters as a
        percentage of ``common_urls_number`` (left as is when that number is
        zero).  ``DiffSortValue`` is ``DiffPercent / 100 * DiffWeight``.

        Returns ``self`` so the call can be used inside an expression.
        """
        self.FactorName, self.FactorNumber = factor_name, factor_number

        differing = self.Diff
        if differing:
            self.DiffWeight = float(self.Accumulated) / differing

        if common_urls_number:
            self.DiffPercent, self.ZeroesPercent = [
                float(count) * 100 / common_urls_number
                for count in (differing, self.Zeroes)
            ]

        self.DiffSortValue = self.DiffPercent / 100 * self.DiffWeight
        return self

    def __repr__(self):
        shown = (self.FactorName, self.DiffPercent, self.ZeroesPercent, self.Diff)
        return "<FactorDiff(%s) DiffsPercent: %s, Zeroes:%s, TotalDiffs: %s>" % shown


# Text fragments removed from a response before it is handed to the protobuf
# text parser (presumably fields the TReport schema does not describe -- TODO
# confirm).  Compiled once at import time rather than on every call.
_UNPARSED_FIELD_PATTERNS = (
    re.compile(r"ExecutionTime: \d+"),
    re.compile(r"SourceTimestamp: \d+"),
    # The line breaks and indentation inside this literal are part of the
    # pattern and must match the dump byte for byte.
    re.compile("""FirstStageAttribute {
        Key: "[^}]+"
        Value: "[^}]+"
      }"""),
    re.compile(r"FailedPrimusList: .+\n"),
    re.compile(r"EventLog: .+\n"),
    re.compile(r"SearchInfo: .+\n"),
    # NOTE(review): with DOTALL the greedy ".+" runs up to the LAST "}\n" in
    # the text, so everything between this field and that point is dropped.
    re.compile(r"SearcherProp-scheme\.json_local_standard\.nodump \{.+\}\n", re.MULTILINE | re.DOTALL),
)


def parse_response(response):
    """Parse one text-format response into a ``meta_pb2.TReport`` message.

    Fragments matching ``_UNPARSED_FIELD_PATTERNS`` are stripped first.

    :param response: response text in protobuf text format
    :return: the parsed ``TReport``, or ``None`` when nothing but whitespace
        is left after stripping
    :raises google.protobuf.text_format.ParseError: when the remaining text
        cannot be parsed; the message includes the original error and the text
    """
    for pattern in _UNPARSED_FIELD_PATTERNS:
        response = pattern.sub("", response)

    if not response.strip():
        return None

    # Imported explicitly: "import google.protobuf" on its own is not
    # guaranteed to load the text_format submodule.
    from google.protobuf import text_format

    message = meta_pb2.TReport()
    try:
        text_format.Merge(response, message)
    except text_format.ParseError as err:
        raise text_format.ParseError("Failed to parse response (%s):\n %s\n" % (err, response))
    return message


def split_factors(string):
    """Convert a ``;``-separated string of numbers into a list of floats.

    Empty fields (for example from a trailing ``;``) are skipped.
    """
    values = []
    for chunk in string.split(";"):
        if chunk:
            values.append(float(chunk))
    return values


def extract_values(parsed_response, groupingName=None):
    """Collect relevance, factors and shard for every group of a parsed response.

    Only the first document of each group is inspected.  When
    ``groupingName`` is given (and truthy), groupings whose ``Attr`` differs
    from it are skipped.

    :return: dict mapping document URL to the tuple
        ``(Relevance, all factors, shard, fresh factors, mango factors)``;
        empty when the response is falsy or its first ``TotalDocCount`` entry
        is missing or zero.  A URL seen more than once keeps the last tuple.
    """
    by_url = {}
    if not parsed_response:
        return by_url
    doc_counts = parsed_response.TotalDocCount
    if not doc_counts or not doc_counts[0]:
        return by_url

    # Attribute keys whose value is a ';'-separated factor list.
    factor_keys = ('_AllFactors', '_FreshFactors', '_MangoFactors')

    for grouping in parsed_response.Grouping:
        if groupingName and grouping.Attr != groupingName:
            continue
        for group in grouping.Group:
            document = group.Document[0]
            archive_info = document.ArchiveInfo

            # Defaults used when an attribute is absent from the document.
            parsed = {key: [] for key in factor_keys}
            shard = ''
            for attr in archive_info.GtaRelatedAttribute:
                if attr.Key in parsed:
                    parsed[attr.Key] = split_factors(attr.Value)
                elif attr.Key == '_Shard':
                    shard = attr.Value

            by_url[archive_info.Url] = (
                document.Relevance,
                parsed['_AllFactors'],
                shard,
                parsed['_FreshFactors'],
                parsed['_MangoFactors'],
            )
    return by_url


def get_values_from_response(response, grouping=None):
    """Parse one raw response text and return its per-URL values.

    Chains ``parse_response`` and ``extract_values``; ``grouping`` is passed
    on as the ``groupingName`` filter.
    """
    parsed = parse_response(response)
    return extract_values(parsed, groupingName=grouping)


def parse_file(path, grouping=None):
    """Read a dump file and extract the per-URL values of every response in it.

    The file content is split on ``delimiter`` and each piece is processed by
    ``get_values_from_response`` in a pool of 20 worker processes.

    :param path: path of the dump file
    :param grouping: optional grouping name forwarded to the extraction step
    :return: list with one dict per response, in file order; responses that
        produced an empty result are dropped
    """
    with open(path) as f:
        responses = delimiter.split(f.read())
    logging.info("split file %s", path)

    pool = Pool(20)
    try:
        parsed_responses = pool.map(partial(get_values_from_response, grouping=grouping), responses)
    finally:
        # Without close()/join() the worker processes outlive this call and
        # pile up when several files are parsed.
        pool.close()
        pool.join()

    # A list comprehension (rather than filter()) always yields a real list,
    # which callers index and take len() of.
    return [values for values in parsed_responses if values]


def update_with_factor_diff(new_factors, prod_factors, diff):
    """Accumulate position-wise differences between two factor lists into ``diff``.

    Positions present in only one of the lists are ignored.  For every
    position where the values differ, the entry ``diff[position]`` gets its
    ``Diff`` counter incremented and the absolute difference added to
    ``Accumulated``; ``Zeroes`` is incremented as well when the new value is 0.

    :return: the same ``diff`` object, updated in place
    """
    for number, (new_value, prod_value) in enumerate(zip(new_factors, prod_factors)):
        if new_value != prod_value:
            entry = diff[number]
            entry.Diff += 1
            entry.Accumulated += math.fabs(prod_value - new_value)
            if new_value == 0:
                entry.Zeroes += 1
    return diff


def update_diff(factors_index, url, new, prod, diff):
    """Update ``diff`` with the factor lists stored for ``url`` on both sides.

    ``new`` and ``prod`` map URLs to tuples; ``factors_index`` selects which
    tuple element holds the factor list to compare.

    :return: the ``diff`` object returned by ``update_with_factor_diff``
    """
    new_factors, prod_factors = (side[url][factors_index] for side in (new, prod))
    return update_with_factor_diff(new_factors, prod_factors, diff)


def get_full_diff(new_responses, prod_responses, factor_names):
    """Compare two response lists and collect per-factor difference statistics.

    Responses are paired by position; a response without a counterpart in the
    other list is ignored.  Within a pair only URLs present on both sides are
    compared, and each such URL counts once towards the total used for the
    percentages.

    :param new_responses: list of dicts as produced by ``extract_values``
    :param prod_responses: list of dicts as produced by ``extract_values``
    :param factor_names: mapping of factor number to name for the main factor
        set; numbers missing from it are reported as ``UNKNOWN_FACTOR_<n>``
    :return: tuple ``(all, fresh, mango)`` of lists of finalized
        ``FactorDiff`` objects, each containing only factors that differed
    """
    all_diff = defaultdict(lambda: FactorDiff())
    fresh_diff = defaultdict(lambda: FactorDiff())
    mango_diff = defaultdict(lambda: FactorDiff())

    # Positions of the factor lists inside the tuples built by extract_values().
    all_factors_index = 1
    fresh_factors_index = 3
    mango_factors_index = 4

    common_count = 0
    # zip() stops at the shorter list, which skips unpaired responses.
    for new, prod in zip(new_responses, prod_responses):
        for url in new:
            if url not in prod:
                continue
            common_count += 1
            update_diff(all_factors_index, url, new, prod, all_diff)
            update_diff(fresh_factors_index, url, new, prod, fresh_diff)
            update_diff(mango_factors_index, url, new, prod, mango_diff)

    def all_factor_name(number):
        # A factor absent from the names table must not abort the whole
        # comparison; use the same placeholder as read_factor_names().
        try:
            return factor_names[number]
        except (KeyError, IndexError):
            return "UNKNOWN_FACTOR_%s" % number

    diffs_only = [value.Finalize(all_factor_name(key), key, common_count) for key, value in all_diff.items() if value.Diff]
    fresh_diffs_only = [value.Finalize("Factor %s" % key, key, common_count) for key, value in fresh_diff.items() if value.Diff]
    mango_diffs_only = [value.Finalize("Factor %s" % key, key, common_count) for key, value in mango_diff.items() if value.Diff]

    return diffs_only, fresh_diffs_only, mango_diffs_only


def compare_factors(new_file, prod_file, factor_names, grouping=None):
    """Parse two dump files concurrently and diff their factors.

    :param new_file: path of the dump with the new responses
    :param prod_file: path of the dump with the production responses
    :param factor_names: path of the factor names file (see ``read_factor_names``)
    :param grouping: optional grouping name forwarded to ``parse_file``
    :return: the ``(all, fresh, mango)`` tuple produced by ``get_full_diff``
    """
    parse_file_func = partial(parse_file, grouping=grouping)

    # One thread per file; each thread fans out to its own process pool.
    pool = ThreadPool(2)
    try:
        new_responses, prod_responses = pool.map(parse_file_func, [new_file, prod_file])
    finally:
        # Release the worker threads instead of leaving them to the GC.
        pool.close()
        pool.join()

    return get_full_diff(new_responses, prod_responses, read_factor_names(factor_names))


def read_factor_names(path):
    """Read the factor number -> factor name table from a text file.

    Each non-blank line is expected to be ``<number>\\t<name>``.  Only the
    first tab separates the two fields, so a name may itself contain tabs.
    A line without a tab is treated as a bare factor number: the problem is
    logged and the name ``UNKNOWN_FACTOR_<number>`` is used.

    :param path: path of the names file
    :return: dict mapping ``int`` factor number to its name
    :raises ValueError: when the number part of a line is not an integer
    """
    result = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                # maxsplit=1: extra tabs belong to the name instead of
                # breaking the two-value unpacking.
                number, name = line.split("\t", 1)
            except ValueError as e:
                logging.error("Failed to parse line '%s': %s", line, e)
                number = int(line)
                name = "UNKNOWN_FACTOR_%s" % number

            result[int(number)] = name
    return result


def get_all_diff_stats(new_responses, old_responses, factor_names, fstrs, grouping=None):
    """Diff two response dumps and wrap the statistics in a ``Result``.

    ``new_responses``, ``old_responses`` and ``factor_names`` are handed to
    ``compare_factors`` unchanged, which treats them as file paths; ``fstrs``
    is handed to ``Result`` unchanged.
    """
    diffs = compare_factors(new_responses, old_responses, factor_names, grouping=grouping)
    all_factors, fresh, mango = diffs
    return Result(all_factors, fresh, mango, fstrs)
