import re
import json
import math
import datetime as dt
import textwrap
import functools as ft
import itertools as it
import distutils.util

import six
from six.moves import cPickle

from sandbox.common import data as common_data
from sandbox.common import lazy as common_lazy
from sandbox.common import config as common_config
from sandbox.common import encoding as common_encoding
from sandbox.common import patterns as common_patterns
from sandbox.common import platform as common_platform
from sandbox.common import itertools as common_itertools
import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt
import sandbox.common.types.user as ctu
import sandbox.common.types.client as ctc
import sandbox.common.types.resource as ctr

from sandbox.yasandbox.database import mapping
from sandbox.yasandbox import context

import sandbox.sdk2.parameters
import sandbox.sdk2.internal.task


def encode_parameter(value, complex_type):
    """
    Serialize a parameter value for storage in the DB model.

    Values of complex-typed parameters are stored as JSON text. The empty markers `""` and `None`,
    as well as values of simple parameters, are returned unchanged.

    :param value: parameter value
    :param complex_type: truthy if the parameter holds a complex (JSON-serializable) value
    :return: JSON string or the original value
    """
    if not complex_type or value in ("", None):
        return value
    return json.dumps(value)


def decode_parameter(value, complex_type):
    """
    Deserialize a parameter value read from the DB model; counterpart of `encode_parameter`.

    Non-empty strings of complex-typed parameters are parsed as JSON. NaN and infinite floats are
    returned as their string form ("nan", "inf", "-inf"). Anything else is returned unchanged.

    :param value: stored parameter value
    :param complex_type: truthy if the parameter holds a complex (JSON-serialized) value
    :return: decoded value
    """
    looks_like_json = complex_type and isinstance(value, six.string_types) and value
    if looks_like_json:
        return json.loads(value)
    non_finite = isinstance(value, float) and (math.isnan(value) or math.isinf(value))
    return str(value) if non_finite else value


def ramdrive_getter(data):
    """
    Build a RAM drive requirement document from its dict representation.

    :param data: keyword arguments for `mapping.Task.Requirements.RamDrive`, or a falsy value for "no RAM drive"
    :return: validated RamDrive document with its `size` shifted right by 20 bits, or None
    """
    if data:
        ramdrive = mapping.Task.Requirements.RamDrive(**data)
        ramdrive.validate()
        # presumably bytes -> MiB, like the disk_space/ram requirements -- TODO confirm units
        ramdrive.size = ramdrive.size >> 20
        return ramdrive
    return None


