import datetime
import logging
import yaml

from abc import ABCMeta, abstractmethod
from retry import retry

import yt.wrapper as yt

from crypta.graph.metrics.stats_base.lib import YtProcessor, upload_to_solomon
from crypta.lib.python.yql_runner.base_parser import BaseParser
from library.python import resource

# Tab-separated log records: timestamp, level, message.
FORMAT = "\t".join(("%(asctime)s", "%(levelname)s", "%(message)s"))
logging.basicConfig(format=FORMAT)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def get_last_table(history_dir):
    """Return the path of the last table in ``history_dir``.

    "Last" means greatest when absolute paths are compared as strings. This is
    chronological only if table names sort chronologically (e.g. ISO dates).
    NOTE(review): confirm the naming convention of ``history_dir``.

    Args:
        history_dir: str, YT path to a directory.

    Returns:
        The greatest absolute path, or None if ``history_dir`` is empty.
    """
    tables = list(yt.list(history_dir, absolute=True))
    if not tables:
        return None
    # A single linear pass; sorting the whole listing just to take its first
    # element is unnecessary. Ties resolve to the first listed entry, as before.
    return max(tables, key=str)


class ValidationError(Exception):
    """Raised when a validation metric does not pass its threshold.

    Attributes:
        check_name: str, human-readable name of the failed check.
        check_value: the measured value.
        threshold: the value ``check_value`` was compared against.
        message: optional custom text; used instead of the default message
            when it is truthy.
    """

    def __init__(self, check_name, check_value, threshold, message=None):
        self.check_name = check_name
        self.check_value = check_value
        self.threshold = threshold
        self.message = message

    @staticmethod
    def _format_number(value):
        """Format ``value`` with two decimals, falling back to ``str`` for non-numbers."""
        try:
            return "{:.2f}".format(value)
        except (TypeError, ValueError):
            # e.g. None or a string: __str__ must not itself raise while the
            # validation failure is being reported.
            return "{}".format(value)

    def get_default_message(self):
        """Return "<name> is <value> DID NOT PASS threshold <threshold>"."""
        return "{} is {} DID NOT PASS threshold {}".format(
            self.check_name,
            self._format_number(self.check_value),
            self._format_number(self.threshold),
        )

    def __str__(self):
        return self.message or self.get_default_message()


class Metric(object):
    """A thresholded metric read from the validation YQL result.

    ``kind`` and ``field`` identify the result row, and ``series`` names the
    column whose value :meth:`check` compares against ``threshold``.
    """

    def __init__(self, kind, field, series, threshold):
        """
        Args:
            kind: str
            field: str
            series: str, e. g. absolute, percentage
            threshold: float, minimal acceptable value
        """
        self.kind = kind
        self.field = field
        self.series = series
        self.threshold = threshold

    def __str__(self):
        return "{}:{}".format(self.kind, self.field)

    def check(self, result):
        """Raise ValidationError unless ``result`` reaches ``threshold``.

        Args:
            result: numeric metric value.

        Raises:
            ValidationError: if ``result`` is below ``threshold`` or is NaN.
        """
        # Written as ``not (result >= threshold)`` rather than ``result < threshold``:
        # every comparison with NaN is False, so this form makes a NaN metric fail
        # the gate instead of passing silently.
        if not (result >= self.threshold):
            raise ValidationError(str(self), result, self.threshold)
        logger.info("%s is %.2f PASSED threshold %.2f", self, result, self.threshold)


# The base is built with ABCMeta(...) instead of ``__metaclass__ = ABCMeta``:
# Python 3 ignores the ``__metaclass__`` attribute, which left @abstractmethod
# unenforced there. This form works on both Python 2 and 3.
class BaseYQLChecker(ABCMeta("BaseYQLCheckerABC", (object,), {})):
    """Abstract base for checkers that contribute to the shared validation YQL query.

    Subclasses fill ``yql_context_data`` with the template variables they need
    and implement ``validate``.
    """

    # YQLChecker.validate reads ``skip`` on every checker. This default keeps
    # that read working for subclasses that never set it.
    skip = False

    def __init__(self):
        # dict with all needed variables for appropriate yql script
        self.yql_context_data = {}

    @abstractmethod
    def validate(self, yql_result):
        """
        Check the query result; fail if there is at least one bad metric.
        Args:
            yql_result: pd.DataFrame
        """


