#!/usr/bin/env python

import time
import psutil
import logging
import datetime as dt
import subprocess as sp
import collections

import bson
import pymongo
import requests
import concurrent.futures

from sandbox.services.base import service as base_service

from sandbox import common
from sandbox.common import utils
import sandbox.common.types.notification as ctn

from sandbox.yasandbox import controller

# Module-level logger used by the reports, reporters and services below.
logger = logging.getLogger(__name__)


class Sensor(object):
    """
    A single Solomon sensor: a name, a label set and its two most recent values.

    :param name: sensor name, emitted as the "sensor" label
    :param labels: extra labels attached to the emitted point
    :param deriv: when true, `json()` reports the difference between the current
        and the previous value instead of the raw current value
    """

    def __init__(self, name, labels, deriv=False):
        self.name, self.labels, self.deriv = name, labels, deriv
        # Two most recent values and the timestamp of the latest one
        self.prev_value = self.curr_value = self.ts = None

    def set_value(self, value, ts=None):
        """Store `value` as the current value, shifting the old one into `prev_value`."""
        self.prev_value, self.curr_value = self.curr_value, value
        # A falsy `ts` (None or 0) means "now"
        self.ts = ts if ts else int(time.time())

    def json(self):
        """
        Return the Solomon push representation of the sensor.

        :raises ValueError: if the sensor is derivative but its value is a dict
        """
        point = {
            "ts": self.ts,
            "value": self.curr_value,
            "labels": dict({"sensor": self.name}, **self.labels),
        }
        if not self.deriv:
            return point

        if isinstance(point["value"], dict):
            raise ValueError("Value is not a number for sensor {!r}. Data: {!r}".format(self.name, point))
        point["value"] -= self.prev_value
        return point


class Solomon(object):
    """
    In-memory registry of sensors to be pushed to the Solomon push API.

    :param project: Solomon project
    :param cluster: Solomon cluster
    :param service: Solomon service
    """

    def __init__(self, project, cluster, service):
        self.project = project
        self.cluster = cluster
        self.service = service
        self.push_api = common.config.Registry().common.solomon.api + "/push?project={}&cluster={}&service={}".format(
            project, cluster, service
        )
        self.token = common.fs.read_settings_value_from_file(common.config.Registry().common.solomon.token)

        # Sensor key (see `_get_sensor_key`) -> Sensor
        self._sensors = {}

    def _get_sensor_key(self, name, labels):
        """Return a hashable key identifying a sensor by its name and label set."""
        return frozenset(labels.items()).union({name})

    def set_value(self, name, value, labels=None, deriv=False, ts=None):
        """
        Record `value` for the sensor identified by `name` and `labels`, creating it on first use.

        Note that `deriv` is only taken into account when the sensor is created.
        """
        if labels is None:
            labels = {}
        key = self._get_sensor_key(name, labels)

        sensor = self._sensors.get(key)
        if sensor is None:
            sensor = self._sensors[key] = Sensor(name, labels, deriv)
        sensor.set_value(value, ts)

    def get_value_setter(self, default_labels, default_ts=None):
        """
        Return a `set_value`-like function that adds `default_labels` and `default_ts`.

        The returned function also accepts `prefix` (a string, or a tuple/list joined
        with "_") which is prepended to the sensor name with a "_" separator.
        On label conflicts `default_labels` win.
        """
        def set_value(name, value, prefix=None, labels=None, deriv=False, ts=default_ts):
            if prefix:
                if isinstance(prefix, (tuple, list)):
                    prefix = "_".join(prefix)
                name = prefix + "_" + name

            # Build a fresh dict instead of updating the caller's one in place:
            # the result is stored by reference in a newly created Sensor.
            merged_labels = dict(labels or {})
            merged_labels.update(default_labels)

            self.set_value(name, value, labels=merged_labels, deriv=deriv, ts=ts)

        return set_value

    def json(self, ts):
        """
        Build the push payload with every sensor updated at or after `ts`.

        Derivative sensors are skipped until they have two values, and also when the
        value decreased (e.g. a counter reset). Sensors failing to serialize are logged
        and skipped.
        """
        r = {
            "sensors": []
        }
        for sensor in self._sensors.values():
            if sensor.ts < ts:
                continue
            if sensor.deriv and (sensor.prev_value is None or sensor.prev_value > sensor.curr_value):
                continue
            if sensor.name in ("network_serviceExecutorTaskStats", "network_compression"):
                # TODO: SANDBOX-7137: Value contains `dict`
                continue
            try:
                r["sensors"].append(sensor.json())
            except Exception:
                logger.exception("Error processing sensor %r", sensor.name)

        return r


