import time
import math
import random
import logging

from sandbox import common

from sandbox import sdk2
from sandbox.sdk2.helpers import misc as sdk2_misc


class ServiceCommand(sdk2.Task):
    """
    The task sends a service command to the specified cluster servers.

    Clients are split into two sets:

    * "immediate" clients receive the command right away, in batch requests of at most
      `RELOAD_CHUNK_SIZE` hosts each;
    * "distributed" clients receive the command chunk by chunk: one chunk is sent per
      execution of the task, which then asks to be put to wait before sending the next one.

    A client that belongs to both sets is handled only as a distributed one.
    """

    # Maximum amount of hosts to be scheduled for reload in one request
    RELOAD_CHUNK_SIZE = 500

    class Requirements(sdk2.Requirements):
        cores = 1
        ram = 1 << 10
        disk_space = 1 << 10

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Parameters):
        kill_timeout = 300

        # The service command to send; CLEANUP is the default choice.
        with sdk2.parameters.String("Command to send", multiline=True) as command:
            command.values.cleanup = command.Value("CLEANUP", default=True)
            command.values.reload = command.Value("RELOAD")
            command.values.reboot = command.Value("REBOOT")

        with sdk2.parameters.Group("Clients to send service command immediately in one chunk") as group_immediate:
            immediate_by_groups = sdk2.parameters.Bool("Specify hosts by client tags", default=True)
            with immediate_by_groups.value[True]:
                immediate_groups = sdk2.parameters.String(
                    "Clients tag expression for immediate command", default="GENERIC | CUSTOM_IMAGE",
                    description="Clients tag expression to send service command immediately in one chunk"
                )
            with immediate_by_groups.value[False]:
                immediate_hosts = sdk2.parameters.String(
                    "Hosts for immediate command",
                    description="Space-separated list of client ids to send service command "
                    "immediately in one chunk (braces allowed)"
                )

        with sdk2.parameters.Group("Clients to send service command distributed on a time range") as group_distributed:
            distributed_by_groups = sdk2.parameters.Bool("Specify hosts by client tags", default=True)
            with distributed_by_groups.value[True]:
                distributed_groups = sdk2.parameters.String(
                    "Clients tag expression for distributed command", default="MULTISLOT",
                    description="Clients tag expression to send service command distributed on a time range"
                )
            with distributed_by_groups.value[False]:
                distributed_hosts = sdk2.parameters.String(
                    "Hosts for distributed command",
                    description="Space-separated list of client ids to send service command "
                    "distributed on a time range (braces allowed)"
                )
            distribute_range = sdk2.parameters.Integer(
                "Time range in minutes for dispersion", default=180,
                description="Time range in minutes for service command dispersion"
            )
            distribute_period = sdk2.parameters.Integer(
                "Time period in minutes for dispersion", default=5,
                description="Time period in minutes for service command dispersion on specified time range"
            )
            merge_first_n_periods = sdk2.parameters.Integer(
                "Do not distribute first N periods", default=1,
                description="Immediately send service command for first N periods (chunks)"
            )

    def __groups2hosts(self, tags):
        """
        Fetch the clients matching a client tag expression.

        :param tags: clients tag expression passed as-is to the clients API.
        :return: the "items" list of the API response (at most 8000 entries are requested).
        """
        return self.server.client.read(tags=tags, limit=8000)["items"]

    def __hosts2hosts(self, hosts):
        """
        Fetch the clients listed by id.

        :param hosts: whitespace-separated client ids; the tokens are passed through
            `common.proxy.brace_expansion` before the lookup.
        :return: the "items" list of the API response, or an empty list when no ids are given.
        """
        # `split()` without arguments already strips whitespace and drops empty tokens.
        hosts = list(common.proxy.brace_expansion(str(hosts or "").split()))
        if not hosts:
            # Nothing to look up: do not issue a request with an empty id filter and a zero limit.
            return []
        return self.server.client.read(id=",".join(hosts), limit=len(hosts))["items"]

    def on_execute(self):
        """
        Send the command to the immediate clients, then to one distribution chunk per run.

        The list of distributed client ids is stored in `self.Context.distribute` inside the
        `immediate` memoize stage and is read back from the context on every execution.

        :raises common.errors.NothingToWait: when the next chunk must be processed without delay.
        :raises sdk2.WaitTime: to request a pause before the next chunk is processed.
        """
        self.server.DEFAULT_TIMEOUT = 180
        with self.memoize_stage.immediate:
            if self.Parameters.immediate_by_groups:
                immediate = self.__groups2hosts(self.Parameters.immediate_groups)
            else:
                immediate = self.__hosts2hosts(self.Parameters.immediate_hosts)

            if self.Parameters.distributed_by_groups:
                distribute = self.__groups2hosts(self.Parameters.distributed_groups)
            else:
                distribute = self.__hosts2hosts(self.Parameters.distributed_hosts)

            # NOTE(review): the shuffle below discards the ordering by task count produced
            # by the sort - confirm which of the two orderings is actually intended.
            self.Context.distribute = [_["id"] for _ in sorted(distribute, key=lambda x: len(x.get("tasks", [])))]
            random.shuffle(self.Context.distribute)

            # Clients scheduled for the distributed delivery are excluded from the immediate one.
            dset = set(self.Context.distribute)
            immediate = [cl["id"] for cl in immediate if cl["id"] not in dset]
            logging.info("Sending immediate command %r to %d hosts", self.Parameters.command, len(immediate))
            for chunk in common.utils.chunker(immediate, self.RELOAD_CHUNK_SIZE):
                # Log the chunk length: the chunk itself is a list and cannot be formatted with `%d`.
                logging.info("Sending immediate command %r to %d hosts", self.Parameters.command, len(chunk))
                self.server.batch.clients[self.Parameters.command] = dict(
                    id=chunk,
                    comment="Service command from task #{}".format(self.id),
                )

        distribute = self.Context.distribute
        # Floor division keeps the amount of chunks an integer whatever the semantics of `/` is.
        chunks = self.Parameters.distribute_range // self.Parameters.distribute_period
        if not (chunks and distribute):
            logging.info("Chunks: %d, distribute: %r. NO need to switch to WAIT", chunks, distribute)
            return
        chunk_size = int(math.ceil(len(distribute) / float(chunks)))
        # Only the chunks after the first `merge_first_n_periods` ones are followed by a wait.
        # When there are none of them, every chunk is sent without waiting and the wait time is
        # never used, so do not divide by a zero (or negative) amount of periods.
        waiting_periods = chunks - self.Parameters.merge_first_n_periods
        if waiting_periods > 0:
            wait_time = self.Parameters.distribute_range * 60 // waiting_periods
        else:
            wait_time = 0
        logging.debug("Chunks: %d, chunk size: %d, wait_time: %ds", chunks, chunk_size, wait_time)

        now = int(time.time())
        last_run, self.Context.last_run = self.Context.last_run, now
        # Shorten the next wait by the time the previous period overran `wait_time` (never extend it).
        correction = min(0, wait_time - now + last_run) if last_run else 0
        # NOTE(review): presumably the stage body is entered at most `chunks` times over the
        # task's re-executions and `stage.runs` is the 1-based run counter - confirm against
        # `sdk2.helpers.misc.MemoizeStage`.
        with sdk2_misc.MemoizeStage(self, "distribute")(max_runs=chunks) as stage:
            chunk_no = stage.runs - 1
            # NOTE(review): trailing chunks may be empty when there are fewer hosts than chunks.
            chunk = distribute[chunk_no * chunk_size:(chunk_no + 1) * chunk_size]
            logging.info("Processing distribution chunk #%d/%d of %d hosts: %r", chunk_no, chunks, len(chunk), chunk)
            self.server.batch.clients[self.Parameters.command] = dict(
                id=chunk,
                comment="Service command from task #{}".format(self.id),
            )
            if chunk_no < self.Parameters.merge_first_n_periods:
                # Just enter the same code again immediately - some kind of `continue`
                raise common.errors.NothingToWait

            wait = wait_time + correction
            logging.debug("Waiting %ds (correction: %ds)", wait, correction)
            if wait < 1:  # TODO: It should be checked in `sdk2.WaitTime`
                raise common.errors.NothingToWait
            raise sdk2.WaitTime(wait)
