from __future__ import print_function, unicode_literals

import codecs
import datetime
import logging
import math
import time
import traceback
import urllib

import yt.wrapper as yt

import pandas as pd
import numpy as np
from fbprophet import Prophet

from crypta.lib.python.solomon.reporter import create_throttled_solomon_reporter
from crypta.lib.python.yql_runner.base_parser import BaseParser
from crypta.graph.metrics.helpers import wait
from ads.bsyeti.exp_stats.py_lib import exp_stats


def get_logger():
    """Return the "stats" logger configured to print DEBUG+ records to the console.

    The console handler is attached only when the logger has no handlers yet,
    so calling this function more than once does not duplicate every log line.
    """
    stats_logger = logging.getLogger("stats")
    stats_logger.setLevel(logging.DEBUG)
    if not stats_logger.handlers:
        console = logging.StreamHandler()
        console.setLevel(logging.DEBUG)
        console.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(message)s"))
        stats_logger.addHandler(console)
    return stats_logger


# Module-wide "stats" logger used by every helper and class in this file.
logger = get_logger()


def get_date(str_date):
    """Parse a ``YYYY-MM-DD`` string and return it as a ``datetime.date``."""
    parsed = datetime.datetime.strptime(str_date, "%Y-%m-%d")
    return parsed.date()


def check_offset_days(old, new, offset):
    """Return True when ``old`` is exactly ``offset`` days earlier than ``new``."""
    shifted_back = new - datetime.timedelta(days=offset)
    return old == shifted_back


def chunks(data, size):
    """Yield consecutive slices of ``data`` of length ``size``; the last may be shorter.

    Uses ``range`` rather than the Python 2-only ``xrange`` so the helper runs on
    both Python 2 and Python 3.
    """
    for start in range(0, len(data), size):
        yield data[start : start + size]  # E203 # noqa


def upload_to_solomon(conf, data):
    """
    Make solomon report
    Args:
        conf: solomon config; passed as keyword arguments to create_throttled_solomon_reporter
            (together with push_interval=0.1).
        data: iterable of row dicts. Each row needs "fielddate" ("YYYY-MM-DD"), "version",
            "fieldname" and "kind"; the last three become labels, and every other key
            with a non-None value is pushed as a sensor.
    """
    # Keys that describe the row rather than being sensor values.
    non_sensor_keys = {"fielddate", "fieldname", "version", "kind"}
    # upload only fresh data: rows dated more than 7 days ago are skipped
    fresh_since = datetime.datetime.now() - datetime.timedelta(days=7)

    solomon_client = create_throttled_solomon_reporter(push_interval=0.1, **conf)
    counter = 0
    for row in data:
        row_date = datetime.datetime.strptime(row["fielddate"], "%Y-%m-%d")
        if row_date < fresh_since:
            continue

        payload = dict(
            ts_datetime=row_date,
            labels={"version": row["version"], "fieldname": row["fieldname"], "kind": row["kind"]},
        )

        # items() works on both Python 2 and 3 (iteritems is Python 2-only).
        for sensor, value in row.items():
            if value is not None and sensor not in non_sensor_keys:
                solomon_client.set_value(sensor=sensor, value=value, **payload)
                counter += 1
                # Short pause every 1000 values to avoid flooding the reporter.
                if counter % 1000 == 0:
                    time.sleep(0.1)

    # NOTE(review): presumably gives the throttled reporter time to push the
    # remaining buffered values before returning — confirm against the reporter.
    time.sleep(3)


class StatsBaseNoData(ValueError):
    """Signals that a time series has too little non-null data for a prophet fit."""


