import sys
import math
import logging
import operator as op
import datetime as dt
import calendar
import itertools
import collections

import six
from concurrent import futures

from sandbox import common
import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt
import sandbox.common.types.client as ctc
import sandbox.common.types.database as ctd
import sandbox.common.types.resource as ctr
import sandbox.common.types.statistics as ctss

from sandbox.yasandbox.database import mapping
from sandbox.yasandbox.database.clickhouse import exceptions as ch_exceptions

from sandbox.services.modules.statistics_processor import schemas
from sandbox.services.modules.statistics_processor.processors import base
from sandbox.services.modules.statistics_processor.processors import errors

# Logger of this module
logger = logging.getLogger(name=__name__)


class ClickHouseProcessor(base.Processor):
    """
    Store statistics signals into ClickHouse, one table per data model
    from `schemas.clickhouse.SIGNAL_MODELS`.
    """

    # Batch size passed to `db.insert()`
    BATCH_SZ = 10000

    # tables with schema version newer than that of their schemas in local code
    # they are ignored to avoid corruption and/or runtime errors during insertion
    # NOTE(review): `process()` below does not consult this set itself; presumably the caller does -- confirm
    blacklist = set()

    @common.utils.singleton_property
    def db(self):
        """ClickHouse connection of this processor (see `make_clickhouse_connection`)."""
        return self.make_clickhouse_connection()

    @classmethod
    def initialize(cls, ctx):
        """
        Create necessary ClickHouse databases and tables once.

        The context records created databases (``"dbs": {name: True}``) and schema versions
        of created tables (``"models": {model name: schema version}``). For every known model:
          - not recorded yet: its table is created;
          - recorded with a version newer than the local schema: the model is blacklisted;
          - recorded with an older version: the table is renamed and a new one is created.

        :param ctx: a dictionary which represents this processor's context; it is left unmodified
        :return: updated context
        :rtype: dict
        """

        settings = common.config.Registry().server.services.statistics_processor.clickhouse
        db = cls.make_clickhouse_connection()
        context = dict(ctx)

        # nested dictionaries are copied too, otherwise the caller's context would be modified in place
        existing_databases = context["dbs"] = dict(context.get("dbs") or {})
        if settings.database not in existing_databases:
            db.create_database()
            existing_databases[settings.database] = True

        existing_models = context["models"] = dict(context.get("models") or {})
        for model_name, model_class in schemas.clickhouse.SIGNAL_MODELS.items():
            model_version = existing_models.get(model_name)

            if model_version is None:
                logger.info("%s: new, creating tables", model_name)
                db.create_table(model_class)

            elif model_version > model_class.schema_version:
                logger.info(
                    "%s: existing (NEW), blacklisting it (v%d > v%d)",
                    model_name, model_version, model_class.schema_version
                )
                cls.blacklist.add(model_name)
                # the newer version stays recorded in the context
                continue

            else:
                model_class.init_auto_enums(db)
                if model_version < model_class.schema_version:
                    logger.info(
                        "%s: existing (OLD), renaming + creating (v%d < v%d)",
                        model_name, model_version, model_class.schema_version
                    )
                    # move the outdated table aside and create one with the current schema
                    db.rename_table(model_class)
                    db.create_table(model_class)

            existing_models[model_name] = model_class.schema_version

        return context

    @staticmethod
    def make_clickhouse_connection():
        """
        Create a connection to the distributed ClickHouse database configured
        for the statistics processor.
        """
        from sandbox.yasandbox.database.clickhouse import database

        settings = common.config.Registry().server.services.statistics_processor.clickhouse

        return database.DistributedDatabase(
            settings.database,
            settings.cluster,
            settings.connection_url,
            username=settings.username,
            password=common.utils.read_settings_value_from_file(settings.password)
        )

    def process(self, signals, signal_type, _, timer):
        """
        Insert signals into the ClickHouse table of the data model matching the signal type.

        Signals are modified in place: the date field is filled with the timestamp value.

        :param signals: list of signal dictionaries
        :param signal_type: signal type; "_" and "-" characters are stripped before the model lookup
        :param _: logger passed by the caller (unused here)
        :param timer: mapping whose items are used as context managers to time processing stages
        :raises errors.UnknownSignal: there is no data model for the signal type
        :raises errors.ProcessorException: model instantiation or insertion failed
        """
        # `filter()` returns a string only on Python 2; join the characters explicitly to stay portable
        signal_type = "".join(ch for ch in signal_type if ch not in "_-")
        if signal_type not in schemas.clickhouse.SIGNAL_MODELS:
            raise errors.UnknownSignal("Unknown data model {}".format(signal_type))

        for signal in signals:
            signal[self.DATE_FIELD] = signal[self.TIMESTAMP_FIELD]

        try:
            with timer["{} model instantiating".format(signal_type)]:
                model_class = schemas.clickhouse.SIGNAL_MODELS[signal_type]
            with timer["{} prepare to insertion".format(signal_type)]:
                items = [model_class(**signal) for signal in signals]
            with timer["{} insertion".format(signal_type)]:
                self.db.insert(items, batch_size=self.BATCH_SZ)
        except (ch_exceptions.DatabaseException, ValueError, TypeError) as ex:
            logger.error(
                "Failed to process %s signals of type %s. Example: %s",
                len(signals), signal_type, signals[:10],
            )
            tb = sys.exc_info()[2]
            # Wrap explicitly: on Python 3 `six.reraise` does not convert the value to the given type
            # and would re-raise the original exception; on Python 2 this is exactly what `raise` did.
            six.reraise(errors.ProcessorException, errors.ProcessorException(ex), tb)


