import json
import hashlib
import datetime as dt

from six.moves import cPickle

import pymongo
import mongoengine as me

from sandbox.common import patterns
import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt
import sandbox.common.types.notification as ctn

from . import base

#: Alias of `ctt.Priority`, exposed at module level.
Priority = ctt.Priority


class Audit(base.ConnectionSwitcherMixin, me.Document, base.Aggregatable):
    """
    The class represents a mapping between Python object and database storage for "TaskHistory" entity
    """
    RequestSource = ctt.RequestSource

    meta = {"indexes": ["task_id", "date", "status"]}

    #: Id of task
    task_id = me.IntField(required=True, db_field="ti")
    #: Task status
    status = me.StringField(db_field="st")
    #: Date and time since which task entered to given status
    date = me.DateTimeField(required=True, db_field="dt", default=dt.datetime.utcnow)
    #: Entry content
    content = me.StringField(db_field="ct")
    #: The entry author
    author = me.StringField(db_field="a")
    #: The host, which added the history entry
    hostname = me.StringField(db_field="h")
    #: Request id of XMLRPC method, that cause some changes with task
    request_id = me.StringField(db_field="ri")
    #: Ip of remote host that requested changes
    remote_ip = me.StringField(db_field="ip")
    #: Source by which changes was made (WEB or XMLRPC API)
    source = me.StringField(db_field="sc", choices=list(RequestSource), default=RequestSource.SYS)
    #: Client host that requested changes
    client = me.StringField(db_field="cl")
    #: IDs of tasks or resources awaited by the task
    wait_targets = me.DictField(default=None)

    @classmethod
    def first_preparing_for_tasks(cls, task_ids):
        """
        Get time of the first `PREPARING` status of each of the given tasks.

        :param task_ids: tasks identifiers of interest (any iterable; it is materialized into a list
                         because BSON cannot encode sets or generators)
        :return: dict mapping task identifier to the date of its earliest `PREPARING` audit entry;
                 tasks without such entries are absent from the result.
        """
        pipeline = [
            {"$match": {"ti": {"$in": list(task_ids)}, "st": ctt.Status.PREPARING}},
            {"$project": {"ti": 1, "dt": 1}},
            # Sort ascending by date so that `$first` in the group picks the earliest entry per task.
            {"$sort": {"dt": 1}},
            {"$group": {"_id": "$ti", "date": {"$first": "$dt"}}},
        ]
        # Use `cls` (not the hard-coded class) so subclasses and connection switching are honoured.
        return {o["_id"]: o["date"] for o in cls.aggregate(pipeline)}


class ParameterValue(me.base.BaseField):
    """ Field holding an arbitrary value, restricted to what the `json` module can serialize. """

    @staticmethod
    def _jsonifiable(value):
        """ Tell whether `json.dumps` is able to serialize `value`. """
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            return False
        return True

    def validate(self, value, *_):
        """ Report a validation error unless `value` is JSON-serializable. """
        if not self._jsonifiable(value):
            self.error("Parameter value accepts only JSONified types")


