import socket
import logging
import itertools as it
import collections

import six

from sandbox import common
import sandbox.common.types.task as ctt
import sandbox.common.types.notification as ctn

from sandbox.services import base

from sandbox.yasandbox import controller
from sandbox.yasandbox.database import mapping
import sandbox.serviceq.errors as qerrors

from sandbox import deploy


logger = logging.getLogger(__name__)


class CheckSemaphores(base.SingletonService):
    """
    Find and release semaphores that were supposed to be released.

    Every tick compares the semaphores stored in the DB with the ones reported by the
    queue service (Q), checks that each Q semaphore's value equals the sum of the weights
    of its holders, and looks for tasks that still hold semaphores although their status
    is one of the semaphore release statuses. A problem is escalated (reported to the
    ``CheckSemaphoresConsistency`` Juggler check; stuck semaphores are forcibly released)
    only if it was also seen on the previous tick. Findings of the previous tick are kept
    in ``self.context``.
    """
    # Projection of a task document: (execution status, semaphore requirements)
    Task = collections.namedtuple("Task", "status semaphores")

    # NOTE(review): presumably the period between ticks in seconds — confirm in base.SingletonService
    tick_interval = 300

    def get_semaphores(self):
        """
        Fetch all semaphores from Q.

        :return: ``{semaphore_id: semaphore}`` on success; ``False`` if Q answered with an
            error or could not be reached (presumably treated by ``progressive_waiter``
            in :meth:`tick` as "try again" — confirm)
        """
        try:
            return dict(controller.TaskQueue.qclient.semaphores())
        except (qerrors.QException, socket.error) as ex:
            logger.warning("Can't fetch semaphores from Q: %s", ex)
            return False

    def tick(self):
        """
        Run one consistency check of DB and Q semaphores.

        Findings of this tick are logged right away and saved to ``self.context``; only
        those that were also present on the previous tick trigger release attempts and
        Juggler errors/warnings.
        """
        not_released = set()
        inconsistent_semaphores = set()
        # Tasks with an active session are excluded from the "not released" check below
        active_sessions = set(mapping.OAuthCache.objects(
            task_id__exists=True,
            state__in=list(ctt.SessionState.Group.ACTIVE)
        ).fast_scalar("task_id"))

        with common.utils.Timer() as timer:
            db_sems = {sem.id: sem for sem in mapping.Semaphore.objects()}
            q_sems = common.itertools.progressive_waiter(0, 1, 30, self.get_semaphores)[0]
            if not isinstance(q_sems, dict):
                # Q stayed unavailable, so there is nothing to compare with. Skip the tick
                # and keep the previous context so persistent problems are still detected.
                logger.error("Can't fetch semaphores from Q, semaphores check skipped")
                return
            q_sems_index = {sem.name: sid for sid, sem in six.iteritems(q_sems)}

            db_extra_sids = six.viewkeys(db_sems) - six.viewkeys(q_sems)
            q_extra_sids = six.viewkeys(q_sems) - six.viewkeys(db_sems)

            # Ids of all tasks that hold at least one semaphore according to Q
            q_tids = set(it.chain.from_iterable(s.tasks for s in six.itervalues(q_sems)))

            nonactive_tasks = {
                t[0]: self.Task(*t[1:])
                for t in mapping.Task.objects(
                    id__in=list(q_tids - active_sessions)
                ).fast_scalar("id", "execution__status", "requirements__semaphores")
            }

            if db_extra_sids:
                logger.warning("Extra semaphores in DB (%s): %s", len(db_extra_sids), list(db_extra_sids))

            if q_extra_sids:
                logger.warning("Extra semaphores in Q (%s): %s", len(q_extra_sids), list(q_extra_sids))

            for sid, sem in six.iteritems(q_sems):
                # ``dict.itervalues`` does not exist on Python 3; use six like the rest of the file
                sem_value = sum(six.itervalues(sem.tasks))
                if sem_value != sem.value:
                    logger.error(
                        "Detected inconsistency for semaphore %s (%s), value is %s instead of %s",
                        sem.name, sid, sem_value, sem.value
                    )
                    inconsistent_semaphores.add(sid)

            for tid, task in six.iteritems(nonactive_tasks):
                task_sem = task.semaphores
                if task_sem is None:
                    # The task has no semaphore requirements anymore, but Q still lists it as a holder
                    logger.warning("Releasing removed semaphores for task #%s", tid)
                    try:
                        if not controller.TaskQueue.qclient.release_semaphores(tid, task.status, task.status):
                            logger.error("Semaphores for task #%s not released", tid)
                    except (qerrors.QException, socket.error):
                        logger.exception("Can't release semaphore for task %s", tid)
                    continue

                release_statuses = set(ctt.Status.Group.expand(task_sem["release"], ctt.Status.DELETED))
                # Holding semaphores is only a problem once the task reached a release status
                if task.status not in release_statuses:
                    continue

                for acquire in task_sem["acquires"]:
                    # The semaphore may be absent in Q (e.g. a DB-only one); don't let a
                    # single such entry abort the whole check
                    sid = q_sems_index.get(acquire["name"])
                    weight = q_sems[sid].tasks.get(tid, None) if sid is not None else None
                    logger.warning(
                        "Semaphore %s (#%s) still occupied by task #%s in status %s with weight %s",
                        acquire["name"], sid, tid, task.status, weight
                    )
                    not_released.add(tid)

        # Only problems seen on both the previous and the current tick are escalated
        still_not_released = set(self.context.get("not_released", [])) & not_released
        still_inconsistent_semaphores = set(
            self.context.get("inconsistent_semaphores", [])
        ) & inconsistent_semaphores
        still_db_extra_sids = set(self.context.get("db_extra_sids", [])) & db_extra_sids
        still_q_extra_sids = set(self.context.get("q_extra_sids", [])) & q_extra_sids

        errors = []
        warnings = []

        if still_not_released:
            still_not_released = self._release_stuck(still_not_released, nonactive_tasks)
            if still_not_released:
                errors.append("{} task(s) not released semaphore(s)".format(len(still_not_released)))

        if still_inconsistent_semaphores:
            errors.append("{} semaphore(s) in inconsistent state".format(len(still_inconsistent_semaphores)))
        if still_db_extra_sids:
            errors.append("DB has {} extra semaphore(s)".format(len(still_db_extra_sids)))
        if still_q_extra_sids:
            warnings.append("Q has {} extra semaphore(s)".format(len(still_q_extra_sids)))

        if errors:
            deploy.juggler.CheckSemaphoresConsistency.critical("; ".join(errors + warnings))
        elif warnings:
            deploy.juggler.CheckSemaphoresConsistency.warning("; ".join(warnings))
        else:
            deploy.juggler.CheckSemaphoresConsistency.ok("All semaphores in consistent state")

        # Persist this tick's raw findings for comparison on the next tick
        self.context["not_released"] = list(not_released)
        self.context["inconsistent_semaphores"] = list(inconsistent_semaphores)
        self.context["db_extra_sids"] = list(db_extra_sids)
        self.context["q_extra_sids"] = list(q_extra_sids)

        logger.info(
            "Semaphores checked,"
            " found %s task(s) that not released semaphore(s),"
            " %s semaphore(s) with wrong value. Elapsed time: %s",
            len(not_released), len(inconsistent_semaphores), timer
        )

    def _release_stuck(self, tids, nonactive_tasks):
        """
        Try to release semaphores held by tasks flagged on two consecutive ticks.

        Q is first asked to release them with the task's current status (``None`` for a
        deleted task). If Q refuses, the task is re-added to the queue and the release is
        retried with the execution status re-read from the DB; a successful forced release
        is reported by e-mail. Errors are logged per task and do not stop the loop.

        :param tids: ids of the tasks to process
        :param nonactive_tasks: ``{task_id: Task}`` collected by :meth:`tick`
        :return: set of task ids whose semaphores are still not released
        """
        remaining = set(tids)
        cache = None  # hosts cache, built lazily only when a task has to be re-queued
        for tid in list(tids):
            task = nonactive_tasks[tid]
            task_status = None if task.status == ctt.Status.DELETED else task.status
            try:
                if controller.TaskQueue.qclient.release_semaphores(tid, task_status, task_status):
                    logger.info("Semaphores for task #%s successfully released", tid)
                    remaining.remove(tid)
                    continue

                if cache is None:
                    cache = controller.TaskQueue.HostsCache()

                task = mapping.Task.objects.with_id(tid)
                hosts = controller.TaskQueue.task_hosts_list(task, cache)[1]

                try:
                    controller.TaskQueue.add(
                        task, hosts, controller.Task.client_tags(task), task.score or 0
                    )
                except qerrors.QException as ex:
                    logger.error("Error while adding task #%s to the queue: %s", tid, ex)

                if controller.TaskQueue.qclient.release_semaphores(
                    tid, task.execution.status, task.execution.status
                ):
                    message = "Semaphores for task #{} forcedly released".format(tid)
                    logger.warning(message)
                    controller.Notification.save(
                        transport=ctn.Transport.EMAIL,
                        send_to=["sandbox-errors"],
                        send_cc=None,
                        subject="[check_semaphores] Cannot release semaphore",
                        body=message
                    )
                    remaining.remove(tid)
            except Exception:
                # Best effort: one broken task must not prevent releasing the others
                logger.exception("Can't release semaphore for task %s", tid)
        return remaining
