import time
import logging
import requests
import collections
import datetime as dt
import operator as op

from sandbox import common
import sandbox.common.types.client as ctc
import sandbox.common.types.misc as ctm

from sandbox.services import base
from sandbox.services.modules.metrics_reporter import solomon

logger = logging.getLogger(__name__)


class ClientAvailabilityChecker(base.SingletonService):
    """
    Service thread for checking dead clients.

    On every tick it:
      * reads clients that are not alive and resets (then reloads and cleans up) those that were not
        already recorded as dead on a previous tick, subject to a sliding-window rate limit;
      * counts clients tagged NEW / MAINTENANCE;
      * sends the collected sensors to Solomon.
    """

    tick_interval = 180
    fresh_limit = 1000  # extra slots added to the dead-clients read limit for not-yet-known dead clients
    reset_clients_limit = 2000  # max number of clients to reset per time window
    reset_clients_time_window = 60 * 60  # (sec)
    TOTAL_TAG = "total"

    def __init__(self, *args, **kwargs):
        super(ClientAvailabilityChecker, self).__init__(*args, **kwargs)
        if self.sandbox_config.server.auth.enabled:
            auth = common.utils.read_settings_value_from_file(self.sandbox_config.server.auth.oauth.token)
        else:
            auth = None
        self.rest = common.rest.Client(auth=auth, component=ctm.Component.SERVICE)

        # Re-created on every tick (see `tick`); sensor helpers must only be called from there.
        self._solomon = None

    def _add_sensor(self, name, tag, value):
        """Add a sensor `name` labelled with ``purpose_tag=tag`` to the current Solomon batch."""
        self._solomon.add_sensor(name, value, labels={"purpose_tag": tag})

    @staticmethod
    def _client_os(client_info):
        """Return the OS family name used as a sensor-name suffix for a client info dict."""
        # Windows clients report linux OS because of executing in WSL. Therefore, checking the tag is required.
        if ctc.Tag.WINDOWS in client_info["tags"]:
            return ctm.OSFamily.WIN_NT
        return client_info["os"]["name"]

    def send_dead_clients_report(self, dead, restarted_tasks):
        """
        Add dead-client and restarted-task sensors to the current Solomon batch.

        :param dead: iterable of client dicts (read via "id", "tags" and "os" keys); clients carrying
            any service tag are skipped
        :param restarted_tasks: iterable of task dicts; a task is attributed to the dead client whose
            id equals the task's "host" value
        """
        dead_client_counts = collections.defaultdict(lambda: collections.defaultdict(int))
        restarted_task_counts = collections.defaultdict(int)
        service_tags = set(ctc.Tag.Group.SERVICE)
        tags_to_consider = set(ctc.Tag.Group.PURPOSE) | set(ctc.Tag.Group.DENSITY)
        # Count restarted tasks per host once, instead of rescanning the whole task list for every
        # dead client (also makes a one-shot iterator of tasks work correctly).
        restarted_per_host = collections.Counter(t["host"] for t in restarted_tasks)

        for c in dead:
            host_tags = set(c["tags"])
            if host_tags & service_tags:
                continue
            restarted = restarted_per_host[c["id"]]
            restarted_task_counts[self.TOTAL_TAG] += restarted

            os = self._client_os(c)

            dead_client_counts[os][self.TOTAL_TAG] += 1
            for tag in host_tags & tags_to_consider:
                dead_client_counts[os][tag] += 1
                restarted_task_counts[tag] += restarted

        for os, counts in dead_client_counts.items():
            for tag, deads_amount in counts.items():
                self._add_sensor("dead_clients_{}".format(os), tag, deads_amount)

        # Restarted tasks are reported for purpose tags and the total only (density tags are not sent).
        for tag in list(ctc.Tag.Group.PURPOSE) + [self.TOTAL_TAG]:
            self._add_sensor("restarted_tasks", tag, restarted_task_counts[tag])

    def _adjust_fresh_by_limit(self, fresh_ids):  # type: (list) -> list
        """
        Trim `fresh_ids` so that at most `reset_clients_limit` clients are reset per
        `reset_clients_time_window` seconds.

        Reset history is kept in the service context as ``[timestamp, count]`` pairs. Entries older
        than the window are dropped, and the current reset (possibly of zero clients) is appended.

        :param fresh_ids: ids of newly dead clients
        :return: the prefix of `fresh_ids` that may be reset now
        """
        now = int(time.time())
        reset_history = self.context.get("reset_history", [])  # list of pairs: [timestamp, fresh_dead_clients_number]
        window_start = now - self.reset_clients_time_window
        # Entries are appended chronologically, so expired ones form a prefix.
        while reset_history and reset_history[0][0] < window_start:
            del reset_history[0]
        # Clamp at zero: a negative limit (e.g. `reset_clients_limit` lowered while older history is
        # still inside the window) would make the slice below keep "all but the last N" ids.
        reset_limit = max(self.reset_clients_limit - sum(map(op.itemgetter(1), reset_history)), 0)
        if len(fresh_ids) > reset_limit:
            logger.warning(
                "Too many dead clients lately (limit is %d per %dm), %d clients will be reset later",
                self.reset_clients_limit, self.reset_clients_time_window // 60, len(fresh_ids) - reset_limit
            )
            fresh_ids = fresh_ids[:reset_limit]
        reset_history.append([now, len(fresh_ids)])
        self.context["reset_history"] = reset_history
        return fresh_ids

    def find_dead_clients(self, send_report=False):
        """
        Reset clients that became dead since the previous run and remember the dead ones.

        A dead client is "fresh" if its id is not in the ``known`` list saved in the service context.
        Fresh clients allowed by the rate limit (see `_adjust_fresh_by_limit`) are reset, reloaded and
        cleaned up. Clients postponed by the rate limit are NOT stored as known, so they are picked up
        as fresh again on a later run.

        :param send_report: when True and fresh dead clients exist, add dead-client sensors
        :return: ids stored in the context as known dead clients
        """
        # if the run was delayed, there was potential downtime
        # in that case we must not mark clients as dead right away
        offset = max(int((dt.datetime.utcnow() - self._model.time.next_run).total_seconds()), 0)
        known_ids = self.context.get("known", [])

        dead_clients = {
            c["id"]: c for c in self.rest.client.read(
                alive=False, alive_offset=offset,
                limit=len(known_ids) + self.fresh_limit
            )["items"]
        }
        # Plain set arithmetic instead of the Python 2-only dict.viewkeys().
        fresh_ids = list(set(dead_clients) - set(known_ids))
        deferred_ids = set()
        if fresh_ids:
            to_reset = self._adjust_fresh_by_limit(fresh_ids)
            deferred_ids = set(fresh_ids) - set(to_reset)
            restarted = []
            # Skip the REST calls entirely when the rate limit leaves nothing to reset.
            if to_reset:
                logger.info("Resetting clients: %s", to_reset)
                # NOTE(review): result is treated as the list of restarted tasks (dicts with a "host" key)
                restarted = self.rest.client(to_reset)
                payload = dict(
                    id=to_reset,
                    comment="Reload and clear dead clients after reset (by {})".format(type(self).__name__)
                )
                self.rest.batch.clients.reload.update(payload)
                self.rest.batch.clients.cleanup.update(payload)
            if send_report:
                self.send_dead_clients_report(dead_clients.values(), restarted)
        else:
            logger.info("No new dead clients")
        # Postponed clients must stay out of "known", otherwise they would never be treated as fresh
        # again and thus never reset.
        self.context["known"] = [cid for cid in dead_clients if cid not in deferred_ids]
        return self.context["known"]

    def find_idle_clients(self):
        """
        Add per-OS sensors counting clients tagged NEW and clients tagged MAINTENANCE.

        At most 1000 matching clients are read per call.
        """
        new_clients = collections.defaultdict(int)
        clients_in_service = collections.defaultdict(int)

        for client_info in self.rest.client.read(tags="NEW | MAINTENANCE", limit=1000)["items"]:
            tags = client_info["tags"]
            os = self._client_os(client_info)
            if ctc.Tag.MAINTENANCE in tags:
                clients_in_service[os] += 1
            if ctc.Tag.NEW in tags:
                new_clients[os] += 1

        for os, cnt in new_clients.items():
            self._add_sensor("new_clients_{}".format(os), self.TOTAL_TAG, cnt)
        for os, cnt in clients_in_service.items():
            self._add_sensor("clients_in_service_{}".format(os), self.TOTAL_TAG, cnt)

    def tick(self):
        """
        Run one check cycle: collect dead/idle client sensors into a new Solomon batch and send it.

        HTTP errors while sending are logged, not raised.
        """
        self._solomon = solomon.SolomonClient()

        self.find_dead_clients(send_report=True)
        self.find_idle_clients()

        try:
            self._solomon.send_data()
        except requests.HTTPError as error:
            logger.exception("Failed to send data to Solomon: %s", error)
