#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
import codecs
import argparse
import calendar
import datetime
import numbers
import time
import traceback
from collections import defaultdict
from xml.sax.saxutils import escape as xml_escape

import scipy.stats
import requests
from nile.api.v1 import clusters
from yql.api.v1.client import YqlClient

from videolog_common import (
    get_cluster,
    get_driver,
    YqlRunner,
    apply_replacements,
)

# Title given to the YQL operations started through YqlRunner in main().
TITLE = "Video Player Fast Metrics | CHYT"
# CHYT (ClickHouse-over-YT) query template; make_chyt_query() fills in the
# @[...] placeholders. For one service and a set of versions it counts events
# per (Version, EventName, bucket), where bucket = cityHash64(VSID) % 100 splits
# the rows into 100 hash groups. The per-bucket counts later serve as the
# samples compared by the Mann-Whitney test. The result is stored in
# @[output_table].
CHYT_STUB = """
use chyt.@[cluster]@[clique];

create table "@[output_table]" Engine = YtTable()
as
select
    `Version`,
    bucket,
    EventName,
    count(*) as `count`
from @[tables]
where Service == '@[service]' and has(@[versions], `Version`)
group by `Version`, EventName, cityHash64(VSID) % 100 as bucket
order by EventName, `Version`, bucket
"""
# Head of every HTML report. It loads the report's JS/CSS bundle and passes the
# project (service) name and the report interval bounds START_TS/END_TS to it.
html_stub = (
    u'<html><head><meta charset="UTF-8">'
    '<script src="https://pcode-ci.s3.mds.yandex.net/'
    'nirvana-report/ims-player-fast-metrics/index.js"></script>'
    '<link rel="stylesheet" href="https://pcode-ci.s3.mds.yandex.net/'
    'nirvana-report/ims-player-fast-metrics/index.css">'
    '<script> var PROJECT="{project}"; var START_TS = {start_ts}; var END_TS = {end_ts};</script>'
    '</head><body>'
)
# Startrek REST API base URL and the default headers for its requests.
# NOTE: STARTREK_TOKEN is read at import time, so importing this module fails
# with KeyError when the variable is not set.
api = "https://st-api.yandex-team.ru/v2"
headers = {
    "Content-Type": "application/json",
    "Authorization": "OAuth {}".format(os.environ["STARTREK_TOKEN"]),
}


class Moscow(datetime.tzinfo):
    """Fixed-offset timezone: UTC+03:00, no DST, reported as "Moscow"."""

    _OFFSET = datetime.timedelta(hours=3)
    _NO_DST = datetime.timedelta(0)

    def utcoffset(self, dt):
        """Return the constant +3h offset, whatever ``dt`` is."""
        return self._OFFSET

    def tzname(self, dt):
        """Return the fixed zone name."""
        return "Moscow"

    def dst(self, dt):
        """Return a zero DST adjustment; the offset never changes."""
        return self._NO_DST


# Shared tzinfo instance for every Moscow-time datetime in this module.
moscow = Moscow()
# Format of full --from/--to timestamps. These strings are also compared
# lexicographically against 30-minute log table names in fast_path_filter().
dtformat = "%Y-%m-%dT%H:%M:%S"


def parse_date(s):
    """Parse ``s`` as a Moscow-time datetime.

    Accepts either a full ``dtformat`` timestamp (recognised by the "T"
    separator) or a bare ``YYYY-MM-DD`` date, which means midnight.
    Raises ValueError for anything else.
    """
    pattern = dtformat if "T" in s else "%Y-%m-%d"
    naive = datetime.datetime.strptime(s, pattern)
    return naive.replace(tzinfo=moscow)


def _ch_escape(value):
    """Escape ``value`` for use inside a single-quoted ClickHouse string."""
    # Escape backslashes first, so the backslashes added for quotes are not doubled.
    return value.replace("\\", "\\\\").replace("'", "\\'")


