import logging
import datetime as dt
import itertools as it

import concurrent.futures

from sandbox import common
import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt
import sandbox.common.types.user as ctu
import sandbox.common.types.resource as ctr
import sandbox.common.types.statistics as ctst

from sandbox.yasandbox import controller
from sandbox.yasandbox.database import mapping

from sandbox.services import base

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def _filter_tids_to_delete(tids, delete_created_in_gui=False):
    """
    Narrow down a collection of task ids to those which may be switched to DELETED.

    A task id is dropped from the result if:
      - the task owns at least one resource in state READY;
      - the task is a parent of a task whose status is one of `Cleaner.SUCCESS_STATUSES`;
      - `delete_created_in_gui` is false and the task has an audit record with
        source WEB and status DRAFT.

    :param tids: iterable of task ids to check
    :param delete_created_in_gui: when true, the audit-based (WEB/DRAFT) exclusion is skipped
    :return: set of task ids which passed all the checks
    """
    candidates = set(tids)

    # Tasks which still own READY resources must be kept.
    with_ready_resources = mapping.Resource.objects(
        task_id__in=tids,
        state=ctr.State.READY,
        read_preference=mapping.ReadPreference.SECONDARY
    ).scalar("task_id")
    candidates = candidates - set(with_ready_resources)

    # Parents of tasks in one of the "success" statuses must be kept.
    parents_of_finished = mapping.Task.objects(
        parent_id__in=candidates,
        execution__status__in=Cleaner.SUCCESS_STATUSES,
        read_preference=mapping.ReadPreference.SECONDARY,
    ).scalar("parent_id")
    candidates -= set(parents_of_finished)

    if delete_created_in_gui:
        return candidates

    # Tasks having a WEB-sourced DRAFT audit record are kept unless explicitly allowed.
    created_in_gui = mapping.Audit.objects(
        task_id__in=candidates,
        source=ctt.RequestSource.WEB,
        status=ctt.Status.DRAFT,
    ).scalar("task_id")
    candidates -= set(created_in_gui)

    return candidates