class ServerStatusReport(object):
    def __init__(self, server_status, top):
        self.server_status = server_status
        self.top = top

    def to_solomon(self, solomon, ts):
        s = self.server_status

        if "setName" not in s["repl"]:
            logger.error("Skip report from %s, not a replicaset member", s["host"])
            return

        host = s["host"].split(":", 1)[0]
        shard = s["repl"]["setName"]

        set_value = solomon.get_value_setter({"host": host, "shard": shard}, ts)

        # role
        is_master = 1 if s["repl"]["ismaster"] else 0
        set_value("isMaster", is_master, prefix="repl")

        # opcounters
        for k, v in s["opcounters"].items():
            set_value(k, v, prefix="opcounters", deriv=True)

        # network
        for k, v in s["network"].items():
            set_value(k, v, prefix="network", deriv=True)

        # asserts
        for k, v in s["asserts"].items():
            set_value(k, v, prefix="asserts", deriv=True)

        # connections
        for f in ("current", "available"):
            set_value(f, s["connections"][f], prefix="connections")
        set_value("totalCreated", s["connections"]["totalCreated"], prefix="connections", deriv=True)

        # globalLock
        for f in ("activeClients", "currentQueue"):
            for k, v in s["globalLock"][f].items():
                set_value(k, v, prefix=("globalLock", f))
        set_value("totalTime", s["globalLock"]["totalTime"] / 10.0**6, prefix="globalLock")

        # mem
        for f in ("resident", "virtual"):
            set_value(f, s["mem"][f], prefix="mem")

        # metrics.commands
        for k, v in s["metrics"]["commands"].items():
            if k == "<UNKNOWN>":
                continue
            if "failed" in v and "total" in v:
                set_value("failed", v["failed"], prefix=("metrics", "commands", k.strip("_")), deriv=True)
                set_value("total", v["total"], prefix=("metrics", "commands", k.strip("_")), deriv=True)

        # metrics.document
        for k, v in s["metrics"]["document"].items():
            set_value(k, v, prefix=("metrics", "document"), deriv=True)

        # metrics.operation
        for k, v in s["metrics"]["operation"].items():
            set_value(k, v, prefix=("metrics", "operation"), deriv=True)

        # metrics.cursor
        for k, v in s["metrics"]["cursor"]["open"].items():
            set_value(k, v, prefix=("metrics", "cursor", "open"))
        set_value("timedOut", s["metrics"]["cursor"]["timedOut"], prefix=("metrics", "cursor"), deriv=True)

        # top.totals
        for dbcoll, commands in self.top["totals"].items():
            if "." not in dbcoll:
                continue

            db, collection = dbcoll.split(".", 1)
            if db not in ("events2", "sandbox", "signal"):
                continue

            for command, v in commands.items():
                set_value(
                    "count", v["count"],
                    prefix=("top", "totals", command),
                    labels={"db": db, "collection": collection},
                    deriv=True
                )
                set_value(
                    "time", v["time"] / 10.0**6,
                    prefix=("top", "totals", command),
                    labels={"db": db, "collection": collection},
                    deriv=True
                )


