# coding: utf-8
from __future__ import unicode_literals
from __future__ import print_function

import base64
import datetime
import os
import random
import requests
import subprocess
import tempfile
import time
import tvmauth

import pandas as pd
import yt.wrapper as yt


class Report(object):
    """Minimal report interface: concrete reports override ``run``."""

    def run(self):
        """Produce the report; the base implementation produces nothing."""
        return None


def table_name(table_template):
    """Fill ``{crypta_env}`` in ``table_template`` with the current environment.

    The environment is read from ``ENV_TYPE`` (default ``"production"``) and lower-cased.
    """
    env_type = os.environ.get("ENV_TYPE", "production")
    return table_template.format(crypta_env=env_type.lower())


class YtTableReport(object):

    """Base class for reports rendered from a YT table into a PNG image.

    Subclasses provide ``query`` (a YT ``select_rows`` query, written without the
    leading ``SELECT``), ``make_dataframe`` and ``to_html``.
    """

    STYLES = [
        dict(selector="th", props=[("font-size", "120%"), ("text-align", "left")]),
        dict(selector="td", props=[("text-align", "right")]),
        dict(selector="table, th, td", props=[("border", "1px solid black"), ("border-collapse", "collapse")]),
        dict(selector="table", props=[("width", "100%")]),
    ]

    # Default image width in pixels; subclasses may override it (per class or per instance).
    width = 1200

    # wkhtmltoimage binary, resolved relative to the current working directory.
    WKHTMLTOIMAGE = "crypta/utils/wkhtml/wkhtmltox/bin/wkhtmltoimage"

    def __init__(self, table_path, fields_mapping):
        """
        :param table_path: table path template, may contain ``{crypta_env}``.
        :param fields_mapping: report-specific config (list of field descriptions for
            table reports).
        """
        self.table_name_template = table_path
        self.fields_mapping = fields_mapping

    def run(self):
        """Build the report and return the path of the rendered PNG file."""
        data = self.select_data()
        df = self.make_dataframe(data)
        html = self.to_html(df)
        return self.html2image(html, self.width)

    def _prepare_html(self, html_content):
        """Return the complete HTML document handed to wkhtmltoimage.

        Every ``<table`` gets full width and zero cell spacing, and every raw field
        name (plus the literal column name ``field``) is replaced with its centred
        verbose label.
        """
        html_content = html_content.replace("<table", '<table width="100%" cellspacing="0"')

        for field in self.fields_mapping + [{"field": "field", "verbose": "метрика"}]:
            html_content = html_content.replace(
                field["field"], "<center>{vb}</center>".format(vb=field["verbose"])
            )

        return '<meta charset="UTF-8">' + html_content

    def html2image(self, html_content, width):
        """Convert HTML into a PNG image using the bundled wkhtmltoimage binary.

        :param html_content: HTML text (as produced by ``to_html``).
        :param width: image width in pixels.
        :returns: path of a temporary PNG file; it is not removed here.
        :raises subprocess.CalledProcessError: if chmod or wkhtmltoimage fails.
        """
        html_fd, tmp_html = tempfile.mkstemp(".html")
        png_fd, tmp_png = tempfile.mkstemp(".png")
        # Only the path is needed: wkhtmltoimage opens the output file by name.
        os.close(png_fd)

        try:
            # Write encoded bytes so the same code works on Python 2 and 3.
            with os.fdopen(html_fd, "wb") as html_file:
                html_file.write(self._prepare_html(html_content).encode("utf-8"))

            # Original note: "for some run is required" -- presumably the binary can
            # lose its executable bit when packaged.
            subprocess.check_call(["chmod", "u+x", self.WKHTMLTOIMAGE])
            # check_call already raises on a non-zero exit code. It must not be wrapped in
            # an ``assert``: ``python -O`` would strip the whole call.
            subprocess.check_call(
                [
                    self.WKHTMLTOIMAGE,
                    "--quality",
                    "100",
                    "--format",
                    "png",
                    "--width",
                    str(width),
                    tmp_html,
                    tmp_png,
                ]
            )
        finally:
            os.remove(tmp_html)

        return tmp_png

    def select_data(self):
        """Run ``self.query`` against YT and return all rows (JSON format) as a list."""
        return list(yt.select_rows(self.query, format="json"))

    def make_dataframe(self, data):
        raise NotImplementedError()

    def to_html(self, df):
        raise NotImplementedError()

    @property
    def table_name(self):
        """Table path with the current environment substituted."""
        return table_name(self.table_name_template)

    @property
    def query(self):
        raise NotImplementedError()


