import logging
import datetime as dt
import itertools as it
import collections

import concurrent.futures

from sandbox import common
import sandbox.common.types.task as ctt
import sandbox.common.types.misc as ctm
import sandbox.common.types.resource as ctr
import sandbox.common.types.statistics as ctss

from sandbox.services import base

from sandbox.yasandbox import controller
from sandbox.yasandbox.database import mapping

logger = logging.getLogger(name=__name__)


# Lightweight view of a task audit record: record id, its date, the audited task id and the recorded status (may be None)
AuditEntry = collections.namedtuple("AuditEntry", "id date task_id status")


class TasksAuditPosition(common.patterns.Abstract):
    """
    Position of the WAIT_TASK checker in the task audit collection.

    `last_check_time` -- date of the latest processed audit record;
    `checked_audit_ids` -- ids of already processed records that fall into the current selection window;
    `audit_objects` -- `AuditEntry` items selected for processing and not processed yet.
    """

    __slots__ = [
        "last_check_time", "checked_audit_ids",
        "audit_objects"
    ]
    # Default values for the slots above, one per slot
    # (assumes common.patterns.Abstract pairs them by position -- derived from __slots__ to keep them in sync)
    __defs__ = [None] * len(__slots__)

    @classmethod
    def make(cls, service):
        """
        Load audit records which have not been processed yet.

        Records are selected starting from `last_check_time - service.time_delta`; records whose ids were
        already processed (`checked_audit_ids`) are filtered out.

        :param service: TaskStateSwitcher instance; its `audit_position` (or `context` when the former is empty),
            and `time_delta` are read
        :return: new position object with `audit_objects` filled
        """
        context = service.audit_position or service.context  # see comment at TaskStateSwitcher.audit_position

        last_check_time = context.get("last_check_time", dt.datetime.utcnow() - dt.timedelta(minutes=30))
        min_check_time = last_check_time - service.time_delta
        # Stored either as a set (in-memory position) or as a list (dumped position)
        checked_audit_ids = set(context.get("checked_audit_ids", []))

        with common.utils.Timer() as t:
            # Use pymongo, because it's much faster when dealing with lots of objects
            audit_objects = [
                AuditEntry(d["_id"], d["dt"], d["ti"], d.get("st", None))
                for d in mapping.Audit.objects(date__gt=min_check_time).as_pymongo().limit(1000000)
            ]

            if len(audit_objects) > 10000:
                logger.warning(
                    "Selected %d Audit rows from mongo, last_check_time %s, checked_audit_ids size %d",
                    len(audit_objects),
                    last_check_time,
                    len(checked_audit_ids)
                )

            # Keep only ids that are still inside the selection window, so the set doesn't grow forever
            checked_audit_ids &= {a.id for a in audit_objects}
            # Filter out already processed entries
            audit_objects = [a for a in audit_objects if a.id not in checked_audit_ids]

        logger.info(
            "Found %d audit records created from %s UTC till now (took %0.2fs)",
            len(audit_objects), min_check_time, t.secs
        )

        return cls(
            last_check_time, checked_audit_ids,
            audit_objects
        )

    def dump(self):
        """
        Serialize the position to a dict suitable for storing in the service context.

        Audit ids are dumped as a list rather than a set: `make` converts it back with `set(...)`,
        and a list, unlike a set, can be encoded by BSON/JSON-based storages.
        """
        return dict(
            checked_audit_ids=list(self.checked_audit_ids or ()),
            last_check_time=self.last_check_time,
        )