def make_chyt_query(args, output_table, tables):
    """Build the CHYT aggregation query from CHYT_STUB.

    :param args: parsed CLI arguments; reads proxy, clique, service, control,
        experiment (a list of versions) and root.
    :param output_table: YT path of the table the query creates.
    :param tables: source YT table paths (main() guarantees at least one);
        several tables are read through ``concatYtTables``.
    :returns: the query text.
    """
    if len(tables) == 1:
        from_ = '"{}"'.format(tables[0])
    else:
        from_ = "concatYtTables({})".format(
            ",".join('"{}"'.format(table) for table in tables)
        )
    versions = [args["control"]] + args["experiment"]
    # Build the ClickHouse array literal explicitly instead of relying on
    # Python's list repr. The repr breaks on values containing quotes and, on
    # Python 2, emits u'...' prefixes for unicode strings. For plain values
    # the text is identical to the old output, e.g. "['a', 'b']".
    versions_literal = (
        "[" + ", ".join("'" + _ch_escape(v) + "'" for v in versions) + "]"
    )
    return apply_replacements(
        CHYT_STUB,
        {
            # No-op: CHYT_STUB has no @[root] placeholder.
            "@[root]": args["root"],
            "@[cluster]": args["proxy"].lower(),
            "@[clique]": "/{}".format(args["clique"]) if args["clique"] else "",
            "@[output_table]": output_table,
            "@[tables]": from_,
            # The template already wraps this in single quotes.
            "@[service]": _ch_escape(args["service"]),
            "@[versions]": versions_literal,
        },
    )


def daily_path_filter(path, from_, to_):
    """Return True if the daily table at ``path`` falls within [from_, to_].

    Only the date part (text before "T") of each bound is used, compared
    lexicographically with the table's base name, so a daily table is kept
    when its day overlaps the interval.
    """
    table_name = path.rsplit("/", 1)[-1]
    first_day = from_.partition("T")[0]
    last_day = to_.partition("T")[0]
    return first_day <= table_name <= last_day


def fast_path_filter(path, from_, to_, dailies):
    """Return True if the 30-minute table at ``path`` should be read.

    The table's base name must lie within [from_, to_]. When daily tables
    were found, it must additionally come strictly after the last daily
    table's date and not belong to that day, so no data is counted twice.
    """
    table_name = path.rsplit("/", 1)[-1]
    in_range = from_ <= table_name <= to_
    if not dailies:
        return in_range
    last_daily = dailies[-1].rsplit("/", 1)[-1]
    if not last_daily:
        return in_range
    return in_range and table_name > last_daily and last_daily not in table_name


def wrap(rec):
    """Recursively turn bytes into text inside dicts and lists.

    bytes are decoded as UTF-8 with undecodable bytes replaced by U+FFFD.
    Dict keys and values and list items are converted too. Any other value
    (including tuples) is returned unchanged.
    """
    if isinstance(rec, bytes):
        return rec.decode("utf8", errors="replace")
    if isinstance(rec, list):
        return list(map(wrap, rec))
    if isinstance(rec, dict):
        return dict((wrap(key), wrap(value)) for key, value in rec.items())
    return rec


def mw_wrapper(*args, **kwargs):
    """Call scipy.stats.mannwhitneyu, falling back to (0, 1) on ValueError.

    Arguments are forwarded unchanged. When scipy rejects the input with a
    ValueError, the fallback result has a p-value of 1, so the comparison
    is never reported as significant.
    """
    try:
        result = scipy.stats.mannwhitneyu(*args, **kwargs)
    except ValueError:
        result = (0, 1)
    return result