class AuditProcessor(ClickHouseProcessor):
    """
    Convert task status change signals into signals carrying durations of task statuses,
    using task audit records stored in MongoDB.
    """

    @classmethod
    def initialize(cls, ctx):
        # No ClickHouse initialization is performed for this processor.
        # NOTE(review): returns None, unlike the parent's "updated context" -- confirm the caller expects it
        pass

    # For these statuses duration will be 0; they are never reported as previous statuses either
    UNINVESTIGATED_STATUSES = (
        (set(ctt.Status.Group.FINISH) | set(ctt.Status.Group.BREAK) | {ctt.Status.DRAFT}) - {ctt.Status.RELEASING}
    )

    def process(self, signals, signal_type, logger_, timer):
        """
        Select audit records for corresponding tasks from MongoDB and calculate stages durations
        (in-place modification).

        For every signal, the latest audit record of its task dated before the signal's timestamp is found:
          - if that record's status is not in `UNINVESTIGATED_STATUSES`, a copy of the signal is emitted
            for that previous status, dated by the record, with ``duration`` equal to the seconds spent in it;
          - if the signal's own status is in `UNINVESTIGATED_STATUSES`, the signal itself is emitted
            with zero ``duration``.
        Nothing else is emitted. The resulting signals are stored by `ClickHouseProcessor.process`.
        """

        # audit records are fetched and sorted by DATE_FIELD, so read the same attribute when comparing
        date_of = op.attrgetter(self.DATE_FIELD)

        with timer["loading audit"]:
            tids = set(map(op.itemgetter("task_id"), signals))
            with mapping.switch_db(mapping.Audit, ctd.ReadPreference.SECONDARY) as Audit:
                objects = list(
                    Audit.objects(
                        task_id__in=tids, status__exists=True
                    ).only("task_id", self.DATE_FIELD, "status")
                )
            objects.sort(key=date_of, reverse=True)
            # task id -> audit records of the task, newest first
            records = collections.defaultdict(list)
            for o in objects:
                records[o.task_id].append(o)

        with timer["processing"]:
            result_signals = []
            for signal in signals:
                current_timestamp, current_status = map(signal.get, (self.TIMESTAMP_FIELD, "status"))
                previous_status = None
                previous_timestamp = None
                # records are newest first, so the first older one immediately precedes the signal
                for obj in records[signal["task_id"]]:
                    obj_date = date_of(obj)
                    if obj_date < current_timestamp:
                        previous_timestamp = obj_date
                        previous_status = obj.status
                        break

                if previous_status is not None and previous_status not in self.UNINVESTIGATED_STATUSES:
                    new_signal = signal.copy()
                    new_signal["duration"] = (signal[self.TIMESTAMP_FIELD] - previous_timestamp).total_seconds()
                    new_signal[self.DATE_FIELD] = new_signal[self.TIMESTAMP_FIELD] = previous_timestamp
                    new_signal["status"] = previous_status
                    result_signals.append(new_signal)
                if current_status in self.UNINVESTIGATED_STATUSES:
                    signal["duration"] = 0
                    result_signals.append(signal)

        return super(AuditProcessor, self).process(result_signals, signal_type, logger_, timer)


