# coding: utf-8

import socket
import logging
import datetime as dt
import operator as op
import itertools as it
import calendar
import collections

import cPickle as pickle

logger = logging.getLogger(__name__)

from sandbox import common
import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt
import sandbox.common.types.user as ctu
import sandbox.common.types.resource as ctr
import sandbox.common.types.statistics as ctst

from sandbox.yasandbox import controller
from sandbox.yasandbox.database import mapping

import sandbox.serviceq.errors as qerrors
import sandbox.common.joint.errors as jerrors
import sandbox.common.joint.socket as jsocket


class TaskManager(object):
    """
    The Answer to the Ultimate Question of Task, the Universe, and Everything.
    """

    remoteName = 'task'

    CACHE = {}

    def __init__(self):
        """Create the manager and ensure the indexes of the `mapping.Task` model."""
        mapping.Task.ensure_indexes()

    def create(
        self,
        task_type,
        owner=None,
        author=None,
        parent_id=None,
        host=None,
        model=None,
        cores=None,
        arch=None,
        priority=None,
        parameters=None,
        request=None,
        context=None,
        ram=None,
        suspend_on_status=None,
        score=None
    ):
        """
        Create a new task of the given type, save it, write a "Created" audit record and return it.

        :param task_type: task type name; resolved to a task class with `getTaskClass`
        :param owner: value stored as the task owner
        :param author: task author; defaults to `request.user.login` when a request is given.
            When `parent_id` is given the author is replaced with the parent's author
        :param parent_id: id of the parent task; the new task takes its author and `hidden` flag from it
        :param host: stored as `required_host` when truthy
        :param model: stored as `model`; an empty string is used when falsy
        :param cores: stored as `cores` when not `None`
        :param arch: stored as `arch`; `ctm.OSFamily.ANY` is used when falsy
        :param priority: explicit task priority (see the priority selection in the body)
        :param parameters: dict with extra task fields. The keys `priority`, `fail_on_any_error`
            and `tasks_archive_resource` are popped from it (the caller's dict is modified),
            `se_tag` is read from it, and whatever is left is passed to `task.from_dict`
        :param request: request object; its user is checked against `author` and its source
            is stored in the audit record
        :param context: dict merged into `task.ctx` after `task.init_context()`
        :param ram: stored as `required_ram` when truthy
        :param suspend_on_status: stored as `suspend_on_status` when not `None`
        :param score: stored as `score` when not `None`
        :return: created task object
        :raises common.errors.TaskError: if the task type is unknown or `author` differs from the
            requesting user and that user is not a super user
        """
        from yasandbox.proxy.task import getTaskClass

        task_cls = getTaskClass(task_type)
        if not task_cls:
            raise common.errors.TaskError("Unknown task type <" + str(task_type) + ">")

        if request:
            if not author:
                author = request.user.login
            elif request.user.login != author and not request.user.super_user:
                # Only a super user may create tasks on behalf of somebody else
                raise common.errors.TaskError(
                    "Task author ({}) should be the same as the user, which creates it ({}).".format(
                        author, request.user.login
                    )
                )

        if parameters is None:
            parameters = {}

        task = task_cls()
        task.type = task_type
        task.owner = owner
        task.author = author
        if parent_id:
            task.parent_id = parent_id
            # A subtask always gets the author of its parent, whatever was passed in
            task.author = mapping.Task.objects(id=parent_id).scalar("author")[0]
        task.model = model or ''
        if cores is not None:
            task.cores = cores
        task.arch = arch or ctm.OSFamily.ANY
        # No priority given anywhere and the request came from the web interface on behalf of
        # a non-anonymous (or super) user: use the `api.next` value of the default limits.
        if (
            not priority and 'priority' not in parameters and
            request and request.source == request.Source.WEB and
            (request.user.login != ctu.ANONYMOUS_LOGIN or request.user.super_user)
        ):
            task.priority = ctu.DEFAULT_PRIORITY_LIMITS.api.next
        else:
            # 'priority' is always removed from `parameters` so that `from_dict` below does not see it;
            # the explicit `priority` argument wins over the value taken from `parameters`.
            params_priority = parameters.pop('priority', (None, None))
            task.priority = task_cls.Priority().__setstate__(params_priority) if priority is None else priority

        if ram:
            task.required_ram = ram
        if host:
            task.required_host = host
        # These two parameters are stored in the task context rather than as task fields
        if 'fail_on_any_error' in parameters:
            task.ctx['fail_on_any_error'] = parameters.pop('fail_on_any_error')
        if 'tasks_archive_resource' in parameters:
            task.ctx['tasks_archive_resource'] = parameters.pop('tasks_archive_resource')
        if suspend_on_status is not None:
            task.suspend_on_status = suspend_on_status
        if score is not None:
            task.score = score
        # A subtask inherits the `hidden` flag of its parent
        task.hidden = mapping.Task.objects.only('hidden').get(id=parent_id).hidden if parent_id else False
        task.se_tag = parameters.get('se_tag')

        if parameters:
            task.from_dict(parameters)
        if not task.descr:
            task.descr = ''

        task.save()
        logger.info("New task #%s of type '%s' in status '%s' created.", task.id, task.type, task.status)
        controller.Task.audit(mapping.Audit(
            task_id=task.id,
            content="Created",
            status=task.status,
            author=task.author,
            source=request.source if request else None
        ))
        task.init_context()

        if context:
            if 'GSID' in context:
                context = task._update_task_gsid(context)
            task.ctx.update(context)
        # Second save: persists the context prepared after the first `task.save()`
        self.update(task)
        return task

    @staticmethod
    def task_status_history(task_id=None, since=None, to=None, remote_ip=None, order_by="+date"):
        """
        Fetch audit records, optionally restricted by task, date range and remote address.

        Every argument is optional; a restriction is added to the query only when
        the corresponding argument is given.

        :param task_id: single task id or an iterable of task ids
        :param since: lower bound (inclusive) for the record date
        :param to: upper bound (inclusive) for the record date
        :param remote_ip: `remote_ip` value the records must have
        :param order_by: field specification passed to `order_by`
        :return: A list of :class:`Audit` documents
        :rtype: list
        """
        filters = {}
        optional_filters = (
            ("date__gte", since),
            ("date__lte", to),
            ("remote_ip", remote_ip),
        )
        for lookup, value in optional_filters:
            if value:
                filters[lookup] = value

        # An iterable is treated as a set of ids, anything else non-None as a single id
        if isinstance(task_id, collections.Iterable):
            filters["task_id__in"] = list(task_id)
        elif task_id is not None:
            filters["task_id"] = task_id

        return list(mapping.Audit.objects(**filters).order_by(order_by))

    @staticmethod
    def exists(task_id):
        """
            Check whether the task exists

        :param task_id: task identifier
        :return: True if such a task exists, False otherwise
        :rtype: bool
        """
        return bool(mapping.Task.objects(id=task_id).scalar('id'))

    @staticmethod
    def still_in_status(task_id, status):
        """
        Check whether the given task is still in the given status.

        The check is a single query filtered by both the task id and `execution.status`.

        :param task_id: id of task to be checked.
        :param status: Task status to be checked.
        :return: `True` if the task is still in the status and `False` otherwise.
        """
        return bool(mapping.Task.objects(id=task_id, execution__status=status).scalar('id'))

    def restart_task(self, task, event, request=None):
        """
        Restart given task: account the auto-restart attempt, reset the task's resources and enqueue it.

        :param task: task object
        :param event: event description passed through to :meth:`enqueue_task`
        :param request: request object passed through to :meth:`enqueue_task`
        :return: True, if task was restarted, False otherwise
        """
        from sandbox.yasandbox.manager import resource_manager

        if not task.can_restart():
            logger.warning('Task #%s cannot be restarted.', task.id)
            return False

        if task.status == ctt.Status.TEMPORARY:
            stored_interval, restarts_left = mapping.Task.objects.fast_scalar(
                "execution__auto_restart__interval", "execution__auto_restart__left"
            ).with_id(task.id)

            default_interval = common.config.Registry().server.services.auto_restart_tasks.run_interval
            if not stored_interval:
                stored_interval = default_interval

            if restarts_left is None:
                restarts_left = common.config.Registry().common.task.execution.max_restarts
            restarts_left -= 1

            if restarts_left < 0:
                logger.warning("No restarts left for task #%s", task.id)
                task.set_status(ctt.Status.EXCEPTION, "No auto restarts left", force=True)
                return False

            # Every auto restart stretches the interval by a factor of 55/34
            next_interval = int(stored_interval * 55 / 34)
            mapping.Task.objects.filter(id=task.id).update_one(
                set__execution__auto_restart__left=restarts_left,
                set__execution__auto_restart__interval=next_interval
            )

        for res in resource_manager.list_task_resources(task.id):
            if res.is_deleted():
                resource_manager.reset_resource_hosts(res.id)
            elif res.type.restart_policy == res.type.OnRestart.IGNORE:
                # Resources of such types are left as they are
                pass
            else:
                if res.type.restart_policy == res.type.OnRestart.RESET:
                    target_state = ctr.State.NOT_READY
                else:
                    target_state = ctr.State.DELETED
                resource_manager.reset_resource(res.id, state=target_state)

        task.set_info("Task was restarted.")
        self.enqueue_task(task, event, request)
        logger.info("Task #%s has been restarted.", task.id)
        return True

    def enqueue_task(self, task, event=None, request=None, sync=False):
        """
        Enqueue given task. If necessary change task status to ENQUEUING to do some preparing work.

        For a task in DRAFT status the method first takes the lock (`lock_host`), stores the number
        of auto restarts and runs :meth:`on_first_enqueue`. The wait/stop exceptions raised there
        select the status the task is finally switched to instead of ENQUEUING.

        :param task: task object
        :param event: string with description of event that cause task enqueuing
        :param request: SandboxRequest object
        :param sync: if true, `server.sync_enqueuing` is enabled in the config and the target status
            is ENQUEUING, the task is saved and handed to `controller.TaskQueue.finalize_enqueue_task`
            right here; otherwise (or if that call reports `FAILED`) the task is switched to its
            new status and its id is pushed to the prequeue
        :return: True if task was successfully enqueued, False if a concurrent update of the task
            was detected. A failure to push the task to the prequeue is only logged, True is still returned
        :raises common.errors.IncorrectStatus: if the task cannot be switched to ENQUEUING
            from its current status
        :raises common.errors.AuthorizationError: if `controller.Task.check_user_permission` rejects the request
        """
        old_status = task.status
        logger.info("Enqueuing task #%s (current status is '%s', host: '%s')", task.id, old_status, task.host)
        if not task.Status.can_switch(old_status, task.Status.ENQUEUING):
            raise common.errors.IncorrectStatus(
                "Task #{} cannot be enqueued: current status is '{}', host is '{}'.".format(
                    task.id, task.status, task.host
                )
            )

        updated = False
        node_id = common.config.Registry().this.id
        new_status = task.Status.ENQUEUING

        permission_ok, error_desc = controller.Task.check_user_permission(task, request)
        if not permission_ok:
            logger.error(error_desc)
            raise common.errors.AuthorizationError(error_desc)

        if old_status == task.Status.DRAFT:
            try:
                # Conditional update: succeeds only if nobody changed status or lock_host in between
                updated = mapping.Task.objects(
                    id=task.id,
                    execution__status=old_status,
                    lock_host=task.lock_host
                ).update_one(
                    set__lock_host=node_id,
                    set__lock_time=dt.datetime.utcnow()
                )
                if not updated:
                    task.reload()
                    logger.warning(
                        "Task #%s enqueueing failed. "
                        "Possible race condition (1) with another master host, "
                        "current status is '%s' (was '%s'), lock_host: '%s'.",
                        task.id, task.status, old_status, task.lock_host
                    )
                    return False

                mapping.Task.objects.filter(id=task.id).update(
                    set__execution__auto_restart__left=task.max_restarts
                )
                task.mapping().reload()
                self.on_first_enqueue(task)
            # The following exceptions are not errors: they choose the status the task goes to
            except common.errors.WaitTask:
                new_status = task.Status.WAIT_TASK
            except common.errors.WaitTime:
                new_status = task.Status.WAIT_TIME
            except common.errors.NothingToWait:
                pass
            except common.errors.TaskStop:
                new_status = task.Status.STOPPING
            except Exception as ex:
                logger.error("Error while enqueuing task #%s: %s", task.id, ex)
                if updated:
                    # Give back the lock taken above (only if it is still ours and the status is unchanged)
                    mapping.Task.objects(
                        id=task.id,
                        execution__status=old_status,
                        lock_host=node_id
                    ).update_one(
                        set__lock_host=""
                    )
                controller.Task.audit(mapping.Audit(
                    task_id=task.id,
                    author=task.author,
                    content="Enqueuing failed: " + str(ex),
                ))

                # Store the traceback in the pickled task context, then re-raise the original error
                mp = task.mapping()
                mp.reload()
                ctx = pickle.loads(mp.context)
                ctx["__last_error_trace"] = common.utils.format_exception()
                mp.context = str(pickle.dumps(ctx))
                mp.save()
                raise

        # can't put this into above if branch without having to fiddle with try/except blocks
        if old_status == ctt.Status.DRAFT:
            mp = task.mapping()
            common.statistics.Signaler().push(dict(
                type=ctst.SignalType.TASK_CREATION,
                date=mp.time.created,
                timestamp=mp.time.created,
                author=task.author,
                owner=task.owner,
                task_type=task.type,
                unique_key="",
                hints=[],
                tags=task.tags,
                task_id=task.id,
                parent_id=mp.parent_id,
                scheduler_id=mp.scheduler,
                sdk_type=ctt.SDKType.SDK1,
            ))

        # `FAILED` is used here as "not enqueued synchronously", which routes to the status-switch path below
        if new_status != task.Status.ENQUEUING or not sync or not common.config.Registry().server.sync_enqueuing:
            queue_put_result = controller.TaskQueue.EnqueueResult.FAILED
        else:
            task.save(save_condition={"execution__status": task.status})
            tw_task = controller.TaskWrapper(mapping.Task.objects.with_id(task.id))
            queue_put_result = controller.TaskQueue.finalize_enqueue_task(task=tw_task.model, logger=logger)

        if queue_put_result == controller.TaskQueue.EnqueueResult.FAILED:
            # Try to "lock" the object in the database by setting it to "ENQUEUING" status and host to this.
            try:
                task.set_status(
                    new_status=task.Status.ENQUEUING,
                    event=event,
                    request=request,
                    lock_host=node_id,
                    set_last_action_user=True
                )
            except common.errors.UpdateConflict:
                task.reload()
                logger.warning(
                    "Task #%s enqueueing failed. "
                    "Possible race condition (2) with another master host, "
                    "current status is '%s' (was '%s'), lock_host: '%s'.",
                    task.id, task.status, old_status, task.lock_host
                )
                return False

            # Now we can safely save the task object and release the lock.
            if new_status != task.Status.ENQUEUING:
                task.set_status(
                    new_status=new_status,
                    request=request,
                    lock_host=node_id
                )
            controller.TaskStatusNotifierTrigger.create_from_task(task.id)
            task.save(release=True)

            if task.status == task.Status.ENQUEUING:
                try:
                    controller.TaskQueue.qclient.prequeue_push(task.id)
                except (
                        qerrors.QTimeout, qerrors.QRetry, jerrors.Reconnect, jerrors.RPCError, jsocket.EOF, socket.error
                ) as ex:
                    # Best effort: the failure is only logged, the method still returns True
                    logger.error("Error pushing task #%s to prequeue: %s", task.id, ex)
        else:
            # Reached only after the synchronous branch above, so `tw_task` is always defined here
            tw_task.release_lock_host()

        return True

    @classmethod
    def on_first_enqueue(cls, task, log=None):
        """
        Hook to be run once per task life, on the very first enqueuing of the task.

        Validates the task's custom tasks archive resource with
        `controller.TaskWrapper.validate_custom_tasks_resource` and then calls the task's
        `on_enqueue` method, logging how long that call took.

        :param task:    Task object to be enqueued.
        :param log:     Logger object to be used for logging. If it is not given (or falsy),
                        the module logger is used.
        """
        active_log = log if log else logger
        controller.TaskWrapper.validate_custom_tasks_resource(task.id, task.tasks_archive_resource)

        with common.utils.Timer() as elapsed:
            task.on_enqueue()
            active_log.info("Task #%s has been enqueued totally in %s", task.id, elapsed)

    @staticmethod
    def _restore_task(mp, cls=None):
        """
        Turn a task mapping object into a task object by delegating to `Task.restore`.

        :param mp: task mapping object, passed to `Task.restore` as is
        :param cls: second argument of `Task.restore`, passed as is
        :return: whatever `Task.restore` returns
        """
        # Imported here rather than at module level
        from yasandbox.proxy.task import Task
        return Task.restore(mp, cls)

    @staticmethod
    def load_task_ctx(task_id):
        """
        Load only the context of the given task.

        :param task_id: task identifier
        :return: unpickled task context, or an empty dict if no task document was found
        """
        obj = mapping.Task.objects.only('context').with_id(task_id)
        # The context is stored pickled in the `context` field
        return pickle.loads(obj.context) if obj else {}

    @classmethod
    def load(cls, task_id, load_ctx=True):
        """
            Load a task object

            :param task_id: task identifier
            :param load_ctx: whether to load the task context
            :return: task object; `None` if `task_id` cannot be converted with `mapping.ObjectId`
            :rtype: yasandbox.proxy.task.Task
        """
        try:
            task_id = mapping.ObjectId(task_id)
        except (ValueError, TypeError):
            return None

        query = mapping.Task.objects(id=task_id)
        if not load_ctx:
            # Queryset modifiers return the narrowed queryset; the result has to be kept,
            # otherwise the exclusion is lost and the context is fetched anyway.
            query = query.exclude('context')
        return cls._restore_task(query.first())

    @classmethod
    def update(cls, task):
        """
        Save the given task object.

        :param task: task object to save
        :return: always `True`
        """
        task.save()
        return True

    @staticmethod
    def list_query(*args, **kwargs):
        """Build a task list query; all arguments are forwarded to `controller.Task.list_query`."""
        return controller.Task.list_query(*args, **kwargs)

    @classmethod
    def list(
        cls,
        limit=0, offset=0, parent_id=None, task_type='', completed_only=False, owner='', status='', host='',
        id=0, hidden=False, load=True, show_childs=False, descr_mask='', important_only=False,
        load_ctx=True, model='', arch='', order_by='-id',
        created=None, updated=None, author=''
    ):
        """
        List tasks matching the given filters.

        All filter arguments are forwarded to :meth:`list_query` (`task_type` is passed as `type`).

        :param load: if true, task objects are returned; otherwise only task ids
        :return: list of task objects or list of task ids, depending on `load`
        """
        filters = dict(
            parent_id=parent_id, type=task_type, completed_only=completed_only, owner=owner,
            status=status, host=host, id=id, hidden=hidden, show_childs=show_childs,
            descr_mask=descr_mask, important_only=important_only, limit=limit, offset=offset,
            model=model, arch=arch, load_ctx=load_ctx, order_by=order_by,
            created=created, updated=updated, author=author
        )
        found = cls.list_query(**filters)
        if load:
            return [cls._restore_task(doc) for doc in found]
        return list(found.scalar('id'))

    @classmethod
    def list_subtasks(cls, parent_id, completed_only=False, hidden=True):
        """
        List ids of the subtasks of the given task.

        :param parent_id: id of the parent task
        :param completed_only: forwarded to :meth:`list_query`
        :param hidden: forwarded to :meth:`list_query`
        :return: list of subtask ids
        """
        return list(cls.list_query(parent_id=parent_id, completed_only=completed_only, hidden=hidden).scalar('id'))

    @classmethod
    def count(
        cls, parent_id=None, task_type='', completed_only=False, owner='', status='', host='', id=None,
        hidden=False, show_childs=False, descr_mask='', important_only=False, model='', arch='',
        created=None, updated=None, author=''
    ):
        """
        Count tasks matching the given filters.

        The filters are forwarded to :meth:`list_query` (`task_type` is passed as `type`)
        with an empty `order_by`.

        :return: result of `count()` on the built query
        """
        filters = dict(
            parent_id=parent_id, type=task_type, completed_only=completed_only,
            owner=owner, status=status, host=host, id=id, hidden=hidden, show_childs=show_childs,
            descr_mask=descr_mask, important_only=important_only, model=model, arch=arch,
            order_by='',
            created=created, updated=updated, author=author
        )
        matching = cls.list_query(**filters)
        return matching.count()

    @staticmethod
    def _depends_on_query(task_id=0):
        """
        Build a query for tasks whose `requirements.resources` contain any resource of the given task.

        :param task_id: id of the task owning the resources
        :return: task queryset
        """
        return mapping.Task.objects(
            requirements__resources__in=mapping.Resource.objects(task_id=task_id).scalar('id')
        )

    @classmethod
    def list_dependent(cls, task_id=0, limit=0, offset=0):
        """
        List ids of tasks depending on resources of the given task, ordered by id.

        :param task_id: id of the task owning the resources
        :param limit: size of the slice taken from the query
        :param offset: start of the slice taken from the query
        :return: list of task ids
        """
        # NOTE(review): with the default `limit=0` the slice is `[offset:offset]` - confirm how
        # the queryset treats an empty slice (nothing vs. no limit).
        return list(cls._depends_on_query(task_id).order_by('id')[offset:offset + limit].scalar('id'))

    @classmethod
    def count_dependent(cls, task_id=0):
        """
        Count tasks depending on resources of the given task.

        :param task_id: id of the task owning the resources
        :return: result of `count()` on the dependency query
        """
        return cls._depends_on_query(task_id).count()

    @staticmethod
    def list_not_ready_tasks_dependencies(task_ids):
        """
        For the given tasks find the required resources which are not in READY state.

        :param task_ids: ids of tasks to check
        :return: dict mapping task id to the set of its required resource ids that are not READY;
            tasks without such resources are not included
        :rtype: dict
        """
        docs = mapping.Task.objects.only('requirements.resources').in_bulk(task_ids)
        required_ids = it.chain.from_iterable([doc.requirements.resources for doc in docs.itervalues()])
        not_ready = set(
            mapping.Resource.objects(
                id__in=required_ids,
                state__ne=mapping.Resource.State.READY,
            ).scalar('id')
        )

        pending_by_task = {}
        for doc in docs.itervalues():
            pending = not_ready & set(doc.requirements.resources)
            if pending:
                pending_by_task[doc.id] = pending
        return pending_by_task

    @staticmethod
    def list_dependencies(task_id):
        """
        Group the resources required by the given task by the tasks owning them.

        :param task_id: id of the task whose requirements are inspected
        :return: list of `(owner task id, [resource ids])` pairs sorted by owner task id
        :rtype: list
        """
        lookup = mapping.Task.objects(id=mapping.ObjectId(task_id)).scalar('requirements__resources')
        required = next(lookup, None)
        if required:
            rows = mapping.Resource.objects(id__in=required).scalar('id', 'task_id')
        else:
            rows = []

        by_owner = collections.defaultdict(list)
        for resource_id, owner_task_id in rows:
            by_owner[owner_task_id].append(resource_id)
        return sorted(by_owner.iteritems())

    @classmethod
    def fast_load_list(cls, ids, load_ctx=True):
        """
        Load several tasks with a single bulk query.

        :param ids: ids of tasks to load
        :param load_ctx: whether to load task contexts
        :return: task objects in the order of `ids`; for an id missing in the bulk result
            `_restore_task` is called with `None`
        """
        query = mapping.Task.objects
        if not load_ctx:
            # Queryset modifiers return the narrowed queryset; the result has to be kept,
            # otherwise the exclusion is lost and contexts are fetched anyway.
            query = query.exclude('context')
        objects = query.in_bulk(map(mapping.ObjectId, ids))
        return map(cls._restore_task, map(objects.get, ids))

    def get_dependent_list(self, task_id=0, limit=0, offset=0):
        """
        Load tasks depending on resources of the given task.

        :param task_id: id of the task owning the resources
        :param limit: forwarded to :meth:`list_dependent`
        :param offset: forwarded to :meth:`list_dependent`
        :return: `(task id, task object)` pairs
        """
        task_id_list = self.list_dependent(task_id, limit, offset)
        task_list = self.fast_load_list(task_id_list)
        return zip(task_id_list, task_list)

    def get_dependent_count(self, task_id=0):
        """Return the number of tasks depending on resources of the given task (see :meth:`count_dependent`)."""
        return self.count_dependent(task_id)

    def get_dependence_tree(self, task_id, level=0):
        """
        Recursively collect the tasks whose resources the given task requires.

        Each loaded dependency (other than the task itself) gets two extra attributes:
        `depends_on_resources` with the loaded resources and `depends_on_tasks` with its own subtree.

        :param task_id: id of the task to build the tree for
        :param level: current recursion depth; an empty list is returned once it exceeds 2
        :return: list of `(task_id, dependency task object)` pairs, where `task_id` is the id
            the method was called with
        """
        if level > 2:
            return []
        from yasandbox.manager import resource_manager

        branches = []
        for owner_id, required_resources in self.list_dependencies(task_id):
            owner = self.load(owner_id)
            if owner is not None and owner.id != task_id:
                owner.depends_on_resources = resource_manager.fast_load_list(required_resources)
                owner.depends_on_tasks = self.get_dependence_tree(owner_id, level + 1)
            branches.append((task_id, owner))
        return branches

    def get_child_tree(self, task_id):
        """
        Recursively collect the subtasks of the given task.

        Each loaded subtask gets a `child_tasks` attribute holding its own subtree.

        :param task_id: id of the parent task
        :return: list of `(task_id, subtask object)` pairs, where `task_id` is the id
            the method was called with
        """
        subtree = []
        for child_id in self.list_subtasks(task_id, hidden=True):
            child = self.load(child_id)
            if child is not None:
                child.child_tasks = self.get_child_tree(child_id)
            subtree.append((task_id, child))
        return subtree

    @staticmethod
    def register_dep_resource(task_id, resource_id):
        """
        Add the resource to the `requirements.resources` set of the task.

        :param task_id: id of the task to update
        :param resource_id: id of the resource to add
        """
        mapping.Task.objects(id=task_id).update_one(add_to_set__requirements__resources=resource_id)

    @staticmethod
    def forget_dep_resource(task_id, resource_id):
        """
        Remove the resource from `requirements.resources` of the task.

        :param task_id: id of the task to update
        :param resource_id: id of the resource to remove
        """
        mapping.Task.objects(id=task_id).update_one(pull__requirements__resources=resource_id)

    def server_host_option(self, host, opt):
        """
        Read an option from the 'system' section of the client record loaded for the host.

        :param host: host name passed to `client_manager.load`
        :param opt: option name
        :return: option value, or `None` if the section or the option is absent
        """
        from yasandbox.manager import client_manager
        return client_manager.load(host).get('system', {}).get(opt, None)

    @staticmethod
    def get_history(task_id):
        """
        Return the audit history of the task ordered by date and then by id.

        :param task_id: task identifier
        :return: list of dicts with keys `time`, `event`, `host`, `status`, `author`, `request_id`,
            `remote_ip` and `source`. Empty `event`, `status`, `request_id` and `remote_ip`
            values are replaced with "-", an empty `author` with "sandbox"
        :rtype: list
        """
        records = mapping.Audit.objects(task_id=task_id).order_by("+date", "+id")
        if not records:
            return []
        return [
            {
                "time": record.date,
                "event": record.content or "-",
                "host": record.hostname,
                "status": record.status or "-",
                "author": record.author or "sandbox",
                "request_id": record.request_id or "-",
                "remote_ip": record.remote_ip or "-",
                "source": record.source,
            }
            for record in records
        ]

    @classmethod
    def get_last_history_event(cls, task_id, event):
        """
        Select history records of the task whose event text equals `event`.

        Despite the name, every matching record is returned, in the order of :meth:`get_history`.

        :param task_id: task identifier
        :param event: event text to look for
        :return: list of matching history records (dicts)
        """
        return [e for e in cls.get_history(task_id) if e['event'] == event]

    @staticmethod
    def get_task_url(task_id):
        """Return the link to the task as built by `common.utils.get_task_link`."""
        return common.utils.get_task_link(task_id)

    @classmethod
    def list_host_queue(cls, host='', limit=0, offset=0):
        """
        List tasks related to the host: the ones locked by it in an active status
        followed by the ones the queue client reports for it.

        Each returned task object gets an `effective_priority` attribute: `None` for tasks
        locked by the host, the score reported by the queue client for queued ones.

        :param host: host name
        :param limit: number of entries to take
        :param offset: number of leading entries to skip
        :return: list of task objects
        """
        active_statuses = (
            ctt.Status.PREPARING,
            ctt.Status.EXECUTING,
            ctt.Status.FINISHING,
            ctt.Status.RELEASING,
            ctt.Status.STOPPING
        )
        entries = [
            (task_id, None)
            for task_id in mapping.Task.objects(
                execution__status__in=active_statuses,
                lock_host=host
            ).scalar('id')
        ]
        entries += [(tid, score) for tid, _, score in controller.TaskQueue.qclient.queue_by_host(host)]

        # NOTE(review): with the default `limit=0` this slice is empty - confirm callers always pass a limit.
        page = entries[offset:offset + limit]
        objects = list(cls.fast_load_list([tid for tid, _ in page]))
        # Pair every object with the entry it was loaded from. Indexing the unsliced list by the
        # position within the page (as was done before) picks wrong scores whenever offset > 0.
        for obj, (_, score) in zip(objects, page):
            # Skip entries that did not resolve to a task object
            if obj is not None:
                obj.effective_priority = score
        return objects

    @classmethod
    def count_host_queue(cls, host=''):
        """
        Count tasks related to the host: the ones locked by it in an active status
        plus the ones the queue client reports for it.

        :param host: host name
        :return: total number of such tasks
        """
        active_statuses = (
            ctt.Status.PREPARING,
            ctt.Status.EXECUTING,
            ctt.Status.FINISHING,
            ctt.Status.RELEASING,
            ctt.Status.STOPPING
        )
        locked_ids = mapping.Task.objects(execution__status__in=active_statuses, lock_host=host).scalar('id')
        locked_count = len(locked_ids)
        queued_count = len(controller.TaskQueue.qclient.queue_by_host(host))
        return locked_count + queued_count

    def get_bulk_fields(self, ids, fields, safe_xmlrpc=False, strict_mode=False):
        """
        Query database for several fields of multiple tasks

            :param ids: - identifiers of tasks to load
            :param fields: - names of the fields to select (attributes of the wrapper class below)
            :param safe_xmlrpc: - pass the values (and the `ctx` field, if requested) through
              the `common.proxy.safe_xmlrpc_*` helpers
            :param strict_mode: - raise `common.errors.TaskError` if some of the tasks
              are still not found after the second query
            :return: a dictionary of the following structure
              (order of values is the same as order of fields to query for):

                .. code-block:: python

                    {str(task_id): [field values]}

            :rtype: dict
        """

        ids = [mapping.ObjectId(raw_id) for raw_id in ids]

        def to_epoch(moment):
            # A falsy value (no date stored) is returned unchanged
            return moment and calendar.timegm(moment.timetuple())

        class ModelWrapper(object):
            # Exposes the fields of a task mapping document under the attribute names
            # accepted in `fields`

            def __init__(self, mp):
                self._doc = mp

            id = property(lambda self: self._doc.id)
            descr = property(lambda self: self._doc.description)
            info = property(lambda self: self._doc.execution.description)
            parent_id = property(lambda self: self._doc.parent_id)
            host = property(lambda self: self._doc.execution.host)
            hidden = property(lambda self: self._doc.hidden)
            model = property(lambda self: self._doc.requirements.cpu_model)
            priority = property(lambda self: ctt.Priority.make(self._doc.priority).__getstate__())
            arch = property(lambda self: self._doc.requirements.platform)
            execution_space = property(lambda self: self._doc.requirements.disk_space)
            important = property(lambda self: self._doc.flagged)
            status = property(lambda self: self._doc.execution.status)
            type = property(lambda self: self._doc.type)
            owner = property(lambda self: self._doc.owner)
            author = property(lambda self: self._doc.author)
            timestamp = property(lambda self: to_epoch(self._doc.time.created))
            timestamp_start = property(lambda self: to_epoch(self._doc.execution.time.started))
            timestamp_finish = property(lambda self: to_epoch(self._doc.execution.time.finished))
            updated = property(lambda self: to_epoch(self._doc.time.updated))

            @common.utils.singleton_property
            def ctx(self):
                return pickle.loads(self._doc.context) if self._doc.context else {}

        ctx_position = fields.index('ctx') if 'ctx' in fields else None

        def collect_res(_ids):
            found = {}
            for mp in mapping.Task.objects.in_bulk(_ids).itervalues():
                view = ModelWrapper(mp)
                row = [getattr(view, name) for name in fields]
                if safe_xmlrpc:
                    if ctx_position is not None:
                        row[ctx_position] = common.proxy.safe_xmlrpc_dict(row[ctx_position])
                    row = common.proxy.safe_xmlrpc_list(row)
                found[str(view.id)] = row
            return found

        def absent_ids(found):
            return set(ids).difference(mapping.ObjectId(key) for key in found)

        # probably mongo have some lag in shards request
        # so the query is repeated once for the items missed by the first one
        result = collect_res(ids)
        if len(result) != len(ids):
            result.update(collect_res(list(absent_ids(result))))

        if strict_mode:
            still_absent = absent_ids(result)
            if still_absent:
                raise common.errors.TaskError(
                    "Can't find {} in result.".format(sorted(list(still_absent))))

        return result

    def set_priority(self, request, task, priority, user=None):
        """
        Change the priority of the task on behalf of a user.

        The new priority is pushed to the queue client first and then stored on the task.

        :param request: request object; its user is used when `user` is not given
        :param task: task object
        :param priority: new priority
        :param user: user to act on behalf of
        :raises ValueError: if the user has no permission for the task or the priority exceeds
            the one returned by `controller.Group.allowed_priority`
        :raises qerrors.QNeedValidation: re-raised from the queue client after being logged
        """
        actor = user or request.user
        if not task.user_has_permission(actor):
            raise ValueError('User "{}" not allowed to change priority for task "{}"'.format(
                actor.login, task
            ))

        allowed = controller.Group.allowed_priority(request, task.owner, actor)
        if priority > allowed:
            raise ValueError("Can not increase task #{} owned by '{}' priority '{}' for user '{}'".format(
                task.id, task.owner, task.priority, actor.login
            ))

        try:
            controller.TaskQueue.qclient.push(task.id, int(priority), None)
        except qerrors.QAlreadyExecuting:
            # Not an error: the priority is still stored on the task below
            logger.warning("Task #%s already executing", task.id)
        except qerrors.QNeedValidation:
            logger.warning("Task queue validation required")
            raise
        task.set_priority(priority)

    @classmethod
    def task_effective_priority(cls, task_id):
        """
        Return the highest priority the queue client reports for the task.

        The value is the maximum of the first elements of the entries found in the first item
        of the second component of `qclient.queue_by_task(task_id)`. When that component is
        empty, the fallback entry `[None, ""]` is used instead.

        :param task_id: task identifier
        """
        queue_items = controller.TaskQueue.qclient.queue_by_task(task_id)[1]
        first_item = next(iter(queue_items), [[None, ""]])
        return max([entry[0] for entry in first_item])
