import logging
import requests
import collections
import datetime as dt

from sandbox import common
import sandbox.common.types.misc as ctm
import sandbox.common.types.client as ctc
import sandbox.common.types.notification as ctn

from sandbox.services import base

import sandbox.serviceq.client as qclient

from sandbox.yasandbox import controller
from sandbox.yasandbox.database import mapping

# Logger of this module
logger = logging.getLogger(__name__)


class CpuBalancer(base.SingletonService):
    """
    Process information about CPU usage on servers and try to balance it.

    Every tick the service:

    1. refreshes servers' state information (`controller.State.update_state_infromation`);
    2. turns the API off (`State.api_enabled`) on the most CPU-loaded servers, up to a limit,
       and turns it back on on servers whose load dropped or when too many servers are disabled;
    3. tells ServiceQ which of its nodes should not become its primary ("unwanted contenders").
    """
    tick_interval = 120
    CPU_STATISTICS_TRESHOLD = 5  # in minutes; CPU statistics older than this are ignored
    MAXIMUM_DISABLED_SERVERS = 4  # upper bound for servers with the API turned off at once
    LOWER_BOUND_CPU_USAGE = 90  # server CPU usage above which its API may be turned off
    LOWER_BOUND_CPU_USAGE_TO_START = 50  # server CPU usage below which the API may be turned back on
    LOWER_BOUND_MONGO_CPU_USAGE = 3500  # summed shard CPU usage below which the API may be turned back on
    LOWER_BOUND_SERVICEQ_CPU_USAGE = 60  # CPU usage above which a ServiceQ node is considered highly loaded
    UNWANTED_CONTENDERS_PERCENT = 50  # maximum share (in percent) of ServiceQ nodes marked as unwanted

    class CpuUsage(object):
        """ Average of the last `CPU_MEASURES_LEN` CPU measures of a single server """
        CPU_MEASURES_LEN = 5  # Number of last measures to check

        def __init__(self):
            # Bounded window: appending to a full deque drops the oldest measure
            self.measures = collections.deque(maxlen=self.CPU_MEASURES_LEN)
            # Stays 0 until `CPU_MEASURES_LEN` measures have been collected
            self.cpu_usage = 0
            # NOTE(review): `oauth_token` and `rest` are not used anywhere in this file;
            # kept for compatibility, consider removing them.
            self.oauth_token = (
                common.utils.read_settings_value_from_file(common.config.Registry().server.auth.oauth.token)
                if common.config.Registry().server.auth.enabled else
                None
            )
            self.rest = common.rest.Client(
                auth=self.oauth_token, component=ctm.Component.SERVICE, total_wait=30
            )

        def add(self, measure):
            """ Record a new measure and recompute the average once the window is full """
            self.measures.append(measure)
            if len(self.measures) == self.CPU_MEASURES_LEN:
                # `float` avoids integer (floor) division on Python 2
                self.cpu_usage = float(sum(self.measures)) / self.CPU_MEASURES_LEN

    @common.utils.singleton_property
    def qclient(self):
        """ ServiceQ client used to read and update contenders """
        return qclient.Client()

    def __init__(self, *args, **kwargs):
        super(CpuBalancer, self).__init__(*args, **kwargs)
        # server name -> CpuUsage; filled only for servers that are ServiceQ nodes
        self.cpu_usages = {}

    @staticmethod
    def _shard_state(shard):
        """ Return the MongoDB member state string ("stateStr") of a shard instance, or None if absent """
        info = shard.info
        return info["stateStr"] if "stateStr" in info else None

    def active_servers(self, servers):
        """
        Select servers with fresh CPU statistics and all shard instances reachable.

        :param servers: iterable of `mapping.State` documents
        :return: dict of server name -> server document
        """
        stale_before = dt.datetime.utcnow() - dt.timedelta(minutes=self.CPU_STATISTICS_TRESHOLD)
        active_servers = {}

        for server in servers:
            cpu_updated = server.cpu_updated
            if cpu_updated is None or cpu_updated < stale_before:
                continue
            states = [self._shard_state(instance) for instance in server.shards_info]
            if all(state is not None and "not reachable" not in state for state in states):
                active_servers[server.name] = server

        logger.info("Active servers %s.", ", ".join(active_servers.iterkeys()))
        return active_servers

    def turn_server(self, server, state):
        """ Set `api_enabled` of the server document to `state` and save it """
        logger.info("Turn api on server %s to %s.", server.name, state)
        server.api_enabled = state
        server.save()

    def updated_turned_off_servers(self, turned_off, maximum_disabled_servers):
        """
        Turn the API back on on disabled servers that may serve requests again.

        If more servers than `maximum_disabled_servers` are disabled, all of them are turned on.
        Otherwise only servers whose summed shard CPU usage and own CPU usage are below
        `LOWER_BOUND_MONGO_CPU_USAGE` and `LOWER_BOUND_CPU_USAGE_TO_START` respectively are.

        :param turned_off: dict of server name -> server with the API disabled;
            re-enabled servers are removed from it in place
        :param maximum_disabled_servers: how many servers may stay disabled
        """
        if len(turned_off) > maximum_disabled_servers:
            turned_on = list(turned_off.values())
            logger.info("Turn on all server")
        else:
            turned_on = []
            for server in turned_off.itervalues():
                sum_cpu = sum(instance.shard_cpu_usage for instance in server.shards.itervalues())
                if (
                    sum_cpu < self.LOWER_BOUND_MONGO_CPU_USAGE and
                    server.cpu_usage < self.LOWER_BOUND_CPU_USAGE_TO_START
                ):
                    logger.info(
                        "Turn on server %s. It's cpu usage: %s, mongo usage: %s.",
                        server.name, server.cpu_usage, sum_cpu
                    )
                    turned_on.append(server)

        # Mutate `turned_off` only after iterating over it
        for server in turned_on:
            self.turn_server(server, True)
            del turned_off[server.name]

    def __check_serviceapi_availability(self, serviceapi_servers):
        """
        Count serviceapi servers that fail their health check.

        :param serviceapi_servers: dict of client hostname -> server FQDN (may be empty string)
        :return: number of servers that did not answer the health check successfully
        """
        port = common.config.Registry().server.api.port
        non_active_count = 0
        for hostname, fqdn in serviceapi_servers.iteritems():
            try:
                response = requests.get("http://{}:{}/health_check".format(fqdn, port), timeout=30)
                # An error status means the API is unhealthy even though it answered
                response.raise_for_status()
            except Exception as ex:
                logger.warning("Health check of server %s (%s) failed: %s", hostname, fqdn, ex)
                non_active_count += 1
        return non_active_count

    def update_api_availability(self):
        """
        Turn the API off on the most CPU-loaded servers and back on on recovered ones.

        Only active servers (see `active_servers`) that are also registered as clients tagged
        `ctc.Tag.SERVER` take part in balancing; the API is always turned back on on the others.
        The number of servers allowed to have the API off is `MAXIMUM_DISABLED_SERVERS` minus
        the number of serviceapi servers failing their health check.
        """
        servers = list(mapping.State.objects())
        active_servers = self.active_servers(servers)
        serviceapi_servers = {
            client.hostname: client.info.get("system", {}).get("fqdn", "")
            for client in mapping.Client.objects(tags=str(ctc.Tag.SERVER))
        }
        # Only servers that also run serviceapi are subject to balancing
        # (the set difference is computed before the loop, so deleting is safe)
        for server_name in active_servers.viewkeys() - serviceapi_servers.viewkeys():
            del active_servers[server_name]

        # Never leave the API turned off on a server that is not balanced
        for server in servers:
            if server.name not in active_servers and not server.api_enabled:
                self.turn_server(server, True)

        turned_off = {name: server for name, server in active_servers.iteritems() if not server.api_enabled}
        maximum_disabled_servers = (
            self.MAXIMUM_DISABLED_SERVERS - self.__check_serviceapi_availability(serviceapi_servers)
        )
        self.updated_turned_off_servers(turned_off, maximum_disabled_servers)

        # `turn_server` updates `api_enabled` in place, so servers re-enabled above are not counted
        turn_off_count = maximum_disabled_servers - sum(
            1 for server in active_servers.itervalues() if not server.api_enabled
        )
        for server in sorted(active_servers.itervalues(), key=lambda s: s.cpu_usage, reverse=True):
            if turn_off_count <= 0:
                break
            if server.api_enabled and server.cpu_usage > self.LOWER_BOUND_CPU_USAGE:
                turn_off_count -= 1
                self.turn_server(server, False)

    def serviceq_unwanted_contenders(self):
        """
        Tell ServiceQ which of its nodes should not become its primary.

        Nodes hosting a MongoDB PRIMARY shard come first, followed by nodes with fresh CPU usage
        above `LOWER_BOUND_SERVICEQ_CPU_USAGE` (highest averaged usage first); at most
        `UNWANTED_CONTENDERS_PERCENT` percent of the nodes are marked as unwanted.
        An e-mail notification is saved when the current ServiceQ primary becomes unwanted.
        """
        contenders = self.qclient.contenders()
        if not contenders:
            logger.warning("ServiceQ returned no contenders, skip unwanted contenders update")
            return

        # The first contender is treated as the current primary.
        # NOTE(review): contenders presumably look like "<host>.<domain>[:<port>]" — confirm.
        primary = contenders[0].split(".", 1)[0]
        nodes = {node.split(".", 1)[0]: node.split(":", 1)[0] for node in contenders}
        logger.info("Got nodes from serviceq: %s", nodes.keys())

        stale_before = dt.datetime.utcnow() - dt.timedelta(minutes=self.CPU_STATISTICS_TRESHOLD)
        primary_servers = []
        high_load_servers = []
        for server in mapping.State.objects():
            if server.name not in nodes:
                continue
            # Create the accumulator only on a miss: `CpuUsage()` is not free to construct
            usage = self.cpu_usages.get(server.name)
            if usage is None:
                usage = self.cpu_usages[server.name] = self.CpuUsage()
            usage.add(server.cpu_usage)

            if any(self._shard_state(shard) == "PRIMARY" for shard in server.shards_info):
                primary_servers.append((server.name, usage.cpu_usage))
            elif (
                server.cpu_usage > self.LOWER_BOUND_SERVICEQ_CPU_USAGE and
                server.cpu_updated is not None and
                server.cpu_updated > stale_before
            ):
                high_load_servers.append((server.name, usage.cpu_usage))

        high_load_servers.sort(key=lambda item: item[1], reverse=True)
        # Integer arithmetic avoids float rounding of the percentage
        maximum_unwanted_contenders = int(len(nodes) * self.UNWANTED_CONTENDERS_PERCENT // 100)
        unwanted_contenders = primary_servers + high_load_servers
        logger.info(
            "Calculate unwanted contenders with limit %s: %s",
            maximum_unwanted_contenders, unwanted_contenders
        )
        unwanted_contenders = [name for name, _ in unwanted_contenders[:maximum_unwanted_contenders]]

        if primary in unwanted_contenders:
            # Notify only on the transition, not on every tick
            if not self.context.get("qprimary_is_unwanted"):
                reason = (
                    "primary of MongoDB"
                    if any(name == primary for name, _ in primary_servers) else
                    "high CPU load"
                )
                controller.Notification.save(
                    transport=ctn.Transport.EMAIL,
                    send_to=["sandbox-errors"],
                    send_cc=None,
                    subject="ServiceQ primary on host with {}".format(reason),
                    body="ServiceQ primary on host {} with {}".format(primary, reason)
                )
            self.context["qprimary_is_unwanted"] = True
        else:
            self.context["qprimary_is_unwanted"] = False

        unwanted_contenders = [nodes[name] for name in unwanted_contenders]
        logger.info("Result unwanted contenders: %s", unwanted_contenders)
        self.qclient.set_unwanted_contenders(unwanted_contenders)

    def tick(self):
        """ Refresh servers' state information, then rebalance the API and ServiceQ primary """
        controller.State.update_state_infromation()
        self.update_api_availability()
        self.serviceq_unwanted_contenders()
