import logging
import collections

import six

import sandbox.common.types.task as ctt
import sandbox.common.types.resource as ctr
from sandbox.common import errors as common_errors
from sandbox import sdk2


class MoveResourcesToMds(sdk2.Task):
    """
    Upload READY resources to MDS.

    The task runs in one of two modes (see ``on_execute``):

    * Parent mode: find READY resources matching the "Resource filter"
      parameters, assign each of them to one host that holds a copy, spawn one
      child task of the same type per host and wait for all children.
    * Child mode (host requirement set and the task has a parent): upload the
      resources listed in ``resource_id`` via ``agentr.upload_to_mds`` using the
      "Upload options" parameters.
    """

    # Maximum number of resources returned by the single search query in
    # ``resources_list``; further matches are not processed by this run.
    RESOURCES_LIMIT = 500

    class Requirements(sdk2.Requirements):
        cores = 4
        ram = 8192
        disk_space = 3522

        class Caches(sdk2.Requirements.Caches):
            pass  # the task does not use any shared caches

    class Parameters(sdk2.Parameters):
        # Every parameter in this group is named "resource_<field>"; <field> is
        # used as the search query key in ``resources_list``.
        with sdk2.parameters.Group("Resource filter") as resource_filter:
            resource_id = sdk2.parameters.List("Resource ids for mds uploading")
            resource_owner = sdk2.parameters.String("Resource owner")
            resource_type = sdk2.parameters.String("Resource type")
            resource_task = sdk2.parameters.List("Task ids of resource for mds uploading")
        # Every parameter in this group is passed by name to agentr.upload_to_mds.
        with sdk2.parameters.Group("Upload options") as upload_options:
            sync = sdk2.parameters.Bool("Sync upload to mds", default=False)
            force = sdk2.parameters.Bool("Force upload to mds", default=False)

    def resources_list(self):
        """
        Return the descriptions of READY resources matching the filter parameters.

        Every non-empty parameter of the "Resource filter" group becomes a query
        field named after the part that follows the ``resource_`` prefix
        (``resource_owner`` -> ``owner`` and so on). At most ``RESOURCES_LIMIT``
        resources are returned.

        :return: list of resource dicts, each read individually by id
        """
        resources_query = {"state": ctr.State.READY}
        for parameter in type(self).Parameters.resource_filter:
            value = getattr(self.Parameters, parameter.name)
            if not value:
                continue
            # Split only on the first "_" so a field name containing "_" stays intact.
            resources_query[parameter.name.split("_", 1)[1]] = value

        logging.info("Find resources with query: %s", resources_query)
        found = self.server.resource.read(limit=self.RESOURCES_LIMIT, **resources_query)["items"]
        resource_ids = [item["id"] for item in found]

        # NOTE(review): each resource is re-read individually, presumably because
        # the list endpoint does not include fields such as "sources" — confirm.
        return [self.server.resource[res_id].read() for res_id in resource_ids]

    def hosts_by_resources(self):
        """
        Assign every found resource to exactly one host that stores it.

        Hosts are processed greedily, starting with the host that holds the most
        resources. Each host takes all of its resources that have not been
        assigned yet; hosts left with nothing are omitted.

        :return: dict mapping a host to a non-empty list of resource ids
        """
        unassigned = {res["id"]: res for res in self.resources_list()}

        resources_by_host = collections.defaultdict(list)
        for resource in six.itervalues(unassigned):
            logging.info("Check resource: %s", resource)
            for host in resource["sources"]:
                resources_by_host[host].append(resource["id"])

        busiest_first = sorted(
            six.iteritems(resources_by_host), key=lambda item: len(item[1]), reverse=True
        )

        target_hosts = {}
        for host, resource_ids in busiest_first:
            target_resources = []
            for res_id in resource_ids:
                # Claim and remove in one step, so a repeated id (for example a
                # host listed twice in "sources") is not taken twice.
                if res_id in unassigned:
                    del unassigned[res_id]
                    target_resources.append(res_id)
            if target_resources:
                target_hosts[host] = target_resources

        logging.info("Hosts to resources map: %s", target_hosts)
        return target_hosts

    def on_execute(self):
        """
        Upload resources to MDS (child mode) or fan out per-host subtasks (parent mode).

        Child mode is selected when the task has a host requirement and a parent.

        :raises common_errors.TaskError: in parent mode, when any subtask ends in
            a BREAK-group status or in FAILURE
        """
        if self.Requirements.host and self.parent:
            upload_options = {
                parameter.name: getattr(self.Parameters, parameter.name)
                for parameter in type(self).Parameters.upload_options
            }

            logging.info(
                "Upload resources with upload options %s to mds: %s", upload_options, self.Parameters.resource_id
            )

            uploaded_queue = self.agentr.upload_to_mds(self.Parameters.resource_id, **upload_options)
            logging.info("Uploaded queue: %s", uploaded_queue)
            return

        with self.memoize_stage["create_subtasks"]:
            # Initialise the list before any early exit, so that a re-run which
            # skips this memoized stage still finds a valid (empty) list.
            self.Context.subtasks = []

            hosts_by_resources = self.hosts_by_resources()
            if not hosts_by_resources:
                logging.info("There are no resources for mds upload")
                return

            for host, resource_ids in six.iteritems(hosts_by_resources):
                subtask = self.type(
                    self, resource_id=resource_ids, sync=self.Parameters.sync, force=self.Parameters.force,
                    __requirements__={"host": host}
                )
                self.Context.subtasks.append(subtask.enqueue().id)

            raise sdk2.WaitTask(
                self.Context.subtasks, ctt.Status.Group.FINISH + ctt.Status.Group.BREAK, wait_all=True
            )

        failure_statuses = frozenset(list(ctt.Status.Group.BREAK) + [ctt.Status.FAILURE])
        failed_tasks = [
            task.id
            for task in (sdk2.Task[task_id] for task_id in self.Context.subtasks)
            if task.status in failure_statuses
        ]

        if failed_tasks:
            # Task ids are integers; they must be converted before str.join.
            raise common_errors.TaskError(
                "Subtasks {} are failed.".format(", ".join(str(task_id) for task_id in failed_tasks))
            )