class SessionCompletionProcessor(ClickHouseProcessor):
    """
    Calculate resources usage of finished task sessions from Solomon data.

    Every incoming signal has a "host" and a "start"/"finish" interval; it is turned into
    CPU, RAM, disk I/O and network I/O usage signals, which are stored in ClickHouse.
    """

    # Solomon cluster/service selectors of the installation, merged into every request
    PRODUCTION_PARAMS = {
        "cluster": "sandbox",
        "service": "sandbox",
    }
    PRE_PRODUCTION_PARAMS = {
        "cluster": "sandbox_1",
        "service": "sandbox_1",
    }
    # Selectors of the requested sensors; treated as constants and never modified
    CPU_PARAMS = {
        "partition": "load",
        "info": "cpu_usage"
    }
    RAM_PARAMS = {
        "partition": "total_rss",
    }
    DISK_AND_NETWORK_PARAMS = {
        "partition": "*",
        "info": "disk_io|net_io",
        "sensor": "disk_io_bytes|net_io_bytes"
    }

    # Percentiles stored for every resource: signal key -> fraction
    CPU_RAM_FRACTIONS = {
        "p50": 0.5,
        "p75": 0.75,
        "p90": 0.9
    }
    DISK_NET_FRACTIONS = {
        "p90": 0.9
    }

    # Number of signals processed (queried from Solomon) in parallel
    MAX_WORKERS = 20

    @classmethod
    def initialize(cls, ctx):
        # No ClickHouse initialization is performed for this processor
        pass

    @classmethod
    def _installation_params(cls):
        """Return Solomon cluster/service selectors for the current installation."""
        is_production = common.config.Registry().common.installation == ctm.Installation.PRODUCTION
        return cls.PRODUCTION_PARAMS if is_production else cls.PRE_PRODUCTION_PARAMS

    @staticmethod
    def query_solomon(signal, host, req_params):
        """
        Request data of the host's sensors for the time interval of the signal.

        :param signal: signal with "start" and "finish" values (encoded with `base.DATETIME_CONVERTER`)
        :param host: host name, used as the "host" selector
        :param req_params: other selectors (label -> value); they take precedence over "host"
        :return: Solomon API response; the methods below read sensors data from its optional "vector" key
        """
        auth_session = common.auth.TVMSession(common.tvm.TVM.get_service_ticket(["solomon"])["solomon"])
        rest_client = common.rest.Client(base_url=common.config.Registry().common.solomon.api, auth=auth_session)

        params = {
            "host": host
        }
        params.update(req_params)
        # format according to https://solomon.yandex-team.ru/docs/concepts/querying
        params_str = "drop_nan({" + ", ".join('{}="{}"'.format(key, value) for key, value in params.items()) + "})"

        request_body = {
            "from": base.DATETIME_CONVERTER.encode(signal["start"]),
            "to": base.DATETIME_CONVERTER.encode(signal["finish"]),
            "program": params_str
        }

        return rest_client.projects.sandbox.sensors.data.create(**request_body)

    def query_solomon_and_bundle_data(self, signal, auth_params=None):
        """
        Query Solomon for resources usage of a single signal and build one signal per resource.

        The "host" key is popped from `signal` (in-place modification).

        :param signal: signal with "host", "start" and "finish" keys
        :param auth_params: cluster/service selectors; defaults to those of the current installation
        :return: dictionary with "cpu", "ram", "disk_io" and "network_io" signals
        """
        if auth_params is None:
            auth_params = self._installation_params()
        host = signal.pop("host")
        signal_bundle = {}

        # selectors are merged into per-request copies, class-level dictionaries stay intact
        cpu_response = self.query_solomon(signal, host, dict(self.CPU_PARAMS, **auth_params))
        signal_bundle["cpu"] = self.per_core_cpu_usage(dict(signal), host, cpu_response)

        ram_response = self.query_solomon(signal, host, dict(self.RAM_PARAMS, **auth_params))
        signal_bundle["ram"] = self.ram_usage(dict(signal), ram_response)

        disk_network_response = self.query_solomon(signal, host, dict(self.DISK_AND_NETWORK_PARAMS, **auth_params))
        signal_bundle["disk_io"] = self.disk_network_usage(dict(signal), disk_network_response, "disk_io")
        signal_bundle["network_io"] = self.disk_network_usage(dict(signal), disk_network_response, "net_io")

        return signal_bundle

    @staticmethod
    def percentiles(signal, values, labeled_fractions):
        """
        Store percentiles of `values` into `signal` (modified in place and returned).

        :param labeled_fractions: signal key -> fraction; every key is set to 0 when `values` is empty
        """
        sorted_values = sorted(values)

        if sorted_values:
            percentile_values = [
                common.utils.percentile(sorted_values, fraction) for fraction in labeled_fractions.values()
            ]
        else:
            percentile_values = [0] * len(labeled_fractions)

        signal.update(zip(labeled_fractions.keys(), percentile_values))

        return signal

    def per_core_cpu_usage(self, signal, host, response):
        """
        Convert CPU load samples into numbers of busy cores and store their percentiles into `signal`.

        Samples are interpreted as percents of all cores of the host (rounded up to whole cores);
        the core count is taken from the host's client record.

        :raises IndexError: there is no client with such a host name
        """
        time_series_values = itertools.chain.from_iterable(
            sensor["timeseries"]["values"] for sensor in response["vector"]
        ) if "vector" in response else []

        with mapping.switch_db(mapping.Client, ctd.ReadPreference.SECONDARY) as Client:
            objects = list(
                Client.objects(
                    hostname=host
                ).only("hardware")
            )

        total_nb_cores = objects[0].hardware.cpu.cores

        def nb_active_cores_from_percentage(value):
            return math.ceil(total_nb_cores * value / 100.0)

        values = [nb_active_cores_from_percentage(item) for item in time_series_values]

        return self.percentiles(signal, values, self.CPU_RAM_FRACTIONS)

    def ram_usage(self, signal, response):
        """Store percentiles of all RAM samples of the response into `signal`."""
        time_series_values = itertools.chain.from_iterable(
            sensor["timeseries"]["values"] for sensor in response["vector"]
        ) if "vector" in response else []

        return self.percentiles(signal, time_series_values, self.CPU_RAM_FRACTIONS)

    def disk_network_usage(self, signal, response, info_label):
        """
        Store the sum ("total_bytes") and percentiles of samples of the sensors
        whose "info" label equals `info_label` into `signal`.
        """
        time_series_values = list(
            itertools.chain.from_iterable(
                sensor["timeseries"]["values"]
                for sensor in response["vector"]
                if sensor["timeseries"]["labels"]["info"] == info_label
            ) if "vector" in response else []
        )
        signal["total_bytes"] = int(sum(time_series_values))
        return self.percentiles(signal, time_series_values, self.DISK_NET_FRACTIONS)

    @staticmethod
    def filter_signals(signals):
        """
        Keep only signals of hosts that are registered clients and match
        ``Client.tags_query(~(MULTISLOT | OSX))``.
        """
        # a real list: `map()` is lazy on Python 3
        host_names = [signal["host"] for signal in signals]

        with mapping.switch_db(mapping.Client, ctd.ReadPreference.SECONDARY) as Client:
            clients = set(
                Client.objects(
                    q_obj=Client.tags_query(~(ctc.Tag.MULTISLOT | ctc.Tag.Group.OSX)),
                    hostname__in=host_names
                ).scalar("hostname")
            )

        return [s for s in signals if s["host"] in clients]

    def process_signals_concurrently(self, signals, auth_params=None):
        """
        Run `query_solomon_and_bundle_data` for all signals in a thread pool.

        :param auth_params: cluster/service selectors; defaults to those of the current installation
        :return: list of signal bundles in the order of `signals`
        """
        if auth_params is None:
            auth_params = self._installation_params()
        with futures.ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as pool:
            return list(pool.map(lambda signal: self.query_solomon_and_bundle_data(signal, auth_params), signals))

    def process(self, signals, signal_type, logger_, timer):
        """
        Drop signals of unsuitable hosts (see `filter_signals`), query Solomon about resources usage
        for the rest and store it as CPU, RAM, disk I/O and network I/O usage signals.

        The "host" key is popped from every processed signal (in-place modification).
        `signal_type` is not used: every resource is stored under its own signal type.
        """

        with timer["filtering"]:
            signals = self.filter_signals(signals)

        auth_params = self._installation_params()

        with timer["processing"]:
            signal_bundles = self.process_signals_concurrently(signals, auth_params)

        for bundle_key, signal_subtype in (
            ("cpu", ctss.SignalType.TASK_CPU_USAGE),
            ("ram", ctss.SignalType.TASK_RAM_USAGE),
            ("disk_io", ctss.SignalType.TASK_DISK_IO_USAGE),
            ("network_io", ctss.SignalType.TASK_NETWORK_IO_USAGE),
        ):
            batch = [bundle[bundle_key] for bundle in signal_bundles]
            super(SessionCompletionProcessor, self).process(batch, signal_subtype, logger_, timer)


