import math
import pickle
import httplib
import datetime
import operator as op
import functools as ft
import itertools as it

import six
import json
import flask

from sandbox import common
from sandbox.common import config
import sandbox.common.types.task as ctt
import sandbox.common.types.misc as ctm
import sandbox.common.types.client as ctc
import sandbox.common.types.statistics as ctst

from sandbox.web.api import v1
from sandbox.yasandbox import context
from sandbox.yasandbox.database import mapping

from sandbox.yasandbox import controller
from sandbox.yasandbox.api.json import mappers as task_mappers

from sandbox.serviceapi import mappers
from sandbox.serviceapi import constants as sa_consts
from sandbox.serviceapi.mules import signaler
from sandbox.serviceapi.web import RouteV1, exceptions


# Service configuration registry; `TaskProlongate.post` reads `registry.server.auth.enabled` from it.
registry = config.Registry()


class Task(RouteV1(v1.task.TaskList)):
    """ Handler of the v1 task list route: searches tasks by query parameters and renders them as JSON. """

    # Public query parameter name -> (key in the remapped query, second element).
    # NOTE(review): the second element looks like the corresponding database field path (None when there is
    # no direct field); both are consumed by `remap_query`, which is not defined in this file -- confirm there.
    LIST_QUERY_MAP = {
        "id": ("id", "id"),
        "type": ("type", "type"),
        "status": ("status", "execution__status"),
        "parent": ("parent_id", "parent_id"),
        "scheduler": ("scheduler", "scheduler"),
        "template_alias": ("template_alias", "template_alias"),
        "host": ("host", "requirements__host"),
        "arch": ("arch", "requirements__platform"),
        "model": ("model", "requirements__cpu_model"),
        "requires": ("requires", "requirements__resources"),
        "owner": ("owner", "owner"),
        "author": ("author", "author"),
        "desc_re": ("descr_mask", "description"),
        "children": ("show_childs", None),
        "hidden": ("hidden", None),
        "se_tag": ("se_tag", None),
        "priority": ("priority", None),
        "important": ("important_only", "flagged"),
        "created": ("created", "time__created"),
        "updated": ("updated", "time__updated"),
        "fields": ("fields", None),
        "input_parameters": ("input_parameters", None),
        "output_parameters": ("output_parameters", None),
        "any_params": ("any_params", None),
        "tags": ("tags", "tags"),
        "all_tags": ("all_tags", None),
        "hints": ("hints", "hints"),
        "all_hints": ("all_hints", None),
        "release": ("release", "release__status"),
        "semaphore_acquirers": ("semaphore_acquirers", "semaphore_acquirers"),
        "semaphore_waiters": ("semaphore_waiters", "semaphore_waiters"),
        "limit": ("limit", None),
        "offset": ("offset", None),
        "order": ("order_by", None),
    }

    # Page size used when scanning tasks to post-filter them by their semaphore-waiting state.
    CHUNK_SIZE = 1000

    @common.patterns.singleton_classproperty
    def legacy_client(cls):
        """ REST client for the legacy API, created with `total_wait=0` and with no retryable HTTP codes. """
        client = common.rest.Client(
            sa_consts.LegacyAPI.url + "api/v1.0", total_wait=0
        )
        client.RETRYABLE_CODES = ()
        return client

    @staticmethod
    def tasks_meta(task_types):
        """
        Fetch metadata of the given task types from the legacy server.

        :param task_types: task types to request metadata for
        :return: whatever the legacy `task/meta` endpoint returns
        :raises exceptions.ServiceUnavailable: the legacy server answered 503 or timed out
        :raises exceptions.BadRequest: the legacy server answered with any other HTTP error
        """
        try:
            return Task.legacy_client.task.meta.create(types=task_types)
        except Task.legacy_client.HTTPError as ex:
            if ex.status == httplib.SERVICE_UNAVAILABLE:
                raise exceptions.ServiceUnavailable("Legacy server temporary unavailable")
            raise exceptions.BadRequest(str(ex))
        except Task.legacy_client.TimeoutExceeded:
            raise exceptions.ServiceUnavailable("Legacy server temporary unavailable")

    @staticmethod
    def _status_arg_list(x):
        """
        Expand a sequence of task statuses, possibly containing status group names, into a flat list of statuses.

        Names of `ctt.Status.Group` members are replaced by the statuses of that group;
        any other value is kept as is.
        """
        groups = set(map(str, ctt.Status.Group))
        ret = []
        for s in x:
            if s in groups:
                ret.extend(list(getattr(ctt.Status.Group, s)))
            else:
                ret.append(s)
        return ret

    @classmethod
    def _add_children(cls, items, docs):
        """
        Lazily attach a "children" list to each mapped task item.

        Children of all `docs` are fetched with a single query; each item gets a list of
        `{"<child id>": <child execution status>}` dicts (empty when the task has no children).
        """
        # Build children list in one query; ordering by parent_id is required for `groupby` to group correctly
        children = {
            pid: [(item[0], item[2]) for item in ids]
            for pid, ids in it.groupby(
                mapping.Task.objects(
                    parent_id__in=set(d.id for d in docs)
                ).fast_scalar("id", "parent_id", "execution__status").order_by("+parent_id"),
                key=op.itemgetter(1)
            )
        }
        for item in items:
            item["children"] = [{str(row[0]): row[1]} for row in children.get(item["id"], [])]
            yield item

    @staticmethod
    def priority_parser(priority):
        """ Upper-case a priority string and split it on ":". """
        return priority.upper().split(":")

    @classmethod
    def _semaphore_waiters(cls, docs):
        """
        Return the set of task IDs the task queue reports as semaphore waiters among the ENQUEUED `docs`,
        or None when the request does not come from the web UI.
        """
        return (
            set(controller.TaskQueue.qclient.semaphore_waiters(
                [_.id for _ in docs if _.execution.status == ctt.Status.ENQUEUED]
            ))
            if context.current.request.source == ctt.RequestSource.WEB else
            None
        )

    @classmethod
    def _filter_items(cls, items, fields):
        """
        Lazily project each item onto the requested `fields`.

        Dotted field names ("a.b.c") address nested dicts; a missing path component yields None.
        """
        field_getters = {
            field: ft.partial(reduce, lambda a, b: (a or {}).get(b), field.split("."))
            for field in fields or {}
        }
        return (
            {
                field: getter(item)
                for field, getter in field_getters.iteritems()
            }
            for item in items
        )

    @staticmethod
    def _parameters_parser(params):
        """
        Parse a JSON object of task parameters.

        NaN and infinite float values are replaced by their string representation.

        :raises ValueError: the JSON is invalid, is not an object, or has non-string keys
        """
        data = json.loads(params)
        if not isinstance(data, dict):
            raise ValueError("Parameters must be in format of dictionary: '{\"<name>\": <value>, ...}'")
        for key, value in six.iteritems(data):
            if not isinstance(key, six.string_types):
                raise ValueError("Parameters must be in format of dictionary: '{\"<name>\": <value>, ...}'")
            if isinstance(value, float) and (math.isnan(value) or math.isinf(value)):
                data[key] = str(value)
        return data

    # Converters applied to remapped query values before the database query is built.
    REFORMAT_FIELDS = {
        "status": _status_arg_list.__func__,
        "descr_mask": lambda _: _.encode("utf-8"),
        "priority": priority_parser.__func__,
        "release": lambda x: [_.lower() for _ in x],
        "input_parameters": _parameters_parser.__func__,
        "output_parameters": _parameters_parser.__func__,
        "tags": lambda x: [_.upper() for _ in x],
    }

    @classmethod
    def json_args(cls):
        """ Keyword arguments for `json.dumps`: compact output for XHR requests, indented otherwise. """
        return {
            "indent": None if flask.request.is_xhr else 2,
            "separators": (",", ":") if flask.request.is_xhr else (", ", ": "),
            "ensure_ascii": False,
        }

    @classmethod
    def _fetch_semaphore_waiting_docs(cls, query, total, offset, limit):
        """
        Collect up to `limit` documents of `query` that the task queue reports as semaphore waiters,
        skipping the first `offset` of them.

        The query is scanned in pages of `CHUNK_SIZE` documents, because the waiting state comes
        from the task queue rather than from the database query itself.
        """
        real_offset = 0
        skipped = 0
        docs = []

        while real_offset < total:
            tmp_docs = list(query.skip(real_offset).limit(cls.CHUNK_SIZE))
            tmp_semaphore_waiters = cls._semaphore_waiters(tmp_docs)

            for doc in tmp_docs:
                if len(docs) >= limit:
                    break

                if doc.id in tmp_semaphore_waiters:
                    if skipped >= offset:
                        docs.append(doc)
                    else:
                        skipped += 1

            # Stop when the page is full or the last (short) chunk has been processed
            if len(docs) >= limit or len(tmp_docs) < cls.CHUNK_SIZE:
                break
            real_offset += cls.CHUNK_SIZE
        return docs

    @staticmethod
    def _order_by_ids(docs, ids_order):
        """
        Return `docs` sorted by the position of their IDs in `ids_order`.

        Positions are looked up in a dict instead of calling `list.index` per document. For duplicated IDs
        the first occurrence wins; documents whose ID is absent from `ids_order` go last, in their original order.
        """
        positions = {}
        for pos, task_id in enumerate(ids_order):
            positions.setdefault(task_id, pos)
        last = len(ids_order)
        return sorted(docs, key=lambda doc: positions.get(doc.id, last))

    @classmethod
    def get(cls, query):
        """
        List tasks matching the request query.

        :param query: raw query parameters; remapped by `remap_query` and converted by `REFORMAT_FIELDS`
        :return: an empty `TaskList` when the ID filter resolves to nothing, otherwise a JSON response
            with "limit", "offset", "total" and "items" keys
        :raises exceptions.BadRequest: on invalid query parameters, on a UI description search without an
            "owner"/"author"/"type"/"tags"/"hints" filter, or when the database query cannot be built
        """
        try:
            query, offset, limit = cls.remap_query(query, save_query=True)
            for k, v in cls.REFORMAT_FIELDS.iteritems():
                if k in query:
                    query[k] = v(query[k])
        except (TypeError, ValueError, KeyError) as ex:
            raise exceptions.BadRequest("Query parameter validation error: " + str(ex))

        # Semaphore filters are resolved via the task queue into an explicit list of task IDs
        ids_to_intersect = None
        sem_id = query.pop("semaphore_acquirers", None)
        if sem_id is not None:
            semaphore_obj = dict(controller.TaskQueue.qclient.semaphores(sem_id)).get(sem_id)
            ids_to_intersect = semaphore_obj.tasks.keys() if semaphore_obj is not None else []

        sem_id = query.pop("semaphore_waiters", None)
        if sem_id is not None and ids_to_intersect is None:
            # if both "semaphore_acquirers" and "semaphore_waiters" are present, ignoring second
            ids_to_intersect = dict(controller.TaskQueue.qclient.semaphore_wanting(sem_id)).get(sem_id, [])

        # Combine semaphore-derived IDs with an explicit "id" filter, if any
        query_task_ids = query.get("id", None)
        if query_task_ids is not None and ids_to_intersect is not None:
            if isinstance(query_task_ids, (list, tuple)):
                ids_to_intersect = set(ids_to_intersect)
                query["id"] = [task_id for task_id in query_task_ids if task_id in ids_to_intersect]
            else:
                query["id"] = query_task_ids if query_task_ids in ids_to_intersect else []
        elif ids_to_intersect is not None:
            query["id"] = ids_to_intersect

        ids_order = None
        short_task_result = (context.current.request.source == ctt.RequestSource.WEB)
        fields = set(query.pop("fields", []))
        context_in_fields = any(field == "context" or field.startswith("context.") for field in fields)
        # Task context is loaded for UI requests or when a context field is requested
        query["load_ctx"] = query.get("load_ctx") or short_task_result or context_in_fields
        if not query.get("order_by") and "id" in query:
            # No explicit ordering: results are returned in the order the IDs were given
            query["order_by"] = None
            ids_order = query["id"]

        # For UI requests WAIT_MUTEX is not queried directly: when requested together with statuses other than
        # ENQUEUED it is dropped; otherwise ENQUEUED tasks are queried and, if ENQUEUED itself was not requested,
        # post-filtered down to semaphore waiters.
        filter_wait_mutex = False
        if short_task_result and query.get("status"):
            statuses = set(common.utils.chain(query["status"]))
            if ctt.Status.WAIT_MUTEX in statuses:
                if statuses - {ctt.Status.WAIT_MUTEX, ctt.Status.ENQUEUED}:
                    statuses.remove(ctt.Status.WAIT_MUTEX)
                    query["status"] = statuses
                else:
                    filter_wait_mutex = ctt.Status.ENQUEUED not in statuses
                    query["status"] = ctt.Status.ENQUEUED

        if query.get("descr_mask") and short_task_result:
            if not any(param in query for param in ("owner", "author", "type", "tags", "hints")):
                raise exceptions.BadRequest(
                    "Search by description without 'owner', 'author', 'type', 'tags' or 'hints' fields is not allowed."
                )

        # empty intersection leads to ignoring kwargs["id"] in db query
        task_ids = query.get("id", None)
        if isinstance(task_ids, (list, tuple, set, frozenset)) and len(task_ids) == 0:
            return v1.schemas.task.TaskList.create(
                limit=limit,
                offset=offset,
                total=0,
                items=[],
            )

        if not query.get("order_by"):
            query.pop("order_by", None)
        try:
            query = controller.Task.list_query(**query)
            query = query.lite()
        except Exception as ex:
            raise exceptions.BadRequest(str(ex))
        # NOTE(review): with WAIT_MUTEX post-filtering, `total` counts all matching tasks before the filter
        total = query.count()

        if filter_wait_mutex:
            docs = cls._fetch_semaphore_waiting_docs(query, total, offset, limit)
        else:
            docs = list((query if not offset else query.skip(offset)).limit(limit))

        # Only a list/tuple carries a meaningful order of requested IDs
        if ids_order and isinstance(ids_order, (list, tuple)):
            docs = cls._order_by_ids(docs, ids_order)

        # Weather documents are loaded only when all fields or the "scores" field are requested
        if not fields or "scores" in fields:
            weather_map = {_.type: _ for _ in mapping.Weather.objects(type__in={_.type for _ in docs})}
        else:
            weather_map = {}

        # Every doc found by the post-filter is a semaphore waiter by construction
        semaphore_waiters = set(_.id for _ in docs) if filter_wait_mutex else cls._semaphore_waiters(docs)
        postprocesses = []
        if fields:
            fields.add("id")
            mapper_fields = {_ for _ in fields if not _.startswith("context.")}
            if context_in_fields:
                mapper_fields.add("context")
            postprocesses.append(ft.partial(cls._filter_items, fields=fields))
            if "children" in fields:
                postprocesses.append(ft.partial(cls._add_children, docs=docs))
        else:
            mapper_fields = task_mappers.TaskMapper.get_base_fields()
            if short_task_result:
                mapper_fields.add("short_task_result")
            postprocesses.append(ft.partial(cls._add_children, docs=docs))

        task_mapper = task_mappers.TaskMapper(
            mapper_fields,
            context.current.request.user, context.current.request.base_url,
            semaphore_waiters, weather_map, task_docs=docs,
            types_resolve_func=cls.tasks_meta
        )
        tasks = (task_mapper.dump(doc, legacy=False) for doc in docs)
        for postprocess in postprocesses:
            tasks = postprocess(tasks)
        return flask.current_app.response_class(
            (
                json.dumps(
                    {"limit": limit, "offset": offset, "total": total, "items": list(tasks)},
                    **cls.json_args()
                ).encode("utf8"),
                "\n"
            ),
            httplib.OK, mimetype="application/json; charset=utf8",
        )