class IdsStabilityChecker(BaseYQLChecker):
    """Check ID stability of a freshly built table against the last valid table.

    Contributes the ``source`` and ``stable`` YQL template variables. Once the
    query has run, it checks every configured metric against its threshold.
    When ``generate_date`` equals the ``generate_date`` attribute of
    ``last_valid_table``, the checker marks itself as skipped (presumably there
    is nothing new to compare).
    """

    KIND_ = "stable"

    def __init__(self, metrics, checking_table, generate_date, last_valid_table):
        """
        Args:
            metrics: list, the element of which may be instance of Metric
            checking_table: str, YT path of the table being validated
            generate_date: str, generation date of ``checking_table``
            last_valid_table: str, YT path of the previous (last valid) table

        Raises:
            ValueError: if ``last_valid_table`` is empty.
        """
        logger.info("STARTED IdsStabilityChecker")

        if not last_valid_table:
            logger.error("last_valid_table is needed, but it's empty")
            # Fail fast with a clear message. Both the YQL "old" table and the
            # generate_date lookup below require this path.
            raise ValueError("last_valid_table is needed, but it's empty")

        self.checking_table = checking_table
        self.generate_date = generate_date
        self.last_valid_table = last_valid_table

        logger.info("Checking table is   %s", self.checking_table)
        logger.info("Last valid table is %s", self.last_valid_table)

        self.metrics = metrics
        self.yql_context_data = {
            "source": self.checking_table,
            "stable": {
                "new": self.checking_table,
                "old": self.last_valid_table,
                "period": "last",
                "kind": self.KIND_,
                "fields": [m.field for m in self.metrics],
            }
        }
        self.skip = False

        if generate_date == yt.get_attribute(last_valid_table, "generate_date"):
            logger.warning("The same generated_date %s for checking table and last valid table", generate_date)
            self.skip = True

    def _get_metric_results(self, yql_result):
        """Look up every metric's value in ``yql_result``.

        Args:
            yql_result: pd.DataFrame indexed by (kind, field)

        Returns:
            list of ``(metric, value)`` pairs. Metrics whose cell is missing are
            logged and left out, and every value stays paired with the metric it
            was read for.
        """
        pairs = []
        for metric in self.metrics:
            try:
                value = yql_result.loc[metric.kind, metric.field][metric.series]
            except KeyError:
                logger.exception("SKIPPED metric %s", metric)
                continue
            pairs.append((metric, value))
        return pairs

    def _get_results(self, yql_result):
        """Return the values found for ``self.metrics``, with missing ones skipped."""
        return [value for _, value in self._get_metric_results(yql_result)]

    def validate(self, yql_result):
        """Check every metric found in ``yql_result`` against its threshold.

        Args:
            yql_result: pd.DataFrame indexed by (kind, field)

        Raises:
            whatever ``metric.check`` raises for the first failing metric
            (ValidationError for Metric instances).
        """
        if self.skip:
            logger.info("SKIPPED metrics: %s", ", ".join(map(str, self.metrics)))
        else:
            # Each value is checked against its own metric. Zipping self.metrics
            # with a list that had missing entries dropped would shift later
            # values onto the wrong metric and threshold.
            for metric, result in self._get_metric_results(yql_result):
                metric.check(result)

        logger.info("FINISHED IdsStabilityChecker")


