import os
import json
import time
import random
import signal
import logging
import textwrap
import datetime as dt
import subprocess as sp
import collections

import psutil
import requests

from . import base
from . import unified_agent

# Balancer settings shared by the launchers in this module; each launcher's
# ``balancer()`` merges its own ``domain``/``port`` on top of these.
BASE_BALANCER_CONFIG = dict(
    location="/",
    balancing_options=dict(
        connection_timeout="200ms",
        backend_timeout="30s",
        retries_count=10,
        balancer_retries_timeout="70s",
        keepalive_count=5,
        # Despite its name, this option means "retry on 5xx":
        # https://wiki.yandex-team.ru/samogon/kontrolka-servanta/balancer/
        fail_on_5xx=True,
        balancing_type=dict(mode="rr"),
    ),
)


class ServerLauncher(base.SandboxLauncher):
    """Launcher for the Sandbox server (``server.tgz`` package).

    Starts ``server.py`` under the Sandbox venv Python, publishes a balancer
    config for the ``api`` domain on port 9998 and health-checks the server in
    ``ping``, restarting it on repeated HTTP check failures, when its status
    endpoint is unavailable, or when too many requests are being processed.
    """
    __name__ = "server_launcher"

    __MAX_PROCESSING_REQUESTS = 250  # restart when more requests than this are reported as processing
    __MIN_MAIL_PERIOD = 900  # in seconds
    __HTTP_TIMEOUT = 15  # in seconds; the SIGALRM watchdog in `ping` uses twice this value
    __MAX_PING_FAILS = 3
    __SERVER_URL = "http://localhost:9998"
    __CHECK_FILE = "server_check_stop"  # touched before failure-triggered restarts
    __MIN_RESTART_PERIOD = 600  # in seconds; checks are skipped this long after the check file is touched
    __MEMORY_LIMIT = "64G"

    @classmethod
    def cgroup_settings(cls):
        """Cgroup limits for the server: a memory limit only."""
        return {"memory": {"memory.limit_in_bytes": cls.__MEMORY_LIMIT}}

    def postinstall(self):
        """Run the base post-install steps, then remove the ``/home/zomb-sandbox/ui`` symlink if present."""
        super(ServerLauncher, self).postinstall()
        if os.path.islink("/home/zomb-sandbox/ui"):
            os.unlink("/home/zomb-sandbox/ui")

    def start(self):
        """Prepare the unpacked ``server.tgz`` and spawn ``server.py`` with the Sandbox venv Python."""
        self._write_solomon_descr_to_config("server.solomon.push")
        os.chmod(os.path.join(self.get_pack("server.tgz"), "sandbox", "web", "media"), 0o775)
        # The server reads its settings from sandbox/etc/settings.yaml inside the package.
        self._make_symlink(
            self.config_as_file(),
            os.path.join(self.get_pack("server.tgz"), "sandbox", "etc", "settings.yaml"),
            force=True
        )

        self.create(
            [
                os.path.join(self._sandbox_venv_path, "bin", "python"),
                os.path.join(self.get_pack("server.tgz"), "sandbox", "bin", "server.py")
            ],
            env={
                "PYTHONPATH": ":".join((
                    "/skynet", os.path.join(self.data_dir(), "packages", "tasks"),
                ))
            }
        )
        os.chmod(self.config_as_file(), 0o644)

    @property
    def __cfg__(self):
        """Base config plus an override that makes the client SDK use the system ``svn`` binary.

        The merged result is cached in ``self._config``.
        """
        if self._config is None:
            self._config = self.merge_dicts(
                super(ServerLauncher, self).__cfg__,
                {"client": {"sdk": {"svn": {"use_system_binary": True}}}}
            )
        return self._config

    def packages(self):
        """Base packages plus ``server.tgz``."""
        return super(ServerLauncher, self).packages() + ["server.tgz"]

    def user(self):
        """The service user's fields as a dict."""
        return self._service_user._asdict()

    def balancer(self):
        """Shared balancer settings for the ``api`` domain on port 9998."""
        return self.merge_dicts(BASE_BALANCER_CONFIG, {
            "domain": "api",
            "port": 9998,
        })

    @staticmethod
    def logfiles():
        """ List of log files shipped to Elasticsearch """
        return [
            "/var/log/nginx/error.log",
            "/var/log/nginx/sandbox_access.log",
            "/var/log/sandbox/server.log",
            "/var/log/sandbox/xmlrpc.log",
            "/var/log/sandbox/sandbox.log",
            "/var/log/sandbox/service.log",
            "/var/log/sandbox/serviceq.log",
            "/var/log/sandbox/rest.log"
        ]

    @staticmethod
    def __touch(fname):
        """Update the mtime of ``fname``, creating it with mode 0666 if it does not exist.

        :return: True on success, False if the file could neither be touched nor created.
        """
        try:
            os.utime(fname, None)
            return True
        except OSError:
            pass
        try:
            with open(fname, "a"):
                pass
            os.chmod(fname, 0o666)
            return True
        except (IOError, OSError):
            # On Python 2 `open()` raises IOError, which is not an OSError subclass there.
            return False

    def ping(self):
        """Health-check the server and restart it when it looks broken.

        Checks are skipped from 23:56 to 00:04 local time and within
        ``__MIN_RESTART_PERIOD`` seconds of the check file being touched.

        :return: True if the server is healthy (or checks are skipped), False if a
            check failed, None right after a restart caused by repeated failures.
        """
        now = dt.datetime.now()
        if (now.hour == 23 and now.minute > 55) or (now.hour == 0 and now.minute < 5):
            return True

        checkf = os.path.join(base.SANDBOX_RUNTIME_DIR, self.__CHECK_FILE)
        try:
            if os.stat(checkf).st_mtime + self.__MIN_RESTART_PERIOD > time.time():
                return True
        except OSError:
            pass  # no check file yet

        if self.persist is None:
            self.persist = collections.Counter()

        # Counted before the check and reset only after a successful one, so the
        # __MAX_PING_FAILS-th call in a row without a success restarts unconditionally.
        self.persist["fails"] += 1
        if self.persist["fails"] >= self.__MAX_PING_FAILS:
            self.persist["fails"] = 0
            self.__touch(checkf)
            self.restart()
            return

        # If a connection hangs up, die! The watchdog covers only this request and
        # is always cancelled afterwards, so no alarm is left pending on return.
        signal.alarm(self.__HTTP_TIMEOUT * 2)
        try:
            requests.get(self.__SERVER_URL + "/sandbox/http_check", timeout=self.__HTTP_TIMEOUT).raise_for_status()
        except Exception:
            logging.exception("Server did not respond to HTTP check")
            return False
        finally:
            signal.alarm(0)

        self.persist["fails"] = 0
        try:
            r = requests.get(self.__SERVER_URL + "/api/v1.0/service/status/server", timeout=self.__HTTP_TIMEOUT)
            r.raise_for_status()
            status = r.json()
        except Exception:
            # Log first so the traceback is recorded even if the restart itself fails.
            logging.exception("Unable to fetch server status")
            self.__touch(checkf)
            self.restart()
            return False

        processing = status["requests"]["processing"]
        logging.info(
            "Server currently processing %d requests, queued %d requests",
            processing, status["requests"]["queued"]
        )
        if processing > self.__MAX_PROCESSING_REQUESTS:
            # NOTE(review): unlike the other restart paths, the check file is not
            # touched here, so the following pings are not suppressed -- confirm intent.
            self.restart()
            logging.error(
                "Server is overloaded (see logs for threads dump): %s",
                json.dumps(status, sort_keys=True),
            )
            return False
        return True


