import gevent.monkey
gevent.monkey.patch_all()  # noqa

import gc
import os
import time
import math
import uuid
import psutil
import random
import socket
import logging
import datetime as dt
import threading
import collections

import six

from sandbox import common
import sandbox.common.types.misc as ctm
import sandbox.common.types.resource as ctr
import sandbox.common.types.statistics as ctss

import sandbox.common.joint.errors as jerrors
import sandbox.common.joint.server as jserver

from sandbox import sdk2

import sandbox.agentr.client as ar_client
import sandbox.agentr.errors as ar_errors
import sandbox.taskbox.utils as tb_utils
import sandbox.taskbox.dispatcher.worker
import sandbox.taskbox.statistics as tb_statistics


class WorkerRegistry(object):
    def __init__(
        self, run_dir, tasks_bin=None, config=None, kill_workers_on_start=False, system_collector_interval=1,
    ):
        """
        Pick up already running workers and start the background maintenance threads.

        :param run_dir: directory where sockets of spawned workers are created
        :param tasks_bin: optional custom tasks binary, used for all workers when provided
        :param config: settings registry, `common.config.Registry()` is created when omitted
        :param kill_workers_on_start: terminate all discovered workers and start with an empty registry
        :param system_collector_interval: value passed to the worker factory constructor
        """
        self._run_dir = run_dir
        self._logger = logging.getLogger("workers")

        # Custom tasks binary (if any) is used for all workers
        self._tasks_binary = tasks_bin or None

        self._agentr = ar_client.Session.taskbox(self._logger)

        config = config or common.config.Registry()
        self.config = config
        dispatcher_settings = config.taskbox.dispatcher
        self._binary_ttl = dispatcher_settings.tasks_binary_ttl
        self._worker_ttl = dispatcher_settings.worker_ttl

        worker_module = sandbox.taskbox.dispatcher.worker
        factory_cls = (
            worker_module.SubprocessFactory
            if dispatcher_settings.use_subproc else
            worker_module.ProcmanFactory
        )
        self._worker_factory = factory_cls(system_collector_interval)

        # Per-binary locks and last usage marks must exist before `kill_workers()` may be called
        self._workers = self._worker_factory.find_workers()
        self._lock_dict = dict((binary_id, threading.Lock()) for binary_id in self._workers)
        self._last_tasks_binary_call = dict.fromkeys(self._workers, dt.datetime.utcnow())
        self._lock = threading.Lock()

        if kill_workers_on_start:
            self.kill_workers()
            self._lock_dict = {}
            self._last_tasks_binary_call = {}

        watcher = threading.Thread(target=self._worker_watcher)
        watcher.daemon = True
        self._worker_watcher_thread = watcher
        watcher.start()

        # Binaries are synced in advance on non-local installations only
        if config.common.installation in ctm.Installation.Group.NONLOCAL:
            presyncer = threading.Thread(target=self.presync_binaries)
            presyncer.daemon = True
            self._worker_presync_binaries = presyncer
            presyncer.start()

    def __len__(self):
        """ Return the total number of registered workers over all tasks binaries. """
        with self._lock:
            return sum(len(binary_workers) for binary_workers in six.itervalues(self._workers))

    def __iter__(self):
        """ Iterate over a snapshot of all registered workers; the lock is held only while copying. """
        with self._lock:
            snapshot = [
                worker
                for binary_workers in six.itervalues(self._workers)
                for worker in binary_workers
            ]

        for worker in snapshot:
            yield worker

    def _fill_empty_binary(self, resource, now):
        """
        Register tasks binary `resource` with an empty worker list, a new lock and `now` as its last usage time.

        Any existing entries for the same resource are overwritten.
        """
        self._last_tasks_binary_call[resource] = now
        self._lock_dict[resource] = threading.Lock()
        self._workers[resource] = []

    def presync_binaries(self):
        """
        Endlessly look for recently created tasks binaries and sync them in advance.

        Once a minute asks the REST API for READY `SandboxTasksBinary` resources with `taskbox_enabled`
        attribute created during the last five minutes (50 at most), registers the ones missing
        in the registry and syncs them via AgentR. Never returns.
        """
        client = common.rest.Client(
            auth=common.fs.read_settings_value_from_file(self.config.client.auth.oauth_token),
            version=2
        )
        client = client << client.HEADERS({ctm.HTTPHeader.NO_LINKS: "true"})
        while True:
            now = dt.datetime.utcnow()
            try:
                resources = client.resource.read(
                    type=sdk2.service_resources.SandboxTasksBinary,
                    state=ctr.State.READY,
                    attrs={"taskbox_enabled": True},
                    created="{}..{}".format(
                        common.format.utcdt2iso(now - dt.timedelta(minutes=5)),
                        common.format.utcdt2iso(now)
                    ),
                    limit=50
                )

                resources_set = set(r["id"] for r in resources["items"])
                new_resources = set()
                if resources_set:
                    with self._lock:
                        new_resources = resources_set - set(self._workers)
                        for resource in new_resources:
                            self._fill_empty_binary(resource, now)

                if new_resources:
                    self._logger.info("Sync new resources: %s", ",".join(map(str, new_resources)))
                for resource in new_resources:
                    # A resource is already registered at this point and will not be reported as new again,
                    # so a failure on one of them must not prevent syncing the rest.
                    try:
                        self._agentr.resource_sync(resource)
                    except Exception as ex:
                        self._logger.error("Error on presync of resource %s.", resource, exc_info=ex)

            except Exception as ex:
                self._logger.error("Error on resource presync thread.", exc_info=ex)

            # Sleep regardless of the outcome: retrying immediately after a failure
            # would turn a persistent error into a busy loop flooding the API and the log.
            time.sleep(60)

    def _worker_watcher(self):
        """
        Endlessly maintain the registry; never returns.

        On every iteration (with a pause of pi..pi+1 seconds between them):
          - once per process lifetime logs top objects when process memory exceeds 7 GiB;
          - registers dependant resources reported by AgentR which are unknown to the registry;
          - drops dead workers and terminates those not used for `worker_ttl` minutes;
          - forgets tasks binaries without workers which were not requested for `tasks_binary_ttl` minutes,
            removing them from AgentR's dependant resources first.
        """
        process = psutil.Process(os.getpid())
        objects_printed = False
        while True:
            if not objects_printed and process.memory_info()[0] / 2.**30 > 7:
                objects_printed = True
                self._log_top_objects()

            try:
                tasks_binaries = set(item[0] for item in self._agentr.dependant_resources())
            except ar_errors.NoTaskSession as ex:
                self._logger.error("Task session removed. Make new session.", exc_info=ex)
                self._agentr = ar_client.Session.taskbox(self._logger)
                time.sleep(math.pi + random.random())
                continue
            except Exception as ex:
                self._logger.error("Error on getting dependant resources.", exc_info=ex)
                time.sleep(math.pi + random.random())
                continue

            now = dt.datetime.utcnow()
            worker_ttl = dt.timedelta(minutes=self._worker_ttl)
            binary_ttl = dt.timedelta(minutes=self._binary_ttl)

            # Shallow snapshot: worker lists are shared with the registry and guarded by per-binary locks
            with self._lock:
                workers = dict(self._workers)

            extra_tasks_binaries = tasks_binaries - set(workers)

            if extra_tasks_binaries:
                self._logger.info("Add extra resource to registry: %s", ", ".join(map(str, extra_tasks_binaries)))
                with self._lock:
                    for resource in extra_tasks_binaries:
                        # The binary could have been registered (together with its workers)
                        # after the snapshot was taken, do not wipe it out.
                        if resource not in self._workers:
                            self._fill_empty_binary(resource, now)

            expired_binaries = []
            for worker_binary_id, worker_arr in six.iteritems(workers):
                with self._lock_dict[worker_binary_id]:
                    for worker in list(worker_arr):
                        if worker.check_if_dead():
                            self._logger.info("Worker %r terminated:\n", worker)
                            worker_arr.remove(worker)
                            # Already removed: falling through to the expiration check would remove
                            # the worker for the second time and crash the watcher with ValueError.
                            continue
                        if worker.last_call_time + worker_ttl < now:
                            self._logger.info("Worker %r expired. Terminate it.", worker)
                            worker.terminate()
                            worker_arr.remove(worker)
                    last_call = self._last_tasks_binary_call[worker_binary_id]
                    if last_call + binary_ttl < now and not worker_arr:
                        if worker_binary_id in tasks_binaries:
                            self._logger.info("Resource %s expired. Remove it from task_deps.", worker_binary_id)
                            try:
                                self._agentr.remove_dependant_resource(worker_binary_id)
                            except Exception as ex:
                                self._logger.error("Error on remove dependant resources.", exc_info=ex)
                                # Keep the binary registered, removal is retried on the next iteration
                                continue

                        expired_binaries.append(worker_binary_id)

            if expired_binaries:
                with self._lock:
                    for worker_binary_id in expired_binaries:
                        # Re-check under the registry lock: the binary could have been requested again
                        # since the decision above, in which case its entries are about to be used.
                        last_call = self._last_tasks_binary_call.get(worker_binary_id)
                        requested_again = last_call is not None and last_call + binary_ttl >= now
                        if requested_again or self._workers.get(worker_binary_id):
                            continue
                        self._workers.pop(worker_binary_id, None)
                        self._lock_dict.pop(worker_binary_id, None)
                        self._last_tasks_binary_call.pop(worker_binary_id, None)

            time.sleep(math.pi + random.random())

    def _log_top_objects(self):
        """ Log up to 1000 most numerous objects (by `str()`) and object types tracked by the garbage collector. """
        self._logger.warning("Dispatcher consume too much memory, try to print top objects by it's count")
        stats = collections.Counter()
        type_stats = collections.Counter()
        for item in gc.get_objects():
            # Best effort diagnostics: objects failing to be stringified or hashed are just skipped.
            # `Exception` instead of a bare `except` lets interpreter/greenlet exit requests through.
            try:
                stats[str(item)] += 1
            except Exception:
                pass
            try:
                type_stats[type(item)] += 1
            except Exception:
                pass
        self._logger.warning("Top objects by it's count are %s", stats.most_common(1000))
        self._logger.warning("Top objects types by it's count are %s", type_stats.most_common(1000))

    def kill_workers(self):
        """ Terminate every registered worker and reset the registry to an empty one. """
        with self._lock:
            for binary_id, binary_workers in six.iteritems(self._workers):
                binary_lock = self._lock_dict[binary_id]
                with binary_lock:
                    for doomed in binary_workers:
                        self._logger.info("Killing %s", doomed)
                        doomed.terminate()
            self._workers = {}

    def get_or_create_worker(self, tasks_binary_id, logger, lock):
        now = dt.datetime.utcnow()
        with self._lock:
            if tasks_binary_id not in self._lock_dict:
                self._lock_dict[tasks_binary_id] = threading.Lock()
            if tasks_binary_id not in self._workers:
                self._workers[tasks_binary_id] = []
            self._last_tasks_binary_call[tasks_binary_id] = now

        with self._lock_dict[tasks_binary_id]:
            workers = self._workers[tasks_binary_id]
            for worker in sorted(workers, key=lambda x: x.requests_in_progress):
                if (
                    worker.requests_in_progress >= self.config.taskbox.dispatcher.max_requests_in_progress or
                    worker.check_if_dead()
                ):
                    continue

                # Found suitable free worker
                logger.info("Use existing worker: %r", worker)
                break

            # No suitable workers, create new one.
            else:
                if self._tasks_binary:
                    tasks_bin = self._tasks_binary
                else:
                    agentr = ar_client.Session.taskbox(logger)
                    tasks_bin = agentr.resource_sync(tasks_binary_id)
                    resource_meta = agentr.resource_meta(tasks_binary_id)
                    if resource_meta.get("multifile"):
                        linux_bin = resource_meta.get("system_attributes", {}).get("linux_platform", None)
                        if not linux_bin:
                            raise ValueError("Taskboxed directory-resource doesn't provide linux binary")
                        tasks_bin = os.path.join(tasks_bin, linux_bin)

                worker_id = uuid.uuid4().hex[:8]
                socket_path = os.path.join(self._run_dir, "worker.{}.sock".format(worker_id))

                logger.info("Starting worker for #%s using %s", tasks_binary_id, tasks_bin)
                worker = self._worker_factory.spawn_worker(worker_id, tasks_bin, tasks_binary_id, socket_path)
                workers.append(worker)
                logger.info("Worker started: %s", worker)

            if lock:
                worker.requests_in_progress += 1
            worker.last_call_time = now

            return worker