class MetricaRunnerExpStas(object):

    """Run metrics from exp-stats

    Calls ``exp_stats`` for one experiment against its etalon for a single day
    and turns every requested field of every non-etalon line into a flat row.
    """

    KIND = "exp_stats"
    # Metric config names that get no extra sub-name suffix in the output field name.
    MAIN_NAMES = frozenset(["v2_main", "v2exp_main"])

    def __init__(self, version, fields, date, exp_id, key, etalon_exp_id, select_types=None, name=None, **kwargs):
        """
        Args:
            version: metrics version stored in every output row.
            fields: exp-stats field names to request; must contain "clicks".
            date: day to compute as "YYYY-MM-DD" (12:00 of that day is used).
            exp_id: experiment id.
            key: optional extra exp-stats key (``--key exp,<key>``); None disables it.
            etalon_exp_id: id of the etalon experiment.
            select_types: optional iterable passed as ``--select_type``.
            name: metric config name. Unless it is None or one of MAIN_NAMES, the part
                after the first "_" is appended to every output field name.
            **kwargs: ignored, so a whole metric config dict can be splatted in.
        Raises:
            ValueError: if "clicks" is not among ``fields``.
        """
        self.date = datetime.datetime.strptime(date + " 12:00", "%Y-%m-%d %H:%M")
        self.min_date = self.date - datetime.timedelta(days=365 * 2)
        self.exp_id = str(exp_id)
        self.etalon_exp_id = str(etalon_exp_id)
        if "clicks" not in fields:
            raise ValueError("Please, add clicks to metrics - it's required to compute eps.")

        self.fields = fields
        self.version = version
        self.key = key
        self.select_types = select_types
        self.name = name

    def run(self):
        """Run exp-stats and return one row per (non-etalon line, field).

        Returns:
            list of dicts with keys version, ts, dt, kind, field, absolute, deviation.
            ``ts`` is ``time.mktime`` of the run date (local time). Rows whose value
            or deviation is NaN/inf are skipped.
        """
        arguments = self._get_arguments()

        ts = int(time.mktime(self.date.timetuple()))
        dt = datetime.datetime.strftime(self.date, "%Y-%m-%d")
        output = []

        print("https://search.crypta.yandex-team.ru/?query={}".format(urllib.quote_plus(" ".join(arguments))))
        # Does not depend on the line or field, so compute it once.
        subname_suffix = self._subname_suffix()
        for line in exp_stats(arguments):
            if line["exp"] == self.etalon_exp_id:
                # take metrics from exp line only; the etalon line is skipped
                continue

            # These parts of the output field name depend on the line, not the field.
            rsya = str(line["rsya"]).strip('"')
            key_suffix = self._key_suffix(line)

            for field in self.fields:
                attributes = self._field_attributes(line[field])
                output_field = "{kind}_x_{field}{suffix}{subsuf}_rsya_{rsya}".format(
                    kind=self.KIND, field=field, suffix=key_suffix, subsuf=subname_suffix, rsya=rsya
                )

                if any(math.isnan(val) or math.isinf(val) for val in attributes.values()):
                    print("Skip nan or inf row {} {}".format(field, attributes))
                    continue

                output.append(
                    dict(
                        version=self.version,
                        ts=ts,
                        dt=dt,
                        kind=self.KIND,
                        field=output_field,
                        absolute=float(attributes["Value"]),
                        deviation=float(attributes["Deviation"]),
                    )
                )

        return output

    @staticmethod
    def _field_attributes(value):
        """Return {"Value": float, "Deviation": float} for one exp-stats cell.

        Cells without attributes (a "strong" value) are converted with ``float()``
        and get zero deviation.
        """
        if not value.attributes:
            return {"Value": float(value), "Deviation": 0}
        return {
            "Value": float(value.attributes["Value"]),
            "Deviation": float(value.attributes["Deviation"]),
        }

    def _key_suffix(self, line):
        """Return "_k_<key>_<value of key in line>", or "" when no key is configured."""
        if self.key is None:
            return ""
        return "_k_{key}_{suf}".format(key=self.key, suf=str(line[self.key]).strip('"'))

    def _subname_suffix(self):
        """Return "_<name without its first '_'-separated part>", or "" for main/unnamed metrics.

        A missing name (None) previously crashed with AttributeError; it now means "no suffix".
        """
        if self.name is None or self.name in self.MAIN_NAMES:
            return ""
        suffix = "_{}".format(self.name.split("_", 1)[-1])
        assert suffix != "_"
        return suffix

    def _get_arguments(self):
        """Build the exp-stats command line for this experiment, etalon and day."""
        arguments = [
            "./exp_stats",  # just for correct arg parse
            "--exp",
            ",".join([self.exp_id, self.etalon_exp_id]),
            "--etalon-exp",
            self.etalon_exp_id,
            "--fields",
            "{field}".format(field=",".join(self.fields)),
            "--keep-etalons",
            "--range",
            "{date:%Y%m%d}..{date:%Y%m%d}".format(date=self.date),
            "--from",
            "force-mq",
        ]

        if self.key is not None:
            arguments.extend(["--key", "exp,{key}".format(key=self.key)])
        if self.select_types is not None:
            arguments.extend(["--select_type", ",".join(map(str, self.select_types))])

        return arguments


