import itertools
import logging
import math
import re
from textwrap import dedent

from .constants import CONTROL_EXP_ID, COST_EPS, STDDEV_K


def parse_apc_check_output(out, use_sigma=False):
    """Extract the average and its spread from ``apc_check`` output text.

    :param out: the raw output text to search (``None``/empty is treated as unparseable).
    :param use_sigma: when true, read the trailing ``sigma avrg`` value (divided by 100) as the
        spread; otherwise the spread is half the width of the bracketed ``[low, high]`` interval.
    :return: ``(avrg, spread)`` as floats, or ``(0.0, 0.0)`` after logging an error when the text
        does not match or the captured numbers cannot be converted to ``float``.
    """
    if use_sigma:
        pattern = re.compile(r"p [0-9\.]+, avrg ([0-9\.]+), \[[0-9/.]+, [0-9/.]+\], sigma avrg ([0-9\.]+).*$")
    else:
        pattern = re.compile(r"p [0-9\.]+, avrg ([0-9\.]+), \[([0-9/.]+), ([0-9/.]+)\].*$")

    match = pattern.search(out) if out else None

    if match:
        # The character classes also accept text float() rejects (e.g. "1/2", ".", "1.2.3");
        # route such values to the same fallback as a non-matching line instead of raising.
        try:
            mean = float(match.group(1))
            if use_sigma:
                spread = float(match.group(2)) / 100.0
            else:
                spread = (float(match.group(3)) - float(match.group(2))) / 2.0
        except ValueError:
            pass
        else:
            return mean, spread

    logging.error("Failed to parse apc_check output: {}".format(out))
    return 0.0, 0.0


def _percent_delta_cell(factor_mean, factor_std):
    """Turn a multiplicative factor around 1.0 into a percent-delta table cell.

    Returns ``((factor_mean - 1) * 100, 0.0, factor_std * 100, True)``. A ``None`` mean or std is
    propagated as ``None`` so the cell renders as ``?`` (or without a +/- part) instead of raising.
    """
    if factor_mean is None:
        return (None, 0.0, None, True)
    percent_std = None if factor_std is None else factor_std * 100.0
    return ((factor_mean - 1.0) * 100.0, 0.0, percent_std, True)


def print_zc(zc_res, exp_stats=None, etalon_exp=None, html=False):
    """Render a zC table, optionally normalized to an etalon experiment and joined with cost stats.

    :param zc_res: mapping ``(exp, action) -> (zc_mean, zc_std, action_view)``.
    :param exp_stats: optional iterable of dicts, each holding an ``"ExpID"`` key plus ``shows_d``,
        ``shows_d_sigma``, ``cost_d`` and ``cost_d_sigma``. ``cost_d`` is a percent delta (it is
        applied as the factor ``1 + cost_d / 100``).
    :param etalon_exp: experiment to normalize against; its own row gets neutral values
        (normalized zC 1.0, zero shows/cost deltas).
    :param html: render HTML instead of wiki markup.
    :return: the table text produced by ``make_table``; rows sorted by action view, then experiment.

    Cells whose inputs are missing (no etalon zC for the action, no stats for the experiment,
    a ``None`` zC) render as ``?`` rather than raising.
    """
    has_etalon = etalon_exp is not None
    columns = ["experiment", "action", "zC"]

    if has_etalon:
        columns.append("normalized zC")

    if exp_stats:
        exp_stats_dict = {elem["ExpID"]: {k: v for (k, v) in elem.items() if k != "ExpID"} for elem in exp_stats}
        columns.extend(("shows_d", "cost_d", "(zC * cost)_d"))

        if has_etalon:
            columns.append("(zC_norm * cost)_d")
    else:
        exp_stats_dict = {}

    # The etalon is the reference point: zero deltas with no uncertainty by definition.
    etalon_stat = {"shows_d": 0.0, "shows_d_sigma": 0.0, "cost_d": 0.0, "cost_d_sigma": 0.0}
    missing_cell = (None, 0.0, None, True)

    rows = []
    for ((exp, action), (zc_mean, zc_std, action_view)) in sorted(zc_res.items()):
        is_etalon = has_etalon and exp == etalon_exp
        row = {
            "experiment": (exp, None, None, None),
            "action": (str(action_view), None, None, None),
            "zC": (zc_mean, 1.0, zc_std, True),
        }

        norm_mean, norm_std = None, None
        if has_etalon:
            if is_etalon:
                norm_mean, norm_std = 1.0, 0.0
            else:
                # The etalon may have no result for this action; treat it as unknown.
                etalon_zc_mean, etalon_zc_std, _ = zc_res.get((etalon_exp, action), (None, None, None))

                # Truthiness guard: a zero/None mean cannot be used in the ratio.
                if zc_mean and zc_std and etalon_zc_mean and etalon_zc_std:
                    norm_mean, norm_std = _calculate_ratio_with_std(zc_mean, zc_std, etalon_zc_mean, etalon_zc_std)
            row["normalized zC"] = (norm_mean, 1.0, norm_std, True)

        if exp_stats_dict:
            stat = etalon_stat if is_etalon else exp_stats_dict.get(exp)

            if stat is None:
                row["shows_d"] = row["cost_d"] = row["(zC * cost)_d"] = missing_cell
                if has_etalon:
                    row["(zC_norm * cost)_d"] = missing_cell
            else:
                cost_factor = 1.0 + stat["cost_d"] / 100.0
                cost_factor_std = stat["cost_d_sigma"] / 100.0
                row["shows_d"] = (stat["shows_d"], 0.0, stat["shows_d_sigma"], True)
                row["cost_d"] = (stat["cost_d"], 0.0, stat["cost_d_sigma"], True)
                row["(zC * cost)_d"] = _percent_delta_cell(
                    *_calculate_product_with_std(cost_factor, cost_factor_std, zc_mean, zc_std)
                )
                if has_etalon:
                    row["(zC_norm * cost)_d"] = _percent_delta_cell(
                        *_calculate_product_with_std(cost_factor, cost_factor_std, norm_mean, norm_std)
                    )

        rows.append(row)

    return make_table(sorted(rows, key=lambda r: (r["action"][0], r["experiment"][0])), columns, html=html)