class TaskProlongate(RouteV1(v1.task.TaskProlongate)):
    """ Changes the kill timeout of a task that has an active session. """

    @classmethod
    def post(cls, task_id, delta):
        """
        Add `delta` to the kill timeout of task `task_id`.

        :raises exceptions.NotFound: the task does not exist
        :raises exceptions.Forbidden: the current user has no permission for the task's author/owner
        :raises exceptions.BadRequest: auth is enabled and the request is not from the UI,
            or the task has no active sessions
        """
        task = mapping.Task.objects.with_id(task_id)
        if task is None:
            raise exceptions.NotFound("Task #{} not found.".format(task_id))

        user = context.current.user
        if not controller.user_has_permission(user, (task.author, task.owner)):
            message = "User `{}` is not permitted to modify task #{}".format(user.login, task_id)
            raise exceptions.Forbidden(message)

        auth_enabled = registry.server.auth.enabled
        if auth_enabled and context.current.request.source != ctt.RequestSource.WEB:
            raise exceptions.BadRequest("Kill timeout can be modified only from UI request")

        updated = controller.OAuthCache.update(task, kill_timeout=task.kill_timeout + delta)
        if not updated:
            raise exceptions.BadRequest("No active sessions for task #{}".format(task_id))

        return "", httplib.NO_CONTENT