class MetricaRunner(BaseParser):

    """Run crypta metrics query

    Builds the YQL query from jinja templates and runs it via ``BaseParser``.
    With ``IS_SINGLE`` False (the value set here) the main query is rendered from
    SEPARATE_TEMPLATE, and every metric is also rendered from PARTIAL_TEMPLATE
    into its own YQL library file (see ``get_libs``).
    """

    # Rendering mode: True -> one query from SINGE_TEMPLATE, no libs;
    # False -> SEPARATE_TEMPLATE plus one library per metric.
    IS_SINGLE = False
    # (sic) "SINGE" is a typo for "SINGLE"; kept as-is because it is a public class attribute.
    SINGE_TEMPLATE = "/templates/yql/single.sql.j2"
    SEPARATE_TEMPLATE = "/templates/yql/separate.sql.j2"
    PARTIAL_TEMPLATE = "/templates/yql/partial.sql.j2"

    def __init__(self, version, metrics, source, export_dir="", debug_dir="", diffs=None, **kwargs):
        """
        Args:
            version: metrics version, exposed to the templates.
            metrics: list of metric configs; each needs a "prefix" (used in lib file names).
            source: exposed to the templates as ``source``.
            export_dir: exposed to the templates as ``export_dir``.
            debug_dir: exposed to the templates as ``debug_dir``.
            diffs: optional list exposed to the templates as ``diffs``.
            **kwargs: forwarded to ``BaseParser.__init__`` (callers in this file pass ``date``).
        """
        self.metrics = metrics
        self.version = version
        self.source = source
        self.export_dir = export_dir
        self.debug_dir = debug_dir
        self.diffs = diffs
        super(MetricaRunner, self).__init__(**kwargs)

    def get_context_data(self, **kwargs):
        """Extend BaseParser's template context with this run's parameters.

        Always sets ``udf_url_enable`` to False.
        """
        context = super(MetricaRunner, self).get_context_data(**kwargs)
        context.update(
            udf_url_enable=False,
            version=self.version,
            metrics=self.metrics,
            source=self.source,
            export_dir=self.export_dir,
            debug_dir=self.debug_dir,
            diffs=self.diffs,
        )
        return context

    @property
    def QUERY_TEMPLATE(self):  # N802 # noqa
        """Main query template: SINGE_TEMPLATE if IS_SINGLE else SEPARATE_TEMPLATE."""
        # The bool IS_SINGLE indexes the tuple (False -> 0, True -> 1).
        return (self.SEPARATE_TEMPLATE, self.SINGE_TEMPLATE)[self.IS_SINGLE]

    def get_libs(self):
        """Render one YQL library per metric and return their descriptors.

        Returns an empty list in single-query mode. Otherwise each metric in
        ``self.metrics`` is rendered from PARTIAL_TEMPLATE (``loop_index`` starts
        at 1), written to ``/tmp/lib_<prefix>_<index>.sql``, and described as a
        "filesystem" library whose content is that file path.
        """
        if self.IS_SINGLE:
            return []
        kwargs = self.get_context_data()

        def _libs():
            for index, metrica in enumerate(self.metrics, start=1):
                # NOTE(review): super(BaseParser, self) starts the MRO lookup *after*
                # BaseParser, so any BaseParser.render override is skipped — confirm intent.
                partial = super(BaseParser, self).render(
                    self.PARTIAL_TEMPLATE, metrica=metrica, loop_index=index, **kwargs
                )
                name = "lib_{metrica}_{index}.sql".format(metrica=metrica["prefix"], index=index)
                file_path = "/tmp/{}".format(name)
                with codecs.open(file_path, "w", encoding="utf-8") as ofile:
                    ofile.write(partial)

                yield {
                    # Embedded runs reference the library by its full path, others by file name.
                    "name": file_path if self.is_embedded else name,
                    "content": file_path,
                    "disposition": "filesystem",
                    "type": "library",
                }
                # NOTE(review): runs when the generator resumes after each yield; presumably
                # drops a cached query on BaseParser so it is rebuilt — confirm.
                self._query = None

        return list(_libs())