class CommonReport(YtTableReport):

    """Make report of common metrics.

    Compares the latest dates of the ``v2`` and ``v2exp`` versions for every
    metric listed in ``config["reports"]["common"]``.
    """

    # Metrics whose stored value is multiplied by 100 before display.
    PERCENT_FIELDS = frozenset(["radius_x_recall_mean_opt", "radius_x_precision_mean_opt"])

    def __init__(self, config):
        self.width = 1200
        super(CommonReport, self).__init__(config["tables"]["storage"], config["reports"]["common"])

    @property
    def query(self):
        """Per-(version, dt, field, kind) maxima for the configured fields over the last 35 days."""
        return """
            version,
            dt,
            field,
            kind,
            MAX(absolute) AS absolute,
            MAX(percentage) AS percentage
        FROM [{table}]
        WHERE
            version IN ('v2', 'v2exp')
            AND dt > '{date:%Y-%m-%d}'
            AND ({fields})
        GROUP BY version, dt, field, kind
        """.format(
            table=self.table_name,
            date=(datetime.datetime.now() - datetime.timedelta(days=35)),
            fields=" OR ".join(
                "(kind='{}' AND field='{}')".format(f["kind"], f["field"]) for f in self.fields_mapping
            ),
        )

    def make_dataframe(self, data):
        """Build the comparison table: one row per field, columns grouped by version."""
        df = pd.DataFrame(data)
        series_by_field = {field["field"]: field["series"] for field in self.fields_mapping}
        percent_fields = self.PERCENT_FIELDS

        def to_metric_row(line):
            # Pick the configured series column ("absolute"/"percentage") for this field.
            value = line[series_by_field[line.field]]
            # Scale only when required; a missing (None) value must not raise here,
            # it is dropped by ``dropna`` below.
            if line.field in percent_fields and value is not None:
                value = value * 100.0
            return pd.Series(
                [line.version, line["dt"], line.field, value],
                ["version", "date", "field", "value"],
            )

        df = df.apply(to_metric_row, axis=1).dropna()

        def select_top(df, version):
            """Pivot ``version``'s two latest dates per field and append their ``diff``."""
            ver_df = df[df.version == version]
            top_dates = ver_df.date.sort_values(ascending=False).unique()[:2]
            ver_df = ver_df[ver_df.date.isin(top_dates)]
            ver_df = ver_df.pivot_table(index="field", columns=["version", "date"], values=["value"])["value"][version]

            # Second date column minus the first one.
            ver_df = ver_df.assign(diff=ver_df.diff(axis=1).values[:, 1])
            ver_df.columns = pd.MultiIndex.from_product([[version], ver_df.columns])
            return ver_df

        def max_common_date_diff(df, v0="v2", v1="v2exp"):
            """``v1 - v0`` per field on the latest date present for both versions."""
            versions = df.groupby(by=["date"])["version"].apply(set)
            max_common_date = versions.loc[versions == {v0, v1}].index.max()

            def common_day(df, max_common_date, version):
                return df.loc[(df.date == max_common_date) & (df.version == version)][["field", "value"]].set_index(
                    "field"
                )

            diff = (
                common_day(df, max_common_date, v0)
                .join(common_day(df, max_common_date, v1), how="outer", lsuffix="_" + v0, rsuffix="_" + v1)
                .assign(common_day=lambda row: (row["value_{}".format(v1)] - row["value_{}".format(v0)]))
            )[["common_day"]]
            return diff.rename(columns={"common_day": ("{}/{}".format(v0, v1), max_common_date)})

        return (
            select_top(df, "v2")
            .join(select_top(df, "v2exp"), how="outer")
            .reindex([field["field"] for field in self.fields_mapping])
            .join(max_common_date_diff(df))
        )

    def to_html(self, df):  # noqa
        """Render the comparison table as styled HTML.

        Cells are formatted and coloured by their position within the row, which
        assumes the 7-column layout built by ``make_dataframe`` (every 3rd column
        and the 7th are differences).
        """
        row_width = 7

        def number(val):
            """
            Keep 2 digits after the decimal point,
            add a human readable suffix for big numbers.
            """
            if abs(val) < 1000:
                return "{:.2f}".format(val)
            if abs(val) < 1000000:
                return "{:.2f}K".format(val / 1000.0)
            return "{:.2f}M".format(val / 1000000.0)

        def diff(val):
            return "{:+.2f}".format(float(val))

        def float_format(val):
            # Advance the column position for EVERY cell, including non-float ones
            # (NaNs are filled with ""), so it stays in sync with the row: 1..6, then 0.
            float_format.counter = (float_format.counter + 1) % row_width
            if not isinstance(val, float):
                return str(val)
            if float_format.counter % 3 != 0:
                return number(val)
            return diff(val)

        float_format.counter = 0

        def color_negative_red(negative_bad=True):
            """
            Return a cell-styling function for one metric row.

            Difference cells (every 3rd column and the 7th) are coloured black when
            ``|val| < 0.009``, otherwise red/green by sign: negative is red when
            ``negative_bad`` is true, green otherwise. Other cells get no style.
            """
            color_map = ["red", "green"]
            if not negative_bad:
                color_map.reverse()
            colors = (
                (lambda val: abs(val) < 0.009, "black"),
                (lambda val: val < 0.0, color_map[0]),
                (lambda val: val > 0.0, color_map[1]),
            )

            def inner(val):
                inner.counter = (inner.counter + 1) % row_width
                if inner.counter % 3 != 0:
                    return ""
                if not isinstance(val, (float, int)):
                    return ""
                color = next((name for matches, name in colors if matches(val)), None)
                return "color: {}".format(color) if color else ""

            inner.counter = 0
            return inner

        style = df.fillna("").style
        for metric in self.fields_mapping:
            style = style.applymap(
                color_negative_red(not metric["negative"]), subset=pd.IndexSlice[metric["field"], :]
            )

        return style.format(float_format).set_table_styles(self.STYLES).render()


