import socket
import pymongo
import collections
import bson.timestamp
import datetime as dt
import operator as op
import distutils.version

from sandbox import common
import sandbox.common.types.database as ctd

from sandbox.yasandbox.database import mapping


class Statistics(object):
    """Controller to update and get statistics of Sandbox tasks"""
    Model = mapping.Statistics
    Keys = mapping.Statistics.Keys
    # ``currentOp`` operation types counted as write operations
    Write = type("Write operation types", (common.utils.Enum,), {
        "INSERT": "insert",
        "REMOVE": "remove",
        "UPDATE": "update",
    })
    # ``currentOp`` operation types counted as read operations
    Read = type("Read operation types", (common.utils.Enum,), {
        "QUERY": "query",
        "GETMORE": "getmore"
    })
    # Database size statistics fields names (keys of the ``dbstats`` command result)
    DB_SIZE_FIELDS = ["avgObjSize", "dataSize", "fileSize", "indexSize"]
    # Used to convert microsecond timings reported by ``currentOp`` into seconds
    _MICROS_PER_SECOND = 1000000.0

    @classmethod
    def get_statistics(cls, key):
        """
        Get statistics (amount) of Sandbox tasks and resources per status

        :param key: key to get corresponding statistics. Possible values are :py:attr:`cls.Keys`
        :return: dict with tasks statuses and resources states as keys and amount of entries in corresponding status,
                 or an empty dict if there is no statistics document for the key
        :rtype: dict
        """
        with mapping.switch_db(mapping.Statistics, ctd.ReadPreference.SECONDARY) as Statistics:
            # Evaluate the query while the switched (secondary) database is active and fetch at most
            # one document in a single round trip.
            document = Statistics.objects(key=key).as_pymongo().first()
        return document["data"] if document else {}

    @classmethod
    def current_db_operations_statistics(cls):
        """
        Get current MongoDB operations statistics

        :return: counter with following items:
                    current_op - amount of active operations (replication and shards balancing ones included)
                    read - amount of active read operations (types listed in :py:attr:`cls.Read`)
                    write - amount of active write operations (types listed in :py:attr:`cls.Write`)
                    duration - total running time in seconds of active non-service operations
                    write_lock - total time in seconds that active non-service operations spent holding
                                 or acquiring write locks
                    read_lock - total time in seconds that active non-service operations spent holding
                                or acquiring read locks
        :rtype: collections.Counter
        """
        connection = mapping.get_connection()
        operations = connection.admin.current_op()["inprog"]
        cnt = collections.Counter()
        for operation in operations:
            # Non-active operations are neither counted nor measured
            if not operation.get("active"):
                continue
            cnt["current_op"] += 1
            is_service_operation = (
                # replication activity
                (operation.get("desc") or "").startswith("repl writer worker") or
                operation.get("ns") == "local.oplog.rs" or
                # shards balancing operations
                operation.get("msg", "") == "step 3 of 6"
            )
            if is_service_operation:
                continue

            if operation["op"] in cls.Write:
                cnt["write"] += 1
            elif operation["op"] in cls.Read:
                cnt["read"] += 1

            cnt["duration"] += operation.get("microsecs_running", 0)
            # NOTE(review): this reads flat ``timeLockedMicros``/``timeAcquiringMicros`` maps at the top of
            # ``lockStats``; if the server reports lock stats in another layout these sums simply stay 0.
            ls = operation.get("lockStats", {})
            lc = ls.get("timeLockedMicros", {})
            la = ls.get("timeAcquiringMicros", {})
            cnt["write_lock"] += lc.get("w", 0) + la.get("w", 0)
            cnt["read_lock"] += lc.get("r", 0) + la.get("r", 0)
        for field in ("duration", "write_lock", "read_lock"):
            cnt[field] /= cls._MICROS_PER_SECOND  # microseconds -> seconds
        return cnt

    @classmethod
    def db_size(cls, scale=1):
        """
        Get default database size statistics:
          - avgobjsize: the average size of each document in bytes
          - datasize: the total size in bytes of the data held in this database including the padding factor
          - filesize: the total size in bytes of the data files that hold the database including preallocated space
                      and the padding factor
          - indexsize: the total size in bytes of all indexes created on this database

        :param scale: integer above 0 to scale byte values
        :return: db size statistics, keys are lower-cased names from :py:attr:`cls.DB_SIZE_FIELDS`
        :rtype: dict
        """
        db = mapping.get_connection().get_default_database()
        result = db.command("dbstats", scale=scale)
        return {key.lower(): result[key] for key in cls.DB_SIZE_FIELDS}

    @classmethod
    def db_shards_status(cls):
        """
        Get database shards status

        The config server replica set and every data shard listed in ``config.shards`` are queried directly.

        :return: list of dicts with per shard statistics, sorted by replica set name
        :rtype: list
        """
        conn = mapping.get_connection()
        config_shards = [conn.admin.command("getShardMap")["map"]["config"]]
        data_shards = map(op.itemgetter("host"), conn.config.shards.find())
        result = [
            cls._replicaset_status(shard_hosts)
            for shard_hosts in common.itertools.chain(config_shards, data_shards)
        ]
        # NOTE(review): distutils is deprecated and removed in Python 3.12; consider a natural-sort key instead.
        return sorted(result, key=lambda shard: distutils.version.LooseVersion(shard["replicaset"]))

    @classmethod
    def _replicaset_status(cls, shard_hosts):
        """
        Collect status of a single replica set given as "<replicaset>/<host1>,<host2>,...".

        :return: dict with status date, replica set name, "ok" flag and per-member status merged with
                 votes/hidden/priority from the replica set configuration
        :rtype: dict
        """
        replicaset, hosts = shard_hosts.split("/")
        mongo_client = pymongo.MongoClient(hosts, replicaset=replicaset)
        try:
            replset_status = mongo_client.admin.command("replSetGetStatus")
            replset_conf = mongo_client.admin.command("replSetGetConfig")
        finally:
            # Every client owns sockets and background monitoring threads; release them right away
            # instead of leaking one client per shard on each call.
            mongo_client.close()

        members = {
            member_status["_id"]: {
                key: str(value) if isinstance(value, (dt.datetime, bson.timestamp.Timestamp)) else value
                for key, value in member_status.items()
            }
            for member_status in replset_status["members"]
        }
        for member in members.values():
            member["name"] = cls._resolve_host(member["name"])
            # Not every member/server reports a sync source; treat a missing one as "no sync source".
            member["syncSourceHost"] = cls._resolve_host(member.get("syncSourceHost", ""))
        for member_config in replset_conf["config"]["members"]:
            members[member_config["_id"]].update({
                "votes": member_config["votes"],
                "hidden": member_config["hidden"],
                "priority": member_config["priority"],
            })
        return {
            "date": str(replset_status["date"]),
            "replicaset": replicaset,
            "ok": replset_status["ok"],
            "members": list(members.values()),
        }

    @classmethod
    def service_threads_timeline(cls):
        """
        Get service threads timeline from DB

        :return: All service threads names (lower-cased) and their last and next run datetime as strings
        :rtype: dict
        """
        res = {}
        with mapping.switch_db(mapping.Service, ctd.ReadPreference.SECONDARY) as Service:
            for item in Service.objects:
                res[item.name.lower()] = {
                    "last_run": str(item.time.last_run),
                    "next_run": str(item.time.next_run)
                }
        return res

    @staticmethod
    def _resolve_host(host):
        """
        Resolve the host part of a "host[:port]" string to its canonical name, keeping the port.

        Empty values are returned unchanged. If the name cannot be resolved, the original value is returned,
        so that a single unresolvable member does not break the whole report.
        """
        if not host:
            return host
        hostname, _, port = host.partition(":")
        try:
            resolved_host = socket.gethostbyaddr(hostname)[0]
        except socket.error:
            return host
        return resolved_host + ":" + port if port else resolved_host