class TaskStateSwitcher(base.ThreadedService):
    """
    Service to put tasks in WAIT_* statuses back into the queue; checks all kinds of task triggers on a regular basis.

    Every kind of check is a separate target (see `targets`). Enqueue and expire requests are sent
    through the REST API from a thread pool of `max_subworkers` workers.
    """

    notification_timeout = 5
    tick_interval = 10
    # Number of task ids passed to a single batch "expire" request
    expired_chunk_size = 20
    # A WAIT_TASK trigger is re-checked by the garbage collector not more often than once per this interval
    wait_task_gc_delay = dt.timedelta(minutes=15)
    # check_wait_task skips an iteration while more WAIT_TASK enqueue jobs than this are still in flight
    max_wait_task_jobs_in_progress = 128

    def __init__(self, *args, **kwargs):
        """
        :param max_subworkers: size of the thread pool used to enqueue/expire tasks (30 by default);
            all other arguments are passed to the base service as is
        """
        self.max_subworkers = kwargs.pop("max_subworkers", 30)
        super(TaskStateSwitcher, self).__init__(*args, **kwargs)

        # for backward compatibility with old TaskStateSwitcher to avoid various races
        self.zk_name = type(self).__name__

        # on every restart, triggers for WAIT_TASK will be checked within
        # [<service last run time> - self.time_delta, <now>] time range
        self.time_delta = dt.timedelta(seconds=60)

        # workaround for proper audit position saving from child threads
        self.audit_position = {}

        self.pool = concurrent.futures.ThreadPoolExecutor(max_workers=self.max_subworkers)

        # Keep track of currently enqueueing tasks to exclude them from the re-check
        # of fired triggers (_check_fired_wait_task_triggers)
        self._wait_task_triggers_in_progress = {}  # trigger.source -> future

        self.oauth_token = (
            common.utils.read_settings_value_from_file(common.config.Registry().server.auth.oauth.token)
            if common.config.Registry().server.auth.enabled else
            None
        )

    @property
    def _rest(self):
        """REST API client using the service OAuth token (None when server auth is disabled)."""
        return common.rest.ThreadLocalCachableClient(
            auth=self.oauth_token, component=ctm.Component.SERVICE, total_wait=180
        )

    @property
    def targets(self):
        """Periodic checks run by the service."""
        return [
            self.Target(self.check_wait_res),
            self.Target(self.check_wait_time),
            self.Target(self.check_wait_output),
            self.Target(self.check_expired),

            # Run as frequently as possible
            self.Target(self.check_wait_task, interval=1),

            # Check triggers against target tasks status in case audit got lost
            self.Target(self.wait_task_garbage_collector, interval=60),
        ]

    @staticmethod
    def __clean_triggers(tid, max_time=None):
        """
        Delete time, output and status triggers of a task.

        :param tid: id of the task the triggers belong to (trigger `source`)
        :param max_time: if given, only triggers created not later than this moment are deleted
        """
        query = {"source": tid}
        if max_time is not None:
            query["ctime__lte"] = max_time

        for trigger_type in (
            mapping.TimeTrigger,
            mapping.TaskOutputTrigger,
            mapping.TaskStatusTrigger,
        ):
            result = trigger_type.objects(**query).delete()
            if result:
                logger.debug(
                    "Removed %d %r trigger(s) for task #%d created before %r",
                    result, trigger_type.__name__, tid, max_time
                )

    def enqueue(self, model, event, trigger=None):
        """
        Asynchronously enqueue a task (override this in tests for a different behaviour)

        :param model: task model
        :param event: event description to be written into task audit
        :param trigger: task trigger, due to which it is to be enqueued
        :return: future of the enqueue job
        """

        return self.pool.submit(self._enqueue, model, event, trigger)

    def _enqueue(self, model, event, trigger):
        """
        Pool-side part of `enqueue`.

        Pushes a TASK_WAIT_DELAY signal (time passed since the trigger fired), starts the task via the batch
        REST API and, if the task status has changed, deletes its triggers created up to this moment.
        """
        utcnow = dt.datetime.utcnow()
        if trigger:
            fired_at = (
                trigger.time
                if isinstance(trigger, mapping.TimeTrigger) else
                getattr(trigger, "fired_at", None)
            )
            if fired_at:
                self.signaler.push(dict(
                    type=ctss.SignalType.TASK_WAIT_DELAY,
                    date=utcnow,
                    timestamp=utcnow,
                    owner=model.owner,
                    task_id=model.id,
                    task_type=model.type,
                    status=model.execution.status,
                    delay=int((utcnow - fired_at).total_seconds()),
                ))

        prev_status = model.execution.status
        try:
            # TODO: use X-Current-User header [SANDBOX-7893]
            result = self._rest.batch.tasks.start.update(id=[model.id], comment=event)[0]
            # result["status"] is in common.types.misc.BatchResultStatus; statuses without
            # a same-named logger method are logged with DEBUG level
            logging_function = getattr(logger, result["status"].lower(), logger.debug)
            logging_function("Task #%d's enqueue result: %s", model.id, result["message"])
        except common.rest.Client.HTTPError as exc:
            logger.error("Failed to enqueue task #%d: %s", model.id, exc)
        else:
            model.reload()
            if model.execution.status != prev_status:
                self.__clean_triggers(model.id, utcnow)

    @staticmethod
    def _result_message(checked, enqueued):
        return "scheduled for processing {} / {} checked".format(enqueued, checked)

    def _fire_wait_task_trigger(self, trigger, task):
        """
        Asynchronously enqueue a WAIT_TASK task and remember the job until it is done.

        :param trigger: fired TaskStatusTrigger of the task
        :param task: task model (id, type, execution status and owner are enough)
        """
        logger.info(
            "Task #%s (%s): switching from WAIT_TASK to ENQUEUING. Trigger: %r",
            task.id, task.type, dict(trigger.to_mongo())
        )

        if trigger.wait_all:
            event = "All tasks are ready"
        else:
            event = "One of tasks is ready"

        source = trigger.source
        future = self.enqueue(task, event, trigger)
        self._wait_task_triggers_in_progress[source] = future

        def on_done(done_future):
            # Remove the entry only if it still refers to this very job: a newer job for the same task
            # may have replaced it, and an unconditional pop() would either drop that newer entry
            # or raise KeyError inside the future callback
            if self._wait_task_triggers_in_progress.get(source) is done_future:
                self._wait_task_triggers_in_progress.pop(source, None)
        future.add_done_callback(on_done)

    def _check_fired_wait_task_triggers(self):
        """
        Process activated WAIT_TASK triggers with no targets left (step 0 of `check_wait_task`).

        Triggers whose tasks are being enqueued right now are skipped. If the task is still in WAIT_TASK,
        it is enqueued; otherwise the leftover trigger is deleted.
        """
        # Get all activated triggers with empty targets,
        # but filter out triggers that are currently in queue for enqueueing
        triggers = list(mapping.TaskStatusTrigger.objects(
            activated=True,
            targets__0__exists=False,
            source__nin=list(self._wait_task_triggers_in_progress.keys()),
        ))
        if not triggers:
            return

        # Load all tasks in bulk
        tasks_query = mapping.Task.objects(id__in={t.source for t in triggers})
        task_objects = {t.id: t for t in tasks_query.only("id", "type", "execution__status", "owner")}

        logger.info(
            "Re-checking triggers with empty targets: %s",
            [t.source for t in triggers]
        )
        for trigger in triggers:
            task = task_objects.get(trigger.source, None)
            if task is None:
                logger.error(
                    "Task #%d does not exist, remove trigger with token %s",
                    trigger.source, trigger.token
                )
                trigger.delete()
                continue

            if task.execution.status == ctt.Status.WAIT_TASK:
                # Probably service has crashed between updating trigger and task enqueueing.
                # Check that it's still the same trigger and enqueue task
                try:
                    trigger.reload()
                except mapping.TaskStatusTrigger.DoesNotExist:
                    # reload() raises if the trigger was deleted concurrently (e.g. the task was stopped)
                    logger.warning(
                        "Trigger #%s/%s has been deleted during processing (probably task was stopped or deleted)",
                        trigger.source, trigger.token
                    )
                    continue

                if not trigger.targets:
                    self._fire_wait_task_trigger(trigger, task)
                else:
                    # Trigger has been updated since method call
                    logger.warning(
                        "Trigger #%d has been updated since method call: %s",
                        trigger.source, dict(trigger.to_mongo())
                    )
            else:
                # Probably trigger wasn't deleted after task enqueueing.
                # But it's possible that task has already reached WAIT_TASK again and updated the trigger.
                # Therefore delete trigger only if `targets` == []
                result = mapping.TaskStatusTrigger.objects(source=trigger.source, targets__0__exists=False).delete()
                if result:
                    logger.info(
                        "Trigger #%d was deleted because task is already '%s'",
                        trigger.source, task.execution.status
                    )
                    continue

                # Either trigger is already deleted somehow or
                # it has been updated by the task reaching WAIT_TASK again.
                try:
                    trigger.reload()
                except mapping.TaskStatusTrigger.DoesNotExist:
                    logger.info("Trigger #%d has already been deleted", trigger.source)
                    continue
                task.reload()

                logger.warning(
                    "Trigger #%d wasn't deleted. Task is '%s', trigger is: %s",
                    trigger.source, task.execution.status, dict(trigger.to_mongo())
                )

    def _get_fired_triggers_by_func(self, triggers, has_fired):
        """
        Remove ready targets from WAIT_TASK triggers and collect the triggers that have fired.

        :param triggers: TaskStatusTrigger documents to check
        :param has_fired: callable (trigger, target_task_id) returning the moment the target reached
            an awaited status, or a false value if it has not
        :return: list of trigger documents re-read from primary, which are activated and have no targets left
        """
        triggered = []

        for trigger in triggers:
            # Targets that have reached an awaited status, and the moments they did so
            ready_tids, fired_at = [], []
            for target in trigger.targets:
                local_fired_at = has_fired(trigger, target)
                if local_fired_at:
                    ready_tids.append(target)
                    fired_at.append(local_fired_at)

            if not ready_tids:
                continue

            logger.info("Trigger #%s: these tasks are ready: %s", trigger.source, ready_tids)

            if trigger.wait_all:
                # Remove only the ready targets; the trigger fires when the last one is gone
                update_kwargs = {
                    "pull_all__targets": ready_tids,
                    "fired_at": max(fired_at),
                }
            else:
                # Any ready target is enough
                update_kwargs = {
                    "targets": [],
                    "fired_at": min(fired_at),
                }

            query = dict(source=trigger.source, token=trigger.token)
            trigger_doc = mapping.TaskStatusTrigger.objects(**query).first()
            if trigger_doc is not None:
                trigger_doc.update(**update_kwargs)
                # Read the updated document back from primary
                trigger_doc = (
                    mapping.TaskStatusTrigger
                    .objects(**query)
                    .read_preference(mapping.ReadPreference.PRIMARY)
                    .first()
                )

            # The trigger may be deleted either before the update or between the update and the re-read
            if trigger_doc is None:
                logger.warning(
                    "Trigger #%s/%s has been deleted during processing (probably task was stopped or deleted)",
                    trigger.source, trigger.token
                )
                continue

            # Check if trigger still has something to wait for
            # Also skip not activated triggers (they'll be processed as soon as they're activated)
            if trigger_doc.targets or not trigger_doc.activated:
                continue

            # Trigger has fired
            triggered.append(trigger_doc)

        return triggered

    def _check_triggers_with_tailed_audit(self, audit_objects):
        """
        Update WAIT_TASK triggers whose targets appear in the given audit records.

        :param audit_objects: list of `AuditEntry`
        :return: list of fired triggers (see `_get_fired_triggers_by_func`)
        """
        # Load all affected triggers (even not activated -- they will be filtered later)
        triggers = list(mapping.TaskStatusTrigger.objects(targets__in={_.task_id for _ in audit_objects}))
        if not triggers:
            return []

        logger.debug(
            "[check_wait_task] Check %d triggers (%r), %d audit records",
            len(triggers),
            [t.source for t in triggers],
            len(audit_objects),
        )

        # Group audit entries by task_id for efficient iteration
        audit_by_task = collections.defaultdict(list)
        for audit in audit_objects:
            audit_by_task[audit.task_id].append(audit)

        def func(trigger, target):
            # Date of the first (in selection order) record with an awaited status, if any
            for a in audit_by_task.get(target, []):
                if a.status in trigger.statuses:
                    return a.date
            return None

        return self._get_fired_triggers_by_func(triggers, has_fired=func)

    def check_wait_task(self):
        """
        General algorithm:

        1. Load new unprocessed audit entries and check affected triggers
        NOTE: Update `targets` field during check: remove tasks that reached awaited status
        2. Get all triggers with targets==[] and asynchronously enqueue corresponding tasks

        We repeat these steps without waiting for tasks to be enqueued (and triggers deleted).
        We can't hit trigger twice, because we always update `targets` field first.

        Unfortunately service can fail between trigger update and task start.
        To handle this case we add an additional step:

        0. Get all triggers with targets==[] and enqueue corresponding tasks

        """

        # Check if there are too many tasks queued for enqueueing
        if len(self._wait_task_triggers_in_progress) > self.max_wait_task_jobs_in_progress:
            return "Too many jobs in queue"

        # Check triggers with empty targets first
        self._check_fired_wait_task_triggers()

        # Load new audit entries and check affected triggers
        state = TasksAuditPosition.make(self)
        triggered = self._check_triggers_with_tailed_audit(state.audit_objects)

        # Load all tasks in bulk
        tasks_query = mapping.Task.objects(id__in={t.source for t in triggered})
        task_objects = {t.id: t for t in tasks_query.only("id", "type", "execution__status", "owner")}

        # For return values
        checked = len(triggered)
        enqueued = 0

        for trigger in triggered:
            task = task_objects.get(trigger.source, None)
            if task is None:
                logger.error(
                    "Task #%d does not exist, remove trigger with token %s",
                    trigger.source, trigger.token
                )
                trigger.delete()
                continue

            if task.execution.status == ctt.Status.WAIT_TASK:
                self._fire_wait_task_trigger(trigger, task)
                enqueued += 1
            else:
                # TODO: check this branch never happens
                logger.warning(
                    "Task #%s (%s) status is `%s`, skip trigger",
                    task.id, task.type, task.execution.status
                )

        # Remember checked entries to exclude them from processing during next iteration
        # and save latest processed audit to continue from it next time
        if state.audit_objects:
            state.checked_audit_ids |= {a.id for a in state.audit_objects}
            state.last_check_time = max(_.date for _ in state.audit_objects)

        self.audit_position.update(state.dump())
        return self._result_message(checked, enqueued)

    def wait_task_garbage_collector(self):
        """
        Check triggers against target tasks status.

        Audit entries may get lost sometimes (e.g. database failed before replication is completed).
        Task status is always updated with write_concern=majority, so we can rely on it.

        Fired triggers are not enqueued here: they are left with empty targets and are picked up
        by `_check_fired_wait_task_triggers` during the next `check_wait_task` run.

        :return: result message
        """

        triggers = list(mapping.TaskStatusTrigger.objects(
            activated=True,
            last_gc_check__lt=dt.datetime.utcnow() - self.wait_task_gc_delay
        ))
        if not triggers:
            return "Nothing to check"

        logger.debug(
            "[wait_task_garbage_collector] check %d triggers (%r)",
            len(triggers), [t.source for t in triggers]
        )

        target_tids = {tid for t in triggers for tid in t.targets}

        task_query = mapping.Task.objects(id__in=target_tids)
        task_data = {
            tid: (status, updated)
            for tid, status, updated in task_query.fast_scalar("id", "execution__status", "time__updated")
        }

        def func(trigger, target):
            # Last update time of a target already in an awaited status
            status, updated = task_data.get(target, (None, None))
            if status in trigger.statuses:
                return updated
            return None

        # Triggers are modified in place by this call; the result is only used for reporting
        fired = self._get_fired_triggers_by_func(triggers, has_fired=func)

        logger.info(
            "[wait_task_garbage_collector] fixed %d triggers (%r)",
            len(fired), [t.source for t in fired]
        )

        # Touch all these triggers
        mapping.TaskStatusTrigger \
            .objects(source__in=[t.source for t in triggers]) \
            .update(last_gc_check=dt.datetime.utcnow())

        return "fixed {} / {} checked".format(len(fired), len(triggers))

    def check_wait_output(self):
        """
        Enqueue WAIT_OUT tasks whose awaited output parameters are ready.

        An awaited field is ready when the target task has it among its output parameters, or when
        the target task has finished. A trigger with `wait_all` needs all its fields ready, otherwise
        any one is enough. Triggers of non-existent tasks are removed; output triggers of tasks
        which are neither in WAIT_OUT nor executing are deleted.

        :return: result message and list of enqueue futures
        """
        checked = 0
        enqueued = 0
        # target task id -> output field -> [(source task id, set of (target, field) still awaited, or None)]
        tasks = collections.defaultdict(lambda: collections.defaultdict(list))
        for source, targets, wait_all in mapping.TaskOutputTrigger.objects().scalar("source", "targets", "wait_all"):
            checked += 1
            # With wait_all, one set is shared by all entries of the trigger and emptied as fields become ready;
            # None means "any single field is enough"
            wait_set = {(tf.target, tf.field) for tf in targets} if wait_all else None
            for tf in targets:
                tasks[tf.target][tf.field].append((source, wait_set))

        futures = []
        done_sources = set()
        for target, outputs, status in mapping.Task.objects(id__in=tasks.keys()).scalar(
            "id", "parameters__output", "execution__status"
        ):
            if self.stop_event.is_set():
                break

            msg = "task #{} in status {}".format(target, status)
            awaited = tasks[target]
            if status in ctt.Status.Group.FINISH:
                # A finished task will not produce anything else: treat every awaited field as ready
                ready_fields = list(awaited)
            else:
                # The task may have no output parameters at all
                ready_fields = [output.key for output in (outputs or ())]

            for output in ready_fields:
                for source, wait_set in awaited.get(output, []):
                    if wait_set is not None:
                        # discard(): the same (target, field) pair may be processed more than once
                        wait_set.discard((target, output))
                    if wait_set or source in done_sources:
                        continue

                    done_sources.add(source)
                    logger.debug("Task #%s triggered by output parameter '%s' of %s", source, output, msg)
                    try:
                        model = controller.Task.get(source)
                    except controller.Task.NotExists as ex:
                        logger.warning(ex)
                        logger.info("Removing all triggers for non-existent task #%d", source)
                        self.__clean_triggers(source)
                        continue

                    if model.execution.status == ctt.Status.WAIT_OUT:
                        logger.info(
                            "Task #%s (%s): switching from WAIT_OUT to ENQUEUING", model.id, model.type
                        )
                        futures.append(self.enqueue(model, "Output parameters are ready"))
                        enqueued += 1

                    elif model.execution.status not in ctt.Status.Group.EXECUTE:
                        mapping.TaskOutputTrigger.objects(source=source).delete()

        return self._result_message(checked, enqueued), futures

    def check_wait_res(self):
        """
        Enqueue WAIT_RES tasks whose required resources are all READY.

        :return: result message and list of enqueue futures
        """
        enqueued = 0
        tasks = [
            # A missing resource list means there is nothing to wait for
            (tid, set(resources or ()))
            for tid, resources in mapping.Task.objects(
                execution__status=ctt.Status.WAIT_RES
            ).fast_scalar("id", "requirements__resources")
        ]
        res_ids = set(it.chain.from_iterable(required for _, required in tasks))

        ready_res_ids = set(mapping.Resource.objects(
            id__in=res_ids,
            state=ctr.State.READY
        ).fast_scalar("id"))

        futures = []
        for tid, required in tasks:
            if self.stop_event.is_set():
                break

            if not required.issubset(ready_res_ids):
                continue

            try:
                model = controller.Task.get(tid)
            except controller.Task.NotExists as ex:
                logger.warning(ex)
                continue

            logger.info("Task #%s (%s): switching from WAIT_RES to ENQUEUING", model.id, model.type)
            futures.append(self.enqueue(model, "All dependent resources are ready"))
            enqueued += 1

        return self._result_message(len(tasks), enqueued), futures

    def check_wait_time(self):
        """
        Enqueue tasks whose time triggers are active.

        Tasks in WAIT_TIME, WAIT_TASK or WAIT_OUT are enqueued; time triggers of tasks which are
        neither executing nor ENQUEUING are deleted.

        :return: result message and list of enqueue futures
        """
        checked = controller.TimeTrigger.count()
        enqueued = 0

        futures = []
        for tr in controller.TimeTrigger.active_triggers():
            if self.stop_event.is_set():
                break

            try:
                model = controller.Task.get(tr.source)
            except controller.Task.NotExists as ex:
                logger.warning(ex)
                continue

            if model.execution.status in (ctt.Status.WAIT_TIME, ctt.Status.WAIT_TASK, ctt.Status.WAIT_OUT):
                logger.info(
                    "Task #%s (%s): switching from %s to ENQUEUING", model.id, model.type, model.execution.status
                )
                futures.append(self.enqueue(model, "Time has expired", tr))
                enqueued += 1

            elif (
                model.execution.status not in
                common.utils.chain(ctt.Status.Group.EXECUTE, ctt.Status.ENQUEUING)
            ):
                logger.warning(
                    "Task #%s (%s) status is `%s`. Removing its document from TimeTrigger collection",
                    model.id, model.type, model.execution.status
                )
                controller.TimeTrigger.delete(tr)

        return self._result_message(checked, enqueued), futures

    def _expire(self, tasks):
        """
        Expire a chunk of tasks via the batch REST API (runs in the pool).

        :param tasks: sequence of (task id, status) pairs
        """
        try:
            logger.info("Expiring tasks: %s", tasks)
            task_ids = [task[0] for task in tasks]
            results = self._rest.batch.tasks.expire.update(id=task_ids, comment="Expired")
            for result in results:
                if result["status"] != ctm.BatchResultStatus.SUCCESS:
                    # Fall back to WARNING for statuses without a same-named logger method
                    # instead of failing the whole job with AttributeError
                    log = getattr(logger, result["status"].lower(), logger.warning)
                    log("Task #%s not expired: %s", result["id"], result["message"])
        except common.rest.Client.HTTPError as exc:
            logger.error("Failed to expire tasks %s: %s", tasks, exc)

    def check_expired(self):
        """
        Expire queued, executing and waiting tasks whose `expires_at` has passed.

        :return: result message and list of futures of the expire jobs
        """
        checked = 0
        futures = []
        now = dt.datetime.utcnow()
        statuses = common.utils.chain(ctt.Status.Group.QUEUE, ctt.Status.Group.EXECUTE, ctt.Status.Group.WAIT)
        tasks = list(mapping.Task.objects(
            expires_at__lt=now, execution__status__in=statuses
        ).fast_scalar('id', 'execution__status'))
        for task_chunk in common.utils.chunker(tasks, self.expired_chunk_size):
            if self.stop_event.is_set():
                break
            futures.append(self.pool.submit(self._expire, task_chunk))
            checked += len(task_chunk)

        return self._result_message(checked, checked), futures

    def on_stop(self):
        """Save the audit position into the service context, then wait for all pool jobs to finish."""
        super(TaskStateSwitcher, self).on_stop()
        self.context.update(self.audit_position)
        self._model.save()
        self.pool.shutdown(wait=True)