class YtProcessor(object):
    """Read and write metric rows in a dynamic YT table.

    Args:
        yt_path: path of the dynamic table.
        schema: YT schema, a list of column dicts that each have a "name".
        **kwargs: ignored extra storage config keys.
    """

    # Columns ``select`` groups by; every other schema column is aggregated with MAX.
    GROUP_COLUMNS = ("dt", "version", "field", "kind")

    def __init__(self, yt_path, schema, **kwargs):
        self.table_name = yt_path
        self.schema = schema

    def load_table(self):
        """Make sure the table exists with the configured schema, then mount it if unmounted."""
        self._verify_schema()
        logger.info("Mount dynamic table [%s]", self.table_name)
        if self._get_table_state() == "unmounted":
            yt.mount_table(self.table_name, sync=True)

    def unload_table(self):
        """Unmount the table (with sync=True)."""
        yt.unmount_table(self.table_name, sync=True)

    def select(self, version, min_date):
        """Return JSON rows for ``version`` with dt strictly after ``min_date``.

        Rows are grouped by (dt, version, kind, field). Every other schema column
        is aggregated with MAX. ``dt`` is returned as ``fielddate`` and ``field``
        as ``fieldname``.
        """
        schema_fields = ",\n".join(
            "MAX({name}) AS {name}".format(name=column["name"])
            for column in self.schema
            if column["name"] not in self.GROUP_COLUMNS
        )
        # The format spec of min_date contains the quotes, so it renders as a quoted date literal.
        return yt.select_rows(
            (
                """
                dt AS fielddate,
                version,
                kind,
                field AS fieldname,
                {fields}
            FROM [{table}]
            WHERE version = {version!r}
                AND dt > {min_date:"%Y-%m-%d"}
            GROUP BY dt, version, kind, field
        """
            ).format(
                table=self.table_name,
                fields=schema_fields,
                version=version,
                min_date=min_date,
            ),
            format="json",
        )

    def insert_data(self, data):
        """Flatten metric results into records and upsert them into the table.

        Args:
            data: iterable whose items are lists of row dicts, pandas DataFrames,
                or objects exposing a ``full_dataframe`` DataFrame.
        """
        yt_data = []
        for sub_data in data:
            if isinstance(sub_data, list):
                yt_data.extend(sub_data)
                continue
            dataframe = sub_data if isinstance(sub_data, pd.DataFrame) else sub_data.full_dataframe
            # Cast to object first: on numeric columns newer pandas keeps NaN when asked
            # to substitute None, and NaN is not valid JSON for insert_rows.
            records = dataframe.astype(object).where(pd.notnull(dataframe), None).to_dict(orient="records")
            yt_data.extend(records)
        logger.debug("metrics data %s", yt_data)
        yt.insert_rows(self.table_name, yt_data, update=True, format="json")

    @wait()
    def _check_table_state(self, state="mounted"):
        return self._get_table_state() == state

    def _get_table_state(self):
        """Return the state of the table's first tablet (e.g. "mounted"/"unmounted")."""
        return yt.get("{0}/@tablets/0/state".format(self.table_name))

    def _verify_schema(self):
        """Check table schema (and try migrate if need)

        Creates the dynamic table when it does not exist. When it exists but its
        set of column names differs from ``self.schema``, unmounts it and alters
        the schema. The table is left unmounted in that case.
        """
        logger.info("Create yt table if not exists [%s]", self.table_name)
        schema = self.schema
        if not yt.exists(self.table_name):
            yt.create(
                "table",
                path=self.table_name,
                recursive=True,
                ignore_existing=True,
                attributes={
                    "schema": schema,
                    "dynamic": True,
                    "optimize_for": "scan",
                    "enable_dynamic_store_read": True,
                },
            )
        else:
            current_schema = list(yt.get("{0}/@schema".format(self.table_name)))
            current_schema_fields = set(item["name"] for item in current_schema)
            schema_fields = set(item["name"] for item in schema)
            if current_schema_fields != schema_fields:
                self.unload_table()
                yt.alter_table(self.table_name, schema, dynamic=True)