def print_zc_results(zc_res, exp_stats, calculate_cost, html=False):
    """Render the legacy zC table, optionally with cost deltas relative to the control experiment.

    :param zc_res: mapping ``(exp_id, action, action_view) -> (zc_mean, zc_std)``.
    :param exp_stats: mapping ``exp_id -> dict`` with ``"clicks"`` and ``"cost"`` entries; the entry
        for ``CONTROL_EXP_ID`` supplies the reference cost.
    :param calculate_cost: when true, add the ``cost_d`` and ``zC * cost`` columns.
    :param html: render HTML instead of wiki markup.
    :return: the table text produced by ``make_table``.

    Experiments without usable stats (or a missing/zero control cost) show ``?`` in the cost
    columns instead of making the table rendering fail.
    """
    # DEPRECATED, LEFT FOR COMPATIBILITY
    missing_cell = (None, 0.0, None, True)

    # A missing control entry only disables the cost cells; the zC-only table must still render.
    try:
        cost_etalon = exp_stats[CONTROL_EXP_ID].get("cost")
    except KeyError:
        cost_etalon = None

    rows = []
    for ((exp_id, action, action_view), (zc_mean, zc_std)) in sorted(zc_res.items()):
        row = {
            "experiment": (exp_id, None, None, None),
            "action": (str(action_view), None, None, None),
            "zC": (zc_mean, 1.0, zc_std, True),
            # Pre-filled so make_table never hits a missing cost column.
            "cost_d": missing_cell,
            "zC * cost": missing_cell,
        }

        if exp_id in exp_stats and cost_etalon:
            clicks_exp, cost_exp = exp_stats[exp_id].get("clicks"), exp_stats[exp_id].get("cost")

            if clicks_exp is not None and cost_exp is not None:
                cost_r = float(cost_exp) / float(cost_etalon) * 100.0
                eps = math.sqrt(2.0 / (1.0 + clicks_exp))
                cost_std = eps * COST_EPS * 100.0

                # Same product-of-independent-quantities error propagation as the original inline formula.
                zc_cost, zc_cost_std = _calculate_product_with_std(zc_mean, zc_std, cost_r, cost_std)
                row["cost_d"] = (cost_r - 100, 0.0, cost_std, True)
                row["zC * cost"] = (None if zc_cost is None else zc_cost - 100.0, 0.0, zc_cost_std, True)

        rows.append(row)

    if calculate_cost:
        return make_table(rows, ("experiment", "action", "zC", "cost_d", "zC * cost"), html=html)
    else:
        return make_table(rows, ("experiment", "action", "zC"), html=html)


def make_table(table_rows, columns, html=False):
    """Render rows as an HTML ``<table>`` or as a wiki ``#| ... |#`` table.

    :param table_rows: iterable of mappings from column name to a cell tuple
        ``(value, expectation, deviation, more_is_better)`` as consumed by ``_make_row``.
    :param columns: ordered column names; also used, in bold, as the header row.
    :param html: emit HTML markup instead of wiki markup.
    :return: the table text, one markup line per element, joined with newlines and dedented.
    """
    header_cells = [(name, None, None, None) for name in columns]

    def _cells(row):
        return [row[name] for name in columns]

    if not html:
        lines = ["#|", _make_row(header_cells, header=True)]
        lines.extend(_make_row(_cells(row)) for row in table_rows)
        lines.append("|#")
        return dedent("\n".join(lines))

    lines = ["<table>", "<tr>", _make_row(header_cells, html=True, header=True), "</tr>"]
    for row in table_rows:
        lines += ["<tr>", _make_row(_cells(row), html=True), "</tr>"]
    lines.append("</table>")
    return dedent("\n".join(lines))