class MoneyReport(YtTableReport):

    """Make report of money.

    Shows the last week of the configured ``cost_d`` metrics with mean/min/max
    summary rows and, when available, one "ratio" row per split.
    """

    splits = "shows", "clicks", "cost"
    # Number of trailing rows excluded from the max/min highlighting comparison
    # (incremented once per ratio row added in ``make_dataframe``).
    splits_c = 0

    def __init__(self, config):
        self.width = 1200
        self.stickers = config["stickers"]["money"]
        super(MoneyReport, self).__init__(config["tables"]["storage"], config["reports"]["money"])

    @property
    def query(self):
        """Fetch the configured fields plus their per-split variants for the last 35 days."""
        field_list = [(f["kind"], f["field"]) for f in self.fields_mapping]
        for split in self.splits:
            field_list.extend((f["kind"], f["field"].replace("cost_d", split)) for f in self.fields_mapping)

        return """
            dt,
            field,
            kind,
            MAX(absolute) AS absolute,
            MAX(deviation) AS deviation
        FROM [{table}]
        WHERE
            version = 'v2'
            AND dt > '{date:%Y-%m-%d}'
            AND ({fields})
        GROUP BY dt, field, kind
        """.format(
            table=self.table_name,
            date=(datetime.datetime.now() - datetime.timedelta(days=35)),
            fields=" OR ".join("(kind='{}' AND field='{}')".format(*f) for f in field_list),
        )

    def run(self):
        """Return ``(png_path, sticker)``; the sticker is chosen in ``make_dataframe``."""
        return super(MoneyReport, self).run(), self._sticker

    def make_dataframe(self, data):
        """Pivot the last 7 days, add summary/ratio rows and pick the sticker."""
        df = (
            pd.DataFrame([row for row in data if "_cost_d" in row["field"]])
            .dropna()
            .pivot(index="dt", columns="field", values=["absolute", "deviation"])
            .sort_values(by=["dt"], ascending=False)
            .head(7)  # last week
        )

        def get_ratio_subframe(key):
            """Latest-day values of the ``key`` split as a percentage of their maximum.

            Raises AttributeError when no row matches (the empty frame has no ``field``).
            """
            ratio_frame = pd.DataFrame(
                [row for row in data if key in row["field"] and "_cost_d" not in row["field"]]
            )
            ratio_frame.field = ratio_frame.field.str.replace(key, "_cost_d_")
            return (
                ratio_frame.dropna()
                .pivot(index="dt", columns="field", values=["absolute"])["absolute"]
                .sort_values(by=["dt"], ascending=False)
                .head(1)  # only top
                .apply(lambda row: 100.0 * row / row.max(), axis="columns")
            )

        df.loc["mean"] = df.mean()
        df.loc["min"] = df.min()
        df.loc["max"] = df.max()

        # denezhki: sticker colour depends on where the latest value of the first field lies.
        fieldname = self.fields_mapping[0]["field"]
        latest = df.iloc[0][("absolute", fieldname)]
        if latest == df.loc["min"][("absolute", fieldname)]:
            color = "red"
        elif latest == df.loc["max"][("absolute", fieldname)]:
            color = "green"
        else:
            color = "grey"
        self._sticker = self.stickers[color]

        for ratio in self.splits:
            try:
                absolute = get_ratio_subframe("_{}_".format(ratio))
                df.loc["{key} ratio".format(key=ratio)] = df.iloc[0]
                # NOTE(review): chained assignment on the row returned by ``df.loc[...]``;
                # whether it writes back into ``df`` depends on pandas version -- verify.
                df.loc["{key} ratio".format(key=ratio)].absolute = absolute
                self.splits_c += 1
            except AttributeError:
                # no given slice
                pass

        return df

    def to_html(self, df):  # noqa
        """Render the money table as styled HTML with per-column max/min highlighting."""

        def highlight(line):
            green = "background-color: rgba(88, 214, 141, 0.5)"
            red = "background-color: rgba(236, 112, 99, 0.5)"

            output = [""] * len(line)
            if "ratio" in line.name:
                # no highlight for extra lines
                return output

            for index in range(len(line)):
                if line.get("is_max_{idx}".format(idx=index)):
                    output[index] = green
                elif line.get("is_min_{idx}".format(idx=index)):
                    output[index] = red
            return output

        def merge_cols(line):
            # Dated rows (names start with the year) show "value±deviation".
            for key in (item["field"] for item in self.fields_mapping):
                if line.name.startswith("2"):
                    line[key] = "{0:.4f}±{1:.2f}".format(line[key], line["{key}_dev".format(key=key)])
                else:
                    line[key] = "{0:.4f}".format(line[key])
            return line

        fields = [item["field"] for item in self.fields_mapping]
        df = df.absolute[fields].join(df.deviation, lsuffix="", rsuffix="_dev")

        # Rows that take part in the max/min comparison: all but the trailing
        # ``splits_c`` rows. ``len(df) - n`` instead of ``-n`` so that n == 0 keeps
        # every row (``iloc[:-0]`` would select none).
        compared = df.iloc[: max(len(df) - self.splits_c, 0)]

        hide_columns = []
        for index, x_field in enumerate(fields):
            column = compared[x_field]
            df = df.assign(
                **{
                    "is_max_{idx}".format(idx=index): column == column.max(),
                    "is_min_{idx}".format(idx=index): column == column.min(),
                }
            )
            hide_columns.extend(
                [
                    "is_min_{idx}".format(idx=index),
                    "is_max_{idx}".format(idx=index),
                    "{field}_dev".format(field=x_field),
                ]
            )

        return (
            df.apply(merge_cols, axis=1)
            .rename(columns={self.fields_mapping[0]["series"]: self.fields_mapping[0]["verbose"]})
            .style.apply(highlight, axis=1)
            .hide_columns(hide_columns)
            .set_table_styles(self.STYLES)
            .render()
        )