class YQLChecker(BaseParser):
    """Run the validation YQL query for a group of checkers and validate its result.

    Each checker in ``metrics`` contributes its ``yql_context_data`` to the
    template context. The query result is optionally uploaded and then passed
    to every checker's ``validate``.
    """

    QUERY_TEMPLATE = "/templates/yql/validation_metrics.sql.j2"

    def __init__(self, generate_date, metrics, yt_proxy="hahn", yt_pool="crypta_graph", upload_metrics=True):
        """
        Args:
            generate_date: str, ISO-format of date
            metrics: list, the element of which may be instance of BaseYQLChecker
            yt_proxy: str
            yt_pool: str
            upload_metrics: bool, whether to upload the query result to YT and Solomon
        """
        super(YQLChecker, self).__init__(generate_date, yt_proxy, yt_pool, is_embedded=False)
        self.generate_date = generate_date
        self.metrics = metrics
        self.upload_metrics = upload_metrics

    def get_libs(self):
        """Return no extra libraries for the query."""
        return []

    def get_context_data(self, **kwargs):
        """Build the YQL template context.

        Sets ``version`` and ``date`` and merges in every checker's
        ``yql_context_data``.

        Raises:
            ValueError: if a checker defines ``version`` or ``date``, or if two
                checkers define the same variable with different values. In the
                second case the later checker would otherwise silently overwrite
                the earlier one.
        """
        kwargs["version"] = "v2"
        kwargs["date"] = self.generate_date

        reserved = {"version", "date"}
        merged = {}
        for metric in self.metrics:
            context = metric.yql_context_data
            if reserved.intersection(context):
                raise ValueError("don't overlap the names of YQL template variables")
            clashes = [key for key in context if key in merged and merged[key] != context[key]]
            if clashes:
                raise ValueError(
                    "don't overlap the names of YQL template variables: {}".format(", ".join(sorted(clashes)))
                )
            merged.update(context)

        kwargs.update(merged)
        return super(YQLChecker, self).get_context_data(**kwargs)

    def _get_yql_result(self):
        """Run the query and return its full result as a DataFrame.

        Raises:
            RuntimeError: if the query returned no rows.
        """
        yql_result = self.run().full_dataframe

        if yql_result.empty:
            raise RuntimeError("there are no results from YQL query")

        return yql_result

    def _upload(self, yql_result):
        """Upload data to YT and Solomon.

        Inserts ``yql_result`` into the YT storage table. It then selects the
        "v2" metrics from 30 days before ``generate_date`` onward and sends them
        to Solomon. A Solomon failure is logged and swallowed (best effort);
        YT errors propagate.

        Args:
            yql_result: DataFrame
        """
        storage_conf = yaml.full_load(resource.find("/configs/storage.yaml"))

        yp = YtProcessor(**storage_conf)
        yp.load_table()
        yp.insert_data([yql_result])
        logger.info("UPLOADED metrics to YT")

        data = yp.select(
            version="v2",
            min_date=datetime.datetime.strptime(self.generate_date, "%Y-%m-%d") - datetime.timedelta(days=30),
        )

        try:
            upload_to_solomon(storage_conf["solomon"], data)
            logger.info("UPLOADED metrics to Solomon")
        except Exception:
            # Still best effort, but KeyboardInterrupt/SystemExit are no longer swallowed.
            logger.exception("Solomon upload error")

    def validate(self):
        """Run the query and validate its result with every checker.

        If every checker is marked ``skip``, the query is not run at all.
        Otherwise the result is uploaded (when ``upload_metrics``), indexed by
        (kind, field) and passed to each checker. The first failing checker raises.
        """
        if all(m.skip for m in self.metrics):
            for metric in self.metrics:
                logger.info("SKIPPED metrics: %s", ", ".join(map(str, metric.metrics)))
            return None

        yql_result = self._get_yql_result()

        if self.upload_metrics:
            self._upload(yql_result)

        yql_result.set_index(["kind", "field"], inplace=True)

        for metric in self.metrics:
            metric.validate(yql_result)


class Publisher(object):
    """Validate a freshly built directory and publish its contents.

    ``run`` executes every check (unless ``skip_checks`` is set) and then copies
    every node listed in the input directory into the output directory.
    """

    def __init__(self, input_, output, checks, skip_checks=False):
        """
        Args:
            input_: str, YT path to the source directory
            output: str, YT path to the destination directory (created if missing)
            checks: list of objects with a ``validate()`` method, e.g. YQLChecker
            skip_checks: bool, publish without running ``checks`` when True
        """
        self.input, self.output = input_, output
        self.checks, self.skip_checks = checks, skip_checks

    def validate(self):
        """Run every check in order; an exception from any check stops the rest."""
        for checker in self.checks:
            checker.validate()

    @retry(tries=3, delay=1)
    def publish(self):
        """Copy every node listed in ``self.input`` into ``self.output``.

        The copies happen inside one YT transaction and overwrite existing
        nodes (``force=True``). The whole method is attempted up to 3 times,
        with a 1-second delay between attempts.
        """
        if not yt.exists(self.output):
            yt.mkdir(self.output, recursive=True)

        with yt.Transaction():
            copied = []
            for name in yt.list(self.input):
                destination = yt.ypath_join(self.output, name)
                yt.copy(yt.ypath_join(self.input, name), destination, force=True)
                copied.append(destination)

            logger.info("Copied matching tables:\n\t%s", "\n\t".join(copied))

    def run(self):
        """Validate (unless ``skip_checks`` is set), then publish."""
        if not self.skip_checks:
            self.validate()
        else:
            logger.info("SKIPPED all checks")

        self.publish()