class Solomon(object):
    """
    Buffer metrics and push them to the Solomon push API in batches.
    """

    # Maximal number of metrics in a single push request
    MAX_BATCH_SIZE = 10000

    def __init__(self, project, cluster, service):
        self.project = project
        self.cluster = cluster
        self.service = service
        self.url_query = "/push?project={}&cluster={}&service={}".format(project, cluster, service)
        self.token = common.fs.read_settings_value_from_file(common.config.Registry().common.solomon.token)
        self.rest_client = common.rest.Client(base_url=common.config.Registry().common.solomon.api, auth=self.token)
        # metrics waiting for the next `send_signals()` call
        self.metrics = []

    def add_metric(self, signal):
        """Buffer a metric until the next `send_signals()` call."""
        self.metrics.append(signal)

    def __send_signals_batch(self, batch):
        self.rest_client[self.url_query].create(metrics=batch)

    def send_signals(self):
        """
        Push all buffered metrics in batches of at most `MAX_BATCH_SIZE` and empty the buffer.

        If a request fails, the error is propagated and only the metrics that have not been
        pushed yet stay buffered, so successfully pushed batches are not sent again next time.
        """
        sent = 0
        try:
            for start in range(0, len(self.metrics), self.MAX_BATCH_SIZE):
                self.__send_signals_batch(self.metrics[start:start + self.MAX_BATCH_SIZE])
                sent = start + self.MAX_BATCH_SIZE
        finally:
            del self.metrics[:sent]