class TaskCommit(RouteV1(v1.task.TaskCommit)):
    """ Commits the current request session via `controller.OAuthCache.commit`. """

    @classmethod
    def post(cls):
        """ Commit the session bound to the current request and respond with 204. """
        session = context.current.request.session
        controller.OAuthCache.commit(session)
        return "", httplib.NO_CONTENT


class TaskCurrentExecution(RouteV1(v1.task.TaskCurrentExecution)):
    """ Updates execution information of the task bound to the current session. """

    @classmethod
    def put(cls, body):
        """
        Update the execution description and/or disk usage of the current session's task.

        When disk usage is reported, a TASK_DISK_USAGE statistics signal is sent as well.

        :raises exceptions.NotFound: the session's task does not exist
        """
        task_id = context.current.request.session.task_id
        doc = mapping.Task.objects.lite().with_id(task_id)
        if doc is None:
            # Without this check the attribute accesses below fail with AttributeError
            raise exceptions.NotFound("Task #{} not found.".format(task_id))

        update_expr = {}
        utcnow = datetime.datetime.utcnow()
        if body.description is not ctm.NotExists:
            update_expr["set__execution__description"] = body.description
        du, work_du = body.disk_usage, body.work_disk_usage
        if du is not None:
            du_max, du_last = 0, 0
            # On MULTISLOT clients, prefer the working-set usage (plus synced resources) over the overall value
            if mapping.Client.objects(
                hostname=context.current.request.session.client_id,
                pure_tags=str(ctc.Tag.MULTISLOT),
                read_preference=mapping.ReadPreference.SECONDARY
            ).lite().first():
                du_max, du_last = work_du.max or du.max, work_du.last or du.last
                if work_du.resources is not None:
                    du_max += work_du.resources
                    du_last += work_du.resources

            # these values are also reset on task restart, see yasandbox.api.json.batch.BatchTask._op()
            if (du_max, du_last) == (0, 0):
                du_max, du_last = du.max, du.last
            # Values are stored shifted right by 10 bits; the stored maximum never decreases
            if du_max is not None:
                update_expr.update(set__execution__disk_usage__max=max(doc.execution.disk_usage.max, du_max >> 10))
            if du_last is not None:
                update_expr.update(set__execution__disk_usage__last=du_last >> 10)

            signaler.send_msg(dict(
                type=ctst.SignalType.TASK_DISK_USAGE,
                date=utcnow,
                timestamp=utcnow,
                max_working_set_usage=work_du.max or 0,
                last_working_set_usage=work_du.last or 0,
                reserved_working_set=doc.requirements.disk_space << 20,  # stored in MiB
                resources_synced=work_du.resources or 0,
                place_delta=du_max,
                task_id=doc.id,
                task_type=doc.type,
                owner=doc.owner,
                task_tags=doc.tags
            ))

        if update_expr:
            mapping.Task.objects(id=doc.id).update_one(**update_expr)
        return "", httplib.NO_CONTENT