class SimpleMoneyReport(MoneyReport):

    """Money report built from the ``percentage`` column, with zero deviation."""

    # A one-element tuple: the former bare string "cost_d" was iterated character
    # by character, adding bogus per-letter fields to the query and ratio lookups.
    splits = ("cost_d",)
    # No ratio rows are produced for this split, so this presumably just keeps the
    # trailing summary row out of the max/min highlighting comparison.
    splits_c = 1

    @property
    def query(self):
        """Fetch ``MAX(percentage)`` as ``absolute`` for the configured fields (last 35 days)."""
        field_list = [(f["kind"], f["field"]) for f in self.fields_mapping]
        for split in self.splits:
            field_list.extend((f["kind"], f["field"].replace("cost_d", split)) for f in self.fields_mapping)

        return """
            dt,
            field,
            kind,
            0.0 AS deviation,
            MAX(percentage) AS absolute
        FROM [{table}]
        WHERE
            version = 'v2'
            AND dt > '{date:%Y-%m-%d}'
            AND ({fields})
        GROUP BY dt, field, kind
        """.format(
            table=self.table_name,
            date=(datetime.datetime.now() - datetime.timedelta(days=35)),
            fields=" OR ".join("(kind='{}' AND field='{}')".format(*f) for f in field_list),
        )