class ParametersMeta(base.ConnectionSwitcherMixin, me.Document):
    """
    Parameters views' meta information.

    `hash` is recalculated from `params` on every `save()`.
    """

    meta = {"indexes": ["hash", "accessed"]}

    class ParameterMeta(me.EmbeddedDocument):
        """ Meta information of a single parameter. """
        meta = {"allow_inheritance": True}

        #: Name
        name = me.StringField(db_field="n", required=True)
        #: Parameter is required for task execution.
        required = me.BooleanField(db_field="re")
        #: Title, short description.
        title = me.StringField(db_field="ti")
        #: Long description.
        description = me.StringField(db_field="de")
        #: View modifiers for GUI, for example: repetition, multiline, checkboxes, format.
        modifiers = me.DictField(db_field="mo")
        #: Context contains fields that define parameter behaviour and filtration option.
        context = me.DictField(db_field="ctx")
        #: Parameter is output.
        output = me.BooleanField(db_field="out")
        #: Type defines (en|de)code functions and UI view
        type = me.StringField(db_field="tp", choices=list(ctt.ParameterType))
        #: Mapping <parameter value> -> <list of dependent parameters' names>.
        sub_fields = me.DictField(db_field="sf")
        #: Parameter is complex and its value is stored in DB as dumped json.
        complex = me.BooleanField(db_field="co")
        #: Default value
        default = ParameterValue(db_field="dv")
        #: "Do not copy" flag; presumably the value is not carried over when a task is copied -- TODO confirm
        do_not_copy = me.BooleanField(db_field="cop")

    #: parameters hash
    hash = me.StringField()
    #: parameters meta
    params = me.ListField(me.EmbeddedDocumentField(ParameterMeta))
    #: update time
    accessed = me.DateTimeField(db_field="up", required=True, default=dt.datetime.utcnow)

    @property
    def calculated_hash(self):
        """
        MD5 hex digest of the concatenated JSON dumps (with sorted keys) of `params`, in list order.
        """
        dumped = "".join(p.to_json(sort_keys=True) for p in self.params)
        # `hashlib.md5` accepts only bytes on Python 3. Text is encoded as UTF-8; a byte string
        # (Python 2 `str`) is hashed as-is, so previously computed digests stay the same.
        if not isinstance(dumped, bytes):
            dumped = dumped.encode("utf-8")
        return hashlib.md5(dumped).hexdigest()

    def save(self, *args, **kws):
        """
        Refresh `hash` from the current `params` and save the document.

        :return: whatever the parent `save()` returns.
        """
        self.hash = self.calculated_hash
        return super(ParametersMeta, self).save(*args, **kws)