def calc_metric(args, metric, metric_name, positive_direction=None):
    """Compare one metric between the control and every experiment version.

    :param args: parsed CLI arguments; reads ``args["control"]`` and the list
        ``args["experiment"]``.
    :param metric: mapping version -> per-bucket values. main() passes
        ``defaultdict`` instances, so absent versions do not raise.
    :param metric_name: display name stored under ``row["metric"]``.
    :param positive_direction: optional "up"/"down" hint, stored on the row
        for the formatter.
    :returns: dict with ``metric``, ``value`` (control total) and ``exps``: one
        entry per experiment (sorted), holding value, diff, perc_diff, both
        one-sided Mann-Whitney p-values, their minimum ``pvalue`` and the
        inferred ``direction`` ("up", "down" or "none").
    """
    control_buckets = metric[args["control"]]
    control_value = sum(control_buckets)
    row = {"value": control_value, "exps": [], "metric": metric_name}
    try:
        print(u"calculating {}".format(metric_name))
    except UnicodeError:
        # Python 2 cannot format or print non-ASCII names to a byte stream
        # without an encoding. Progress output must never abort the run (it
        # used to drop into pdb here), so fall back to an ASCII-safe repr.
        print("calculating {!r}".format(metric_name))
    if positive_direction:
        row["positive_direction"] = positive_direction
    for exp in sorted(args["experiment"]):
        exp_buckets = metric[exp]
        exp_value = sum(exp_buckets) or 0
        diff = exp_value - (control_value or 0)
        try:
            perc_diff = diff / float(control_value)
        except ZeroDivisionError:
            # Sentinel for a zero control total, not a real ratio.
            perc_diff = 1
        # "less" tests control < experiment, "greater" tests the opposite.
        pvalue_less = mw_wrapper(
            control_buckets, exp_buckets, alternative="less"
        )[1]
        pvalue_greater = mw_wrapper(
            control_buckets, exp_buckets, alternative="greater"
        )[1]
        if pvalue_less < pvalue_greater:
            direction = "up"
        elif pvalue_greater < pvalue_less:
            direction = "down"
        else:
            direction = "none"
        row["exps"].append(
            {
                "value": exp_value,
                "diff": diff,
                "perc_diff": perc_diff,
                "pvalue_less": pvalue_less,
                "pvalue_greater": pvalue_greater,
                "pvalue": min(pvalue_less, pvalue_greater),
                "direction": direction,
            }
        )
    return row