class JugglerTaskStatusReport(Report):
    """Report the state of the Crypta task-status Juggler checks."""

    def __init__(self, env):
        super(JugglerTaskStatusReport, self).__init__()
        self.env = env
        self.preloaded_report = None
        service = "all-tasks" if env == "Prod" else "testing-tasks"
        self.juggler_filter = {"namespace": "crypta", "host": "crypta-task-status", "service": service}

    @staticmethod
    def get_check_status(juggler_filter, timeout=60):
        """Yield ``(status, description)`` for every check matching ``juggler_filter``.

        :param timeout: seconds to wait for the Juggler API (previously unbounded).
        :raises requests.HTTPError: on a non-2xx response.
        :raises requests.Timeout: if the API does not answer in time.
        """
        filter_params = {"filters": [juggler_filter]}
        response = requests.post(
            "https://juggler-api.search.yandex.net/v2/checks/get_checks_state",
            headers={"Content-Type": "application/json"},
            json=filter_params,
            # NOTE(review): TLS certificate verification is disabled; consider a CA bundle.
            verify=False,
            timeout=timeout,
        )

        response.raise_for_status()

        for check in response.json()["items"]:
            yield check["status"], check["description"]

    @staticmethod
    def get_dashboard_link(juggler_filter):
        """Juggler UI link (last day) for the checks matching ``juggler_filter``."""
        url_params = "&".join([k + "=" + v for k, v in juggler_filter.items()])
        return "https://juggler.yandex-team.ru/check_details/?last=1DAY&" + url_params

    def _get_cached_report(self):
        """Fetch check states once and reuse them for later calls."""
        if self.preloaded_report is None:
            self.preloaded_report = list(self.get_check_status(self.juggler_filter))
        return self.preloaded_report

    def _get_formatted_report(self, fmt):
        """Format every check with ``fmt`` and concatenate the results."""
        dashboard_link = self.get_dashboard_link(self.juggler_filter)
        report_text = "".join(
            fmt.format(env=self.env, status=status, descr=description, url=dashboard_link)
            for status, description in self._get_cached_report()
        )
        # An empty description leaves an empty ``````` `` code span -- drop it.
        return report_text.replace("``````", "")

    def run(self):
        yield self._get_formatted_report("{env} status is {status}: ```{descr}``` [Dashboard]({url})")


class StreamJugglerReport(JugglerTaskStatusReport):
    """Same as the task-status report, but for the stream-logs Juggler checks."""

    def __init__(self, env, *args, **kwargs):
        super(StreamJugglerReport, self).__init__(env, *args, **kwargs)
        # Swap the parent's filter for the stream-logs service of this environment.
        service_name = "stream-logs" if env == "Prod" else "stream-logs-testing"
        self.juggler_filter = {
            "namespace": "crypta",
            "host": "crypta-task-status",
            "service": service_name,
        }