class ReplicaSetStatusReport(object):
    """
    Convert `replSetGetStatus` output into per-member `replicationLag` sensors.

    :param rs_status: result of the `replSetGetStatus` admin command
    """

    def __init__(self, rs_status):
        self.rs_status = rs_status

    def to_solomon(self, solomon, ts):
        """
        Record replication lag (in seconds, clamped at 0) of every member relative to the primary.

        Nothing is recorded when the set has no primary or the primary has no optime.
        Non-primary members without an optime are skipped.
        """
        status = self.rs_status
        shard = status["set"]
        members = status["members"]

        primaries = [m for m in members if m["stateStr"] == "PRIMARY"]
        others = [m for m in members if m["stateStr"] != "PRIMARY"]
        # If several members claim to be PRIMARY, the last one is used
        primary = primaries[-1] if primaries else None

        if primary is None:
            logger.warning("No primary for shard '%s'", shard)
            return

        primary_optime = primary.get("optime", {}).get("ts", None)
        if primary_optime is None:
            logger.warning("Primary for shard '%s' doesn't have optime: %s", shard, primary)
            return
        primary_dt = primary_optime.as_datetime()

        set_value = solomon.get_value_setter({"shard": shard}, ts)
        set_value("replicationLag", 0, labels={"host": primary["name"].split(":", 1)[0]})

        for member in others:
            optime = member.get("optime", {}).get("ts", None)
            if optime is not None:
                lag = (primary_dt - optime.as_datetime()).total_seconds()
                set_value("replicationLag", max(lag, 0), labels={"host": member["name"].split(":", 1)[0]})


class CurrentOpReport(object):
    """
    Aggregate in-progress user operations from `currentOp` output per database and collection.

    :param current_op: result of the `currentOp` admin command
    """

    def __init__(self, current_op):
        self.current_op = current_op

    def get_user_operations(self):
        """
        Yield operations from `inprog` that target a user collection.

        Skipped: entries without a "<db>.<collection>" namespace (including entries with
        no "ns" key at all), the oplog (`local.oplog.rs`) and `<db>.$cmd` namespaces.
        """
        for op in self.current_op["inprog"]:
            # Not every entry carries "ns"; treat a missing one like a namespace without a dot
            ns = op.get("ns") or ""
            if "." not in ns:
                continue
            if ns == "local.oplog.rs":
                continue
            if ns.split(".", 1)[1] == "$cmd":
                continue
            yield op

    def to_solomon(self, solomon, ts):
        """
        Record `ops_count` and `ops_duration` (seconds) per collection with host label "cluster".

        :param solomon: `Solomon`-like object providing `set_value`
        :param ts: timestamp assigned to every sensor (the same aligned tick
            timestamp the other reports use)
        """
        # (db, collection) -> [operation count, total microseconds running]
        totals = collections.defaultdict(lambda: [0, 0])

        for op in self.get_user_operations():
            db, collection = op["ns"].split(".", 1)
            entry = totals[(db, collection)]
            entry[0] += 1
            entry[1] += op.get("microsecs_running", 0)

        for (db, collection), (count, micros) in totals.items():
            solomon.set_value(
                "ops_count", count,
                labels={"host": "cluster", "db": db, "collection": collection},
                ts=ts
            )
            solomon.set_value(
                "ops_duration", micros / 10.0**6,
                labels={"host": "cluster", "db": db, "collection": collection},
                ts=ts
            )


class StorageSizeReport(object):
    """
    Convert per-collection storage statistics into Solomon sensors.

    :param stats: mapping db name -> collection name -> metric name -> value
    """

    def __init__(self, stats):
        self.stats = stats

    def to_solomon(self, solomon, ts):
        """Record one sensor per (db, collection, metric), labelled with host "cluster"."""
        rows = (
            (db, collection, metric, value)
            for db, per_db in self.stats.items()
            for collection, metrics in per_db.items()
            for metric, value in metrics.items()
        )
        for db, collection, metric, value in rows:
            labels = {"db": db, "collection": collection, "host": "cluster"}
            solomon.set_value(metric, value, labels=labels, ts=ts)