class LocalUpdate(object):
    """
    Class to synchronize DB model with sdk2-task object. Methods of class DON'T require direct access to DB.

    All methods mutate the passed model in place; none of them saves it.
    """

    # Matches every character that is not allowed in a generated parameter group block name
    NOT_NAME_RE = re.compile(r'[^a-z0-9]', re.IGNORECASE)

    @staticmethod
    def _validate_tags(tags):
        """
        Upper-case and validate task tags.

        :param tags: tag(s) to validate; flattened with `common_itertools.chain`
        :return: list of upper-cased tags
        :raises: whatever `ctt.TaskTag.test` raises for an invalid tag
        """
        value = [tag.upper() for tag in common_itertools.chain(tags)]
        for tag in value:
            ctt.TaskTag.test(tag)
        return value

    @classmethod
    def update_tags(cls, tags, target):
        """
        Validate tags and assign them to `target.tags`.

        :param tags: tag(s) to set
        :param target: object with a `tags` attribute (e.g. task model)
        :return: the validated, upper-cased list of tags that was assigned
        """
        tags = cls._validate_tags(tags)
        target.tags = tags
        return tags

    @classmethod
    def _update_caches(cls, caches, model):
        """
        Update `model.requirements.caches`.

        :param caches: `None` clears the caches; a dict (cache key -> value) replaces them;
            any other value (e.g. `ctm.NotExists`) leaves the field untouched
        :param model: task model
        """
        if caches is None:
            model.requirements.caches = None
        elif isinstance(caches, dict):
            model.requirements.caches = [
                mapping.Task.Requirements.Cache(key=key, value=value)
                for key, value in six.iteritems(caches)
            ]

    @classmethod
    def set_expires(cls, expires, model):
        """
        Set task expiration.

        :param expires: expiration delta in seconds, or `None` to disable expiration
        :param model: task model; `expires_delta` and `expires_at` (naive UTC datetime) are updated
        """
        model.expires_delta = expires
        if expires is None:
            model.expires_at = None
        else:
            model.expires_at = dt.datetime.utcnow() + dt.timedelta(seconds=expires)

    @classmethod
    def set_explicit_hints(cls, hints, model):
        """
        Set explicit hints in model to be saved later.

        :param hints: iterable of hints; each one is converted to `str` and empty results are dropped
        :param model: task model
        """
        # Build a real list: under Python 3 `filter()` returns a one-shot iterator, which is exhausted
        # after the first pass (e.g. in `update_hints`) and may not be accepted by a list field.
        stringified = (str(hint) for hint in hints)
        model.explicit_hints = [hint for hint in stringified if hint]

    @classmethod
    def update_hints(cls, parameters, model):
        """
        Recalculate `model.hints` from explicit hints and values of hint parameters.

        :type parameters: sandbox.sdk2.Parameters
        :param parameters: task parameters class used to look up which parameters are hints
        :param model: task model; `model.parameters` must already be filled
        """
        hints = set(model.explicit_hints or [])
        for param in common_itertools.chain(model.parameters.input, model.parameters.output):
            pclass = getattr(parameters, param.key, None)
            if pclass and pclass.hint and param.value:
                hints.add(str(param.value))
        model.hints = list(hints)

    @classmethod
    def _update_tasks_resource(cls, resource_id, model):
        """
        Set or clear the tasks resource requirement of the model.

        Cached `parameters_meta` and `reports` are dropped when the resource is cleared or changed.

        :param resource_id: tasks resource id, or a falsy value to clear the requirement
        :param model: task model
        """
        # explicitly reset value to mark mongoengine field as "changed"
        model.tasks_archive_resource = None
        current = model.requirements.tasks_resource
        # NOTE(review): `resource_id` is compared as given; if it arrives as a string (it is not converted
        # in `update_common_fields`), it never equals the int id and the caches are always dropped.
        if not resource_id or (current and current.id) != resource_id:
            model.parameters_meta = None
            model.reports = None
        model.requirements.tasks_resource = None

        if resource_id:
            # The requirement was just reset above, so a fresh embedded document is always created
            model.requirements.tasks_resource = model.Requirements.TasksResource()
            model.tasks_archive_resource = model.requirements.tasks_resource.id = int(resource_id)

    @classmethod
    @context.timer_decorator()
    def update_common_fields(cls, data, model):
        """
        Apply common task fields from a dict to the model.

        Only keys present in `data` (and in its nested "requirements" dict) are applied. Fields whose keys
        are absent are left untouched. Values are converted on the way, e.g. `disk_space` and `ram` are
        shifted right by 20 bits (presumably bytes -> MiB).

        :param data: dict of task fields, optionally with a nested "requirements" dict
        :param model: task model to update in place
        """
        # `or {}` additionally tolerates an explicit `"requirements": None`
        requirements = data.get("requirements") or {}
        settings = common_config.Registry()

        def lower_or_empty(value):
            return (value or "").lower()

        def host_requirement(value):
            # a LOCAL installation always pins the host to this instance's id
            if settings.common.installation == ctm.Installation.LOCAL:
                return settings.this.id
            return lower_or_empty(value)

        # (name, converter) pairs copied onto the model itself; `None` means "store the value as is"
        model_fields = (
            ("enable_yav", bool),
            ("kill_timeout", common_data.force_int),
            ("fail_on_any_error", bool),
            ("dump_disk_usage", bool),
            ("tcpdump_args", None),
            ("description", None),
            ("hidden", bool),
            ("owner", None),
            ("max_restarts", common_data.force_int),
            ("unique_key", None),
            ("suspend_on_status", None),
            ("push_tasks_resource", None),
        )
        # (name, converter) pairs copied onto `model.requirements`
        requirement_fields = (
            ("disk_space", lambda value: common_data.force_int(value) >> 20),
            ("ram", lambda value: common_data.force_int(value) >> 20),
            ("platform", lower_or_empty),
            ("host", host_requirement),
            ("cores", common_data.force_int),
            ("dns", lambda value: ctm.DnsType.val2str(value) and value),
            ("cpu_model", lower_or_empty),
            ("privileged", bool),
            (
                "client_tags",
                # an empty value keeps the current client tags
                lambda value: str(ctc.Tag.Query.cast(value)) if value else model.requirements.client_tags
            ),
            ("porto_layers", lambda value: list(common_itertools.chain(value))),
            (
                "semaphores",
                lambda value: (
                    mapping.Task.Requirements.Semaphores(**ctt.Semaphores(**value).to_dict())
                    if value else
                    None
                )
            ),
            ("ramdrive", ramdrive_getter),
            ("container_resource", ft.partial(common_data.force_int, default=None)),
            (
                "resources_space_reserve",
                lambda value: (
                    [
                        mapping.Task.Requirements.BucketReserve(bucket=bucket, size=size)
                        for bucket, size in sandbox.sdk2.parameters.ResourcesSpaceReserveValue(value).items()
                    ]
                    if value else
                    None
                )
            ),
        )

        # Plain loops (instead of lambdas closing over loop variables) keep every field name and
        # converter bound to its own iteration without relying on lazy evaluation order.
        for name, convert in model_fields:
            if name in data:
                value = data.get(name)
                setattr(model, name, convert(value) if convert else value)

        for name, convert in requirement_fields:
            if name in requirements:
                setattr(model.requirements, name, convert(requirements.get(name)))

        if "tags" in data:
            cls.update_tags(list(common_itertools.chain(data.get("tags"))), model)
        if "hints" in data:
            cls.set_explicit_hints(list(common_itertools.chain(data.get("hints"))), model)
        if "expires" in data:
            cls.set_expires(common_data.force_int(data.get("expires"), None), model)
        # skipped when the requirements dict carries "tasks_resource" (that key is not handled here)
        if "tasks_archive_resource" in data and "tasks_resource" not in requirements:
            cls._update_tasks_resource(data.get("tasks_archive_resource"), model)
        score = data.get("score")
        if score is not None:
            model.score = score

        cls._update_caches(requirements.get("caches", ctm.NotExists), model)

        priority = data.get("priority")
        if priority:
            model.priority = int(ctt.Priority.make(priority))

    @classmethod
    @context.timer_decorator()
    def update_requirements(cls, task, model):
        """
        Copy sdk2 task requirements to `model.requirements`.

        Besides the declared requirements, this method does three things:
        - sets the tasks resource requirement;
        - derives the platform from the tasks resource architecture when no specific platform is requested;
        - fills `model.requirements.platforms` from the tasks resource platform system attributes.

        :param task: sdk2 task instance
        :param model: task model to update in place
        :raises ValueError: if the tasks resource requirement is a bare integer id rather than a resource
            object (reported as "not found"); this is checked before the model is modified
        """
        tasks_resource = task.Requirements.tasks_resource
        # Validate first so a failure does not leave the model half-updated.
        # `six.integer_types` also covers `long` on Python 2.
        if tasks_resource and isinstance(tasks_resource, six.integer_types):
            raise ValueError("Tasks resource {} not found".format(tasks_resource))

        cls._update_tasks_resource(tasks_resource and int(tasks_resource), model)
        model.requirements.platforms = None
        if tasks_resource:
            binary_arch = common_platform.get_arch_from_platform(tasks_resource.arch)
            req_platform = model.requirements.platform
            # the binary architecture only overrides an unset or "any" platform requirement
            if binary_arch != ctm.OSFamily.ANY and (not req_platform or req_platform == ctm.OSFamily.ANY):
                model.requirements.platform = binary_arch
            system_attributes = tasks_resource.system_attributes
            for platform_name in ("linux_platform", "osx_platform", "osx_arm_platform", "win_nt_platform"):
                if getattr(system_attributes, platform_name, None):
                    if model.requirements.platforms is None:
                        model.requirements.platforms = []
                    # e.g. "osx_arm_platform" -> "osx_arm"
                    model.requirements.platforms.append(platform_name.rsplit("_", 1)[0])

        for p in task.__class__.Requirements:
            # dummy and static requirements are not stored in the model
            if p.dummy or p.__static__:
                continue

            value = p.__encode__(getattr(task.Requirements, p.name))
            field = getattr(mapping.Task.Requirements, p.name)
            # wrap encoded values into embedded documents where the DB schema expects them
            if value and isinstance(field, mapping.me.EmbeddedDocumentField):
                value = (
                    field.document_type(**value)
                    if isinstance(value, dict) else
                    field.document_type(value)
                )
            elif (
                value and
                isinstance(field, mapping.me.ListField) and
                isinstance(field.field, mapping.me.EmbeddedDocumentField)
            ):
                value = [field.field.document_type(**v) for v in value]
            setattr(model.requirements, p.name, value)

        cls._update_caches(task.Requirements.Caches.__getstate__(), model)

    @classmethod
    @context.timer_decorator()
    def update_parameters(cls, task, model):
        """
        Store sdk2 task parameter values in `model.parameters`.

        Dummy parameters are skipped. Output parameters are stored only when a value has been assigned.
        Container parameters additionally update `model.requirements.container_resource`.

        :param task: sdk2 task instance
        :param model: task model to update in place
        """
        input_params = []
        output_params = []
        for p in task.__class__.Parameters:
            if p.dummy or (p.__output__ and p.name not in task.Parameters.__values__):
                continue

            # TODO: remove after all SDK2 tasks will be fixed [SANDBOX-6188]
            # update container in requirements [SANDBOX-7051]
            if issubclass(p, sandbox.sdk2.parameters.Container):
                container = getattr(task.Parameters, p.name)
                model.requirements.container_resource = container and int(container)

            obj = mapping.Task.Parameters.Parameter(
                key=p.name,
                value=encode_parameter(
                    sandbox.sdk2.internal.task.safe_encode(p, p.name, getattr(task.Parameters, p.name)),
                    p.__complex_type__
                )
            )
            if p.__output__:
                output_params.append(obj)
            else:
                input_params.append(obj)

        # This is only needed for schedulers. In tasks `enable_yav` is set during `on_save`.
        if hasattr(task.Parameters, "enable_yav"):
            model.enable_yav = task.Parameters.enable_yav

        if model.parameters is None:  # ensure parameters existence for SDK1->SDK2 migrated tasks
            model.parameters = mapping.Task.Parameters()

        model.parameters.input = input_params
        model.parameters.output = output_params

    @classmethod
    def task_reports(cls, task_cls):
        """
        Build report info documents for all reports declared on a task class.

        :param task_cls: sdk2 task class
        :return: list of `mapping.Template.ReportInfo`
        """
        return [
            mapping.Template.ReportInfo(label=r.label, title=r.title)
            for r in six.viewvalues(task_cls.__reports__)
        ]

    @classmethod
    @context.timer_decorator()
    def update_reports(cls, task, model):
        """ Replace `model.reports` with the reports declared on the task's class. """
        model.reports = cls.task_reports(task.__class__)

    @classmethod
    @context.timer_decorator()
    def update_notifications(cls, task, model):
        """
        Replace `model.notifications` with the notifications from the task parameters.

        :param task: sdk2 task instance
        :param model: task model to update in place
        """
        model.notifications = [
            mapping.Task.Notification(
                transport=notification.transport,
                statuses=notification.statuses,
                recipients=notification.recipients,
                check_status=getattr(notification, "check_status", None),
                juggler_tags=getattr(notification, "juggler_tags", [])
            )
            for notification in (task.Parameters.notifications or [])
        ]

    @classmethod
    @context.timer_decorator()
    def update_context(cls, task, model):
        """
        Merge the sdk2 task context into the pickled context stored in the model.

        Keys from the task context overwrite stored ones; other stored keys are kept.

        :param task: sdk2 task instance
        :param model: task model to update in place
        """
        # NOTE: this unpickles data stored in the model; it must never receive externally supplied bytes
        ctx = cPickle.loads(model.context) if model.context else {}
        ctx.update(task.Context.__getstate__())
        model.context = six.ensure_binary(cPickle.dumps(ctx, protocol=2))

    @classmethod
    def parameter_meta(cls, parameter_meta_class, pc, short=False):
        """
        Make parameter meta model from class.

        :param parameter_meta_class: class of parameter_meta
        :param pc: parameter class
        :param short: fill only name, output/complex flags and default value
        :return: parameter_meta_class instance
        """
        pm = parameter_meta_class()
        if pc.ui and not short:
            pm.type = pc.ui.type
            pm.modifiers = pc.ui.modifiers
            # pm.context = pc.ui.context  # ui.context could be removed since pm.context will be redefined.

        pm.name = pc.name
        pm.output = bool(getattr(pc, "__output__", False))
        pm.complex = getattr(pc, "__complex_type__", None)
        # don't call class property's function
        default_value = None if common_patterns.is_classproperty(pc, "default_value") else pc.default_value
        # deferred defaults are not stored in the meta
        if not isinstance(default_value, common_lazy.Deferred):
            pm.default = encode_parameter(
                (
                    sandbox.sdk2.internal.task.safe_encode(pc, pc.name, default_value)
                    if hasattr(pc, "__encode__") else
                    default_value
                ),
                pm.complex
            )
        if short:
            return pm
        pm.required = pc.required
        pm.title = common_encoding.force_unicode_safe(pc.description)
        pm.description = textwrap.dedent(common_encoding.force_unicode_safe(pc.__doc__ or ""))
        pm.do_not_copy = getattr(pc, "do_not_copy", False)

        ctx = {}
        if hasattr(pc, "get_custom_parameters"):
            # copy, so popping "sub_fields" below cannot mutate a dict owned by the parameter class
            ctx = dict(pc.get_custom_parameters() or {})
        ctx.pop("sub_fields", None)
        pm.context = ctx

        if hasattr(pc, "multiline") and pc.multiline is not None:
            pm.modifiers = pm.modifiers or {}
            pm.modifiers["multiline"] = pc.multiline

        if hasattr(pc, "sub_fields") and pc.sub_fields:
            pm.sub_fields = pc.sub_fields

        return pm

    @classmethod
    def parameters_meta_list(cls, parameters_classes, parameter_meta_class, short=False):
        """
        Construct list of parameter_meta_class meta objects.

        Unless `short` is set, group block entries are inserted where the parameter group changes.
        Leaving a group emits a block with an empty name and a `None` title.

        :param parameters_classes: parameter classes, in display order
        :param parameter_meta_class: class of parameter meta
        :param short: build reduced meta objects (see `parameter_meta`) and no group blocks
        :return: list of parameter_meta_class instances
        """

        def block(title):
            # group separator entry; `title=None` closes the current group
            return parameter_meta_class(
                name=("_grp_" + cls.NOT_NAME_RE.sub("_", title.lower())) if title is not None else "",
                title=title,
                type=ctt.ParameterType.BLOCK,
                required=False,
            )

        parameters_meta = []
        current_group = None

        for pc in parameters_classes:
            pm = cls.parameter_meta(parameter_meta_class, pc, short=short)
            if not short:
                if pc.group and current_group != pc.group:
                    current_group = pc.group
                    if pm.type != ctt.ParameterType.BLOCK:  # sdk1 doesn't generate group first block views, sdk2 does
                        parameters_meta.append(block(current_group))
                elif not pc.group and current_group:
                    current_group = None
                    parameters_meta.append(block(current_group))

            parameters_meta.append(pm)

        return parameters_meta

    @classmethod
    def parameters_meta(cls, parameters_classes, short=False):
        """
        :type parameters_classes: list[sandbox.sdk2.legacy.SandboxParameter]
        :param short: build reduced meta objects without group blocks
        :rtype: mapping.ParametersMeta
        """
        parameters_meta = cls.parameters_meta_list(
            parameters_classes, mapping.ParametersMeta.ParameterMeta, short=short
        )
        return mapping.ParametersMeta(params=parameters_meta)

    @classmethod
    def template_parameters_meta(cls, parameters_classes):
        """
        Build parameter meta for task templates: output parameters and group-closing blocks are excluded.

        :type parameters_classes: list[sandbox.sdk2.legacy.SandboxParameter]
        :rtype: list(TaskTemplate.Task.ParameterMeta)
        """
        return [
            pm for pm in cls.parameters_meta_list(parameters_classes, mapping.TaskTemplate.Task.ParameterMeta)
            if not pm.output and pm.name and pm.title is not None
        ]