class DispatcherError(jerrors.SilentException):
    """ Raised by the dispatcher when a worker has crashed or did not start in time. """


class Server(tb_utils.JointServer):
    """ Taskbox's dispatcher """

    # Time for worker to load task code and start listening
    WORKER_STARTUP_TIMEOUT = 60

    @property
    def service_name(self):
        """ Name of this service, always "dispatcher". """
        return "dispatcher"

    def __init__(self, config, cleanup=False):
        """
        Create the dispatcher together with its worker registry and statistics collection.

        :param config: settings registry
        :param cleanup: when true, workers found on start are killed by the registry
        """
        super(Server, self).__init__(config)

        # Workers collect system statistics only if statistics are enabled at all
        statistics = config.taskbox.statistics
        interval = None
        if statistics.enabled:
            interval = statistics.system_collector_interval

        self.worker_registry = WorkerRegistry(
            run_dir=config.taskbox.dirs.run,
            tasks_bin=config.taskbox.worker.tasks_binary,
            config=config,
            kill_workers_on_start=cleanup,
            system_collector_interval=interval,
        )

        self.__setup_statistics_collection(config)

    def statistics_sender(self, config):
        """
        Periodically push collected system statistics of the dispatcher and all its workers as signals.

        Runs until `self.stopping` is set, one round per `dispatcher_sender_interval`.

        :param config: settings registry
        """
        period = config.taskbox.statistics.dispatcher_sender_interval
        last_running = time.time()
        while not self.stopping:
            delay = last_running + period - time.time()
            gevent.sleep(max(0, delay))
            last_running = int(time.time())

            signaler = common.statistics.Signaler()
            # The dispatcher itself goes first, followed by every registered worker
            components = common.itertools.chain(
                ((self, True),),
                ((w, False) for w in self.worker_registry)
            )
            for component, is_dispatcher in components:
                if is_dispatcher:
                    component_role = tb_statistics.TaskboxComponent.DISPATCHER
                    component_uuid = None
                    resource_id = None
                else:
                    component_role = tb_statistics.TaskboxComponent.WORKER
                    component_uuid = component.uuid
                    resource_id = component.tasks_binary_id

                signals = []
                for point in component.system_statistics:
                    signals.append({
                        "type": ctss.SignalType.TASKBOX_STATISTICS,
                        "date": point.time,
                        "timestamp": point.time,
                        "user_cpu": point.user_cpu,
                        "system_cpu": point.system_cpu,
                        # NOTE(review): `>> 20` presumably converts bytes to mebibytes; confirm the units
                        "rss": point.rss >> 20,
                        "vms": point.vms >> 20,
                        "role": component_role,
                        "uuid": component_uuid,
                        "server": config.this.fqdn,
                        "resource_id": resource_id,
                    })
                signaler.push(signals)

    def __setup_statistics_collection(self, config):
        """
        Start greenlets collecting and sending system statistics if they are enabled in settings.

        Always defines `system_statistics`, `system_statistics_greenlet` and `statistics_sender_greenlet`
        attributes; they stay `None` when statistics are disabled.

        :param config: settings registry
        """
        self.system_statistics = None
        self.system_statistics_greenlet = None
        # Defined up front like its siblings, so the attribute exists even when statistics are disabled
        self.statistics_sender_greenlet = None
        statistics = config.taskbox.statistics

        if not statistics.enabled:
            self._logger.info("Taskbox statistics collection greenlet is disabled")
            return

        common.statistics.Signaler().register(
            tb_statistics.TaskboxStatsSignalHandler(token=config.client.auth.oauth_token)
        )

        self.system_statistics = common.os.SystemStatistics()
        self.system_statistics_greenlet = gevent.spawn(
            tb_statistics.system_statistics_collector,
            statistics=self.system_statistics,
            interval=statistics.system_collector_interval,
            logger=self._logger,
            stop_checker=lambda: self.stopping,
        )
        self.statistics_sender_greenlet = gevent.spawn(
            self.statistics_sender,
            config=config,
        )

    @jserver.RPC.simple()
    def shutdown(self):
        """ Shutdown server: raise the `stopping` flag checked by the statistics greenlets. """
        self.stopping = True

    @jserver.RPC.full()
    def ping(self, ret=True, job=None):
        """
        Log the received value and return it unchanged.

        :param ret: value to echo back
        :param job: object whose `log` attribute is used for logging; it is dereferenced
            unconditionally, so it must not be left as `None`
        :return: `ret` as is
        """
        job.log.info("Pinged with %r value", ret)
        return ret

    @jserver.RPC.full()
    def ensure_worker(self, tasks_binary_id, job=None):
        """
        Make sure there is a worker for the given tasks binary, without counting a request on it.

        :param tasks_binary_id: id of the tasks binary resource
        :param job: object whose `log` attribute is passed to the registry as a logger
        :return: id of the found or spawned worker
        """
        return self.worker_registry.get_or_create_worker(tasks_binary_id, job.log, lock=False).worker_id

    @jserver.RPC.dupgenerator()
    def call(self, tasks_binary_id, request_data, request_id=None, job=None):
        """
        Pass a request to a worker serving the given tasks binary and relay the exchange in both directions.

        Yields `(rtype, data)` pairs produced by the worker's call generator; every value sent into
        this generator is forwarded to the worker's one. When the worker's generator is exhausted,
        the result of `call.wait()` is delivered via `StopIteration`.

        :param tasks_binary_id: id of the tasks binary resource whose worker must process the request
        :param request_data: request payload, passed to the worker as is
        :param request_id: optional request id; when given together with `job`,
            the job's logger is replaced with a tracking logger for this id
        :param job: object with `log` attribute, optional
        :raises DispatcherError: if the worker turned out to be dead after a connection error
            or did not accept the request within `WORKER_STARTUP_TIMEOUT` seconds
        """
        logger = self._logger
        if request_id is not None and job is not None:
            job.log = common.log.tracking_logger(job.log, request_id)
            logger = job.log

        # `lock=True` makes the registry increment the worker's counter of requests in progress
        worker = self.worker_registry.get_or_create_worker(tasks_binary_id, logger, lock=True)

        # NOTE(review): `unlock_after_use()` presumably decrements that counter back on exit; confirm
        with worker.unlock_after_use():
            # Retry connection errors for some time (i.e. wait until worker is ready).
            # No need to sleep here, because rpc client will sleep between attempts.
            started = time.time()
            while time.time() - started < self.WORKER_STARTUP_TIMEOUT:
                try:
                    call = worker.call(request_data, request_id)
                # TODO: need a special error when worker is not listening
                except socket.error as e:
                    if worker.check_if_dead():
                        logger.warning("%r is dead, can't process request", worker)
                        raise DispatcherError("Worker {} for sbr:{} has crashed. Reason: {}".format(
                            worker, tasks_binary_id, worker._stderr
                        ))
                    else:
                        logger.info("%r is not ready yet: %s", worker, e)
                        continue

                # Connection established: bridge the worker's generator with the caller of this one.
                # Every path below leaves the function, so the retry loop never repeats after this point.
                gen = call.generator
                try:
                    rtype, data = gen.next()
                    while True:
                        response = yield rtype, data
                        rtype, data = gen.send(response)
                except StopIteration:
                    # NOTE(review): returning a value this way relies on Python 2 generator semantics,
                    # under PEP 479 (Python 3.7+) it is converted into RuntimeError.
                    raise StopIteration(call.wait())
            else:
                # The loop condition became false, i.e. the worker did not start listening in time
                raise DispatcherError("Worker for sbr:{} timeouted during startup".format(tasks_binary_id))