class OpLatenciesReport(object):
    """
    Convert `$collStats` latency histograms into per-bucket Solomon sensors.

    :param stats: mapping db name -> collection name -> list of `$collStats` documents
        (one per shard), each containing `latencyStats` histograms
    """

    def __init__(self, stats):
        self.stats = stats

    def _sum_histograms(self, histograms):
        """
        Merge per-shard histograms into one per operation kind.

        Buckets are re-keyed to whole milliseconds (`micros // 1000`); everything up to
        1000 microseconds goes to bucket 1.

        :return: dict kind ("reads"/"writes"/"commands") -> {bucket_ms: count}
        """
        totals = {kind: collections.defaultdict(int) for kind in ("reads", "writes", "commands")}

        for shard_stats in histograms:
            latency = shard_stats["latencyStats"]
            for kind, buckets in totals.items():
                for entry in latency[kind]["histogram"]:
                    micros = entry["micros"]
                    bucket_ms = 1 if micros <= 1000 else micros // 1000
                    buckets[bucket_ms] += entry["count"]

        return totals

    def to_solomon(self, solomon, ts):
        """Record derivative `opLatency_<kind>` sensors per collection and bucket."""
        for db, per_db in self.stats.items():
            for collection, per_shard in per_db.items():
                setter = solomon.get_value_setter({"db": db, "collection": collection, "host": "cluster"}, ts)
                summed = self._sum_histograms(per_shard)
                for kind in summed:
                    sensor_name = "opLatency_{}".format(kind)
                    for bucket, count in summed[kind].items():
                        setter(sensor_name, count, labels={"bucket": "{} ms".format(bucket)}, deriv=True)


class BaseReporter(object):
    """
    Base class for objects that query a MongoDB instance and build a report from it.

    :param uri: connection string passed to `pymongo.MongoClient`
    """

    def __init__(self, uri):
        timeout_ms = 5000  # used for socket, connect and server selection timeouts
        self.client = pymongo.MongoClient(
            uri,
            serverSelectionTimeoutMS=timeout_ms,
            connectTimeoutMS=timeout_ms,
            socketTimeoutMS=timeout_ms,
        )

    def get_report(self):
        """Return a report object exposing `to_solomon`; must be overridden by subclasses."""
        raise NotImplementedError


class ServerStatusReporter(BaseReporter):
    """Collect `serverStatus` and `top` from a mongod instance."""

    def get_report(self):
        admin = self.client.admin
        server_status = admin.command("serverStatus")
        top = admin.command("top")
        return ServerStatusReport(server_status, top)


class ReplicaSetStatusReporter(BaseReporter):
    """Collect `replSetGetStatus` from a mongod instance."""

    def get_report(self):
        rs_status = self.client.admin.command("replSetGetStatus")
        return ReplicaSetStatusReport(rs_status)


class CurrentOpReporter(BaseReporter):
    """Collect `currentOp` output (intended to run against mongos)."""

    def get_report(self):
        current_op = self.client.admin.command("currentOp")
        return CurrentOpReport(current_op)


class StorageSizeReporter(BaseReporter):
    """Collect storage statistics (`collstats`) for every collection of every database."""

    def _get_coll_stats(self, collstats):
        """Pick the reported fields from a `collstats` result, defaulting missing ones to 0."""
        fields = ("count", "size", "avgObjSize", "storageSize", "totalIndexSize", "nindexes")
        return {k: collstats.get(k, 0) for k in fields}

    def get_report(self):
        """
        Build a `StorageSizeReport` with stats for all collections.

        A collection whose `collstats` command fails with `OperationFailure` is logged
        and skipped; other pymongo errors (e.g. connectivity) still propagate.
        """
        stats = {}
        for db_name in self.client.list_database_names():
            db = self.client[db_name]
            db_stats = stats[db_name] = {}
            # `collection_names()` is deprecated (removed in pymongo 4); this file
            # already relies on the 3.6+ `list_database_names()`.
            for col in db.list_collection_names():
                try:
                    db_stats[col] = self._get_coll_stats(db.command("collstats", col))
                except pymongo.errors.OperationFailure as ex:
                    # One collection rejecting the command must not discard the stats of all others
                    logger.warning("Can't get collstats for %s.%s: %s", db_name, col, ex)
        return StorageSizeReport(stats)