class Template(me.Document):
    """
    Task template document. It doesn't store task fields related to execution.

    The document is abstract; `Task` inherits its fields.
    """

    meta = {
        "abstract": True,
        "indexes": [
            {
                "fields": ["#parameters_meta"],
                "sparse": True,
            },
        ],
    }

    #: Task requirements.
    class Requirements(me.EmbeddedDocument):

        class RamDrive(me.EmbeddedDocument):
            """ RAM drive requirements specification. """

            Type = ctm.RamDriveType

            #: Required amount of space to be allocated for the RAM drive in megabytes.
            size = me.IntField(min_value=0, required=True)
            #: Type of RAM drive to be mounted on task execution.
            type = me.StringField(choices=list(Type), required=True)

        class Semaphores(me.EmbeddedDocument):
            """ Required semaphore """
            class Acquire(me.EmbeddedDocument):
                #: Semaphore's name
                name = me.StringField(required=True)
                #: Task's weight
                weight = me.IntField(default=1, min_value=1)
                #: Capacity for automatically created semaphore
                capacity = me.IntField(default=0, min_value=0)
                #: Semaphore allowed for all
                public = me.BooleanField(default=False)

            #: Semaphores to acquire
            acquires = me.ListField(me.EmbeddedDocumentField(Acquire))
            #: List of statuses in which semaphore will be released
            release = me.ListField(me.StringField(choices=list(ctt.Status.statuses_and_groups)))

        class Cache(me.EmbeddedDocument):
            #: cache key
            key = me.StringField(db_field="k", required=True)
            #: cache value
            value = me.StringField(db_field="v")

        class TasksResource(me.EmbeddedDocument):
            """ Tasks resource info. """

            #: Id of resource.
            id = base.ReferenceField(required=True)
            #: Taskbox enabled.
            taskbox_enabled = me.BooleanField(db_field="tb")
            #: Age of tasks resource API.
            age = me.IntField(db_field="age")

        class BucketReserve(me.EmbeddedDocument):
            #: bucket name
            bucket = me.StringField()
            #: reserve size in bytes
            size = me.IntField(min_value=0)

        #: Platform name the task is bound on.
        platform = me.StringField(db_field="plat")
        #: paths of binaries
        platforms = me.ListField(me.StringField(), db_field="plats")
        #: CPU model the client should have.
        cpu_model = me.StringField(db_field="cpu")
        #: minimum number of CPU cores
        cores = me.IntField()
        #: Host the task is bound on.
        host = me.StringField(db_field="host")
        #: Resource ID(s) the task is required to run.
        resources = me.ListField(base.ReferenceField(), db_field="res")
        #: Estimated disk usage in MiB provided by the user.
        disk_space = me.IntField(db_field="disk", min_value=0, required=True)
        #: RAM in MiB.
        ram = me.IntField(db_field="ram", default=0, min_value=0, required=True)
        #: RAM drive requirements if specified.
        ramdrive = me.EmbeddedDocumentField(RamDrive)
        #: Execute task with root privileges
        privileged = me.BooleanField(db_field="pri")
        #: DNS type
        dns = me.StringField(choices=list(ctm.DnsType))
        #: Client tags
        client_tags = me.StringField(db_field="tags")
        #: Semaphores
        semaphores = me.EmbeddedDocumentField(Semaphores, db_field="sem")
        #: Porto layers
        porto_layers = me.ListField(base.ReferenceField(), db_field="pl")
        #: Caches
        caches = me.ListField(me.EmbeddedDocumentField(Cache), db_field="ca", default=None)
        #: Tasks resource for task executing
        tasks_resource = me.EmbeddedDocumentField(TasksResource, db_field="tr")
        #: Resource id of the container
        container_resource = base.ReferenceField(db_field="cr")
        #: Space to reserve in MDS buckets for resources the task going to create
        resources_space_reserve = me.ListField(
            me.EmbeddedDocumentField(BucketReserve), db_field="reserve", default=None
        )

    class BaseNotification(me.EmbeddedDocument):
        """ Base notification for task and scheduler. """
        meta = {"allow_inheritance": True}

        #: send notification with specified transport param
        transport = me.StringField(db_field="tr", choices=list(ctn.Transport), required=True)
        #: list of addresses
        recipients = me.ListField(me.StringField(), db_field="rs", required=True)
        #: juggler check status
        # NOTE: stored under its attribute name. A misspelled `df_field="cs"` keyword used here was silently
        # ignored by mongoengine, so existing documents keep the value in "check_status"; the DB name is
        # spelled out explicitly to preserve that. Moving to "cs" would require a data migration.
        check_status = me.StringField(db_field="check_status", choices=list(ctn.JugglerStatus))
        #: tags for Juggler
        juggler_tags = me.ListField(me.StringField(min_length=1, max_length=128), db_field="jt")

    class Notification(BaseNotification):
        """ Notification info for task. """

        #: task statuses that are waiting by the trigger
        statuses = me.ListField(me.StringField(choices=list(ctt.Status)), db_field="st", required=True)

    class Parameters(me.EmbeddedDocument):
        class Parameter(me.EmbeddedDocument):
            #: key
            key = me.StringField(db_field="k", required=True)
            #: value
            value = ParameterValue(db_field="v")
            #: reset on task restart
            reset_on_restart = me.BooleanField(db_field="r")

        #: input parameters
        input = me.ListField(me.EmbeddedDocumentField(Parameter), db_field="i")
        #: output parameters
        output = me.ListField(me.EmbeddedDocumentField(Parameter), db_field="o")

    class ReportInfo(me.EmbeddedDocument):
        #: Label (id).
        label = me.StringField(db_field="l")
        #: Title.
        title = me.StringField(db_field="t")

    class DefaultHooks(me.EmbeddedDocument):
        #: on_create hook
        on_create = me.BooleanField(db_field="c", default=False)
        #: on_save hook
        on_save = me.BooleanField(db_field="s", default=False)
        #: on_enqueue hook
        on_enqueue = me.BooleanField(db_field="e", default=False)

    #: Task owner.
    owner = me.StringField()
    #: Task author
    author = me.StringField(required=True)
    #: Task type.
    type = me.StringField(required=True)
    #: Flags the task should be hidden.
    hidden = me.BooleanField(db_field="hid")
    #: Task queueing priority.
    priority = me.IntField(db_field="pr", required=True)
    #: Flags the task flagged :), i.e., the importance of the task from the user's view.
    flagged = me.BooleanField(db_field="flg")
    #: User-provided description.
    description = me.StringField(db_field="desc")
    #: Task execution requirements.
    requirements = me.EmbeddedDocumentField(Requirements, db_field="req", required=True)
    #: User-provided task context.
    context = me.BinaryField(db_field="ctx")
    #: notification info
    notifications = me.ListField(me.EmbeddedDocumentField(Notification), db_field="noti")
    #: Simultaneously executing tag
    se_tag = me.StringField(db_field="se")
    #: Max allowed number of restarts from status TEMPORARY
    max_restarts = me.IntField(db_field="mr")
    #: Maximum task execution time interval
    kill_timeout = me.IntField(db_field="kt")
    #: Switch to FAILURE on any error
    fail_on_any_error = me.BooleanField(db_field="fae")
    #: Id of resource with tasks archive
    tasks_archive_resource = base.ReferenceField(db_field="tar")  # TODO: Remove it. SANDBOX-5538
    #: Enable detailed disk usage statistics
    dump_disk_usage = me.BooleanField(db_field="ddu")
    #: Tcpdump arguments that are used for network packets logging
    tcpdump_args = me.StringField(db_field="tda")
    #: Task parameters (SDK2)
    parameters = me.EmbeddedDocumentField(Parameters, db_field="param")
    #: Parameters meta
    parameters_meta = me.ReferenceField(ParametersMeta, db_field="pm")
    #: Reports info.
    reports = me.ListField(me.EmbeddedDocumentField(ReportInfo), db_field="rs", default=None)
    #: Task tags
    tags = me.ListField(me.StringField())
    #: Task hints
    hints = me.ListField(me.StringField(), db_field="h")
    #: Explicit hints (set by user, not from task parameters)
    explicit_hints = me.ListField(me.StringField(), db_field="eh")
    #: Task deduplication unique key
    unique_key = me.StringField(db_field="uk")
    #: Expires time delta, in seconds (expires_at = time of restart + expires_delta)
    expires_delta = me.IntField(db_field="exd")
    #: Enable access to Yandex Vault
    enable_yav = me.BooleanField(db_field="yav", default=False)
    #: Suspend on statuses
    suspend_on_status = me.ListField(
        me.StringField(choices=list(ctt.Status.Group.SUSPEND_ON_STATUS)),
        db_field="suspend_st"
    )
    #: Task score
    score = me.IntField(default=0)
    #: push tasks resource to subtasks
    push_tasks_resource = me.BooleanField(db_field="ptr", default=False)
    #: if hosts_match_score is not overridden in task class
    use_default_hosts_match_score = me.BooleanField(db_field="dhms", default=False)
    #: not overridden hooks
    default_hooks = me.EmbeddedDocumentField(DefaultHooks, db_field="hooks")

    @patterns.singleton_property
    def ctx(self):
        """ Task context unpickled from `context`; an empty dict when no context is stored. """
        # NOTE(review): unpickling can execute arbitrary code embedded in the payload; this is safe only
        # as long as `context` is never populated from untrusted input -- confirm.
        return cPickle.loads(self.context) if self.context else {}

    @property
    def is_new(self):
        """ Returns `True` in case the object has just been created and not saved to the database yet. """
        return self._created or "_id" not in self.to_mongo()

    def cast_to(self, dst_cls, fields_cls=None):
        """
        Create an instance of `dst_cls` initialized with field values of self, viewed as `fields_cls`.

        :param dst_cls: class to instantiate; must be a subclass of `fields_cls`
        :param fields_cls: class whose declared fields are copied; `Template` when omitted
        :return: new `dst_cls` instance; the `_id`/`id` values of self are not copied
        """
        if fields_cls is None:
            fields_cls = Template

        assert isinstance(self, fields_cls)
        assert issubclass(dst_cls, fields_cls)

        son_obj = self.to_mongo(use_db_field=False, fields=fields_cls._fields)
        son_obj.pop("_id", None)  # `_id` is dumped by default
        son_obj.pop("id", None)  # `id` is dumped if `fields_cls` has it defined

        return dst_cls(**son_obj)

    def clone(self):
        """ Copy of self of the same class, built via `cast_to` over all fields of that class. """
        return self.cast_to(type(self), fields_cls=type(self))


