import collections
import datetime as dt
import itertools as it
import logging

import sandbox.common.types.notification as ctn
import sandbox.common.types.task as ctt
from sandbox import common
from sandbox.common import utils
from sandbox.services import base
from sandbox.yasandbox import controller
from sandbox.yasandbox.database import mapping

logger = logging.getLogger(__name__)

# Upper bound on the number of items passed to a single database query
BATCH_SIZE = 10000

# Every check re-reads audit records starting CHECK_DELTA seconds before the saved last check time
# (an overlap window); records already handled inside it are skipped via the persisted checked ids
CHECK_DELTA = 60  # seconds

# Task execution description longer than this (as measured by len()) is truncated in e-mail bodies;
# 2 << 20 == 2 * 1024 * 1024
INFO_LENGTH_LIMIT = 2 * 1024 * 1024

# Task description longer than this is cut in Telegram and Q notifications
DESCRIPTION_MAX_LEN = 200

# Statuses in which a task is not running: finished statuses plus every status
# from which the task may still be switched to ENQUEUING
TASK_IDLE_STATUSES = frozenset(ctt.Status.Group.FINISH).union(
    status for status in ctt.Status if ctt.Status.can_switch(status, ctt.Status.ENQUEUING)
)


class Notifier(object):
    """
    Turns task audit records into notifications according to task status notification triggers.

    Meant to be used as a context manager: `__enter__` loads the persisted state from the service context,
    `__exit__` saves it back and deletes the triggers collected for removal during the run.
    """

    class Stopping(Exception):
        """Raised to interrupt processing early (service stop requested or GC time budget exhausted)."""

    class State(common.patterns.Abstract):
        # State persisted in the service context between runs:
        #   last_check_time   -- date of the last processed audit record
        #   last_gc_time      -- start time of the last garbage collection that ran to completion
        #   checked_audit_ids -- ids of already processed audit records inside the current overlap window
        __slots__ = (
            "last_check_time",
            "last_gc_time",
            "checked_audit_ids",
        )
        # NOTE(review): the set() default is a single shared object; __enter__ always passes an explicit value
        __defs__ = (None, None, set())

    def __init__(self, service):
        self._service = service
        self.state = None  # initialized in __enter__

        self._gc_max_duration = service.service_config["gc_max_duration"]
        self._task_idle_interval = service.service_config["task_idle_interval"]

        # Only triggers created before this moment are deleted in __exit__
        self._start_time = dt.datetime.utcnow()
        # Task ids whose triggers are all to be deleted (filled by collect_garbage)
        self._tasks_to_drop_triggers_from = set()
        # Ids of fired triggers of idle tasks (filled by notify)
        self._trigger_ids_to_drop = set()

    def _raise_if_stopping(self):
        """Raise `Stopping` if the owning service has been asked to stop."""
        if self._service._stop_requested.is_set():
            raise self.Stopping

    def __enter__(self):
        context = self._service.context
        self.state = self.State(
            last_check_time=context.get("last_check_time", dt.datetime.utcnow() - dt.timedelta(minutes=30)),
            last_gc_time=context.get("last_gc_time", dt.datetime.min),
            checked_audit_ids=set(context.get("checked_audit_ids", ())),
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Persist progress first, so it is kept even if trigger removal below fails
        self.save_context()

        if self._trigger_ids_to_drop:
            n_deleted = self._delete_triggers("id", self._trigger_ids_to_drop)
            logger.info("Removed %d used triggers", n_deleted)

        if self._tasks_to_drop_triggers_from:
            n_deleted = self._delete_triggers("source", self._tasks_to_drop_triggers_from)
            logger.info("Removed %d garbage triggers for %d tasks", n_deleted, len(self._tasks_to_drop_triggers_from))

        # `Stopping` is an expected early exit and is swallowed; any other exception propagates
        return exc_type is None or exc_type is self.Stopping

    def _delete_triggers(self, field, values):
        """
        Delete, in batches of BATCH_SIZE, the triggers whose `field` value is in `values`.

        Only triggers created before this notifier was constructed are affected (presumably so that
        triggers added while the run was in progress are not removed).

        :param field: trigger field name to match against (e.g. "id" or "source")
        :param values: iterable of field values
        :return: total number of deleted triggers
        """
        n_deleted = 0
        for batch in utils.grouper(values, BATCH_SIZE):
            n_deleted += mapping.TaskStatusNotificationTrigger.objects.filter(
                ctime__lt=self._start_time, **{field + "__in": batch}
            ).delete()
        return n_deleted

    def save_context(self):
        """Copy the current state into the service context and save the service model."""
        self._service.context.update(self.state)
        self._service._model.save()

    def notify(self):
        """
        Send notifications for audit records created since the last check.

        Records are fetched starting CHECK_DELTA seconds before `state.last_check_time`. Records already listed
        in `state.checked_audit_ids` are skipped, so the overlap does not produce duplicate notifications.
        Progress is recorded per audit record.

        :raises Notifier.Stopping: when the service is asked to stop
        """
        assert self.state, "State is not initialized"
        logger.info("Triggers check: started")

        lower_time_bound = self.state.last_check_time - dt.timedelta(seconds=CHECK_DELTA)

        audit_qs = (
            mapping.Audit.objects
            .filter(date__gt=lower_time_bound)
            .only("id", "date", "task_id", "status", "content")
            .order_by("+date")  # the code below relies on the ordering by date
            .as_pymongo()
        )

        audit_objects = [
            {
                "id": obj["_id"],
                "date": obj["dt"],
                "task_id": obj["ti"],
                "status": obj.get("st", None),
                "content": obj.get("ct", ""),
            }
            for obj in audit_qs
        ]

        # Rotate `checked_audit_ids`: keep only ids still inside the fetched window, so the set does not grow forever
        self.state.checked_audit_ids &= {obj["id"] for obj in audit_objects}

        audit_objects = [obj for obj in audit_objects if obj["id"] not in self.state.checked_audit_ids]

        logger.info("Fetched %d audit records created from %s till now", len(audit_objects), lower_time_bound)

        # Triggers corresponding to fetched audit records
        task_triggers = collections.defaultdict(list)  # maps task id to the list of its triggers

        task_ids = {obj["task_id"] for obj in audit_objects}

        for batch in utils.grouper(task_ids, BATCH_SIZE):
            for trigger in mapping.TaskStatusNotificationTrigger.objects.filter(source__in=batch).order_by("+source"):
                task_triggers[trigger.source].append(trigger)

        # Tasks corresponding to fetched triggers
        tasks = {}  # maps task id to task object

        for batch in utils.grouper(list(task_triggers), BATCH_SIZE):
            task_qs = mapping.Task.objects.only("id", "type", "description", "owner", "execution").filter(id__in=batch)
            tasks.update((task.id, task) for task in task_qs)

        # Maps task id to the id of its latest audit record (audit_objects is ordered by date)
        last_events = {}
        for obj in audit_objects:
            last_events[obj["task_id"]] = obj["id"]

        logger.debug(
            "Fetched %d triggers from %d audit records, %d tasks to be processed.",
            sum(len(triggers) for triggers in task_triggers.values()), len(audit_objects), len(tasks)
        )

        # Process audit records
        for obj in audit_objects:
            self._raise_if_stopping()

            # The record is marked as processed before anything is sent, so a failure while sending
            # does not make the next run process it again
            self.state.checked_audit_ids.add(obj["id"])
            self.state.last_check_time = obj["date"]

            task_id = obj["task_id"]

            triggers = task_triggers.get(task_id)
            if not triggers:
                continue

            task = tasks.get(task_id)
            if not task:
                continue

            for trigger in triggers:
                if obj["status"] not in trigger.statuses:
                    continue

                if trigger.transport == ctn.Transport.JUGGLER:
                    self._notify_juggler(trigger, task, obj)
                elif trigger.transport in (ctn.Transport.TELEGRAM, ctn.Transport.Q):
                    self._notify_messenger(trigger, task, obj)
                else:
                    self._notify_email(trigger, task, obj, is_last_event=last_events[task_id] == obj["id"])

                # A fired trigger of a task that is no longer running is not needed anymore
                if task.execution.status in TASK_IDLE_STATUSES:
                    logger.info(
                        "Going to remove: used trigger %s -> %s for task #%d",
                        trigger.statuses, trigger.recipients, task_id
                    )
                    self._trigger_ids_to_drop.add(trigger.id)

    @staticmethod
    def _notify_juggler(trigger, task, audit):
        """
        Save a Juggler notification for `audit`. Errors are logged and swallowed (best effort).

        :param trigger: fired trigger
        :param task: task model the audit record belongs to
        :param audit: audit record dict as built in `notify`
        """
        task_id = audit["task_id"]
        try:
            recipients = controller.Notification.juggler_expanded_recipients(
                trigger.recipients, juggler_key=ctn.JugglerCheck.TASK_STATUS_CHANGED
            )
            body = u"\nTask {type} #{id} owned by {owner} is {st}.\n{desc}\n{url}\n".format(
                type=task.type,
                id=task_id,
                st=audit["status"],
                desc=task.description,
                url=common.utils.get_task_link(task_id),
                owner=task.owner
            )
            if recipients:
                controller.Notification.save(
                    transport=trigger.transport, send_to=recipients, send_cc=[], subject=None, body=body,
                    author=str(task.owner), task_id=task_id,
                    task_model=task, check_status=trigger.check_status,
                    juggler_tags=trigger.juggler_tags
                )
        except Exception:
            logger.exception("Can't process juggler trigger. Skip it.")

    @staticmethod
    def _notify_messenger(trigger, task, audit):
        """
        Save a Telegram or Q notification for `audit`, with the task description cut to DESCRIPTION_MAX_LEN.

        :param trigger: fired trigger (transport is TELEGRAM or Q)
        :param task: task model the audit record belongs to
        :param audit: audit record dict as built in `notify`
        """
        task_id = audit["task_id"]
        link = common.utils.get_task_link(task_id)
        if trigger.transport == ctn.Transport.TELEGRAM:
            link = "#<a href='{href_id}'>{id}</a>".format(href_id=link, id=task_id)

        try:
            # NOTE(review): only the long (cut) description goes through `common.encoding.escape`;
            # a short one is sent as is -- confirm this asymmetry is intended
            if not task.description or len(task.description) <= DESCRIPTION_MAX_LEN:
                descr = task.description
            else:
                descr = u"{}...".format(common.encoding.escape(task.description)[:DESCRIPTION_MAX_LEN])
        except Exception:
            descr = "Can't cut description to {} symbols".format(DESCRIPTION_MAX_LEN)
            logger.exception("Can't cut description of task %s", task_id)

        body = u"[Sandbox] {type} {link} is {st}.\nTask owner: {owner}.\nTask description: {descr}".format(
            type=task.type, link=link, st=audit["status"], owner=task.owner,
            descr=descr
        )
        controller.Notification.save(
            transport=trigger.transport, send_to=trigger.recipients, send_cc=[], subject=None, body=body,
            author=str(task.owner), content_type="text/html", task_id=task_id,
            view=ctn.View.EXECUTION_REPORT,
            task_model=task,
        )

    @staticmethod
    def _notify_email(trigger, task, audit, is_last_event):
        """
        Save a notification for the remaining transports (e.g. e-mail) for `audit`.

        :param trigger: fired trigger
        :param task: task model the audit record belongs to
        :param audit: audit record dict as built in `notify`
        :param is_last_event: whether `audit` is the latest fetched record of the task; if so, the current
            task execution description is used as message info, otherwise the audit record content
        """
        task_id = audit["task_id"]
        all_recipients = controller.Notification.expand_groups_emails(trigger.recipients)
        # Strip the domain part and deduplicate
        send_to = sorted(set(email.split("@")[0] for email in all_recipients))

        if is_last_event:
            # Current task state is the most relevant for the latest record
            info = task.execution.description or ""
            if len(info) > INFO_LENGTH_LIMIT:
                # Unicode template: a byte-string one fails on non-ASCII unicode text under Python 2
                info = u"{}...".format(info[:INFO_LENGTH_LIMIT])
        else:
            info = audit["content"] or ""

        subj = u"[Sandbox] {type} #{id} is {st}".format(type=task.type, id=task_id, st=audit["status"])
        body = u"\nTask {type} #{id} owned by {owner} is {st}.\n{desc}\n{url}\n{info}\n".format(
            type=task.type,
            id=task_id,
            st=audit["status"],
            desc=task.description,
            url=common.utils.get_task_link(task_id),
            info=info,
            owner=task.owner
        )

        logger.info(
            "Notify (%s) via %s from task #%d (%s)",
            ", ".join(send_to), trigger.transport, task_id, audit["status"]
        )

        controller.Notification.save(
            transport=trigger.transport, send_to=send_to, send_cc=[], subject=subj, body=body,
            author=str(task.owner), content_type="text/html", task_id=task_id,
            view=ctn.View.EXECUTION_REPORT, task_model=task,
        )

    def collect_garbage(self):
        """
        Collect ids of tasks whose triggers should be deleted: tasks that no longer exist, and tasks that are
        in one of TASK_IDLE_STATUSES and were not updated for longer than the configured idle interval.

        The actual deletion happens in `__exit__`. `state.last_gc_time` is updated only when all tasks
        were checked.

        :raises Notifier.Stopping: when the service is asked to stop or GC exceeds its time budget
        """
        assert self.state, "State is not initialized"
        logger.info("GC: started")

        gc_start = dt.datetime.utcnow()
        gc_deadline = gc_start + dt.timedelta(seconds=self._gc_max_duration)

        idle_interval = dt.timedelta(seconds=self._task_idle_interval)

        trigger_tasks = mapping.TaskStatusNotificationTrigger.objects.all().only("source").as_pymongo()
        trigger_tasks = sorted({t["source"] for t in trigger_tasks})

        deleted_tasks = []
        idle_tasks = []

        try:
            for batch in utils.grouper(trigger_tasks, BATCH_SIZE):
                self._raise_if_stopping()

                # Stop collecting when GC takes more than allowed
                if dt.datetime.utcnow() > gc_deadline:
                    raise self.Stopping

                # Materialize once: the result is iterated twice below
                tasks = list(
                    mapping.Task.objects.filter(id__in=batch)
                    .only("id", "execution__status", "time__updated")
                    .as_pymongo()
                )

                non_existent = set(batch) - {t["_id"] for t in tasks}
                deleted_tasks.extend(non_existent)

                for task_id in non_existent:
                    logger.info("Going to remove: triggers for task #%d (non-existent)", task_id)

                for task in tasks:
                    not_updated = gc_start - task["time"]["up"]
                    if task["exc"]["st"] in TASK_IDLE_STATUSES and not_updated > idle_interval:
                        idle_tasks.append(task["_id"])
                        logger.info(
                            "Going to remove: triggers for task #%d (idle in %s for %s)",
                            task["_id"], task["exc"]["st"], utils.td2str(not_updated)
                        )

        finally:
            # Keep what was collected even if GC was interrupted
            self._tasks_to_drop_triggers_from.update(deleted_tasks, idle_tasks)
            logger.info("Collected triggers for %d non-existent tasks", len(deleted_tasks))
            logger.info("Collected triggers for %d idle tasks", len(idle_tasks))

        # Do not record time if GC stops in the middle
        self.state.last_gc_time = gc_start


class TaskStatusNotifier(base.SingletonService):
    """
    Service that turns task status changes into notifications (per task status notification triggers)
    and periodically cleans up triggers that are no longer needed.
    """

    def __init__(self, *args, **kwargs):
        super(TaskStatusNotifier, self).__init__(*args, **kwargs)
        # Reuse the zk_name of the old TaskStatusNotifier implementation (backward compatibility, avoids races)
        self.zk_name = self.__class__.__name__

    @property
    def tick_interval(self):
        """Interval between ticks, read from the `run_interval` service setting."""
        config = self.service_config
        return config["run_interval"]

    def tick(self):
        """Run a notification pass and, if `gc_interval` has elapsed since the last GC, a garbage collection pass."""
        with Notifier(self) as notifier:
            notifier.notify()
            # Persist progress right away: garbage collection below may take long
            notifier.save_context()

            gc_period = dt.timedelta(seconds=self.service_config["gc_interval"])
            gc_due = dt.datetime.utcnow() - notifier.state.last_gc_time > gc_period
            if gc_due:
                notifier.collect_garbage()