class OpLatenciesReporter(BaseReporter):
    """Collect `$collStats` latency histograms for every collection of every database."""

    def get_report(self):
        """
        Build an `OpLatenciesReport` with the `$collStats` documents of all collections.

        A collection whose aggregation fails with `OperationFailure` is logged and
        skipped; other pymongo errors (e.g. connectivity) still propagate.
        """
        pipeline = [{"$collStats": {"latencyStats": {"histograms": True}}}]
        stats = {}
        for db_name in self.client.list_database_names():
            db = self.client[db_name]
            db_stats = stats[db_name] = {}
            # `collection_names()` is deprecated (removed in pymongo 4); this file
            # already relies on the 3.6+ `list_database_names()`.
            for col in db.list_collection_names():
                try:
                    db_stats[col] = list(db[col].aggregate(pipeline))
                except pymongo.errors.OperationFailure as ex:
                    # One collection rejecting $collStats must not discard the stats of all others
                    logger.warning("Can't get latency stats for %s.%s: %s", db_name, col, ex)
        return OpLatenciesReport(stats)


class MongoMonitor(base_service.SingletonService):
    """
    Report mongodb metrics and stats to Solomon.

    On every tick all reporters query their instances concurrently. The resulting
    reports are converted into Solomon sensors, and the payload is pushed from a
    separate thread pool so that a slow push does not delay collection.
    """

    tick_interval = 15

    def __init__(self, *args, **kwargs):
        super(MongoMonitor, self).__init__(*args, **kwargs)

        self._instances = list(self._get_instances())
        # ThreadPoolExecutor rejects max_workers=0; keep one worker even when
        # no mongod/mongos instances are configured.
        self._pool = concurrent.futures.ThreadPoolExecutor(max_workers=max(1, len(self._instances)))

        s = self.sandbox_config.server.solomon.push
        self._solomon = Solomon(s.project, s.cluster, s.service)

        self._pusher = concurrent.futures.ThreadPoolExecutor(max_workers=4)
        self._session = requests.Session()
        self._session.headers["Content-Type"] = "application/json"
        if self._solomon.token:
            self._session.headers["Authorization"] = "OAuth " + self._solomon.token

    def _get_instances(self):
        """Yield the reporters for every configured mongod and, if set, the mongos."""
        for instance in self.service_config["mongod"]:
            yield ServerStatusReporter(instance)
            yield ReplicaSetStatusReporter(instance)

        if self.service_config["mongos"]:
            yield CurrentOpReporter(self.service_config["mongos"])
            yield StorageSizeReporter(self.service_config["mongos"])
            yield OpLatenciesReporter(self.service_config["mongos"])

    def _get_reports_from_instances(self):
        """
        Query all reporters concurrently and yield the successfully built reports.

        BSON and pymongo errors are logged and the corresponding report is dropped.
        """
        def _get_report_func(x):
            try:
                return x.get_report()
            except (bson.InvalidBSON, pymongo.errors.PyMongoError) as e:
                logger.error("Error: %s", str(e))
            return None

        for report in self._pool.map(_get_report_func, self._instances):
            if report is None:
                continue
            yield report

    def _push_to_solomon(self, url, payload, ts):
        """POST `payload["sensors"]` to `url` in chunks; failures are logged, never raised."""
        logger.debug(
            "Push to Solomon... (total %d sensors, delay ~%ss)",
            len(payload["sensors"]), int(time.time()) - ts
        )
        # Solomon has default limit of 4000 sensors per request
        for chunk in utils.chunker(payload["sensors"], 3999):
            data = {
                "sensors": chunk,
            }
            try:
                r = self._session.post(url, json=data, timeout=30)
                r.raise_for_status()
            except requests.exceptions.HTTPError as ex:
                logger.exception("...Failed to push sensors to Solomon")
                # Take the response from the exception: a local `r` could be stale or unbound here
                response = ex.response
                logger.error("Reason: %s", response.content if response is not None else None)
            except Exception:
                logger.exception("...Failed to push sensors to Solomon")
        logger.debug("Push to Solomon... Done")

    def tick(self):
        """Collect all reports for the current aligned timestamp and schedule the push."""
        ts = int(time.time())
        # Align to the tick interval so every sensor of this tick shares one timestamp
        ts = ts - ts % self.tick_interval
        logger.debug("Collect metrics for %d", ts)
        for report in self._get_reports_from_instances():
            try:
                report.to_solomon(self._solomon, ts)
            except Exception:
                # A malformed/unexpected report must not prevent pushing all the other metrics
                logger.exception("Failed to convert %s into Solomon sensors", type(report).__name__)
        payload = self._solomon.json(ts)
        # delay push, it can take 10 seconds and longer
        self._pusher.submit(self._push_to_solomon, self._solomon.push_api, payload, ts)

    def on_stop(self):
        # Release the lock here, we don't need it for sending data to Solomon
        # and don't want to block other workers from collecting data.
        self.lock.release()

        self._pool.shutdown(wait=True)
        self._pusher.shutdown(wait=True)