class Task(base.ConnectionSwitcherMixin, Template, base.Aggregatable):
    """
    The class represents a mapping between Python object and database storage for "Task" entity.

    The aggregation pipelines built below address documents by their DB field names
    (e.g. "exc.st" for `execution.status`, "time.up" for `time.updated`).
    """
    meta = {
        "indexes": [
            "type",
            "parent_id",
            "template_alias",
            "scheduler",
            "owner",
            "author",
            "hidden",
            "lock_host",
            "time.created",
            "time.updated",
            "requirements.resources",
            "execution.host",
            "execution.status",
            "release.creation_time",
            "parameters.output.key",
            "acquired_semaphore",
            "tags",
            "hints",
            "unique_key",
            "expires_at"
        ],
    }

    #: Task time marks.
    class Time(me.EmbeddedDocument):
        #: Creation date and time.
        created = me.DateTimeField(db_field="ct", required=True, default=dt.datetime.utcnow)
        #: Task's status last change time.
        updated = me.DateTimeField(db_field="up", required=True, default=dt.datetime.utcnow)

    #: Last task run result.
    class Execution(me.EmbeddedDocument):
        Status = ctt.Status

        #: Execution time marks.
        class Time(me.EmbeddedDocument):
            #: Date and time the task started to execute.
            started = me.DateTimeField(db_field="st")
            #: Date and time the task finished the execution.
            finished = me.DateTimeField(db_field="fin")

        class DiskUsage(me.EmbeddedDocument):
            #: Maximal disk usage by task during execution
            max = me.IntField(db_field="m", min_value=0, default=0)
            #: Used disk space after task execution
            last = me.IntField(db_field="l", min_value=0, default=0)

        class AutoRestart(me.EmbeddedDocument):
            #: Number of remaining restarts from status TEMPORARY
            left = me.IntField(db_field="r")
            #: Time in seconds to wait for next restart
            interval = me.IntField(db_field="i")

        #: Time intervals for various stages of execution
        class Intervals(me.EmbeddedDocument):
            class IntervalData(me.EmbeddedDocument):
                #: Start of the interval
                start = me.DateTimeField(db_field="st")
                #: Duration of the interval
                duration = me.IntField(db_field="d")
                #: Quota consumption in the interval
                consumption = me.IntField(db_field="qp")
                #: Quota pool name
                pool = me.StringField(db_field="p")

                @property
                def finish(self):
                    """ End of the interval (`start` + `duration` seconds), `None` if duration is unknown. """
                    if self.duration is not None:
                        delta = dt.timedelta(seconds=self.duration)
                        return self.start + delta

                @classmethod
                def from_mongo(cls, value):
                    return cls._from_son(value)

            #: Intervals for task staying in the queue
            queue = me.ListField(me.EmbeddedDocumentField(IntervalData), db_field="q")
            #: Intervals for task staying in WAIT_* state
            wait = me.ListField(me.EmbeddedDocumentField(IntervalData), db_field="w")
            #: Intervals for task staying in EXECUTE state
            execute = me.ListField(me.EmbeddedDocumentField(IntervalData), db_field="e")

        #: Client hostname the task execution was performed on.
        host = me.StringField()
        #: Execution status.
        status = me.StringField(choices=list(Status), db_field="st", required=True)
        #: Time marks.
        time = me.EmbeddedDocumentField(Time, required=True)
        #: Time intervals of task execution stages
        intervals = me.EmbeddedDocumentField(Intervals, db_field="in")
        #: Description, which was provided during the execution.
        description = me.StringField(db_field="desc")
        #: Disk usage by task
        disk_usage = me.EmbeddedDocumentField(DiskUsage, db_field="du")
        #: Parameters for method `on_release`
        release_params = me.DictField(db_field="rel")
        #: Parameters of auto restart from TEMPORARY
        auto_restart = me.EmbeddedDocumentField(AutoRestart, db_field="ar", required=True)

        @property
        def last_execution_start(self):
            """ Start of the last EXECUTE interval if recorded, otherwise `time.started`. """
            if self.intervals and self.intervals.execute:
                return self.intervals.execute[-1].start
            return self.time.started

        @property
        def last_execution_finish(self):
            """ Finish of the last EXECUTE interval if recorded, otherwise `time.finished`. """
            if self.intervals and self.intervals.execute:
                return self.intervals.execute[-1].finish
            return self.time.finished

    class Release(me.EmbeddedDocument):
        """
        release info document for task
        """
        Status = ctt.ReleaseStatus

        class Message(me.EmbeddedDocument):
            #: release summary - name, tag, branch, etc
            subject = me.StringField(required=True)
            #: text about changes and release info
            body = me.StringField()

        #: release creation time
        creation_time = me.DateTimeField(db_field="ctime", default=dt.datetime.utcnow, required=True)
        #: release author
        author = me.StringField(required=True)
        #: release status (one of Status class values)
        status = me.StringField(choices=list(Status), required=True)
        #: user message for release (notification)
        message = me.EmbeddedDocumentField(Message, required=True)
        #: bullet list with changes
        changelog = me.ListField(me.StringField())

    #: Object ID - atomically incremented positive integer, primary key.
    id = me.SequenceField(primary_key=True)
    #: Parent task ID if the entry is actually a sub-task.
    parent_id = base.ReferenceField(db_field="pid")
    #: Scheduler ID if the task has been created by scheduler.
    scheduler = me.IntField(db_field="sch")
    #: Template alias if the task has been created from template.
    template_alias = me.StringField(db_field="tmpl")
    #: Time marks.
    time = me.EmbeddedDocumentField(Time, required=True)
    #: Last execution results.
    execution = me.EmbeddedDocumentField(Execution, db_field="exc", required=True)
    #: release info
    release = me.EmbeddedDocumentField(Release, db_field="rel")
    #: Server hostname that acquired lock for the task
    lock_host = me.StringField(db_field="lock")
    #: Time when the lock was acquired
    lock_time = me.DateTimeField(db_field="ltime")
    #: Id of task from which this task is copied
    source_id = base.ReferenceField(db_field="sid")
    #: Shows that the task is acquired semaphore(s)
    acquired_semaphore = me.BooleanField(db_field="sem", default=False)
    #: Time when the task expires
    expires_at = me.DateTimeField(db_field="ext")
    #: User requested last action with task
    last_action_user = me.StringField()
    #: Effective client tags for last enqueueing
    effective_client_tags = me.StringField(db_field="ect")
    #: Links to TaskStatusEvent records
    status_events = me.ListField(me.StringField(), db_field="tse")

    @classmethod
    def ensure_indexes(cls):
        """
        The method will provide query indexes usage hints in addition to indexes creation for the collection.
        """
        super(Task, cls).ensure_indexes()
        db = cls._get_db()
        # Provide here all popular queries hints with sorting by ID, otherwise Mongo will perform full collection scan
        # because its query planner selects index by ID when sorting by that field.
        db.command(
            "planCacheSetFilter", cls._get_collection_name(),
            query={"sch": "X"}, sort={"_id": pymongo.DESCENDING},
            indexes=[{"sch": 1}]
        )
        db.command(
            "planCacheSetFilter", cls._get_collection_name(),
            query={
                "type": {"$in": ["X"]},
                "pid": {"$exists": False},
                "hid": {"$ne": True},
                "exc.st": {"$ne": ctt.Status.DELETED}
            },
            sort={"_id": pymongo.DESCENDING},
            # Index key patterns, as in the filter above. Index *names* ("type_1", "_id_") must not be
            # written as key patterns: `{"type_1": 1}` would denote an index on a non-existent field.
            indexes=[{"type": 1}, {"_id": 1}]
        )

    @classmethod
    def last_task_per_scheduler(cls, scheduler_ids):
        """
        Get identifiers of last tasks created by schedulers.

        :param scheduler_ids: schedulers identifiers of interest (any iterable; materialized into a list
                              because BSON cannot encode sets or generators)
        :return: dict with mapped scheduler identifiers to last task created through scheduler.
        """
        pipeline = [
            {"$match": {"sch": {"$in": list(scheduler_ids)}}},
            {"$project": {"_id": 1, "sch": 1}},
            # Descending IDs make `$first` in the group pick the newest task of each scheduler.
            {"$sort": {"_id": -1}},
            {"$group": {"_id": "$sch", "task_id": {"$first": "$_id"}}},
        ]
        return {o["_id"]: o["task_id"] for o in cls.aggregate(pipeline)}

    @classmethod
    def tasks_per_status(cls, delta=dt.timedelta(hours=1), exceptional=True):
        """
        Get tasks per status amount.

        Statuses FAILURE, SUCCESS, DELETED, RELEASED and DRAFT are "exceptional". Behaviour:

        * `delta` given, `exceptional` true: exceptional statuses are counted only among tasks whose execution
          finished in [now - delta - 1s, now - 1s); all other statuses are counted over the whole collection.
        * `delta` given, `exceptional` false: every status is counted only among tasks whose execution
          finished in that window.
        * `delta` falsy: only non-exceptional statuses are counted (over the whole collection);
          exceptional statuses stay at zero.

        :param delta:       timedelta object to bound minimal finished time.
        :param exceptional: restrict the windowed count to exceptional statuses (see above).
        :return: Amount of tasks per status, keyed by lower-cased status name; every status is present.
        :rtype: dict
        """
        exceptional_statuses = [
            ctt.Status.FAILURE,
            ctt.Status.SUCCESS,
            ctt.Status.DELETED,
            ctt.Status.RELEASED,
            ctt.Status.DRAFT,
        ]

        res = {st.lower(): 0 for st in ctt.Status}

        def count(match):
            # Group the matched tasks by execution status and merge the amounts into `res`.
            pipeline = [
                {"$match": match},
                {"$group": {"_id": "$exc.st", "amount": {"$sum": 1}}},
            ]
            res.update({row["_id"].lower(): row["amount"] for row in cls.aggregate(pipeline)})

        if delta:
            now = dt.datetime.utcnow()
            match = {
                "exc.time.fin": {
                    "$gte": now - delta - dt.timedelta(seconds=1),
                    "$lt": now - dt.timedelta(seconds=1)
                },
            }
            if exceptional:
                match["exc.st"] = {"$in": exceptional_statuses}
            count(match)
            if not exceptional:
                return res
        count({"exc.st": {"$nin": exceptional_statuses}})
        return res

    @classmethod
    def tasks_count(cls, delta=None, by="type", status=None):
        """
        Get tasks per given field amount for the given period of time and given status.

        :param delta:   timedelta object to bound minimal `time.updated`, all tasks if `None`.
        :param by:      DB field path to group collection by.
        :param status:  filter tasks by status if provided: a single status string or an iterable of statuses.
        :return:        Amount of tasks per given field
        :rtype: dict
        """
        match = {}
        if delta:
            # Raw pipeline: `Task.time` is stored as "time" and `Time.updated` as "up".
            match["time.up"] = {"$gte": dt.datetime.utcnow() - delta}
        if status:
            # Strings are iterable on Python 3, so a single status must be excluded explicitly
            # to avoid matching on its individual characters.
            if hasattr(status, "__iter__") and not isinstance(status, (str, bytes)):
                match["exc.st"] = {"$in": list(status)}
            else:
                match["exc.st"] = status

        pipeline = [{"$match": match}] if match else []
        pipeline.append({"$group": {"_id": "$" + by, "amount": {"$sum": 1}}})
        return {row["_id"]: row["amount"] for row in cls.aggregate(pipeline)}


class ClientTagsToHostsCache(base.ConnectionSwitcherMixin, me.Document):
    """ Cached list of hosts matching a client tags expression; the expression is the primary key. """

    #: Client tags expression (primary key)
    client_tags = me.StringField(primary_key=True)
    #: Hosts that match the expression
    hosts = me.ListField(field=me.StringField())
    #: Last access time, UTC (defaults to `utcnow()` at instantiation)
    accessed = me.DateTimeField(required=True, default=dt.datetime.utcnow)


class TaskTagCache(base.ConnectionSwitcherMixin, me.Document):
    """ Registry of every tag that has ever been assigned to a task. """

    meta = dict(indexes=["login", "hits"])

    #: The tag itself (primary key)
    tag = me.StringField(primary_key=True)
    #: Login of the user who created the tag
    login = me.StringField(required=True)
    #: Creation time, UTC
    created = me.DateTimeField(required=True, default=dt.datetime.utcnow, db_field="ct")
    #: Last access time, UTC
    accessed = me.DateTimeField(required=True, default=dt.datetime.utcnow, db_field="at")
    #: Number of tasks carrying this tag
    hits = me.IntField()