class MetricFormatter(object):
    """Render calc_metric() rows as a Startrek comment and HTML reports.

    ``metrics[0]`` is rendered as the "Main metrics" table and the remaining
    rows as "Event metrics". ``additional_data`` carries run details used in
    the comment and HTML header (start_ts, from_, to_, share_urls, args).
    """

    def __init__(
        self, control, experiments, metrics, task=None, additional_data=None
    ):
        self.cnt = control  # control version name
        self.exps = sorted(experiments)  # same order as calc_metric's "exps"
        self.metrics = metrics
        self.task = task  # Startrek issue key used in API URLs, or None
        self.additional_data = additional_data or {}

    @staticmethod
    def format_diff(diff):
        """Format a numeric diff with an explicit sign and four decimals.

        Non-numeric values (e.g. None) render as an en dash. Minus signs are
        replaced with en dashes for display.
        """
        # numbers.Real covers int/float (and long on Python 2) without
        # referencing the Python-2-only ``long`` builtin.
        if not isinstance(diff, numbers.Real):
            return u"–"
        return u"{sign}{diff:.04f}".format(
            sign=u"+" if diff > 0 else u"", diff=diff
        ).replace(u"-", u"–")

    @staticmethod
    def _format_pvalue(perc_diff, pvalue, positive_direction="unknown"):
        """Return a Startrek-coloured p-value: gray when p >= 0.01 or no change."""
        real_direction = (
            "up" if (perc_diff is None or perc_diff > 0) else "down"
        )
        if pvalue >= 0.01 or (perc_diff is not None and perc_diff == 0):
            color = "gray"
        elif positive_direction == "unknown":
            color = "yellow"
        elif positive_direction == real_direction:
            color = "green"
        else:
            color = "red"
        return "!!({color}){pvalue:.04f}!!".format(color=color, pvalue=pvalue)

    def format_pvalue(self, metric, positive_direction="unknown"):
        """Format an experiment entry's p-value, or an en dash if it has none."""
        if "pvalue" not in metric:
            return u"–"
        return self._format_pvalue(
            metric["perc_diff"],
            metric["pvalue"],
            positive_direction=positive_direction,
        )

    @staticmethod
    def get_abstract_pvalue(metric, positive_direction="unknown"):
        """Describe an experiment entry's p-value as a format-neutral dict.

        The colour is gray for p >= 0.01 or a zero change. Otherwise it is
        yellow when the desired direction is unknown, green when the change
        goes the desired way, gray when no direction was inferred, and red
        otherwise.
        """
        if "pvalue" not in metric:
            return u"-"
        perc_diff = metric["perc_diff"]
        pvalue = metric["pvalue"]
        real_direction = metric["direction"]
        result = {
            "type": "pvalue",
            "color": "gray",
            "value": pvalue,
            "direction": real_direction,
        }
        if pvalue >= 0.01 or (perc_diff is not None and perc_diff == 0):
            return result
        if positive_direction == "unknown":
            result["color"] = "yellow"
        elif positive_direction == real_direction:
            result["color"] = "green"
        elif real_direction == "none":
            result["color"] = "gray"
        else:
            result["color"] = "red"
        return result

    @staticmethod
    def st_format_elem(elem):
        """Render one table cell in Startrek wiki markup."""
        if isinstance(elem, dict):
            return u"!!({color}){pvalue:.04f}!!".format(
                color=elem["color"], pvalue=elem["value"]
            )
        return format(elem)

    def st_format_row(self, row):
        """Render one table row in Startrek wiki markup."""
        return u"|| " + u" | ".join(map(self.st_format_elem, row)) + u" ||"

    @staticmethod
    def html_format_elem(elem, tag="td"):
        """Render one table cell as HTML (``tag`` is "td" or "th")."""
        if isinstance(elem, dict):
            return u'<{tag} class="{color}">{pvalue:.04f} ({direction})</{tag}>'.format(
                color=elem["color"],
                pvalue=elem["value"],
                tag=tag,
                direction=elem["direction"],
            )
        # ``bytes`` is ``str`` on Python 2 and the raw-bytes type on Python 3.
        if isinstance(elem, bytes):
            elem = elem.decode("utf8", errors="replace")
        # Cell text such as EventName comes from logs. Escape it so it cannot
        # inject markup or scripts into the report.
        return u"<{tag}>{elem}</{tag}>".format(
            elem=xml_escape(u"{}".format(elem)), tag=tag
        )

    def html_format_row(self, row, header=False):
        """Render a row; the first cell (or every cell for a header) is a <th>."""
        if not header:
            formatted_row = [self.html_format_elem(row[0], tag="th")] + list(
                map(self.html_format_elem, row[1:])
            )
        else:
            formatted_row = list(
                map(lambda x: self.html_format_elem(x, tag="th"), row)
            )
        return u"<tr>" + u"".join(formatted_row) + u"</tr>"

    def st_format_table(self, rows):
        """Render rows as a Startrek wiki table."""
        rows = [u"#|"] + [self.st_format_row(row) for row in rows] + [u"|#"]
        return u"\n".join(rows)

    def html_format_table(self, rows):
        """Render rows as an HTML table; the first row becomes the header."""
        header, rows = rows[0], rows[1:]
        return (
            "<table>\n"
            + "<thead>\n"
            + self.html_format_row(header, header=True)
            + "</thead>\n<tbody>\n"
            + "\n".join(map(self.html_format_row, rows))
            + "\n</tbody>\n</table>"
        )

    def generate_metric_table(self, metrics, exp_id=None):
        """Build table rows (header first) for all experiments or just ``exp_id``.

        Metrics lacking an entry for ``exp_id`` are reported on stderr and
        skipped.
        """
        result = []
        row = [u"Metric", u"{} value".format(self.cnt)]
        for exp in self.exps if exp_id is None else [self.exps[exp_id]]:
            row += [
                u"{} value".format(exp),
                u"{} diff".format(exp),
                u"{} Percent diff".format(exp),
                u"{} pValue".format(exp),
            ]
        result.append(row)
        for metric in metrics:
            row = [metric["metric"], metric["value"]]
            if exp_id is not None and exp_id >= len(metric["exps"]):
                sys.stderr.write(
                    "unable to calculate {}\n".format(metric["metric"])
                )
                continue
            for exp in (
                metric["exps"] if exp_id is None else [metric["exps"][exp_id]]
            ):
                row += [
                    exp["value"],
                    self.format_diff(exp["diff"]),
                    self.format_diff(exp["perc_diff"]),
                    self.get_abstract_pvalue(
                        exp,
                        positive_direction=metric.get(
                            "positive_direction", "unknown"
                        ),
                    ),
                ]
            result.append(row)
        return result

    def generate_st_comment(self):
        """Build the Startrek comment text summarising the run parameters."""
        ad = self.additional_data
        result = [
            u"**Запуск стартовал**: {}".format(ad.get("start_ts")),
            u"**От**: {}".format(ad.get("from_")),
            u"**До**: {}".format(ad.get("to_")),
            u"**Контроль**: {}".format(self.cnt),
            u"**Эксперименты**: {}".format(self.exps),
        ]
        if ad.get("share_urls"):
            result.append(
                u"**Ссылки на запросы**: {}".format(
                    u", ".join(ad["share_urls"])
                )
            )
        result.append(u"\nМетрики см. в html-файлах ниже.")
        return u"\n".join(result)

    @staticmethod
    def _js_timestamp(value):
        """Convert a --from/--to string to epoch milliseconds, as a string.

        parse_date() accepts both date-only and full timestamps and attaches
        the fixed Moscow offset. The result therefore does not depend on the
        host timezone, unlike the non-portable strftime("%s") on a naive
        datetime.
        """
        return "{}000".format(calendar.timegm(parse_date(value).utctimetuple()))

    def generate_html(self, exp=None):
        """Render the HTML report for experiment index ``exp`` (or all of them)."""
        ad = self.additional_data
        result = [
            html_stub.format(
                start_ts=self._js_timestamp(ad["from_"]),
                end_ts=self._js_timestamp(ad["to_"]),
                project=ad["args"]["service"],
            )
        ]
        result.append(u"<h1>Main metrics</h1>")
        result.append(
            self.html_format_table(
                self.generate_metric_table([self.metrics[0]], exp_id=exp)
            )
        )
        result.append(u"<h1>Event metrics</h1>")
        result.append(
            self.html_format_table(
                self.generate_metric_table(self.metrics[1:], exp_id=exp)
            )
        )
        result.append("</body></html>")
        return u"\n".join(result)

    @staticmethod
    def _post_comment(text, task, attachments=None):
        """POST a comment (with optional attachment ids) to issue ``task``."""
        json_ = {"text": text}
        if attachments:
            json_["attachmentIds"] = attachments
        # NOTE(review): TLS verification is disabled (verify=False); consider
        # pointing ``verify`` at the internal CA bundle instead.
        kwargs = dict(headers=headers, json=json_, verify=False)
        return requests.post(
            "{}/issues/{}/comments".format(api, task), **kwargs
        )

    @staticmethod
    def st_upload_file(text, task, filename=None):
        """Upload ``text`` (UTF-8 encoded) as a Startrek attachment.

        Returns the ``requests`` response. The attachment id is expected under
        "id" in its JSON body.
        """
        if not filename:
            filename = "{}-{}.html".format(task, int(time.time()))
        upload_headers = headers.copy()
        # Let requests set the multipart Content-Type with its boundary.
        upload_headers.pop("Content-Type")
        # Upload from memory. A fixed "tmpfile" in the CWD used to leak its
        # handle and clash between concurrent runs. The multipart filename
        # stays "tmpfile", as before.
        files = {"file": ("tmpfile", text.encode("utf8"))}
        return requests.post(
            "{api}/attachments?filename={filename}".format(
                api=api, filename=filename
            ),
            files=files,
            headers=upload_headers,
            verify=False,
        )

    def generate_comment_and_attachments(self):
        """Build the comment and upload one HTML report per experiment.

        Returns ``(comment, attachment_ids, htmls)``. Raises Exception when
        an upload response carries no attachment id.
        """
        comment = self.generate_st_comment()
        attachments = []
        htmls = []
        for i, exp in enumerate(self.exps):
            html = self.generate_html(exp=i)
            htmls.append(html)
            exp_id = exp.split("-")[-1]
            filename = "{}_{}.html".format(exp_id, self.task)
            req = self.st_upload_file(html, self.task, filename=filename)
            try:
                id_ = req.json()["id"]
            except (ValueError, KeyError, TypeError):
                # Non-JSON body, missing "id", or a non-object JSON payload.
                raise Exception(
                    u"unsuccessful query: {} {}".format(
                        req.status_code, req.text
                    )
                )
            attachments.append(id_)
        return (comment, attachments, htmls)

    def post_comment(self, data=None):
        """Post the report comment, generating and uploading it unless ``data`` is given.

        ``data`` is a ``(comment, attachments, htmls)`` tuple as returned by
        generate_comment_and_attachments(). Returns ``(response, comment,
        htmls)``.
        """
        if not data:
            comment, attachments, htmls = (
                self.generate_comment_and_attachments()
            )
        else:
            comment, attachments, htmls = data
        req = self._post_comment(comment, self.task, attachments=attachments)
        return req, comment, htmls