def _calculate_product_with_std(mean1, std1, mean2, std2):
    """Return ``(mean, std)`` of the product of two independent uncertain quantities.

    The mean is ``mean1 * mean2``; the std is
    ``sqrt((std1*std2)^2 + (mean1*std2)^2 + (mean2*std1)^2)``.
    A ``None`` mean makes both results ``None``; a ``None`` std makes only the std ``None``.
    """
    if mean1 is None or mean2 is None:
        return None, None

    mean = mean1 * mean2
    if std1 is None or std2 is None:
        return mean, None

    variance = (std1 * std2) ** 2 + (mean1 * std2) ** 2 + (mean2 * std1) ** 2
    return mean, math.sqrt(variance)


def _calculate_ratio_with_std(num_mean, num_std, denom_mean, denom_std):
    """Return the approximate ``(mean, std)`` of ``num / denom`` for independent quantities.

    Uses the second-order (delta-method) approximations from
    http://www.stat.cmu.edu/~hseltman/files/ratio.pdf:

    * mean ~= (mu_x / mu_y) * (1 + sigma_y^2 / mu_y^2)
    * std  ~= |mu_x / mu_y| * sqrt(sigma_x^2 / mu_x^2 + sigma_y^2 / mu_y^2)
            = sqrt(sigma_x^2 + (mu_x / mu_y)^2 * sigma_y^2) / |mu_y|

    The second form of the std is used because it stays defined when ``num_mean`` is 0 and is
    always non-negative, even for negative ratios.

    :return: ``(ratio_mean, ratio_std)``, or ``(None, None)`` if any input is ``None`` or the
        denominator mean is 0.
    """
    if any(x is None for x in (num_mean, num_std, denom_mean, denom_std)) or denom_mean == 0:
        return None, None

    ratio = num_mean / denom_mean
    ratio_mean = ratio * (1.0 + (denom_std / denom_mean) ** 2)
    ratio_std = math.sqrt(num_std ** 2 + (ratio * denom_std) ** 2) / abs(denom_mean)

    return ratio_mean, ratio_std


def _make_row(values, html=False, header=False):
    """Format one table row from cell tuples ``(value, expectation, deviation, more_is_better)``.

    * ``None`` renders as ``?``.
    * Strings render as-is, or bold (``<b>``/``**``) when ``header`` is true.
    * Floats with both ``expectation`` and ``deviation`` are coloured by ``_paint_value``; other
      floats render with 4 decimals.
    * Anything else (ints, bools, other types) renders via ``str`` so no column is ever dropped.

    ``more_is_better`` defaults to True only when it is ``None``; an explicit ``False`` is honoured.

    :return: ``<th>``/``<td>`` cells concatenated (HTML) or a ``|| a | b ||`` wiki row.
    """
    print_values = []

    for (value, expectation, deviation, more_is_better) in values:
        if value is None:
            print_values.append("?")
        elif isinstance(value, str):
            print_values.append(("<b>{}</b>" if html else "**{}**").format(value) if header else value)
        elif isinstance(value, float):
            if expectation is not None and deviation is not None:
                prefer_more = True if more_is_better is None else more_is_better
                print_values.append(_paint_value(value, expectation, deviation, prefer_more, html))
            else:
                print_values.append("{:.4f}".format(value))
        else:
            # ints/bools as before; unknown types are stringified instead of silently
            # dropped, which would shift every following cell into the wrong column.
            print_values.append(str(value))

    if html:
        if header:
            return "".join("<th>{}</th>".format(v) for v in print_values)
        else:
            return "".join("<td>{}</td>".format(v) for v in print_values)
    else:
        return "|| {} ||".format(" | ".join(print_values))


def _paint_value(value, expectation, deviation, more_is_better=True, html=False):
    """Format ``value +/- deviation``, highlighting it when it strays from ``expectation``.

    The colour is green when the value moved in the preferred direction (strictly above
    ``expectation`` when ``more_is_better`` is truthy, otherwise not above it) and red otherwise.
    A distance from ``expectation`` greater than ``deviation * STDDEV_K`` is coloured; greater than
    twice that it is also bold. Smaller distances are rendered plain.

    :return: an HTML fragment when ``html`` is true, wiki markup otherwise.
    """
    moved_up = value > expectation
    color = "green" if moved_up == bool(more_is_better) else "red"

    distance = abs(value - expectation)
    if distance > 2.0 * deviation * STDDEV_K:
        if html:
            template = "<b style=\"color:{};\">{:.4f} +/- {:.4f}</b>"
        else:
            template = "**!!({}){:.4f} +/- {:.4f}!!**"
    elif distance > deviation * STDDEV_K:
        if html:
            template = "<p style=\"color:{};\">{:.4f} +/- {:.4f}</p>"
        else:
            template = "!!({}){:.4f} +/- {:.4f}!!"
    else:
        return "{:.4f} +/- {:.4f}".format(value, deviation)

    return template.format(color, value, deviation)