class SolomonConsumption(ClickHouseProcessor):
    """
    Push consumption metrics built from signals to Solomon, then store the signals in ClickHouse.

    Subclasses customize `signals_aggregation` (which signals produce metrics)
    and `process_signal` (which metrics a signal produces).
    """

    PRODUCTION_PARAMS = {
        "cluster": "sandbox_consumption",
        "service": "sandbox_consumption",
    }
    PRE_PRODUCTION_PARAMS = {
        "cluster": "sandbox_1_consumption",
        "service": "sandbox_1_consumption",
    }

    @common.patterns.singleton_property
    def solomon(self):
        """Solomon client of the "sandbox" project; cluster and service depend on the installation."""
        if common.config.Registry().common.installation == ctm.Installation.PRODUCTION:
            selectors = self.PRODUCTION_PARAMS
        else:
            selectors = self.PRE_PRODUCTION_PARAMS
        return Solomon("sandbox", selectors["cluster"], selectors["service"])

    @classmethod
    def initialize(cls, ctx):
        # No ClickHouse initialization is performed for this processor
        pass

    def process_signal(self, signal):
        """Hook: add Solomon metrics for a single signal. Does nothing by default."""
        pass

    def signals_aggregation(self, signals):
        """Hook: choose signals that produce Solomon metrics. All of them by default."""
        return signals

    def process(self, signals, signal_type, logger_, timer):
        """
        Build metrics from the signals chosen by `signals_aggregation`, push them to Solomon,
        then insert all the original signals into ClickHouse.
        """
        chosen = self.signals_aggregation(signals)
        for item in chosen:
            self.process_signal(item)

        self.solomon.send_signals()
        super(SolomonConsumption, self).process(signals, signal_type, logger_, timer)