class TaskAudit(RouteV1(v1.task.TaskAudit)):
    """ Returns the audit trail of a single task. """

    @classmethod
    def get(cls, id_):
        """ Return all audit records of task `id_`, ordered by ascending date, as raw mapped dicts. """
        records = mapping.Audit.objects(task_id=id_)
        ordered = records.order_by("+date").as_pymongo()
        return mappers.task.TaskAuditMapperRaw.dump_list(ordered)


class TaskCurrentAudit(RouteV1(v1.task.TaskCurrentAudit)):
    """ Adds an audit record to, or switches the status of, the task bound to the current session. """

    @classmethod
    def post(cls, body):
        """
        Record `body.message` for the current session's task, switching its status when `body.status` is set.

        :raises exceptions.NotFound: the session's task does not exist
        :raises exceptions.BadRequest: the controller reports an incorrect status
        :raises exceptions.Conflict: the controller reports an update conflict
        """
        request = context.current.request
        task_id = request.session.task_id
        task = mapping.Task.objects.lite().with_id(task_id)
        if task is None:
            raise exceptions.NotFound("Task {} not found".format(task_id))

        prev_status = task.execution.status
        client = request.is_task_session and request.session.client_id
        try:
            if body.status:
                controller.Task.set_status(
                    task, body.status,
                    event=body.message, lock_host=None, keep_lock=None, force=body.force,
                    expected_status=body.expected_status, wait_targets=body.wait_targets,
                )
            else:
                controller.Task.audit(mapping.Audit(task_id=task.id, content=body.message))
        except common.errors.IncorrectStatus as err:
            raise exceptions.BadRequest(str(err))
        except common.errors.UpdateConflict as err:
            raise exceptions.Conflict(str(err))
        except Exception as err:
            # Log the context of an unexpected failure, then let it propagate unchanged
            context.current.logger.error(
                "Unexpected error while switching status"
                " from %s to %s for task #%s (lock: %s, client: %s, force: %s): %s",
                prev_status, body.status, task.id, body.lock, client, body.force, err
            )
            raise
        return "", httplib.NO_CONTENT