class ServiceQ(base.Base):
    """Launcher for the Service Q binary (``serviceq.tgz`` package).

    ``ping`` rotates logs, moves the ``serviceq`` process into its cpuset cgroup
    and queries its status from a helper Python process. It restarts the service
    when it stays failing or restoring for too long.
    """
    __name__ = "serviceq"

    # NOTE(review): %-formatting does not unescape "{{...}}", so the hosts value
    # below literally contains doubled braces; presumably handled downstream -- confirm.
    _CONFIG = textwrap.dedent("""
        serviceq:
          zookeeper:
            enabled: true
            hosts: "sandbox-server{{03,11,25}}.search.yandex.net:2181"
          server:
            mongodb:
              connection_url: "mongodb://localhost:22222/sandbox"
              write_concern:
                %(write_concern)s
            statistics:
              enabled: true
            quotas:
              enabled: true
              use_pools: true
        common:
          installation: "PRODUCTION"
          statistics:
            enabled: true
            database: "signal"
          unified_agent:
            %(unified_agent)s
    """ % dict(
        write_concern=json.dumps(base.WRITE_CONCERN),
        unified_agent=unified_agent.ua_sockets_config_for_server,
    ))

    _CONFIG_PATCHES = base.unfold_list([
        {"sandbox1_server": textwrap.dedent("""
            serviceq:
              zookeeper:
                root: "/serviceq_pre_production/consensus"
              server:
                mongodb:
                  connection_url: "mongodb://localhost:22222/sandbox_restored"
            common:
              installation: "PRE_PRODUCTION"
            client:
              rest_url: "https://www-sandbox1.n.yandex-team.ru/api/v1.0"
        """)},
    ])

    __PING_TIMEOUT = 10  # in seconds
    __MAX_FAIL_INTERVAL = 180  # max interval for transient statuses (STARTING, STOPPING, TRANSIENT) in seconds
    __MAX_RESTORING_INTERVAL = 360  # max interval for status RESTORING in seconds

    _LOGFILE = "/var/log/sandbox/serviceq.log"
    _ELECTOR_LOGFILE = "/var/log/sandbox/serviceq_zk.log"
    _ROTATE_LOGS = [_LOGFILE, _ELECTOR_LOGFILE]

    def packages(self):
        """Base packages plus ``serviceq.tgz``."""
        return super(ServiceQ, self).packages() + ["serviceq.tgz"]

    def user(self):
        """The service user's fields as a dict."""
        return self._service_user._asdict()

    @property
    def binary_path(self):
        """Path to the ``serviceq`` executable inside the unpacked package."""
        return os.path.join(self.get_pack("serviceq.tgz"), "serviceq")

    def start(self):
        """Make the binary executable and spawn it with ``SANDBOX_CONFIG`` pointing at our config."""
        self._write_solomon_descr_to_config("serviceq.server.statistics.semaphores")
        self.change_permissions(self.binary_path, chmod=0o755)
        self.create(
            [self.binary_path],
            env={"SANDBOX_CONFIG": self.config_as_file()}
        )

    @staticmethod
    def _process_name(proc):
        """Name of a psutil process; `name` is an attribute before psutil 2.0 and a method since."""
        name = proc.name
        return name() if callable(name) else name

    def _adjust_cpu(self):
        """Write the pid of the launcher's ``serviceq`` child into the serviceq cpuset cgroup.

        Does nothing if no launcher process or no ``serviceq`` child is found, or if
        the child's pid is already listed in the cgroup's ``tasks`` file.
        Raises IOError/OSError if the cgroup ``tasks`` file cannot be read or written.
        """
        cpuset_cgroup_path = "/sys/fs/cgroup/cpuset/serviceq/tasks"
        with open(cpuset_cgroup_path) as f:
            cgroup_pids = {int(line) for line in f if line.strip()}
        launcher = next(iter(self.procs()), None)
        if not launcher:
            return
        launcher = psutil.Process(pid=launcher.stat()["pid"])
        # psutil 2.0 renamed `get_children()` to `children()`; support both APIs.
        list_children = getattr(launcher, "children", None) or launcher.get_children
        serviceq = next(
            (child for child in list_children() if self._process_name(child) == "serviceq"),
            None,
        )
        if serviceq and serviceq.pid not in cgroup_pids:
            logging.debug("Move process %s to dedicated CPU set", serviceq.pid)
            with open(cpuset_cgroup_path, "w") as f:
                f.write(str(serviceq.pid))

    def ping(self):
        """Rotate logs, pin Service Q to its cpuset and check its status.

        Status comes from a short script, run with the Sandbox venv Python, that
        calls the ``secondary_status`` RPC and prints ``READY``, ``RESTORING`` or
        ``NOT_READY``. The service is restarted when it has been neither ready nor
        restoring for ``__MAX_FAIL_INTERVAL`` seconds, or has stayed restoring for
        ``__MAX_RESTORING_INTERVAL`` seconds.

        :return: True if Service Q reported READY, False otherwise.
        """
        self._rotate_logs(move=True, notify=True)
        if self.persist is None:
            self.persist = {}
        # last_not_fail: last time the service was READY or RESTORING;
        # last_not_restoring: last time it was READY or failing without RESTORING;
        # fail_count: non-READY pings since the last READY one (or the last restart).
        self.persist.setdefault("last_not_fail", dt.datetime.now())
        self.persist.setdefault("last_not_restoring", dt.datetime.now())
        self.persist.setdefault("fail_count", 0)

        # Pinning failures (missing cgroup, vanished process) must not abort the
        # status check below.
        try:
            self._adjust_cpu()
        except Exception:
            logging.exception("Unable to move Service Q to dedicated CPU set")

        status = None
        try:
            output = sp.check_output([
                os.path.join(self._sandbox_venv_path, "bin", "python"),
                "-Buc",
                "import os, sys;"
                "os.environ['SANDBOX_CONFIG'] = '{sandbox_config}';"
                "sys.path[:0] = ['{serviceq_path}', '/skynet'];"
                "from sandbox.serviceq import config, client, types;"
                "qconfig = config.Registry();"
                "qconfig.serviceq.client.timeout = {ping_timeout};"
                "qclient = client.Client(qconfig);"
                "status = qclient._rpc_client().call('secondary_status').wait(timeout={ping_timeout});"
                "output = 'READY' if status in types.Status.Group.READY else "
                "'RESTORING' if status == types.Status.RESTORING else 'NOT_READY';"
                "sys.stdout.write(output);"
                "sys.stdout.flush()".format(
                    sandbox_config=self.config_as_file(),
                    serviceq_path=self.get_pack("serviceq.tgz"),
                    ping_timeout=self.__PING_TIMEOUT
                )
            ], stderr=sp.STDOUT, universal_newlines=True)  # text output, so str comparisons work
            # stderr is merged in; the status word is the last line of the output.
            status = output.split("\n")[-1]
            alive = status == "READY"
            if not alive:
                logging.error("Service Q reported: %s", output)
        except sp.CalledProcessError as ex:
            logging.error("Error while pinging server: %s", ex.output)
            alive = False
        now = dt.datetime.now()
        if alive:
            self.persist["last_not_fail"] = now
            self.persist["last_not_restoring"] = now
            self.persist["fail_count"] = 0
        else:
            self.persist["last_not_fail" if status == "RESTORING" else "last_not_restoring"] = now
            fail_interval = (now - self.persist["last_not_fail"]).total_seconds()
            restoring_interval = (now - self.persist["last_not_restoring"]).total_seconds()
            # First failure after a long gap: start the failure window now instead
            # of restarting immediately.
            if self.persist["fail_count"] == 0 and fail_interval >= self.__MAX_FAIL_INTERVAL:
                self.persist["last_not_fail"] = now
                fail_interval = 0
            self.persist["fail_count"] += 1
            if fail_interval >= self.__MAX_FAIL_INTERVAL or restoring_interval >= self.__MAX_RESTORING_INTERVAL:
                logging.warning(
                    "Service Q ping times: now=%s, last_not_fail=%s, last_not_restoring=%s, fail_count=%s",
                    now, self.persist["last_not_fail"], self.persist["last_not_restoring"], self.persist["fail_count"]
                )
                self.persist["last_not_fail"] = now
                self.persist["last_not_restoring"] = now
                self.persist["fail_count"] = 0
                self.restart()
        return alive