class ServerSideUpdate(object):
    """
    Class to synchronize DB model with sdk2-task object.
    Methods of class REQUIRE server side execution to access DB directly.
    """

    # Spellings accepted by `_strtobool`; the same sets that `distutils.util.strtobool` accepts
    _TRUE_VALUES = frozenset(("y", "yes", "t", "true", "on", "1"))
    _FALSE_VALUES = frozenset(("n", "no", "f", "false", "off", "0"))

    @classmethod
    def _strtobool(cls, value):
        """
        Convert a truth-value string to bool.

        Replacement for `distutils.util.strtobool`: distutils is deprecated (PEP 632) and removed in
        Python 3.12. Values are matched case-insensitively; non-string values are converted with `str` first.

        :return: True or False
        :raises ValueError: if the value is not a recognized truth value
        """
        try:
            lowered = value.lower()
        except AttributeError:
            lowered = str(value).lower()
        if lowered in cls._TRUE_VALUES:
            return True
        if lowered in cls._FALSE_VALUES:
            return False
        raise ValueError("invalid truth value {!r}".format(lowered))

    @staticmethod
    def _update_tags_cache(old_tags, tags):
        """
        Update the task tag cache collection.

        Tags present in `tags` but not in `old_tags` are upserted and their hit counter is incremented.
        The `accessed` timestamp is refreshed for all of `tags`.

        :param old_tags: tags before the update
        :param tags: tags after the update
        """
        new_tags = set(tags) - set(old_tags)
        request = getattr(mapping.base.tls, "request", None)
        for tag in new_tags:
            mapping.TaskTagCache.objects(tag=tag).update_one(
                upsert=True,
                # the login is written only when the cache entry is created
                set_on_insert__login=request.user.login if request else ctu.ANONYMOUS_LOGIN,
                inc__hits=1,
            )
        # NOTE(review): local time is used here, while `LocalUpdate.set_expires` uses UTC -- confirm intended
        mapping.TaskTagCache.objects(tag__in=tags).update(set__accessed=dt.datetime.now())

    @classmethod
    def update_tags(cls, tags, target):
        """
        Validate and assign tags to `target`, then refresh the tag cache.

        :param tags: tag(s) to set
        :param target: object with a `tags` attribute (e.g. task model)
        :return: the validated list of tags assigned to `target`
        """
        # Materialize the old tags before `target.tags` is overwritten
        old_tags = list(common_itertools.chain(target.tags))
        tags = LocalUpdate.update_tags(tags, target)
        cls._update_tags_cache(old_tags, tags)
        return tags

    @classmethod
    def validate_tasks_resource(cls, resource_id):
        """
        Fetch a tasks resource and check that its type can hold task code.

        :param resource_id: resource id
        :return: resource document
        :raises ValueError: if the resource does not exist or has an incompatible type
        """
        # TODO: move this method somewhere else (away from `update`) -- SANDBOX-5895
        resource = mapping.Resource.objects.with_id(resource_id)
        if resource is None:
            raise ValueError("Tasks resource error: there is no tasks resource with id={!r}".format(resource_id))

        from sandbox.sdk2 import service_resources as sr
        types = (sr.SandboxTasksArchive, sr.SandboxTasksImage, sr.SandboxTasksBinary)
        if resource.type not in types:
            raise ValueError(
                "Tasks resource error: incompatible type {!r}, must be one of {}".format(resource.type, types)
            )

        return resource

    @classmethod
    def update_tasks_resource(cls, resource_id, model):
        """
        Set or clear the tasks resource requirement of the model, validating the resource in the DB.

        Cached `parameters_meta` and `reports` are dropped when the resource is cleared or changed.

        :param resource_id: tasks resource id, or None to clear the requirement
        :param model: task model
        :raises ValueError: see `validate_tasks_resource` and `updated_tasks_resource_model`
        """
        if not model.requirements:
            model.requirements = model.Requirements()
        if not resource_id or (
            model.requirements.tasks_resource and model.requirements.tasks_resource.id
        ) != resource_id:
            model.parameters_meta = None
            model.reports = None
        if resource_id is None:
            model.requirements.tasks_resource = None
            return

        resource = cls.validate_tasks_resource(resource_id)

        if not model.requirements.tasks_resource:
            model.requirements.tasks_resource = model.Requirements.TasksResource()

        cls.updated_tasks_resource_model(model.requirements.tasks_resource, resource)

    @classmethod
    def updated_tasks_resource_model(cls, tasks_resource, resource):
        """
        Fill the tasks resource requirement from a resource document.

        The id is copied. The taskbox flag and the binary age are read from the resource attributes;
        when those attributes are absent they default to False and 0.

        :param tasks_resource: tasks resource requirement (embedded document) to fill
        :param resource: resource document
        :raises ValueError: if one of those attributes has an invalid value
        """
        tasks_resource.id = resource.id
        tasks_resource.taskbox_enabled = False
        tasks_resource.age = 0
        for attr in resource.attributes:
            try:
                if attr.key == ctr.BinaryAttributes.TASKBOX_ENABLED:
                    tasks_resource.taskbox_enabled = cls._strtobool(attr.value)
                elif attr.key == ctr.BinaryAttributes.BINARY_AGE:
                    tasks_resource.age = int(attr.value)
            except ValueError as er:
                raise ValueError("Tasks resource error: invalid '{}' attribute: {}".format(attr.key, er))

    @classmethod
    def postprocess_common_fields(cls, model, old_tags):
        """
        Perform server-side postprocessing of updated common fields

        :param model: task model
        :param old_tags: list of task tags before the update
        """
        changed_fields = model._get_changed_fields()

        if "tags" in changed_fields:
            cls._update_tags_cache(old_tags, model.tags)

        # the tasks resource requirement (or its id) changed: re-validate it against the DB
        if "req.tr" in changed_fields or "req.tr.id" in changed_fields:
            cls.update_tasks_resource(
                model.requirements.tasks_resource and model.requirements.tasks_resource.id,
                model
            )