class TaskAuditList(RouteV1(v1.task.TaskAuditList)):
    """ Lists audit records filtered by task IDs, remote IP and a date interval. """

    # Widest [since, to] interval a single request may ask for
    LIST_AUDIT_MAX_INTERVAL = datetime.timedelta(days=14)

    @classmethod
    def get(cls, query):
        """
        Return audit records matching `query`, ordered by ascending date.

        :param query: mapping with "id" (task IDs), "remote_ip", "since" and "to"; any of them may be empty
        :raises exceptions.BadRequest: the requested interval exceeds `LIST_AUDIT_MAX_INTERVAL`
        """
        ids = query["id"]
        remote_ip = query["remote_ip"]
        since = query["since"]
        to = query["to"]

        # The upper bound defaults to, and is capped at, the current time (sampled once)
        now = datetime.datetime.utcnow()
        to = now if to is None else min(to, now)

        # `not ids` also covers a missing (None) ID list, where `len(ids)` would raise TypeError
        if since is None and not ids:
            # Without task IDs, limit the search to the last day by default
            since = to - datetime.timedelta(days=1)

        if since and to - since > cls.LIST_AUDIT_MAX_INTERVAL:
            raise exceptions.BadRequest(
                "Intervals larger than {} days are not allowed".format(cls.LIST_AUDIT_MAX_INTERVAL.days)
            )

        # `to` is always set at this point
        me_kwargs = {"date__lte": to}
        if since:
            me_kwargs["date__gte"] = since
        if ids:
            me_kwargs["task_id__in"] = ids
        if remote_ip:
            me_kwargs["remote_ip"] = remote_ip

        raw_docs = mapping.Audit.objects(**me_kwargs).order_by("+date").as_pymongo()
        return mappers.task.TaskAuditMapperRaw.dump_list(raw_docs)