def ensure_utf8(s):
    """Return ``s`` as text, replacing anything that is not valid UTF-8.

    bytes are decoded as UTF-8. Text is round-tripped through UTF-8 with the
    ``surrogateescape`` handler, so escaped undecodable bytes become U+FFFD.
    None is returned unchanged (e.g. a NULL EventName) instead of crashing.
    """
    if s is None:
        return s
    # ``bytes`` is ``str`` on Python 2 and the raw-bytes type on Python 3.
    # The old ``isinstance(s, str)`` check called ``.decode`` on text under
    # Python 3.
    if isinstance(s, bytes):
        return s.decode("utf8", errors="replace")
    return s.encode("utf8", errors="surrogateescape").decode("utf8", errors="replace")


# Events folded into the synthetic "PseudoTVT" metric. The leading digit of
# each name is used as its weight (1, 2 and 3 respectively).
_TVT_EVENTS = ("10SecWatched", "20SecWatched", "30SecHeartbeat")


def _collect_event_metrics(yt, output_table, args):
    """Read the CHYT output table and compute one metric row per EventName.

    Returns ``(metrics, tvt_metric)``:
    - ``metrics`` holds the rows in first-seen EventName order.
    - ``tvt_metric`` is the version -> 100-bucket PseudoTVT accumulator.

    A failure for a single event is logged to stderr and that event is
    skipped.
    """
    # EventName -> Version -> list of per-bucket counts. Grouping explicitly
    # keeps results correct even if rows of one event are not contiguous.
    events = defaultdict(lambda: defaultdict(list))
    event_order = []
    tvt_metric = defaultdict(lambda: [0] * 100)

    for raw_rec in yt.read_table(output_table):
        rec = wrap(raw_rec)
        bucket = rec.get("bucket")
        # Bucket 0 is a real bucket (cityHash64(VSID) % 100 yields 0..99).
        # The old truthiness check silently dropped it.
        if bucket is None:
            continue
        name = ensure_utf8(rec["EventName"])
        version = rec["Version"]
        if name not in events:
            event_order.append(name)
        if name in _TVT_EVENTS:
            tvt_metric[version][bucket] += rec["count"] * int(name[0])
        events[name][version].append(rec["count"])

    metrics = []
    for name in event_order:
        try:
            metrics.append(calc_metric(args, events[name], name))
        except Exception:
            # Best effort: one broken event must not abort the whole report.
            # (format_exc() takes a frame limit, not the exception object.)
            sys.stderr.write(traceback.format_exc() + "\n\n")
    return metrics, tvt_metric