class MetricsProcessor(object):

    """Run metrics and make yt upload

    Metric configs whose "prefix" is "exp_stats" run through MetricaRunnerExpStas.
    All others are YQL metrics run through MetricaRunner. The run_* methods do not
    raise on a metric failure: they increment ``exception_counter`` and keep the
    most recent error in ``last_exception``.
    """

    GENERATE_DATE_FIELD = "generate_date"

    def __init__(self, date, version, metrics, source, storage_conf, history_dir="", export_dir="", debug_dir=""):
        """
        Args:
            date: computation date as "YYYY-MM-DD".
            version: metrics version.
            metrics: list of metric config dicts (each has a "prefix").
            source: YT path of the source table. Its generate_date attribute is read
                when history_dir is set.
            storage_conf: keyword arguments for YtProcessor.
            history_dir: YT directory searched for previous snapshots (for diffs).
            export_dir: passed through to MetricaRunner.
            debug_dir: passed through to MetricaRunner.
        """
        self.date = date
        self.min_date = datetime.datetime.strptime(self.date, "%Y-%m-%d") - datetime.timedelta(days=365 * 2)
        self.version = version
        self.metrics_yql = list(filter(lambda item: item["prefix"] != "exp_stats", metrics))
        self.metrics_exp = list(filter(lambda item: item["prefix"] == "exp_stats", metrics))
        self.source = source
        self.last_exception = None
        self.exception_counter = 0
        self.storage_conf = storage_conf
        self.history_dir = history_dir.rstrip("/")
        self.export_dir = export_dir
        self.debug_dir = debug_dir
        self.diffs = None
        self.init_diffs()

    def init_diffs(self):
        """Collect history tables to diff the source against.

        A table in history_dir (skipping any path containing "difflog") is
        recorded as a "daily", "weekly" or "monthly" diff when its generate_date
        is exactly 1, 7 or 30 days before the source's. Each table gets at most
        one period. Leaves ``self.diffs`` as None when history_dir is empty.
        """
        if not self.history_dir:
            return

        tables = filter(
            lambda table: "difflog" not in str(table),
            sorted(
                yt.search(self.history_dir, node_type="table", attributes=[self.GENERATE_DATE_FIELD]),
                key=lambda table: table.attributes.get(self.GENERATE_DATE_FIELD),
                reverse=True,
            ),
        )

        self.diffs = []
        new_generate_date = get_date(yt.get_attribute(self.source, self.GENERATE_DATE_FIELD))
        for table in tables:
            old_generate_date = get_date(table.attributes.get(self.GENERATE_DATE_FIELD))
            for period, offset in zip(("daily", "weekly", "monthly"), (1, 7, 30)):
                if check_offset_days(old_generate_date, new_generate_date, offset):
                    self.diffs.append({"old": str(table), "period": period})
                    break

    def run(self, **kwargs):
        """Yield results of the YQL metrics first, then of the exp-stats metrics."""
        if self.metrics_yql:
            for line in self.run_yql(**kwargs):
                yield line
        if self.metrics_exp:
            for line in self.run_exp(**kwargs):
                yield line

    def run_exp(self, **kwargs):
        """Yield one list of rows per exp-stats metric; failed metrics are logged and skipped."""
        logger.info("Start ExpStats metrics")
        for metric in self.metrics_exp:
            # exp-stats version in params
            task = MetricaRunnerExpStas(date=self.date, **metric)
            try:
                logger.info("Metric is running...")
                yield task.run()
            except Exception as err:
                self.exception_counter += 1
                self.last_exception = err
                logger.exception("ExpStats metric failed to run")
                print(traceback.format_exc())
        logger.info("Finish run ExpStats metrics")

    def validate_metrics(self, metrics=None, **kwargs):
        """Validate each metric separately.

        Returns:
            tuple (validated_metrics, not_validated_metrics). ``metrics=None`` is
            treated as an empty list.
        """
        logger.info("Metrics are validating...")

        validated_metrics = []
        not_validated_metrics = []

        for metric in metrics or ():
            task = MetricaRunner(
                metrics=[metric],
                version=self.version,
                date=self.date,
                source=self.source,
                export_dir=self.export_dir,
                debug_dir=self.debug_dir,
                diffs=self.diffs,
                **kwargs
            )
            try:
                task.validate()
                validated_metrics.append(metric)
                logger.info("Metric '%s' is valid", metric["prefix"])
            except Exception:
                not_validated_metrics.append(metric)
                logger.warning("Metric '%s' is not valid (skipped)", metric["prefix"])
                print(traceback.format_exc())

        logger.info("Metric validation is complete")

        return validated_metrics, not_validated_metrics

    def run_yql(self, **kwargs):
        """
        1. Validate metrics.
        2. Try to run all valid metrics at one query.
        3. If there are not valid and/or not completed metrics then we run them one at a time.
        """
        validated_metrics, not_ready_metrics = self.validate_metrics(metrics=self.metrics_yql, **kwargs)

        if not validated_metrics:
            self.exception_counter += 1
            self.last_exception = Exception("No metric is valid")
            # Not inside an except block, so logger.exception would log a bogus "NoneType: None" traceback.
            logger.error("No metric is valid")
            return

        try:
            logger.info("Start all validated metrics combined query")
            yield next(self.run_all(metrics=validated_metrics, **kwargs))
            logger.info("Finish all validated metrics combined query")
        except StopIteration:
            # run_all yields nothing when the combined query fails: retry those metrics one by one.
            not_ready_metrics += validated_metrics

        if not_ready_metrics:
            logger.info("Start Each metrics single query")
            for sub_result in self.run_each(metrics=not_ready_metrics, **kwargs):
                yield sub_result
            logger.info("Finish Each metrics single query")

    def run_all(self, metrics, **kwargs):
        """
        All metrics in one query
        Args:
            metrics: list of metrics from yaml config
            **kwargs: forwarded to MetricaRunner

        Yields:
            the MetricaRunner.run() result once on success; nothing on failure
            (the error is counted and logged instead).
        """
        log_metrics = ", ".join(m["prefix"] for m in metrics)
        logger.info("Metrics are running: %s", log_metrics)

        try:
            task = MetricaRunner(
                metrics=metrics,
                version=self.version,
                date=self.date,
                source=self.source,
                export_dir=self.export_dir,
                debug_dir=self.debug_dir,
                diffs=self.diffs,
                **kwargs
            )

            result = task.run()
            logger.info("Running metrics are finished: %s", log_metrics)

            yield result

        except Exception as err:
            self.exception_counter += 1
            self.last_exception = err
            logger.exception("Running metrics are failed: %s", log_metrics)
            print(traceback.format_exc())

    def run_each(self, metrics, **kwargs):
        """Each metric in separate query"""
        for metric in metrics:
            for result in self.run_all([metric], **kwargs):
                yield result

    def print_result(self, data):
        """Print every result that has a ``full_dataframe``; debug-log anything else.

        The previous ``df.head(df.shape[0])`` built a frame and discarded it, so
        nothing was actually printed.
        """
        for line in data:
            try:
                df = line.full_dataframe
            except AttributeError:
                logger.debug(line)
                continue
            print(df.to_string())

    def upload_on_yt(self, data):
        """Upload data on YT"""
        yp = YtProcessor(**self.storage_conf)
        yp.load_table()
        yp.insert_data(data)

    def select_all(self):
        """Read all data from YT (rows of this version newer than ``min_date``)."""
        yp = YtProcessor(**self.storage_conf)
        yp.load_table()
        return yp.select(self.version, self.min_date)