class MdsConsumption(SolomonConsumption):
    """
    Push MDS bucket consumption and quota metrics to Solomon.
    """

    METRICS_DELTA = 1  # in minutes

    def signals_aggregation(self, signals):
        """
        Thin out signals: within every bucket, keep signals at least `METRICS_DELTA` minutes apart,
        starting with the earliest one of the bucket.

        Signals without a bucket name are accounted to `ctr.DEFAULT_S3_BUCKET`.
        """
        min_interval = dt.timedelta(minutes=self.METRICS_DELTA)
        bucket_stats = collections.defaultdict(list)
        for signal in signals:
            bucket_stats[signal["mds_bucket_name"] or ctr.DEFAULT_S3_BUCKET].append(signal)

        result = []
        # `dict.itervalues()` does not exist on Python 3
        for bucket_signals in six.itervalues(bucket_stats):
            bucket_signals.sort(key=op.itemgetter("timestamp"))
            last_kept = None
            for signal in bucket_signals:
                if last_kept is None or signal["timestamp"] - last_kept >= min_interval:
                    result.append(signal)
                    last_kept = signal["timestamp"]
        return result

    def process_signal(self, signal):
        """
        Add two "mds_quota" metrics for the signal's bucket: "consumption" (value of "mds_bucket_used")
        and "quota" (value of "mds_bucket_max_size"), timestamped with the signal's timestamp.
        """
        bucket_name = signal["mds_bucket_name"] or ctr.DEFAULT_S3_BUCKET
        for metric_type, value_key in (("consumption", "mds_bucket_used"), ("quota", "mds_bucket_max_size")):
            self.solomon.add_metric({
                "labels": {
                    "sensor": "mds_quota",
                    "mds_bucket_name": bucket_name,
                    "type": metric_type,
                },
                "value": signal[value_key],
                "ts": int(calendar.timegm(signal["timestamp"].timetuple())),
            })


class ApiQuotaConsumption(SolomonConsumption):
    """
    Push API quota consumption and quota metrics of users to Solomon.
    """

    def process_signal(self, signal):
        """
        Add two "<source>_quota" metrics for the signal's login: "consumption" and "quota",
        taken from the signal keys of the same names and timestamped with the signal's timestamp.
        """
        sensor_name = signal["source"] + "_quota"
        owner = signal["login"]
        for metric_type in ("consumption", "quota"):
            metric_value = signal[metric_type]
            self.solomon.add_metric({
                "labels": {
                    "sensor": sensor_name,
                    "owner": owner,
                    "type": metric_type,
                },
                "value": metric_value,
                "ts": int(calendar.timegm(signal["timestamp"].timetuple())),
            })