class ServiceApi(base.Base):
    """Launcher for the ``serviceapi`` binary (``serviceapi.tgz`` package).

    Shares the Sandbox server config (plus per-installation patches), is
    configured for port 8080 and is health-checked over HTTP in ``ping``.
    """
    __name__ = "serviceapi"

    # Use the same settings as Sandbox Server
    _CONFIG = base.SandboxLauncher._CONFIG
    _CONFIG_PATCHES = base.unfold_list([{
        "sandbox1_server": textwrap.dedent("""
            common:
              unified_agent:
                %(unified_agent)s
              abcd:
                d_tvm_service_id: "d-testing"
                d_api_url: "https://d.test.yandex-team.ru/api/v1"
            server:
              api:
                port: 8080
                workers: 64
        """ % dict(unified_agent=unified_agent.ua_sockets_config_for_server)),
        "sandbox_server": textwrap.dedent("""
            common:
              unified_agent:
                %(unified_agent)s
              abcd:
                d_tvm_service_id: "d-production"
                d_api_url: "https://d-api.yandex-team.ru/api/v1"
            server:
              api:
                port: 8080
                workers: 200
        """ % dict(unified_agent=unified_agent.ua_sockets_config_for_server)),
    }], base=base.SandboxLauncher._CONFIG_PATCHES)

    __MAX_PING_FAILS = 5  # maximum number of health check fails before ServiceApi restarting
    __HTTP_TIMEOUT = 15  # timeout for health check requests
    __SERVER_URL = "http://localhost:8080"  # ServiceApi instance url
    __RESTART_TIMEOUT = 20  # time to ServiceApi restart in seconds

    def packages(self):
        """Only the ``serviceapi.tgz`` package."""
        return ["serviceapi.tgz"]

    def user(self):
        """The service user's fields as a dict."""
        return self._service_user._asdict()

    @property
    def binary_path(self):
        """Path to the ``serviceapi`` executable inside the unpacked package."""
        return os.path.join(self.get_pack("serviceapi.tgz"), "serviceapi")

    def start(self):
        """Make the binary executable and spawn it with ``SANDBOX_CONFIG`` pointing at our config."""
        self.change_permissions(self.binary_path, chmod=0o755)
        self.create(
            [self.binary_path],
            env={"SANDBOX_CONFIG": self.config_as_file()}
        )

    def ping(self):
        """Rotate logs and health-check ServiceApi over HTTP.

        After ``__MAX_PING_FAILS`` failed checks in a row the service is restarted.
        No HTTP check is made within ``__RESTART_TIMEOUT`` seconds of the last
        restart (or of the first ping).

        :return: None right after a restart, False if the health check failed,
            otherwise the result of the base class ``ping``.
        """
        # uwsgi can't reopen logs, therefore perform "copytruncate"
        self._rotate_logs(move=False, notify=False)

        now = dt.datetime.utcnow()
        if self.persist is None:
            self.persist = {"fails": 0, "last_restart": now}

        if self.persist.setdefault("fails", 0) >= self.__MAX_PING_FAILS:
            self.persist["fails"] = 0
            self.persist["last_restart"] = now
            self.restart()
            logging.error("ServiceApi did not respond to HTTP check")
            return

        if now - self.persist.setdefault("last_restart", now) > dt.timedelta(seconds=self.__RESTART_TIMEOUT):
            # If a connection hangs up, die! The watchdog covers only this request
            # and is always cancelled afterwards.
            signal.alarm(self.__HTTP_TIMEOUT * 2)
            try:
                # Non-2xx responses count as failures, as in ServerLauncher.ping.
                requests.get(self.__SERVER_URL + "/health_check", timeout=self.__HTTP_TIMEOUT).raise_for_status()
            except Exception as ex:
                logging.error("Health check failed.", exc_info=ex)
                self.persist["fails"] += 1
                return False
            finally:
                signal.alarm(0)
            self.persist["fails"] = 0

        # Don't forget default `ping` logic
        return super(ServiceApi, self).ping()

    def balancer(self):
        """Shared balancer settings for the ``www`` domain on port 8080."""
        return self.merge_dicts(BASE_BALANCER_CONFIG, {
            "domain": "www",
            "port": 8080,
        })

    def logfiles(self):
        """ServiceApi log files (also used as the list of logs to rotate)."""
        return [
            "/var/log/sandbox/api/access.log",
            "/var/log/sandbox/api/server.log",
        ]

    @property
    def _ROTATE_LOGS(self):
        return self.logfiles()


