import logging
import datetime as dt
import collections

import pymongo

import sandbox.common.types.task as ctt
import sandbox.common.types.client as ctc
import sandbox.common.types.statistics as ctss

from sandbox.services import base
from sandbox.yasandbox import controller
from sandbox.yasandbox.database import mapping as mp

logger = logging.getLogger(__name__)


class ClientSlotsMonitor(base.SingletonService):
    """
    Service thread for uploading clients info to Clickhouse.

    On every tick it builds one ``CLIENT_INFO`` signal per known client (slots, disk, RAM,
    service activity, tags) and pushes the batch through ``self.signaler``.
    """

    # NOTE(review): presumably seconds between ticks — confirm against base.SingletonService.
    tick_interval = 30

    # Fields of ``client.info["system"]`` that must be present (not None) for a client to be reported.
    _REQUIRED_SYSTEM_FIELDS = (
        "total_slots", "used_slots", "ncpu", "total_space", "used_space_value", "reserved_space", "disk_status"
    )

    def __init__(self, *a, **kw):
        super(ClientSlotsMonitor, self).__init__(*a, **kw)

    @staticmethod
    def _latest_clients():
        """Return ``{hostname: client}`` keeping only the most recently ``updated`` document per hostname."""
        latest = {}
        for client_mp in mp.Client.objects.read_preference(pymongo.ReadPreference.SECONDARY).all():
            current = latest.get(client_mp.hostname)
            if current is None or client_mp.updated > current.updated:
                latest[client_mp.hostname] = client_mp
        return latest

    @staticmethod
    def _count_session_slots():
        """
        Count non-aborted task sessions per client.

        The key is the part of the session ``source`` after the first ':' and is later matched
        against client hostnames (presumably ``source`` looks like ``<prefix>:<hostname>`` — confirm).
        Sessions without a source, without a task id, or without ':' in the source are ignored.
        """
        slots = collections.Counter()
        for source, task_id, state in mp.OAuthCache.objects().fast_scalar("source", "task_id", "state"):
            if source and task_id and state != ctt.SessionState.ABORTED and ":" in source:
                slots[source.split(":", 1)[1]] += 1
        return slots

    @staticmethod
    def _service_activity(client_mp):
        """
        Classify the service work currently scheduled on a client.

        Priority: MAINTENANCE tag > pending CLEANUP > pending REBOOT/RESTART/SHUTDOWN > NONE.
        """
        if ctc.Tag.MAINTENANCE in client_mp.tags_set:
            return ctc.ServiceActivity.MAINTAIN
        # Query pending commands once; both remaining checks use the same result.
        pending = controller.Client.pending_service_commands(client_mp)
        if ctc.ReloadCommand.CLEANUP in pending:
            return ctc.ServiceActivity.CLEANUP
        if pending & {ctc.ReloadCommand.REBOOT, ctc.ReloadCommand.RESTART, ctc.ReloadCommand.SHUTDOWN}:
            return ctc.ServiceActivity.RELOAD
        return ctc.ServiceActivity.NONE

    @staticmethod
    def get_clients_info():
        """
        Build the list of ``CLIENT_INFO`` signal dicts, one per known client.

        Clients whose ``info["system"]`` lacks any required field (or lacks the section entirely)
        are logged at error level and skipped, so one malformed client does not break the batch.

        :return: list of dicts ready for ``signaler.push``
        """
        clients = ClientSlotsMonitor._latest_clients()
        utcnow = dt.datetime.utcnow()
        clients_slots = ClientSlotsMonitor._count_session_slots()

        clients_info = []
        for client_mp in clients.values():
            try:
                system_data = client_mp.info["system"] or {}
            except (KeyError, TypeError):
                # No usable system section: every required field is reported as missing below.
                system_data = {}

            invalid_fields = [
                field for field in ClientSlotsMonitor._REQUIRED_SYSTEM_FIELDS if system_data.get(field) is None
            ]
            if invalid_fields:
                logger.error("%s has invalid values for system info fields %s", client_mp.hostname, invalid_fields)
                continue

            used_slots = clients_slots.get(client_mp.hostname, 0)
            if ctc.Tag.MULTISLOT in client_mp.tags and used_slots:
                # For MULTISLOT clients with at least one active session, trust the client-reported
                # value over the session count (NOTE(review): presumably a session may span several
                # slots there — confirm).
                used_slots = system_data["used_slots"]

            # Materialize as a list: on Python 3 ``filter`` would put a lazy iterator into the payload.
            density_tags = [t for t in client_mp.pure_tags if t in ctc.Tag.Group.DENSITY]
            purpose_tag = next(
                (t for t in client_mp.pure_tags if t in ctc.Tag.Group.PURPOSE and t != ctc.Tag.POSTEXECUTE),
                "NONE"
            )

            clients_info.append(dict(
                type=ctss.SignalType.CLIENT_INFO,
                date=utcnow,
                timestamp=utcnow,
                client_id=client_mp.hostname,
                slots_total=system_data["total_slots"],
                slots_used=used_slots,
                service_activity=ClientSlotsMonitor._service_activity(client_mp),
                ncpu=system_data["ncpu"],
                # ``>> 10`` divides by 1024 (unit down-conversion; source units not visible here — confirm).
                ram=client_mp.hardware.ram >> 10,
                disk_total=system_data["total_space"] >> 10,
                disk_used=system_data["used_space_value"] >> 10,
                disk_reserved=system_data["reserved_space"] >> 10,
                disk_status=system_data["disk_status"],
                purpose_tag=purpose_tag,
                tags=client_mp.tags,
                alive=client_mp.alive,
                density_tags=density_tags,
            ))
        return clients_info

    def tick(self):
        """Collect the current client snapshot and push it through the signaler."""
        clients_info = self.get_clients_info()
        logger.debug("%d clients in total", len(clients_info))
        self.signaler.push(clients_info)