class TaskResources(RouteV1(v1.task.TaskResources)):
    """ Lists resources whose `task_id` is the given task. """

    @classmethod
    def get(cls, id_):
        """ Return all resources of task `id_`, highest ID first, with offset and limit reported as 0. """
        resources = mapping.Resource.objects(task_id=id_).lite().order_by("-id")
        total = resources.count()
        items = mappers.resource.ResourceListMapper().dump(resources)
        return v1.schemas.resource.ResourceList.create(offset=0, limit=0, total=total, items=items)


class TaskCurrentContextValue(RouteV1(v1.task.TaskCurrentContextValue)):
    """ Sets a single key of the pickled context of the task bound to the current session. """

    @classmethod
    def put(cls, body):
        """
        Store `body.value` under `body.key` in the current session task's context.

        NOTE(review): the whole context is read, modified and written back, so concurrent updates of
        different keys may overwrite each other.

        :raises exceptions.NotFound: the session's task does not exist
        """
        task_id = context.current.request.session.task_id
        task = mapping.Task.objects.lite().with_id(task_id)
        if task is None:
            raise exceptions.NotFound("Task #{} not found.".format(task_id))

        task_ctx = pickle.loads(task.context)
        task_ctx[body.key] = body.value
        serialized = mapping.Task.context.to_mongo(pickle.dumps(task_ctx))
        mapping.Task.objects(id=task.id).update_one(set__context=serialized)
        return "", httplib.NO_CONTENT