class Taskbox(base.Base):
    """Launcher for the ``taskbox`` binary (``taskbox.tgz`` package).

    Runs under memory, CPU-share and block-device read-bandwidth cgroup limits.
    It is health-checked by calling the dispatcher's ``ping`` from a helper process.
    """
    __name__ = "taskbox"

    __MEMORY_LIMIT = "32G"
    __READ_IO_LIMIT = "104857600"  # 100 MiB/s
    __CPUSHARES = "512"
    _ROTATE_LOGS = ["/var/log/sandbox/taskbox/dispatcher.log", "/var/log/sandbox/taskbox/workers/worker.log"]

    # Use the same settings as Sandbox Server
    _CONFIG = base.SandboxLauncher._CONFIG

    _CONFIG_PATCHES = base.unfold_list([{
        ("sandbox_server", "sandbox1_server"): textwrap.dedent("""
            taskbox:
              statistics:
                enabled: true
              dispatcher:
                server:
                  port: 8090
        """),
    }], base=base.SandboxLauncher._CONFIG_PATCHES)

    @staticmethod
    def _get_blockdev_nums():
        """Return ``"major:minor"`` strings for selected devices under ``/dev``.

        Selected are ``sd*`` entries without a trailing digit (i.e. not partitions
        such as ``sda1``) and ``md*`` entries ending with a digit (e.g. ``md0``).
        """
        _, _, entries = next(os.walk("/dev/"))
        selected = [
            name for name in entries
            if (name.startswith("sd") and not name[-1].isdigit()) or (name.startswith("md") and name[-1].isdigit())
        ]
        nums = []
        for name in selected:
            rdev = os.stat("/dev/" + name).st_rdev
            nums.append("{}:{}".format(os.major(rdev), os.minor(rdev)))
        return nums

    @classmethod
    def _gen_io_limits(cls):
        """Return ``"<major:minor> <bytes per second>"`` read-limit lines, one per selected device."""
        return [dev + " " + cls.__READ_IO_LIMIT for dev in cls._get_blockdev_nums()]

    @classmethod
    def cgroup_settings(cls):
        """Cgroup limits: memory, CPU shares and per-device read bandwidth (if any devices are found)."""
        limits = {}
        read_limits = cls._gen_io_limits()

        limits["memory"] = {"memory.limit_in_bytes": cls.__MEMORY_LIMIT}
        limits["cpu,cpuacct"] = {"cpu.shares": cls.__CPUSHARES}
        # Samogon writes each cgroup file named by a dict key only once, so prefix
        # the key with "./" repetitions to get distinct keys for the same file and
        # write one device limit per key.
        if read_limits:
            cg_name = "blkio.throttle.read_bps_device"
            limits["blkio"] = {
                "./" * i + cg_name: limit
                for i, limit in enumerate(read_limits)
            }
        return limits

    def packages(self):
        """Base packages plus ``taskbox.tgz``."""
        return super(Taskbox, self).packages() + ["taskbox.tgz"]

    def user(self):
        """The service user's fields as a dict."""
        return self._service_user._asdict()

    @property
    def binary_path(self):
        """Path to the ``taskbox`` executable inside the unpacked package."""
        return os.path.join(self.get_pack("taskbox.tgz"), "taskbox")

    def start(self):
        """Make the binary executable and spawn it with ``SANDBOX_CONFIG`` pointing at our config."""
        self.change_permissions(self.binary_path, chmod=0o755)
        self.create(
            [self.binary_path],
            env={"SANDBOX_CONFIG": self.config_as_file()},
        )

    def ping(self):
        """Rotate logs and ping the taskbox dispatcher from a helper Python process.

        A random value is passed to the dispatcher's ``ping``; the check succeeds
        only if the helper prints exactly that value.

        :return: True if the expected value came back, False otherwise.
        """
        self._rotate_logs(move=True, notify=False)

        ping_value = str(random.randint(1, 1000000))
        srv_path = self.get_pack("taskbox.tgz")
        try:
            output = sp.check_output([
                os.path.join(self._sandbox_venv_path, "bin", "python"),
                "-Buc",
                "import os, sys;"
                "sys.path = ['/skynet', {service_path!r}, {code_path!r}] + sys.path;"
                "import sandbox.common.joint.client;"
                "import sandbox.taskbox.client.service as tb_service;"
                "tb = tb_service.Dispatcher();"
                "sys.stdout.write(str(tb.ping({ping_value})));"
                "sys.stdout.flush()".format(
                    service_path=srv_path,
                    code_path=os.path.join(srv_path, "sandbox"),
                    ping_value=ping_value,
                )
            ], stderr=sp.STDOUT, universal_newlines=True)  # text output, so it compares with a str
            alive = output == ping_value
            if not alive:
                logging.error("Mismatch of return value of ping: got %r, expected %r", output, ping_value)
        except sp.CalledProcessError as ex:
            logging.error("Error while pinging server: %s", ex.output)
            alive = False

        return alive
