import logging
import datetime as dt
import operator as op
import itertools as it
import collections

from sandbox.deploy import juggler

from sandbox import common
import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt
import sandbox.common.types.notification as ctn

from sandbox.yasandbox import controller
from sandbox.yasandbox.database import mapping

from sandbox.services import base

logger = logging.getLogger(__name__)


class TaskStatusChecker(base.SingletonService):
    """
    Service that periodically looks for tasks whose state looks inconsistent.

    Each tick runs three checks one after another:

    * task status vs. the latest audit record (a notification is saved on mismatch);
    * tasks in execution statuses vs. their sessions (result is reported to juggler);
    * tasks which stay in transient statuses for too long (result is reported to juggler).
    """

    # Transient statuses inspected by the audit and the "stalled" checks
    STATUSES_TO_CHECK = [ctt.Status.STOPPING, ctt.Status.FINISHING, ctt.Status.ASSIGNED]
    # Minimal age of the task update time for the audit consistency check
    MAX_STATUS_SWITCH_DELAY = 300  # in seconds
    # Minimal age of the task update time for the session consistency check,
    # the same delay is used as a grace period for expired sessions
    MAX_STATUS_WITHOUT_SESSION_DELAY = 600  # in seconds
    # A task is treated as stalled after being not updated for this long
    MAX_TIME_IN_TRANSIENT_STATUS_WARN = 3600  # in seconds
    # A single task not updated for this long makes the check critical
    MAX_TIME_IN_TRANSIENT_STATUS_CRIT = 7200  # in seconds
    TRANSIENT_CRIT_AMOUNT_THRESHOLD = 3  # threshold in tasks, between WARN and CRIT
    # A warning about the same task is not repeated more often than this
    WARNING_RETRY_INTERVAL = 3600  # in seconds
    # Maximum amount of task ids passed to a single database query
    CHUNK_SIZE = 500

    tick_interval = 300  # NOTE(review): presumably in seconds -- confirm against `base`

    def __send_warning(self, subjects, message):
        """
        Logs the warning and saves an email notification about it.

        :param subjects: subject of the notification
        :param message: body of the notification, also written to the log
        """
        logger.warning(message)
        controller.Notification.save(
            transport=ctn.Transport.EMAIL,
            send_to=["sandbox-errors"],
            send_cc=None,
            subject=subjects,
            body=message
        )

    def __send_audit_consistency_warning(self, task_id, status, audit_status, audit_date):
        """
        Sends a warning about a task which status differs from its last audit record.

        :param task_id: identifier of the inconsistent task
        :param status: current status of the task
        :param audit_status: status stored in the latest audit record
        :param audit_date: date of the latest audit record
        """
        message = "Status of task #{} ({}) does not match the last audit record ({} at {})".format(
            task_id, status, audit_status, audit_date
        )
        subject = "[{}] Task status inconsistency detected".format(self.name)
        self.__send_warning(subject, message)

    def check_audit_consistency(self):
        """
        Checks consistency of task status and last record in the audit.

        Inspects tasks in `STATUSES_TO_CHECK` not updated for `MAX_STATUS_SWITCH_DELAY` seconds
        together with the tasks already warned about (their ids are kept in the service context
        under the "sent_warnings" key). A warning per task is repeated no more often than once
        in `WARNING_RETRY_INTERVAL` seconds and is forgotten as soon as the audit matches again.
        """

        now = dt.datetime.utcnow()
        checkpoint = now - dt.timedelta(seconds=self.MAX_STATUS_SWITCH_DELAY)
        logger.info("Fetching tasks in statuses %s with update time older than %s", self.STATUSES_TO_CHECK, checkpoint)
        # Maps stringified task id to the time of the last warning sent about the task
        sent_warnings = self.context.setdefault("sent_warnings", {})
        statuses = dict(mapping.Task.objects(
            execution__status__in=self.STATUSES_TO_CHECK,
            time__updated__lt=checkpoint
        ).scalar("id", "execution__status"))

        # Tasks warned about earlier may have left the checked statuses -- fetch them explicitly
        tids_to_recheck = {int(tid) for tid in sent_warnings} - set(statuses)
        if tids_to_recheck:
            statuses.update(mapping.Task.objects(
                id__in=list(tids_to_recheck)
            ).scalar("id", "execution__status"))
            # Tasks which cannot be found anymore will never be matched against the audit below,
            # so drop them, otherwise they would stay in the context and be re-queried forever
            for missing_tid in tids_to_recheck.difference(statuses):
                sent_warnings.pop(str(missing_tid), None)

        logger.info("Found %s task(s)", len(statuses))
        if not statuses:
            return

        retry_deadline = now - dt.timedelta(seconds=self.WARNING_RETRY_INTERVAL)
        for tids_chunk in common.utils.chunker(list(statuses), self.CHUNK_SIZE):
            # Only records carrying a status are of interest. Being sorted in descending order,
            # records of the same task go together with the latest one first.
            audit = sorted(
                (
                    record
                    for record in mapping.Audit.objects(task_id__in=tids_chunk).scalar("task_id", "date", "status")
                    if record[2]
                ),
                reverse=True
            )
            for task_id, audit_items in it.groupby(audit, key=op.itemgetter(0)):
                _, audit_date, audit_status = next(audit_items)
                status = statuses[task_id]
                str_task_id = str(task_id)
                if audit_status == status:
                    # The task is consistent (again) -- forget about the warning sent, if any
                    sent_warnings.pop(str_task_id, None)
                    continue
                sent_date = sent_warnings.get(str_task_id)
                if sent_date and sent_date > retry_deadline:
                    # Warning about this task has been sent recently
                    continue
                self.__send_audit_consistency_warning(task_id, status, audit_status, audit_date)
                sent_warnings[str_task_id] = now

    def check_session_consistency(self):
        """
        Checks consistency of task in execution status and task session.

        Inspects tasks in statuses of the `EXECUTE` group (except `TEMPORARY` and `SUSPENDED`)
        not updated for `MAX_STATUS_WITHOUT_SESSION_DELAY` seconds. Tasks in status `ASSIGNED`
        without a session are switched back to `ENQUEUED`. The result is reported via
        `juggler.TaskSessionConsistency`: OK if there are no problems, WARNING for a single
        problem and CRITICAL for more than one.
        """
        now = dt.datetime.utcnow()
        checkpoint = now - dt.timedelta(seconds=self.MAX_STATUS_WITHOUT_SESSION_DELAY)
        task_ids_statuses = list(mapping.Task.objects(
            execution__status__in=set(ctt.Status.Group.EXECUTE) - {ctt.Status.TEMPORARY, ctt.Status.SUSPENDED},
            time__updated__lt=checkpoint
        ).fast_scalar("id", "execution__status"))

        tids_without_session = []
        tids_with_expired_session = []

        for tids_statuses in common.utils.chunker(task_ids_statuses, self.CHUNK_SIZE):
            tids_statuses = dict(tids_statuses)
            # Fetch sessions once: they are walked through twice below and both passes
            # have to see the very same set of sessions
            sessions = list(mapping.OAuthCache.objects(task_id__in=list(tids_statuses)))
            tids_without_session_chunk = set(tids_statuses) - {session.task_id for session in sessions}
            tids_without_session.extend(tids_without_session_chunk)
            for tid in tids_without_session_chunk:
                if tids_statuses[tid] != ctt.Status.ASSIGNED:
                    continue
                logger.warning("Task #%s in status %s without session", tid, ctt.Status.ASSIGNED)
                controller.Task.set_status(
                    controller.Task.get(tid),
                    ctt.Status.ENQUEUED,
                    expected_status=ctt.Status.ASSIGNED,
                    event="Get back to the queue due to lack of session"
                )
            for session in sessions:
                # Sessions with falsy `ttl` are never treated as expired
                if session.ttl and session.created + dt.timedelta(seconds=session.ttl) < checkpoint:
                    tids_with_expired_session.append(session.task_id)

        failures = len(tids_without_session) + len(tids_with_expired_session)
        if not failures:
            status = ctm.JugglerCheckStatus.OK
        elif failures > 1:
            status = ctm.JugglerCheckStatus.CRITICAL
        else:
            status = ctm.JugglerCheckStatus.WARNING

        logger.info(
            "Juggler check status: %s (%s, %s)", status, len(tids_without_session), len(tids_with_expired_session)
        )
        if tids_without_session:
            logger.info("Tasks without sessions: %s", tids_without_session)
        if tids_with_expired_session:
            logger.info("Tasks with expired sessions: %s", tids_with_expired_session)

        if status == ctm.JugglerCheckStatus.OK:
            message = "OK"
        else:
            # Only the first ten ids of each kind get into the message
            message = "{} session consistency failure(s): lost sessions {}, expired sessions {}".format(
                failures,
                tids_without_session[:10],
                tids_with_expired_session[:10],
            )
        # The reporting method is named after the lower-cased check status
        getattr(juggler.TaskSessionConsistency, status.lower())(message)

    def check_stalled_transient_statuses(self):
        """
        Checks tasks stalled in transient statuses.

        A task is stalled if it is in one of `STATUSES_TO_CHECK` and has not been updated for
        `MAX_TIME_IN_TRANSIENT_STATUS_WARN` seconds. The result is reported via
        `juggler.TasksTransientStatuses`: CRITICAL if there are at least
        `TRANSIENT_CRIT_AMOUNT_THRESHOLD` stalled tasks or any of them has not been updated for
        `MAX_TIME_IN_TRANSIENT_STATUS_CRIT` seconds, WARNING if there are any other stalled tasks
        and OK otherwise.
        """
        now = dt.datetime.utcnow()
        checkpoint = now - dt.timedelta(seconds=self.MAX_TIME_IN_TRANSIENT_STATUS_WARN)
        stalled = list(mapping.Task.objects(
            execution__status__in=self.STATUSES_TO_CHECK,
            time__updated__lt=checkpoint
        ).fast_scalar("id", "execution__status", "time__updated"))

        critical_updated_threshold = now - dt.timedelta(seconds=self.MAX_TIME_IN_TRANSIENT_STATUS_CRIT)
        if not stalled:
            status = ctm.JugglerCheckStatus.OK
        elif (
            len(stalled) >= self.TRANSIENT_CRIT_AMOUNT_THRESHOLD or
            min(updated for _, _, updated in stalled) < critical_updated_threshold
        ):
            status = ctm.JugglerCheckStatus.CRITICAL
        else:
            status = ctm.JugglerCheckStatus.WARNING

        logger.info(
            "Juggler check status: %s (%s task(s) in transient statuses)", status, len(stalled)
        )
        if stalled:
            stalled_by_statuses = collections.defaultdict(list)
            for tid, task_status, _ in stalled:
                stalled_by_statuses[task_status].append(tid)
            details = []
            for task_status, tids in stalled_by_statuses.items():
                logger.warning("Tasks in status %s: %s", task_status, tids)
                # Only the first ten ids per status get into the message
                details.append("{}: {} ({})".format(task_status, len(tids), tids[:10]))
            message = "Tasks in transient statuses: {}".format(", ".join(details))
        else:
            message = "OK"
        # The reporting method is named after the lower-cased check status
        getattr(juggler.TasksTransientStatuses, status.lower())(message)

    def tick(self):
        """ Runs all the checks one after another """
        self.check_audit_consistency()
        self.check_session_consistency()
        self.check_stalled_transient_statuses()