class TaskContext(RouteV1(v1.task.TaskContext)):
    """ Returns the unpickled context of a task. """

    @classmethod
    def get(cls, id_):
        """
        Return the context of task `id_`.

        NOTE(review): this unpickles data stored in the database; that is safe only while every writer of
        the context field is trusted code.

        :raises exceptions.NotFound: the task does not exist
        """
        task = mapping.Task.objects.lite().with_id(id_)
        if task is not None:
            return pickle.loads(task.context)
        raise exceptions.NotFound("Task #{} not found.".format(id_))


class TaskCurrentContext(RouteV1(v1.task.TaskCurrentContext)):
    """ Replaces the whole context of the task bound to the current session. """

    @classmethod
    def put(cls, body):
        """
        Pickle `body` and store it as the current session task's context.

        :raises exceptions.NotFound: no task document was updated
        """
        task_id = context.current.request.session.task_id
        task_query = mapping.Task.objects(id=task_id)
        serialized = mapping.Task.context.to_mongo(pickle.dumps(body))
        updated = task_query.update_one(set__context=serialized)
        if updated:
            return "", httplib.NO_CONTENT
        raise exceptions.NotFound("Task #{} not found.".format(context.current.request.session.task_id))


class TaskCurrentTriggerTime(RouteV1(v1.task.TaskCurrentTriggerTime)):
    """ Creates a time trigger for the task bound to the current session. """

    @classmethod
    def post(cls, body):
        """
        Create a time trigger firing `body.period` seconds from now (UTC) for the current session's task.

        :raises exceptions.Conflict: a time trigger for the task already exists
        """
        session = context.current.request.session
        fire_at = datetime.datetime.utcnow() + datetime.timedelta(seconds=body.period)
        trigger = controller.TimeTrigger.Model(source=session.task_id, time=fire_at, token=session.token)
        try:
            controller.TimeTrigger.create(trigger)
        except controller.TimeTrigger.AlreadyExists:
            raise exceptions.Conflict("Time trigger for task {} already exists".format(session.task_id))
        return "", httplib.NO_CONTENT


class TaskCurrentTriggerTask(RouteV1(v1.task.TaskCurrentTriggerTask)):
    """ Creates a trigger waking the current session's task when other tasks reach given statuses. """

    @classmethod
    def post(cls, body):
        """
        Create a task status trigger (and, if `body.timeout` is set, a time trigger) for the current task.

        :raises exceptions.BadRequest: no waitable statuses remain after expanding `body.statuses`
        :raises exceptions.NotAcceptable: no target is pending, or `wait_all` is false and some target
            is already in a wanted status
        :raises exceptions.Conflict: a task or time trigger for the current task already exists
        """
        wanted_statuses = ctt.Status.Group.expand(body.statuses) - ctt.Status.Group.NONWAITABLE
        if not wanted_statuses:
            raise exceptions.BadRequest("Empty statuses list to wait for")

        pending = controller.TaskStatusTrigger.get_not_ready_targets(
            targets=body.targets, statuses=wanted_statuses
        )
        if not pending or (not body.wait_all and len(body.targets) != len(pending)):
            raise exceptions.NotAcceptable("Don't need to create trigger")

        session = context.current.request.session
        status_trigger = controller.TaskStatusTrigger.Model(
            source=session.task_id,
            targets=pending,
            statuses=wanted_statuses,
            wait_all=body.wait_all,
            token=session.token,
        )
        try:
            controller.TaskStatusTrigger.create(status_trigger)
        except controller.TaskStatusTrigger.AlreadyExists:
            raise exceptions.Conflict("Task trigger for task {} already exists".format(session.task_id))

        if not body.timeout:
            return "", httplib.NO_CONTENT

        time_trigger = controller.TimeTrigger.Model(
            source=session.task_id,
            time=datetime.datetime.utcnow() + datetime.timedelta(seconds=body.timeout)
        )
        try:
            controller.TimeTrigger.create(time_trigger)
        except controller.TimeTrigger.AlreadyExists:
            raise exceptions.Conflict("Time trigger for task {} already exists".format(session.task_id))
        return "", httplib.NO_CONTENT