class PredictTimeSeries(object):

    """Select values from timeseries and make prediction via prophet"""

    kDEFAULT_PERIODS = 14
    # (window size, overlap) in rows — one row per day after ``select`` — for the window re-fits.
    kPREDICT_PERIODS = ((30, 10), (60, 10), (90, 15), (120, 20))
    kRE_PREDICT_MAX_COUNT = 2

    def __init__(self, storage_conf, key):
        """
        Key should be dict of {
            'version': <version>,
            'field': <field>,
            'kind': <kind>,
            'series': <fieldname of series data>, }
        """
        assert key["version"]
        assert key["field"]
        assert key["kind"]
        assert key["series"]

        self.key = key
        self.min_date = datetime.datetime.now() - datetime.timedelta(days=365 * 2)
        self.yp = YtProcessor(**storage_conf)

    def run(self, periods=kDEFAULT_PERIODS):
        """Select the series, forecast ``periods`` days ahead and store the result."""
        series = self.select()
        forecast = self.predict(series, periods)
        self.commit(*forecast)

    @staticmethod
    def _merge_latest(df_first, df_second):
        """Concatenate two frames and keep one row per ``ts``, the later frame winning.

        The result is sorted by ``ts`` with a fresh 0..n-1 index. ``pd.concat`` is
        used because ``DataFrame.append`` was removed in pandas 2.0.
        """
        return (
            pd.concat([df_first, df_second])
            .drop_duplicates(subset="ts", keep="last")
            .sort_values(by="ts", kind="mergesort")
            .reset_index(drop=True)
        )

    @staticmethod
    def _to_epoch(ds):
        """Convert a naive datetime Series (treated as UTC) to integer Unix seconds.

        Replaces ``strftime("%s")``, a platform-specific (glibc) extension that reads
        naive datetimes in the host's local time zone.
        """
        return ((ds - pd.Timestamp("1970-01-01")) // pd.Timedelta(seconds=1)).astype("int64")

    def _log_skip(self, err, size, start=None):
        """Log that a prophet fit was skipped for lack of data."""
        if start is None:
            logger.info(
                "Skipped predict for [version=%s] [kind=%s] [field=%s] [size=%s]\n%s",
                self.key["version"],
                self.key["kind"],
                self.key["field"],
                str(size),
                err,
            )
        else:
            logger.info(
                "Skipped predict for [version=%s] [kind=%s] [field=%s] [size=%s] [start=%s]\n%s",
                self.key["version"],
                self.key["kind"],
                self.key["field"],
                str(size),
                str(start),
                err,
            )

    def predict(self, series, periods):
        """Make prophet predict for varative time series

        Returns:
            (prophet, rolling): ``prophet`` is the full-series forecast joined on ``ts``
            with one column set per window size (suffix ``_x_<size>``); ``rolling``
            is the rolling mean/std frame from ``_predict_roll``.
        Raises:
            StatsBaseNoData: if the whole series has fewer than 2 non-null values.
        """
        rolling = self._predict_roll(series, periods)
        prophet = self._predict_ph(series, periods)

        for size, cross in self.kPREDICT_PERIODS:
            try:
                # Forecast from the most recent window only.
                predicted = self._predict_ph(series[-size:], periods)

                # Walk earlier windows backwards, each overlapping the next by ``cross``
                # rows. The first negative start is clamped to 0 once so the head is covered.
                start = max(series.shape[0] - 2 * size + cross, 0)
                flag = bool(start)
                predict_counter = 0
                # NOTE: the counter increment is disabled (see todo), so the loop is
                # bounded only by ``start`` going negative.
                while (start >= 0) and (predict_counter < self.kRE_PREDICT_MAX_COUNT):
                    # predict_counter += 1  # todo: uncomment ++
                    try:
                        window_fit = self._predict_ph(series[start : start + size], 0)  # noqa
                        predicted = self._merge_latest(predicted, window_fit)
                    except StatsBaseNoData as err:
                        self._log_skip(err, size, start)
                    start = start - size + cross
                    if start < 0 and flag:
                        start = 0
                        flag = False

            except StatsBaseNoData as err:
                self._log_skip(err, size)
                continue

            prophet = (
                prophet.set_index("ts")
                .join(
                    predicted.drop(columns=["dt"])
                    .add_suffix("_x_{size}".format(size=size))
                    .rename(columns={"ts_x_{size}".format(size=size): "ts"})
                    .set_index("ts"),
                    how="outer",
                )
                .reset_index()
            )

        return prophet, rolling

    def _predict_roll(self, series, periods):
        """Make rolling mean and std over max(periods, kDEFAULT_PERIODS) rows (min_periods=1)."""
        rolling = series.y.rolling(max(periods, self.kDEFAULT_PERIODS), min_periods=1)
        return (
            pd.DataFrame(
                [
                    self._to_epoch(series.ds),
                    series.ds.dt.strftime("%Y-%m-%d"),
                    rolling.mean(),
                    rolling.std(),
                ],
                ["ts", "dt", "rolling_mean", "rolling_std"],
            )
            .T.sort_values(by=["ts", "dt"])
            .reset_index(drop=True)
        )

    def _predict_ph(self, series, periods):
        """Make prophet predict

        Fits Prophet on ``series`` (columns ds, y) and forecasts ``periods`` extra
        days. Returns trend/yhat columns plus ``ts``, ``dt`` and a 0/1
        ``change_point`` flag for changepoints whose mean |delta| >= 0.01.

        Raises:
            StatsBaseNoData: if fewer than 2 rows have a non-null ``y``.
        """
        if series[series["y"].notnull()].shape[0] < 2:
            raise StatsBaseNoData(
                "Prophet cannot be used when number of rows (which are not null)"
                " in the data you're passing is less than 2."
            )

        model = Prophet(
            yearly_seasonality=False,
            weekly_seasonality=True,
            daily_seasonality=False,
            changepoint_range=0.9,
            changepoint_prior_scale=0.3,
            interval_width=0.75,
        )
        model.fit(series, iter=1000)

        future = model.make_future_dataframe(periods=periods)
        forecast = model.predict(future)

        df = forecast[["ds", "trend", "trend_lower", "trend_upper", "yhat", "yhat_lower", "yhat_upper"]]

        signif_changepoints = np.array([])

        if len(model.changepoints) > 0:
            threshold = 0.01
            signif_changepoints = model.changepoints[np.abs(np.nanmean(model.params["delta"], axis=0)) >= threshold]

        predicted = df.assign(
            ts=self._to_epoch(df.ds),
            dt=df.ds.dt.strftime("%Y-%m-%d"),
            change_point=df.ds.isin(signif_changepoints).astype(float),
        ).drop(["ds"], axis=1)

        return predicted

    def select(self):
        """Load the series for ``self.key`` as a daily frame with columns ds (datetime) and y.

        Days missing between the first and last timestamp are added with a NaN ``y``.

        Raises:
            StatsBaseNoData: if there are no non-null rows for the key.
        """
        data = yt.select_rows(
            (
                """
                ts AS ds,
                {self.key[series]} AS y
            FROM [{table}]
            WHERE version = {self.key[version]!r}
                AND field = {self.key[field]!r}
                AND kind = {self.key[kind]!r}
                AND dt > {self.min_date:"%Y-%m-%d"}
        """
            ).format(table=self.yp.table_name, self=self),
            format="json",
        )
        series = pd.DataFrame(list(data)).dropna()
        if series.empty:
            raise StatsBaseNoData(
                "No data for [version={}] [kind={}] [field={}]".format(
                    self.key["version"], self.key["kind"], self.key["field"]
                )
            )
        # Full daily grid of epoch seconds; the outer merge inserts the missing days.
        full_range = np.arange(series.ds.min(), series.ds.max() + 1, 60 * 60 * 24)
        series = series.merge(pd.DataFrame(full_range, columns=["ds"]), how="outer", on="ds")
        series.ds = pd.to_datetime(series.ds, unit="s")
        series = series.sort_values(by=["ds"])
        return series

    def commit(self, *dataframes):
        """Insert data into YT, tagging every row with the key's version, field and kind."""
        dump = []
        for df in dataframes:
            # Cast to object so NaN really becomes None (newer pandas keeps NaN in numeric columns).
            dump.extend(
                df.astype(object)
                .where(pd.notnull(df), None)
                .assign(version=self.key["version"], field=self.key["field"], kind=self.key["kind"])
                .to_dict(orient="records")
            )
        yt.insert_rows(self.yp.table_name, dump, update=True, format="json")