class MongoChecker(base_service.Service):
    """
    Get information about mongo instances and reboot secondary instances.

    Every tick, system CPU usage and the CPU usage of each local mongod process are
    sampled into sliding windows of WINDOW_SIZE samples. When the busiest SECONDARY
    averages above CPU_BORDER and more than three times the next busiest SECONDARY,
    it is restarted (at most once per RELOAD_TRESHOLD minutes).

    Do not modify context in this service process.
    """
    tick_interval = 10
    WINDOW_SIZE = 18  # number of samples kept in each sliding CPU window
    CPU_BORDER = 3500  # average psutil cpu_percent above which a secondary may be restarted
    RELOAD_TRESHOLD = 5  # in minutes
    PIDS_TRESHOLD = 2  # in minutes
    CLIENT_UPDATE_TRESHOLD = 2  # in minutes

    CpuStatistics = collections.namedtuple("CpuStatistics", ("cpu", "name"))

    def __init__(self):
        super(MongoChecker, self).__init__()
        self.last_reload = dt.datetime.utcnow()
        self.last_client_update = None
        self.pids_updated = None
        self.processes = {}  # shard name -> psutil.Process
        self.cpu_consumptions = collections.defaultdict(list)  # shard name -> CPU samples window
        self.system_cpu_consumptions = []  # system CPU samples window
        self.mongo_cpu_usage = {}
        self.state = None

    def load_processes(self, force=False):
        """
        Reload server state and resolve mongod processes of PRIMARY/SECONDARY shards.

        Does nothing if processes were loaded less than PIDS_TRESHOLD minutes ago,
        unless `force` is set; a forced reload also resets the CPU windows.
        """
        if (
            not force and
            self.pids_updated is not None and
            dt.datetime.utcnow() - self.pids_updated < dt.timedelta(minutes=self.PIDS_TRESHOLD)
        ):
            return
        self.state = controller.State.get(common.config.Registry().this.id)

        if self.state is None:
            logger.error("Can't load info about shards on server.")
            self.pids_updated = None
            self.cpu_consumptions = collections.defaultdict(list)
            self.processes = {}
            return

        instances = [
            shard.name for shard in self.state.shards_info
            if shard.info["stateStr"] in ("SECONDARY", "PRIMARY")
        ]

        self.processes = {}

        for name in instances:
            reporter = None
            try:
                reporter = BaseReporter(name)
                status = reporter.client.admin.command("serverStatus")
                pid = int(status["pid"])
                self.processes[name] = psutil.Process(pid)
            except Exception as ex:
                logger.info("Can't load process info for %s. Exception: %s", name, ex)
            finally:
                # The client is only needed to read the pid; close it so repeated
                # reloads don't leak connection pools and monitoring threads.
                if reporter is not None:
                    reporter.client.close()

        if force:
            self.cpu_consumptions = collections.defaultdict(list)
            self.mongo_cpu_usage = {}
        self.pids_updated = dt.datetime.utcnow()

    def reload_mongo_instance(self, name):
        """
        Restart the local mongod service for shard `name` and send an email notification.

        Rate limited to once per RELOAD_TRESHOLD minutes; reloads processes afterwards.
        """
        if dt.datetime.utcnow() - self.last_reload < dt.timedelta(minutes=self.RELOAD_TRESHOLD):
            return
        try:
            # The service name is derived from the port part of "<host>:<port>"
            output = sp.check_output(["/usr/bin/sudo", "service", "mongod_{}".format(name.split(":")[-1]), "restart"])
            logger.info("mongod instance %s restarted: %s", name, output)
            controller.Notification.save(
                transport=ctn.Transport.EMAIL,
                send_to=["sandbox-errors"],
                send_cc=[],
                subject="mongod instance {} restarted".format(name),
                body="{} has been restarted".format(name)
            )
            # NOTE(review): this is not saved here and load_processes(force=True) below
            # re-fetches self.state; unless State.get returns the same cached object the
            # value is lost — TODO confirm.
            self.state.shards[name].shard_reloaded_ts = dt.datetime.utcnow()
        except Exception as ex:
            logger.error("Failed to restart mongod instance %s", name, exc_info=ex)
            controller.Notification.save(
                transport=ctn.Transport.EMAIL,
                send_to=["sandbox-errors"],
                send_cc=[],
                subject="mongod instance {} restart failed".format(name),
                body="{} has NOT been restarted with exception: {}".format(name, ex)
            )
        self.last_reload = dt.datetime.utcnow()
        self.load_processes(force=True)

    def update_cpu_usage(self):
        """
        Sample system CPU usage (blocks for 1 second) into the sliding window.

        Once the window is full, the average is stored in the server state and saved,
        at most once per CLIENT_UPDATE_TRESHOLD minutes.
        """
        if len(self.system_cpu_consumptions) == self.WINDOW_SIZE:
            self.system_cpu_consumptions.pop(0)
        self.system_cpu_consumptions.append(psutil.cpu_percent(1.0))
        if len(self.system_cpu_consumptions) != self.WINDOW_SIZE:
            return
        sum_cpu = sum(self.system_cpu_consumptions)
        now = dt.datetime.utcnow()

        if (
            self.state is not None and
            (
                self.last_client_update is None or
                now - self.last_client_update > dt.timedelta(minutes=self.CLIENT_UPDATE_TRESHOLD)
            )
        ):
            self.state.cpu_usage = sum_cpu / float(self.WINDOW_SIZE)
            self.state.cpu_updated = now
            self.state.save()
            self.state.reload()
            self.last_client_update = now

    def tick(self):
        """Sample CPU of every mongod and restart an outlier SECONDARY if needed."""
        self.update_cpu_usage()
        self.load_processes()
        windows_full = True
        sum_cpu_consumption = {}

        # `.items()` works on both Python 2 and 3 (`iteritems` is Python 2 only)
        for name, proc in self.processes.items():
            cpu_consumption = self.cpu_consumptions[name]
            if len(cpu_consumption) == self.WINDOW_SIZE:
                cpu_consumption.pop(0)
            try:
                cpu_consumption.append(proc.cpu_percent(interval=1.0))
            except psutil.NoSuchProcess as ex:
                logger.error("Mongo process not found.", exc_info=ex)
                self.load_processes(force=True)
                return

            if len(cpu_consumption) < self.WINDOW_SIZE:
                windows_full = False

            sum_cpu_consumption[name] = sum(cpu_consumption)

        # Decide only when every process has a full window of samples
        if not windows_full:
            return

        cpu_statistics = [
            self.CpuStatistics(cpu / float(self.WINDOW_SIZE), name) for name, cpu in sum_cpu_consumption.items()
        ]

        now = dt.datetime.utcnow()

        for cpu, name in cpu_statistics:
            self.state.shards[name].shard_cpu_usage = cpu
            self.state.shards[name].shard_cpu_updated = now

        # Only secondaries are candidates for a restart; a list (not `filter`) keeps len()/sort() portable
        cpu_statistics = [
            stat for stat in cpu_statistics
            if self.state.shards[stat.name].info["stateStr"] == "SECONDARY"
        ]
        if len(cpu_statistics) < 2:
            return

        cpu_statistics.sort(reverse=True)
        if cpu_statistics[0].cpu > self.CPU_BORDER and cpu_statistics[0].cpu > cpu_statistics[1].cpu * 3:
            self.reload_mongo_instance(cpu_statistics[0].name)
