# -*- coding: utf-8 -*-

import logging
import os
import random

import requests

import sandbox.common as common
from sandbox import sdk2
import sandbox.common.types.client as ctc
from sandbox.projects.rtmr.clusters import RTMR_CLUSTERS, RtmrClustersInfo
from sandbox.sdk2.helpers import subprocess as sp


class RtmrRestartHost(sdk2.Task):
    """Restart rtmr host process"""

    # How many times the workers API is queried (each time on a randomly
    # chosen cluster host) before the task gives up.
    _WORKERS_API_ATTEMPTS = 5
    # Timeout, in seconds, passed to requests for a single workers API call,
    # so one unresponsive host cannot block the task indefinitely.
    _WORKERS_API_TIMEOUT = 30

    class Requirements(sdk2.Task.Requirements):
        client_tags = ctc.Tag.GENERIC & ~ctc.Tag.LXC
        disk_space = 2 * 1024  # 2Gb

    class Parameters(sdk2.Task.Parameters):
        description = "Restart rtmr host process"
        kill_timeout = 30 * 60

        # One selectable value per known cluster; the first cluster yielded
        # by RTMR_CLUSTERS is marked as the default choice.
        with sdk2.parameters.String("Cluster name", multiline=True, required=True) as cluster_name:
            _first = True
            for _name in RTMR_CLUSTERS:
                if _first:
                    cluster_name.values[_name] = cluster_name.Value(default=True)
                    _first = False
                else:
                    cluster_name.values[_name] = None

        with sdk2.parameters.Group("Auth settings") as auth_block:
            secret_name = sdk2.parameters.String(
                "Vault secret name with SSH key",
                required=True,
                default_value="robot-rtmr-mnt-ssh"
            )
            secret_owner = sdk2.parameters.String(
                "Vault secret owner",
                required=True,
                default_value="RTMR-DEV"
            )
            remote_user = sdk2.parameters.String(
                "Remote username",
                required=True,
                default_value="robot-rtmr-mnt"
            )

    class Context(sdk2.Task.Context):
        # Absolute path of the file that `sky run` writes failed hosts to;
        # filled in by on_execute on the first run.
        failed_hosts = None

    def get_cluster_hosts(self):
        """Resolve the skynet expression of the selected cluster into hosts."""
        # Imported lazily: the resolver library is only needed at execution time.
        from library.sky.hostresolver import Resolver
        return Resolver().resolveHosts(RtmrClustersInfo().clusters[self.Parameters.cluster_name].skynet)

    def get_rtmr_workers_from_host(self, hostname):
        """Fetch the worker list from the RTMR HTTP API of a single host.

        :param hostname: host whose API on port 8080 is queried
        :return: value of the "Workers" key of the JSON response
        :raises common.errors.TaskError: if the HTTP status code is not 200
        """
        logging.info("Get rtmr workers from " + hostname)
        url = "http://{}:8080/api/v1/workers.json".format(hostname)
        logging.info("Url " + url)
        response = requests.get(url, timeout=self._WORKERS_API_TIMEOUT)
        if response.status_code != 200:
            logging.error("Api response has error %r", response.status_code)
            raise common.errors.TaskError("RTMR api http response code {}".format(response.status_code))
        return response.json()["Workers"]

    def get_rtmr_workers(self, hosts):
        """Fetch the worker list, retrying on randomly chosen hosts.

        :param hosts: iterable of candidate host names
        :return: worker list from the first host that answers successfully
        :raises common.errors.TaskError: if there are no hosts to query or
            every attempt fails
        """
        candidates = list(hosts)
        if not candidates:
            logging.error("No hosts to query rtmr workers from")
            raise common.errors.TaskError("RTMR api error. See logs for more informations.")
        for try_no in range(self._WORKERS_API_ATTEMPTS):
            try:
                return self.get_rtmr_workers_from_host(random.choice(candidates))
            except Exception as e:
                # Log and fall through to the next attempt; re-raising here
                # would defeat the retry loop.
                logging.error("Error get rtmr workers (#%d) %r", try_no, e)
        logging.error("Retry exceeded")
        raise common.errors.TaskError("RTMR api error. See logs for more informations.")

    def get_alive_hosts(self):
        """Return the set of server ids that report an "RTMRTASK" service.

        Workers without a "ServerId" are ignored.
        """
        workers = self.get_rtmr_workers(self.get_cluster_hosts())
        alive_hosts = set()
        for worker in workers:
            server_id = worker.get("ServerId")
            if server_id is None:
                continue
            if any(service.get("Name") == "RTMRTASK" for service in worker.get("Services", [])):
                alive_hosts.add(server_id)
        return alive_hosts

    def verify_failed_hosts(self):
        """Check hosts listed in the failed-hosts file against the live cluster.

        Reads the file at ``self.Context.failed_hosts`` (one host per line,
        blank lines skipped), reports the hosts via ``set_info`` and fails the
        task if any of them is still reported alive by the workers API.

        :raises common.errors.TaskError: if the file cannot be read or a
            failed host is also present among the alive hosts
        """
        failed_hosts = set()
        try:
            with open(self.Context.failed_hosts, "r") as fd:
                for line in fd:
                    line = line.strip()
                    if line:
                        failed_hosts.add(line)
        except IOError as e:
            logging.error("Error open skynet failed-hosts file %r", e)
            raise common.errors.TaskError("Error open skynet failed-hosts file")
        if failed_hosts:
            # Sorted so the message is stable between runs.
            self.set_info("Fail restart on hosts: " + " ".join(sorted(failed_hosts)))
        failed_alive_hosts = self.get_alive_hosts().intersection(failed_hosts)
        if failed_alive_hosts:
            logging.error("Fail restart on hosts: %r", sorted(failed_alive_hosts))
            raise common.errors.TaskError("Error restart on {} hosts".format(len(failed_alive_hosts)))

    def on_execute(self):
        """Run the restart command on the cluster via ``sky run``.

        On a non-zero exit code of the command the failed-hosts file is
        verified with :meth:`verify_failed_hosts`, which may fail the task.
        """
        if self.Context.failed_hosts is None:
            self.Context.failed_hosts = str(sdk2.Path("failed-hosts.txt").absolute())

        skynet_expr = RtmrClustersInfo().clusters[self.Parameters.cluster_name].skynet
        # Remote shell snippet: use systemctl when /bin/systemctl is
        # executable, otherwise fall back to /sbin/restart.
        restart_script = (
            "if [ -x /bin/systemctl ]; then /usr/bin/sudo /bin/systemctl restart rtmr-host.service; else "
            "/usr/bin/sudo /sbin/restart rtmr-host; fi"
        )
        cmd = [
            "/usr/local/bin/sky",
            "run",
            "-u", self.Parameters.remote_user,
            "--retry", "3",
            "--stream",
            "--log_failed_hosts=" + self.Context.failed_hosts,
            restart_script,
            skynet_expr,
        ]

        self.set_info("Restart rtmr-host on " + self.Parameters.cluster_name)
        logging.info("Restart command %r", cmd)
        with sdk2.helpers.ProcessLog(self, logger=logging.getLogger("sky")) as pl:
            with sdk2.ssh.Key(self, self.Parameters.secret_owner, self.Parameters.secret_name):
                proc = sp.Popen(cmd, stdout=pl.stdout, stderr=sp.STDOUT, env=os.environ.copy())
                proc.wait()
                if proc.returncode != 0:
                    self.set_info("Warning: restart command return error code: " + str(proc.returncode))
                    self.verify_failed_hosts()
        self.set_info("Done")