class Cleaner(base.ThreadedService):
    """
    Service for clearing database.

    Exposes two targets (see `targets`):
      - `regular_tasks_collector` walks finished tasks from the newest to the oldest
        and switches outdated ones to DELETED;
      - `common_cleaner` collects and removes outdated tasks, resources, audit records,
        notifications, sessions and parameters metas.
    """

    # Service settings consumed by the base class (units are not visible here — presumably seconds).
    notification_timeout = 60
    tick_interval = 300

    # Historically significant max task id. Tasks with ids not greater than this are never touched.
    MAX_SAFE_TASK_ID = 1000

    # Regular tasks clearing params: TTL in days and amount of tasks fetched per iteration
    REGULAR_TASKS_TTL = 14
    REGULAR_TASKS_BATCH_SIZE = 100

    # Maximum number of uncompleted/erroneous tasks to collect for deletion at once (per tasks type)
    MAX_TASKS_TO_DELETE_AT_ONCE = 1000
    # Number of task ids passed in a single batch API request
    TASKS_TO_DELETE_CHUNK_SIZE = 20

    # Audit clearing params: tasks per iteration and TTL in days
    MAX_TASKS_TO_REDUCE_AUDIT_AT_ONCE = 5000
    AUDIT_TTL = 30

    # DELETED resources clearing params: resources per iteration and TTL in days
    DELETED_RESOURCES_TO_REMOVE_AT_ONCE = 30000
    DELETED_RESOURCE_TTL = 14
    # TTL in days for notifications
    NOTIFICATION_TTL = 60
    # OAuth session tokens TTL in days
    SESSION_TTL = 1
    # Parameters meta TTL in days
    PARAMETERS_META_TTL = 30
    # Maximum metas to check in one iteration
    PARAMETERS_META_CHUNK_SIZE = 1000

    SUCCESS_STATUSES = (
        ctt.Status.SUCCESS, ctt.Status.FAILURE, ctt.Status.RELEASED, ctt.Status.RELEASING, ctt.Status.NOT_RELEASED
    )

    # TTL in days for uncompleted tasks, after which they will be switched to status DELETED
    UNCOMPLETED_TASK_TTL = 7
    UNCOMPLETED_STATUSES = (ctt.Status.DRAFT, ctt.Status.STOPPED)

    # TTL in days for erroneous tasks, after which they will be switched to status DELETED
    ERRONEOUS_TASK_TTL = 14
    ERRONEOUS_STATUSES = (ctt.Status.EXCEPTION, ctt.Status.NO_RES, ctt.Status.TIMEOUT)

    # Maximum number of tasks to remove at once
    MAX_TASKS_TO_REMOVE_AT_ONCE = 10000
    # TTL in days for tasks in status DELETED
    DELETED_TASK_TTL = 7

    # Size of the thread pool used to send batch API requests
    MAX_SUBWORKERS = 10
    # Interval passed to the `regular_tasks_collector` target
    TASKS_COLLECTOR_INTERVAL = 1

    def __init__(self, *args, **kwargs):
        """ Read the OAuth token (if authentication is enabled) and create the subworkers pool. """
        super(Cleaner, self).__init__(*args, **kwargs)

        if self.sandbox_config.server.auth.enabled:
            self._auth = common.utils.read_settings_value_from_file(self.sandbox_config.server.auth.oauth.token)
        else:
            self._auth = None

        self.__subworkers = concurrent.futures.ThreadPoolExecutor(max_workers=self.MAX_SUBWORKERS)

    @property
    def targets(self):
        """ List of service targets: the regular tasks collector and the common cleaner. """
        return [
            self.Target(self.regular_tasks_collector, interval=self.TASKS_COLLECTOR_INTERVAL),
            self.Target(self.common_cleaner),
        ]

    def _batch_update_tasks(self, operation, comment, gerund, tids):
        """
        Perform a batch operation on tasks via REST API and log the outcome.

        HTTP errors are logged and swallowed, so a failed batch does not break the caller.

        :param operation: name of the batch operation ("delete", "stop")
        :param comment: comment sent along with the request
        :param gerund: verb form used in log messages ("deleting", "stopping")
        :param tids: list of task ids to process
        """
        logger.debug("%s %s", gerund.capitalize(), tids)
        api = common.rest.ThreadLocalCachableClient(auth=self._auth)

        try:
            response = api.batch.tasks[operation].update({
                "comment": comment,
                "id": tids,
            })
        except api.HTTPError as error:
            logger.error("Failure while %s tasks %s: %s", gerund, tids, error)
            return

        # Report only the items which were not processed successfully.
        failed = [item for item in response if item["status"] != ctm.BatchResultStatus.SUCCESS]
        if failed:
            logger.warning("Problems while tasks %s: %s", gerund, failed)

    def _switch_to_deleted(self, tids):
        """ Ask the API to delete the given tasks; failures are only logged. """
        self._batch_update_tasks("delete", "Deleted by cleaner", "deleting", tids)

    def _switch_to_stopped(self, tids):
        """ Ask the API to stop the given tasks; failures are only logged. """
        self._batch_update_tasks("stop", "stopped by cleaner", "stopping", tids)

    def _send_signals(self, signals):
        """
        Push a single CLEANER signal with the given counters to the signaler.

        :param signals: dict of counter name to value, merged into the signal
        """
        now = dt.datetime.utcnow()
        self.signaler.push(dict(
            date=now,
            timestamp=now,
            type=ctst.SignalType.CLEANER,
            **signals
        ))

    def regular_tasks_collector(self):
        """
        Switch a batch of outdated finished tasks to DELETED.

        Tasks are scanned in descending id order starting from the cursor stored in the service
        context under "last_task_id". When no cursor is stored (or it is zero), the scan restarts
        from the newest task. When nothing is found below the cursor, the cursor is reset to zero.
        The cursor is saved only if the service is not being stopped.

        :return: tuple of a status message and an empty list
        """
        deleted = 0
        last_task_id = self.context.get("last_task_id")
        if not last_task_id:
            newest_task = mapping.Task.objects().order_by("-id").first()
            if newest_task is None:
                # Empty tasks collection: nothing to scan.
                return "Total tasks deleted: {}".format(deleted), []
            last_task_id = newest_task["id"]

        logger.info("Collecting regular tasks less than %s", last_task_id)
        now = dt.datetime.utcnow()
        tids = set(mapping.Task.objects(
            id__lt=last_task_id,
            id__gt=self.MAX_SAFE_TASK_ID,
            execution__status__in=self.SUCCESS_STATUSES,
            time__updated__lte=now - dt.timedelta(days=self.REGULAR_TASKS_TTL),
            read_preference=mapping.ReadPreference.SECONDARY,
        ).order_by("-id").limit(self.REGULAR_TASKS_BATCH_SIZE).scalar("id"))

        if tids:
            # Move the cursor before filtering, so the filtered out tasks are not fetched again.
            last_task_id = min(tids)
            tids = _filter_tids_to_delete(tids)
        else:
            last_task_id = 0

        logger.debug(
            "Fetched %s tasks in range [%s, %s], last task id: %s",
            len(tids), min(tids or [0]), max(tids or [0]), last_task_id
        )

        if not self.stop_event.is_set():
            futures = [
                self.__subworkers.submit(self._switch_to_deleted, chunk)
                for chunk in common.utils.chunker(sorted(tids), self.TASKS_TO_DELETE_CHUNK_SIZE)
            ]
            if futures:
                concurrent.futures.wait(futures)
            self._send_signals({"tasks_switched_to_deleted": len(tids)})
            deleted = len(tids)
            self.context["last_task_id"] = last_task_id

        return "Total tasks deleted: {}".format(deleted), []

    def collect_tasks_to_delete(self):
        """
        Collect outdated uncompleted and erroneous tasks to be switched to DELETED.

        For each tasks type a separate cursor is kept in the service context
        ("last_uncompleted_task_id", "last_erroneous_task_id"). The cursor is reset to zero
        when less than `MAX_TASKS_TO_DELETE_AT_ONCE` tasks were fetched.

        :return: set of task ids which passed `_filter_tids_to_delete` (GUI-created tasks included)
        """
        now = dt.datetime.utcnow()
        tids_to_delete = set()
        for tasks_type, statuses, ttl in (
            ("uncompleted", self.UNCOMPLETED_STATUSES, self.UNCOMPLETED_TASK_TTL),
            ("erroneous", self.ERRONEOUS_STATUSES, self.ERRONEOUS_TASK_TTL),
        ):
            var_name = "last_{}_task_id".format(tasks_type)
            last_task_id = self.context.get(var_name, self.MAX_SAFE_TASK_ID)
            last_task_id = max(last_task_id, self.MAX_SAFE_TASK_ID)
            logger.info("Collecting %s tasks to delete greater than #%s", tasks_type, last_task_id)
            tids = set(mapping.Task.objects(
                id__gt=last_task_id,
                execution__status__in=statuses,
                time__updated__lte=now - dt.timedelta(days=ttl),
                read_preference=mapping.ReadPreference.SECONDARY,
            ).order_by("+id").limit(self.MAX_TASKS_TO_DELETE_AT_ONCE).scalar("id"))
            logger.debug(
                "Collected %s %s tasks in range [%s, %s]", len(tids), tasks_type, min(tids or [0]), max(tids or [0])
            )
            tids_to_delete.update(tids)
            # A short batch means the end is reached, so the next scan starts over.
            last_task_id = 0 if len(tids) < self.MAX_TASKS_TO_DELETE_AT_ONCE else max(tids)
            self.context[var_name] = last_task_id
        return _filter_tids_to_delete(tids_to_delete, delete_created_in_gui=True)

    def collect_deleted_tasks(self):
        """
        Collect tasks in status DELETED which are old enough to be removed from the database.

        Tasks which still have any resource records are excluded. The scan cursor is kept in the
        service context under "last_task_id_to_remove" and is reset to zero when less than
        `MAX_TASKS_TO_REMOVE_AT_ONCE` tasks were fetched.

        :return: tuple of (set of task ids to remove, list of ids of their audit records)
        """
        last_task_id_to_remove = self.context.get("last_task_id_to_remove", self.MAX_SAFE_TASK_ID)
        last_task_id_to_remove = max(last_task_id_to_remove, self.MAX_SAFE_TASK_ID)

        logger.info("Collecting deleted tasks with ID greater than %s", last_task_id_to_remove)
        now = dt.datetime.utcnow()
        tids = set(
            mapping.Task.objects(
                id__gt=last_task_id_to_remove,
                execution__status=ctt.Status.DELETED,
                time__updated__lte=now - dt.timedelta(days=self.DELETED_TASK_TTL),
                read_preference=mapping.ReadPreference.SECONDARY,
            ).order_by("+id").limit(self.MAX_TASKS_TO_REMOVE_AT_ONCE).scalar("id"),
        )
        logger.debug("Collected %s tasks to remove in range [%s, %s]", len(tids), min(tids or [0]), max(tids or [0]))
        # The wrap-around check must use the same limit as the query above.
        last_task_id_to_remove = 0 if len(tids) < self.MAX_TASKS_TO_REMOVE_AT_ONCE else max(tids)
        tids -= set(mapping.Resource.objects(
            task_id__in=tids,
            read_preference=mapping.ReadPreference.SECONDARY,
        ).scalar("task_id"))
        logger.debug("%s tasks to delete after filtering", len(tids))

        aids = list(mapping.Audit.objects(
            task_id__in=tids,
            read_preference=mapping.ReadPreference.SECONDARY,
        ).scalar("id"))
        logger.debug(
            "Collected %s tasks to remove in range [%s, %s], last task id to remove: %s",
            len(tids), min(tids or [0]), max(tids or [0]), last_task_id_to_remove,
        )
        self.context["last_task_id_to_remove"] = last_task_id_to_remove
        return tids, aids

    def collect_deleted_resources(self):
        """
        Collect outdated resources to remove and the tasks which may be deleted along with them.

        :return: tuple of (set of resource ids to remove, set of task ids to switch to DELETED)
        """
        logger.info("Collecting resources in states DELETED and NOT_READY")
        now = dt.datetime.utcnow()
        # NOTE(review): the log message above mentions only DELETED and NOT_READY,
        # while the query also matches BROKEN resources — confirm which one is intended.
        rid2tid = dict(
            mapping.Resource.objects(
                state__in=[ctr.State.NOT_READY, ctr.State.BROKEN, ctr.State.DELETED],
                mds=None,
                time__updated__lte=now - dt.timedelta(days=self.DELETED_RESOURCE_TTL),
                read_preference=mapping.ReadPreference.SECONDARY
            ).order_by("+id").limit(self.DELETED_RESOURCES_TO_REMOVE_AT_ONCE).scalar("id", "task_id")
        )
        rids_to_remove = set(rid2tid)
        logger.debug(
            "Collected %s resources to remove in range [%s, %s]",
            len(rids_to_remove), min(rids_to_remove or [0]), max(rids_to_remove or [0])
        )

        # Only owners in one of the "success" statuses are considered for deletion.
        tids_to_check = set(mapping.Task.objects(
            id__in=rid2tid.values(),
            execution__status__in=self.SUCCESS_STATUSES,
            read_preference=mapping.ReadPreference.SECONDARY,
        ).scalar("id"))
        tids_to_delete = _filter_tids_to_delete(tids_to_check)

        logger.debug(
            "%s tasks in range [%s, %s] were added for deletion after checking %s tasks",
            len(tids_to_delete), min(tids_to_delete or [0]), max(tids_to_delete or [0]), len(tids_to_check),
        )
        return rids_to_remove, tids_to_delete

    def collect_task_audit(self):
        """
        Collect redundant audit records of outdated non-DELETED tasks.

        For every task the following records are kept: the first one, the first ENQUEUED and
        the first EXECUTING ones met after it, and the last one. All the others are returned
        for removal. The scan cursor is kept in the service context under
        "last_task_id_with_reduced_history" and is updated only if the scan was not interrupted.

        :return: list of audit record ids to remove; empty if the service is being stopped
        """
        last_task_id_with_reduced_history = self.context.get("last_task_id_with_reduced_history", self.MAX_SAFE_TASK_ID)
        last_task_id_with_reduced_history = max(last_task_id_with_reduced_history, self.MAX_SAFE_TASK_ID)

        logger.info("Collecting audit records for tasks greater than %s", last_task_id_with_reduced_history)
        now = dt.datetime.utcnow()
        tids = list(mapping.Task.objects(
            id__gt=last_task_id_with_reduced_history,
            execution__status__ne=ctt.Status.DELETED,
            time__updated__lte=now - dt.timedelta(days=self.AUDIT_TTL),
            read_preference=mapping.ReadPreference.SECONDARY
        ).order_by("+id").limit(self.MAX_TASKS_TO_REDUCE_AUDIT_AT_ONCE).scalar("id"))
        last_task_id_with_reduced_history = 0 if len(tids) < self.MAX_TASKS_TO_REDUCE_AUDIT_AT_ONCE else max(tids)
        aids = []
        cur_tid = 0
        saved_statuses = {
            ctt.Status.ENQUEUED: False,
            ctt.Status.EXECUTING: False
        }
        if self.stop_event.is_set():
            return aids
        for aid, tid, status in mapping.Audit.objects(task_id__in=tids).order_by("+task_id", "+date").scalar(
            "id", "task_id", "status"
        ):
            if cur_tid != tid:
                # Records of the next task have started.
                if self.stop_event.is_set():
                    # Interrupted: remove nothing and leave the cursor untouched.
                    return []
                cur_tid = tid
                # The most recently collected record is the last one of the previous task — keep it.
                if aids:
                    aids.pop()
                for saved_status in saved_statuses:
                    saved_statuses[saved_status] = False
                # The first record of a task is always kept.
                continue
            if saved_statuses.get(status) is False:
                # First occurrence of a status which has to be preserved.
                saved_statuses[status] = True
                continue
            aids.append(aid)
        # Keep the last record of the final task as well, the same way as for all the previous ones.
        if aids:
            aids.pop()
        logger.debug("Last task id with reduced history: %s", last_task_id_with_reduced_history)
        self.context["last_task_id_with_reduced_history"] = last_task_id_with_reduced_history
        return aids

    def remove_resources(self, rids):
        """ Write an audit entry for the given resources, remove them and return the `delete()` result. """
        logger.info("Removing %s resource(s)", len(rids))
        controller.Resource.list_resources_audit("Remove resource from database", rids)
        return mapping.Resource.objects(id__in=rids).delete()

    def remove_audit(self, aids):
        """ Remove task audit records with the given ids and return the `delete()` result. """
        logger.info("Removing %s task audit record(s)", len(aids))
        return mapping.Audit.objects(id__in=aids).delete()

    def remove_tasks(self, tids):
        """ Remove tasks with the given ids and return the `delete()` result. """
        logger.info("Removing %s task(s)", len(tids))
        return mapping.Task.objects(id__in=tids).delete()

    def remove_notifications(self):
        """ Remove notifications older than `NOTIFICATION_TTL` days and return the `delete()` result. """
        logger.info("Removing outdated notifications")
        return mapping.Notification.objects(
            date__lte=dt.datetime.utcnow() - dt.timedelta(days=self.NOTIFICATION_TTL)
        ).delete()

    def remove_sessions(self):
        """
        Drop outdated OAuth cache entries and stop tasks of sessions aborted for stopping.

        External sessions are dropped when expired. Client sessions are dropped only if their
        client does not exist anymore. Client sessions aborted with reason STOPPING, whose task
        is not in one of the STOP group statuses, are left in place and their tasks are requested
        to stop. Any other outdated token is dropped when expired.

        :return: result of the `delete()` call for the dropped tokens
        """
        logger.debug("Removing outdated sessions")
        to_stop, drop, checked, now = [], [], 0, dt.datetime.utcnow()
        ttl = dt.timedelta(days=self.SESSION_TTL)

        def is_expired(cache):
            # Token's own TTL (in seconds) takes precedence over the default one.
            lifetime = dt.timedelta(seconds=cache.ttl) if cache.ttl else ttl
            return cache.validated + lifetime < now

        for cache in mapping.OAuthCache.objects(source=ctu.TokenSource.EXTERNAL_SESSION).lite():
            if is_expired(cache):
                drop.append(cache.token)
        for cache in mapping.OAuthCache.objects(
            validated__lte=now - ttl, source__ne=ctu.TokenSource.EXTERNAL_SESSION
        ).lite():
            if cache.source.startswith(ctu.TokenSource.CLIENT):
                if cache.state == ctt.SessionState.ABORTED and cache.abort_reason == ctt.Status.STOPPING:
                    status = next(iter(mapping.Task.objects(id=cache.task_id).scalar("execution__status")), None)
                    if status not in ctt.Status.Group.STOP:
                        to_stop.append(cache.task_id)
                        continue
                cid = cache.source.partition(":")[-1]
                # Drop tokens only for non-existing clients. Others will drop stalled tokens on next task get.
                if not mapping.Client.objects(hostname=cid).count():
                    logger.debug("Dropping stalled token for client '%s'", cid)
                    drop.append(cache.token)
            elif is_expired(cache):
                logger.debug("Dropping expired token from %s for %s", cache.source, cache.app_id)
                drop.append(cache.token)
            checked += 1
        deleted = mapping.OAuthCache.objects(token__in=drop).delete()
        if to_stop:
            self._switch_to_stopped(to_stop)
        logger.info(
            "Checked %d sessions, removed %d: %r, trying to stop %d: %r",
            checked, len(drop), sorted(drop), len(to_stop), sorted(to_stop)
        )
        return deleted

    def remove_parameters_metas(self):
        """
        Remove outdated parameters metas which are not referenced anymore.

        Up to `PARAMETERS_META_CHUNK_SIZE` metas not accessed for `PARAMETERS_META_TTL` days are
        checked. Those not referenced by any task or scheduler are removed. For the referenced
        ones the access time is refreshed, so they are not checked again on the next iteration.

        :return: result of the `delete()` call, or 0 if nothing was removed
        """
        logger.debug("Removing outdated parameters metas")
        deleted, updated = 0, 0
        deadline = dt.datetime.utcnow() - dt.timedelta(days=self.PARAMETERS_META_TTL)
        ids_to_check = list(
            mapping.ParametersMeta.objects(
                accessed__lte=deadline
            ).fast_scalar("id").limit(self.PARAMETERS_META_CHUNK_SIZE)
        )
        if ids_to_check:
            used_ids = set(it.chain(
                mapping.Task.objects(parameters_meta__in=ids_to_check).fast_scalar("parameters_meta"),
                mapping.Scheduler.objects(parameters_meta__in=ids_to_check).fast_scalar("parameters_meta")
            ))
            ids_to_delete = list(set(ids_to_check) - used_ids)
            if ids_to_delete:
                deleted = mapping.ParametersMeta.objects(id__in=ids_to_delete).delete()
            # Both queries above are restricted to the checked ids, so every used id is a checked one.
            ids_to_update = list(used_ids)
            if ids_to_update:
                updated = mapping.ParametersMeta.objects(
                    id__in=ids_to_update
                ).update(set__accessed=dt.datetime.utcnow())
        logger.info(
            "Checked %d parameters' metas older than %s, removed %d, updated %d",
            len(ids_to_check), deadline, deleted, updated
        )
        return deleted

    def common_cleaner(self):
        """
        Run one full cleaning iteration.

        First collects everything to clean, then submits switching tasks to DELETED to the
        subworkers pool, removes the collected objects from the database while those requests
        are in progress, waits for the requests and finally sends the collected counters.

        :return: tuple of a status message and an empty list
        """
        logger.info("Common cleaner started")

        tids_to_remove = set()
        rids_to_remove = set()
        aids_to_remove = set()
        tids_to_delete = set()

        with common.utils.Timer() as timer:
            with timer[self.collect_tasks_to_delete.__name__]:
                tids = self.collect_tasks_to_delete()
                logger.info("Collected %s tasks in range [%s, %s]", len(tids), min(tids or [0]), max(tids or [0]))
            tids_to_delete.update(tids)
            with timer[self.collect_deleted_tasks.__name__]:
                tids, aids = self.collect_deleted_tasks()
                logger.info("Collected %s tasks with %s audit records", len(tids), len(aids))
                tids_to_remove.update(tids)
                aids_to_remove.update(aids)
            with timer[self.collect_deleted_resources.__name__]:
                rids, tids = self.collect_deleted_resources()
                logger.info("Collected %s resources to remove and %s tasks to delete", len(rids), len(tids))
                rids_to_remove.update(rids)
                tids_to_delete.update(tids)
            with timer[self.collect_task_audit.__name__]:
                aids = self.collect_task_audit()
                logger.info("Collected %s task audit records", len(aids))
                aids_to_remove.update(aids)
            # No need to switch to DELETED the tasks which are about to be removed entirely.
            tids_to_delete -= tids_to_remove

            tids_to_delete_futures = [
                self.__subworkers.submit(self._switch_to_deleted, chunk)
                for chunk in common.utils.chunker(sorted(tids_to_delete), self.TASKS_TO_DELETE_CHUNK_SIZE)
            ]

            signals = {}
            for method, args, signal_name in (
                (self.remove_audit, (aids_to_remove,), "removed_audits"),
                (self.remove_resources, (rids_to_remove,), "removed_resources"),
                (self.remove_tasks, (tids_to_remove,), "removed_tasks"),
                (self.remove_notifications, (), "removed_notifications"),
                (self.remove_sessions, (), "removed_sessions"),
                (self.remove_parameters_metas, (), "removed_parameters_metas"),
            ):
                with timer[method.__name__]:
                    signals[signal_name] = method(*args)

            with timer["wait_tasks_deleting"]:
                concurrent.futures.wait(tids_to_delete_futures)
                signals["tasks_switched_to_deleted"] = len(tids_to_delete)

        self._send_signals(signals)
        logger.info("Common cleaner finished totally in %s", timer)
        return "Finished", []

    def on_stop(self):
        """ Save the service model and wait for the subworkers pool to finish. """
        super(Cleaner, self).on_stop()
        self._model.save()
        self.__subworkers.shutdown(wait=True)