class Greeting(YtTableReport):

    """Just for fun: birthday and holiday greetings."""

    class Message(object):
        """Outgoing message: any combination of an image, a text and a sticker."""

        def __init__(self, image=None, text=None, sticker=None):
            self.image = image
            self.text = text
            self.sticker = sticker

    def __init__(self, config):
        super(Greeting, self).__init__(config["tables"]["staff"], config["greeting"])
        self._greetings = (self._b_day, self._n_year, self._8_march)

    def run(self):
        """Yield every greeting ``Message`` applicable today."""
        for greeting_fun in self._greetings:
            for message in greeting_fun():
                yield message

    def _b_day(self):
        """Check for b_day messages"""
        config = self.fields_mapping["b_day"]
        observed_users = list(config["observed_users"])
        if not observed_users:
            # "IN ()" is not a valid query, and there is nobody to greet anyway.
            return

        day = datetime.date.today().strftime("%m-%d")
        # Build the IN list explicitly: ``repr(tuple)`` renders one user as "('a',)"
        # and, on Python 2, unicode values as "u'a'".
        users = ", ".join("'{}'".format(user) for user in observed_users)
        users_b_day = yt.select_rows(
            """
            staff_login,
            try_get_string(data, "/telegram/0") as tg
            FROM [{table}]
            WHERE is_substr("-{day}", dt)
                AND try_get_string(data, "/telegram/0") IN ({users})
        """.format(
                table=self.table_name, day=day, users=users
            ),
            format="json",
        )

        # Private RNG (seeded by the OS) instead of reseeding the global ``random`` state.
        rng = random.Random()
        templates = config["templates"]
        for row in users_b_day:
            plain = "BB8 beep beep happy birthday for {username}".format(username=row["staff_login"])
            # Encode/decode explicitly so this works with Python 2 and 3 strings.
            b64_message = base64.b64encode(plain.encode("utf-8")).decode("ascii")

            message = "{b64}\n\n{t0}\n{t1}\n{t2}".format(
                b64=b64_message,
                t0=rng.choice(templates["t0"]),
                t1=rng.choice(templates["t1"]),
                t2=rng.choice(templates["t2"]),
            ).format(username=row["tg"])

            yield self.Message(config["image"], message)

    def _n_year(self):
        """Yield the New Year message on December 31st."""
        if datetime.date.today().strftime("%m-%d") != "12-31":
            return
        yield self.Message(self.fields_mapping["n_year"]["image"], self.fields_mapping["n_year"]["text"])

    def _8_march(self):
        """Yield the 8 March message (text + sticker) on March 8th."""
        if datetime.date.today().strftime("%m-%d") != "03-08":
            return
        yield self.Message(
            text=self.fields_mapping["8_march"]["text"], sticker=self.fields_mapping["8_march"]["sticker"]
        )


class DutyReport(Report):
    """Post the planning-meeting link and today's on-duty engineers from ABC."""

    # Seconds to wait for the ABC API (previously unbounded).
    REQUEST_TIMEOUT = 60

    # Characters reserved in Telegram MarkdownV2 text. The messages below already
    # escape ".", "!", "(" and ")" this way, so they are presumably sent with
    # parse_mode=MarkdownV2 -- confirm against the sender.
    _MARKDOWN_V2_RESERVED = "\\_*[]()~`>#+-=|{}.!"

    def __init__(self, tvm_id, tvm_secret, staff_table):
        self._tvm_client = tvmauth.TvmClient(
            tvmauth.TvmApiClientSettings(
                self_tvm_id=tvm_id,
                self_secret=tvm_secret,
                dsts={"abc": 2012190},
            )
        )

        self._abc_host = "https://abc-back.yandex-team.ru/api/v4/"
        self._sklejka_abc_service_id = 5684

        self._staff_table = table_name(staff_table)

    @classmethod
    def _escape_markdown(cls, text):
        """Backslash-escape MarkdownV2-reserved characters (e.g. "_" and "-" in logins)."""
        return "".join("\\" + char if char in cls._MARKDOWN_V2_RESERVED else char for char in text)

    def run(self):
        """Yield the meeting link, then one line per on-duty person (or an alarm).

        :raises requests.Timeout: if ABC does not answer within ``REQUEST_TIMEOUT``.
        """
        yield "[Планёрка](https://wiki.yandex-team.ru/crypta/matching/internal/planjorka/)"

        response = requests.get(
            self._abc_host + "duty/on_duty/",
            params={"service": self._sklejka_abc_service_id},
            headers={"X-Ya-Service-Ticket": self._tvm_client.get_service_ticket_for("abc")},
            timeout=self.REQUEST_TIMEOUT,
        )

        if not response.ok:
            # Non-2xx answers are skipped silently: no duty line is posted.
            return

        duty_list = response.json()
        if not duty_list:
            yield "Тревога\\! Дежурного нет\\!"
            return

        for duty in duty_list:
            staff_login = duty["person"]["login"]
            telegram_login = self._get_tg_login(staff_login)
            login = "@" + self._escape_markdown(telegram_login or staff_login)

            yield "Сегодня дежурит {login}\\. " " [Прибыть в 314 кабинет](https://wiki.yandex-team.ru/crypta/matching/internal/duty/) \\(инструкции, ссылки\\)".format(
                login=login
            )

    def _get_tg_login(self, staff_login):
        """Return the first Telegram login stored for ``staff_login``, or None."""
        tg_logins = list(
            yt.select_rows(
                """
            try_get_string(data, "/telegram/0") as tg
            FROM [{table}] WHERE staff_login = '{staff_login}'
            """.format(
                    table=self._staff_table, staff_login=staff_login
                ),
                format="json",
            )
        )

        return tg_logins[0]["tg"] if tg_logins else None