def main():
    """Run the CHYT aggregation, compute the metrics and publish the report.

    Writes the comment text to --output_file and the per-experiment HTML
    reports to --output_html. When --task is given, it also posts a Startrek
    comment with the HTML reports attached.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", default="//tmp")
    parser.add_argument("--proxy", default="Hahn")
    parser.add_argument("--from")
    parser.add_argument("--to")
    parser.add_argument("--task")
    parser.add_argument("--clique")
    parser.add_argument("--force_table")
    parser.add_argument("--control")
    parser.add_argument("--service", default="StreamPlayer")
    parser.add_argument("--experiment")
    parser.add_argument("--cleanup", default="yes")
    parser.add_argument("--output_file", default="output.json")
    parser.add_argument("--output_html", default="output.html")
    args = vars(parser.parse_args())
    args["experiment"] = sorted(args["experiment"].split(","))

    cluster = get_cluster(clusters, args)
    yt = get_driver(cluster).client

    client = YqlClient(token=os.environ["YQL_TOKEN"])
    client.config.db = None

    yr = YqlRunner(client, title=TITLE)

    start_ts = datetime.datetime.now()
    print("started at: {}".format(start_ts.strftime("%Y-%m-%dT%H:%M:%S")))

    # Anything after "+" (e.g. a UTC offset) is dropped; times are Moscow time.
    from_ = args.get("from").split("+")[0]
    parse_date(from_)  # fail fast on a malformed --from
    # "--to" may be omitted or set to "now" to mean the current Moscow time.
    to_ = (args.get("to") or "").split("+")[0]
    if not to_ or to_ == "now":
        to_ = datetime.datetime.now(moscow).strftime(dtformat)
    print("from: {}".format(from_))
    print("to: {}".format(to_))

    dailies = sorted(
        yt.search(
            root="//logs/jstracer-log/1d",
            path_filter=lambda x: daily_path_filter(x, from_, to_),
        )
    )
    print("daily jstracer tables: {}".format(dailies))

    fasts = sorted(
        yt.search(
            root="//logs/jstracer-log/30min",
            path_filter=lambda x: fast_path_filter(x, from_, to_, dailies),
        )
    )
    print("fast jstracer tables: {}".format(fasts))

    all_tables = dailies + fasts
    if not all_tables:
        raise Exception("No tables available")

    forced_table = args.get("force_table")
    if forced_table:
        table_id = forced_table
    else:
        table_id = "videoplayerchytabt_{}".format(
            start_ts.strftime("%Y%m%dT%H%M%S")
        )
    output_table = "{}/{}".format(args["root"], table_id)
    if forced_table:
        # Reuse an existing aggregated table; no query is run.
        share_url = "-"
    else:
        chyt_query = make_chyt_query(args, output_table, all_tables)
        share_url = yr.run(chyt_query, query_type="CLICKHOUSE").share_url

    metrics, tvt_metric = _collect_event_metrics(yt, output_table, args)
    # PseudoTVT goes first because generate_html renders metrics[0] as the
    # "Main metrics" table. Its failure is deliberately not swallowed.
    metrics.insert(0, calc_metric(args, tvt_metric, "PseudoTVT"))

    mf = MetricFormatter(
        args["control"],
        args["experiment"],
        metrics,
        task=args.get("task"),
        additional_data=dict(
            share_urls=[share_url],
            start_ts=start_ts,
            from_=from_,
            to_=to_,
            dailies=dailies,
            fasts=fasts,
            args=args,
        ),
    )
    req = None
    if args.get("task"):
        req, comment, htmls = mf.post_comment()
    else:
        # Without a task there is nothing to attach to. Render locally
        # instead of uploading orphan attachments; this also defines
        # ``htmls``, which used to be unset on this path (NameError below).
        comment = mf.generate_st_comment()
        htmls = [mf.generate_html(exp=i) for i, _ in enumerate(mf.exps)]

    if args["cleanup"] == "yes":
        # NOTE(review): this also removes a table passed via --force_table.
        yt.remove(output_table)

    # ``if req`` would be False for HTTP error responses; test for presence.
    comment_id = None
    if req is not None:
        try:
            comment_id = req.json().get("longId", "ERROR")
        except (ValueError, AttributeError):
            comment_id = "ERROR"
    output = u"comment posted at: {}\n\n{}".format(
        "https://st.yandex-team.ru/{}#{}".format(args.get("task"), comment_id),
        comment,
    )

    with codecs.open(args["output_file"], "w", "utf8") as f:
        f.write(output)
    with codecs.open(args["output_html"], "w", "utf8") as f:
        f.write(u"\n".join(htmls))


# Script entry point: run the report only when executed directly, not on import.
if __name__ == "__main__":
    main()