class TaskCurrentTriggerOutput(RouteV1(v1.task.TaskCurrentTriggerOutput)):
    """ Creates a trigger waking the current session's task when output parameters of other tasks are set. """

    @classmethod
    def post(cls, body):
        """
        Create an output trigger (and, if `body.timeout` is set, a time trigger) for the current session's task.

        :param body: request with `targets` (task ID -> output field name(s)), `wait_all` and `timeout`
        :raises exceptions.BadRequest: malformed targets, no fields to wait for, or non-existent target tasks
        :raises exceptions.NotAcceptable: no trigger is needed because awaited fields are already set
        :raises exceptions.Conflict: an output or time trigger for the current task already exists
        """
        session = context.current.request.session
        try:
            # Flatten targets into parallel lists of (task ID, output field name)
            task_ids, fields = [], []
            for tid, flds in body.targets.iteritems():
                for fld in common.utils.chain(flds):
                    task_ids.append(int(tid))
                    fields.append(fld)

            if not task_ids:
                raise ValueError("No output parameters to wait for")
            task_ids_set = set(task_ids)

            # Names of output parameters already present on each existing target task
            ready_fields = {
                tid: set(p.get("k") for p in params)
                for tid, params in mapping.Task.objects(
                    id__in=task_ids_set,
                ).read_preference(mapping.ReadPreference.PRIMARY).fast_scalar("id", "parameters__output")
            }

            missing = task_ids_set - set(ready_fields)
            if missing:
                missing_text = ", ".join(map(str, missing))
                raise ValueError("Cannot wait for non-existing task(s): {}".format(missing_text))

            tf_model = controller.TaskOutputTrigger.Model.TargetField
            targets = [
                tf_model(target=tid, field=field) for tid, field in it.izip(task_ids, fields)
                if field not in ready_fields[tid]
            ]
        except (KeyError, TypeError, ValueError) as ex:
            raise exceptions.BadRequest(str(ex))

        # Kept outside the try above so it is never reported as a BadRequest
        wait_all = body.wait_all
        if not targets or (not wait_all and len(task_ids) != len(targets)):
            raise exceptions.NotAcceptable("Don't need to create trigger")

        try:
            controller.TaskOutputTrigger.create(controller.TaskOutputTrigger.Model(
                source=session.task_id,
                targets=targets,
                wait_all=wait_all,
                token=session.token
            ))
        except controller.TaskOutputTrigger.AlreadyExists:
            raise exceptions.Conflict("Output trigger for task {} already exists".format(session.task_id))
        if body.timeout:
            try:
                controller.TimeTrigger.create(controller.TimeTrigger.Model(
                    source=session.task_id,
                    time=datetime.datetime.utcnow() + datetime.timedelta(seconds=body.timeout)
                ))
            except controller.TimeTrigger.AlreadyExists:
                raise exceptions.Conflict("Time trigger for task {} already exists".format(session.task_id))
        return "", httplib.NO_CONTENT


class TaskSemaphores(RouteV1(v1.task.TaskSemaphores)):
    """ Releases semaphores of the task bound to the current session. """

    @classmethod
    def delete(cls):
        """ Ask the task queue to release the current session task's semaphores and respond with 204. """
        task_id = context.current.request.session.task_id
        controller.TaskQueue.qclient.release_semaphores(task_id, None, None)
        return "", httplib.NO_CONTENT


class TaskAuditHosts(RouteV1(v1.task.TaskAuditHosts)):
    """ Returns per-host entries derived from a task's audit records. """

    @classmethod
    def get(cls, id_):
        """ Build `TaskAuditHostsItem` objects for task `id_` from its audit records, processed in date order. """
        intervals = []
        # NOTE(review): the return value of `reduce` is discarded, so this relies on
        # `AuditHostsListItemEntry(acc, audit_doc)` filling `intervals` in place -- confirm in the mapper.
        reduce(
            mappers.task.AuditHostsListItemEntry,
            mapping.Audit.objects(task_id=id_).lite().order_by("+date"), intervals
        )
        # each collected entry is expected to be a mapping of `TaskAuditHostsItem` fields
        result = [v1.schemas.task.TaskAuditHostsItem.create(**item) for item in intervals]
        return result
