import os
import abc
import datetime as dt
import logging
import textwrap
import threading
import multiprocessing

import concurrent.futures

import sandbox.common.types.misc as ctm
from sandbox import common
from sandbox.yasandbox.database import mapping
from sandbox.common.types import statistics as ctss

from . import locking


# Names exported by `from <module> import *`. `DistributedService` (the common base of
# `ThreadedService` and `MultiprocessedService`) is not listed here.
__all__ = ("Service", "SingletonService", "ThreadedService", "MultiprocessedService", "WalleSingletonService")


logger = logging.getLogger(__name__)


class DefaultWorkerState(common.patterns.Abstract):
    # Default per-target state used by `DistributedService.worker_loop` (see `worker_state_class`).
    # `interval`: tick interval override in seconds for the next iteration; None means "use the target's interval".
    # `persistent_state`: value saved to / restored from `threads.<target name>.persistent_state` of the service document.
    __slots__ = ("interval", "persistent_state")
    # Presumably the default values for the slots above, in the same order.
    # NOTE(review): whether `common.patterns.Abstract` copies the `{}` default for each instance is not
    # visible here; if it does not, that dict is shared by every instance — confirm.
    __defs__ = (None, {})


class Service(object):
    """
    Base class for all services. Provides config, service context, periodic execution, etc.
    """
    __metaclass__ = abc.ABCMeta

    #: Timeout for Juggler monitoring script, minutes
    notification_timeout = 30

    #: Enable sending signals with tick() durations
    push_tick_duration = True

    def __init__(self):
        self.sandbox_config = common.config.Registry()
        self.service_config = self.sandbox_config.server.services.get(self.name, {})

        # `mapping.Service` document holding schedule and context; filled by `load_service_state`
        self._model = None
        # Set by `request_stop`; also used as an interruptible sleep between ticks
        self._stop_requested = threading.Event()

        # Flags exposed via `get_status_report`
        self._running = False
        self._tick_in_progress = False
        self._liveness_deadline = None

        self.signaler = common.statistics.Signaler(
            common.statistics.ServerSignalHandler(), component=ctm.Component.SERVICE
        )

    @common.utils.classproperty
    def description(cls):
        """ Class docstring, dedented and stripped; empty string if there is none. """
        return textwrap.dedent(cls.__doc__).strip() if cls.__doc__ else ""

    @common.utils.classproperty
    def name(cls):
        """ Service name derived from the class name; used as the database id and config key. """
        return common.utils.ident(cls.__name__).lower()

    @property
    def context(self):
        """ Service context stored in the service document; saved after every tick. """
        return self._model.context

    def set_alive_deadline(self, deadline=None, seconds=None):
        """
        Set service liveness deadline. Service will be considered alive until this time.
        If no one renews the deadline, service can be killed.

        .. note:: `deadline` and `seconds` are mutually exclusive.

        :param deadline `datetime.datetime`: Moment in time service is considered dead after.
        :param seconds int: Number of seconds relative to the current moment.
        :raises ValueError: if both arguments are given
        """
        if deadline is not None and seconds is not None:
            raise ValueError("Arguments `deadline` and `seconds` are mutually exclusive.")

        if seconds is not None:
            deadline = dt.datetime.utcnow() + dt.timedelta(seconds=seconds)

        self._liveness_deadline = deadline

    def get_status_report(self):
        """
        :return: dict with running/stopping/tick flags and the liveness deadline (ISO 8601 UTC string or None)
        """
        return {
            "running": self._running,
            "stopping": self._stop_requested.is_set(),
            "tick_in_progress": self._tick_in_progress,
            "liveness_deadline": self._liveness_deadline.isoformat() + "Z" if self._liveness_deadline else None,
        }

    @abc.abstractproperty
    def tick_interval(self):
        """
        Service tick interval. This is a delay between tick starts, not between tick executions.
        """
        pass

    @abc.abstractmethod
    def tick(self):
        """
        Service core logic, should not take longer than `tick_interval`.
        """
        pass

    def on_stop(self):
        """
        Called before service shutdown. Overload to implement some kind of cleanup.
        """
        pass

    def request_stop(self):
        """
        Ask service to perform graceful shutdown. By default it'll finish current tick and then terminate.
        """
        self._stop_requested.set()

    def run_forever(self):
        """
        Perform service initialization and run execution loop until stop is requested.

        Exceptions raised by the loop and by `on_stop` are logged, not re-raised; an exception from
        `load_service_state` propagates. The `running` status flag is cleared on exit in every case.
        """
        self._running = True
        try:
            self.load_service_state()

            try:
                self._run_loop()
            except Exception:
                logger.exception("Unhandled exception during `_run_loop`, stopping service...")

            logger.info("Main loop terminated, run `on_stop` hook")
            try:
                self.on_stop()
            except Exception:
                logger.exception("Unhandled exception during `on_stop`")

            logger.info("All done, goodbye!")
        finally:
            # Without this the status report would keep claiming the service is running after it stopped
            self._running = False

    def load_service_state(self):
        """
        Load (or create) this service's `mapping.Service` document and record the current host as its owner.
        """
        # This method should not be executed concurrently by different service instances,
        # because save conflicts are possible. If service is distributed, this method is called
        # after successful exclusive lock acquisition.
        self._model = mapping.Service.objects.with_id(self.name)
        if self._model is None:
            self._model = mapping.Service()
            self._model.name = self.name
            self._model.time = mapping.Service.Time()

        # Update current owner of this service
        self._model.host = self.sandbox_config.this.id
        self._model.timeout = self.notification_timeout
        self._model.save()

    def _run_loop(self):
        """
        Call `tick` every `tick_interval` seconds until stop is requested.

        The schedule is saved before each tick, and the model (including context) is saved again after it.
        Exceptions raised by `tick` propagate to the caller.
        """
        assert self._model, "Service state should be loaded at this point"

        while not self._stop_requested.is_set():
            tick_start = dt.datetime.utcnow()
            next_run = self._model.time.next_run

            # NOTE(review): an unset `next_run` (possible for a freshly created `Service.Time()`?) is
            # treated as "due now" instead of failing on the datetime comparison — confirm mapping defaults
            if next_run is None or tick_start > next_run:
                # Remain alive for the duration of tick
                self.set_alive_deadline(seconds=self.tick_interval)

                # Schedule next execution before tick, this way we'll be able to detect stuck services
                self._model.time.last_run = tick_start
                self._model.time.next_run = self._model.time.last_run + dt.timedelta(seconds=self.tick_interval)
                self._model.save()

                logger.info("Tick started")
                self._tick_in_progress = True
                try:
                    self.tick()
                finally:
                    # Reset even if tick() raised, so the status report does not show a phantom tick
                    self._tick_in_progress = False
                logger.info("Tick finished")

                utcnow = dt.datetime.utcnow()
                tick_duration = utcnow - tick_start

                if self.push_tick_duration:
                    self.signaler.push(dict(
                        type=ctss.SignalType.SERVICE_STATISTICS,
                        date=utcnow,
                        timestamp=utcnow,
                        service=self.name,
                        duration=int(tick_duration.total_seconds() * 1000)
                    ))

                if tick_duration.total_seconds() > self.tick_interval:
                    logger.warning(
                        "Tick duration was longer than interval (%ss > %ss)!",
                        tick_duration.total_seconds(), self.tick_interval
                    )

                # Save service context, it could've been modified during tick
                self._model.save()
            else:
                # Remain alive while waiting for the next tick (plus a little bit just in case)
                self.set_alive_deadline(deadline=next_run + dt.timedelta(seconds=5))
                # Wait until next iteration; returns early if stop is requested
                timeout = (next_run - tick_start).total_seconds()
                self._stop_requested.wait(timeout=timeout)

    @classmethod
    def tick_once_with_context(cls, context):
        """
        Helper method for manual debug execution with some provided context.
        Changes to the context won't be saved anywhere, but can be observed via `context` argument.
        """

        class FakeModel(object):
            # Minimal stand-in for `mapping.Service`: exposes `context`, `save()` does nothing
            def __init__(self, ctx):
                self.context = ctx

            def save(self):
                pass

        instance = cls()
        instance._model = FakeModel(context)
        instance.tick()
        return context


class SingletonService(Service):
    """
    Acquires exclusive lock before run.
    """

    __metaclass__ = abc.ABCMeta

    def __init__(self, *args, **kwargs):
        super(SingletonService, self).__init__(*args, **kwargs)
        # Created lazily by the `lock` property, destroyed by `del self.lock`
        self._lock = None
        self._waiting_for_lock = False
        # TODO: generalize this in a way so that all services have a single and easily determined ZK lock path
        self.zk_name = self.name

    def get_status_report(self):
        report = super(SingletonService, self).get_status_report()
        report.update({
            "waiting_for_lock": self._waiting_for_lock,
        })
        return report

    @property
    def lock(self):
        """ Exclusive lock for this service, created on first access. """
        if self._lock is None:
            self._lock = self._create_lock()
        return self._lock

    @lock.deleter
    def lock(self):
        if self._lock is not None:
            self._lock.destroy()
            self._lock = None

    def _create_lock(self):
        """
        Build the lock: `locking.NoLock` when ZooKeeper is disabled in config, otherwise a
        `locking.FairZookeeperLock` whose disconnect/stepdown callbacks request service stop.
        """
        if not self.sandbox_config.common.zookeeper.enabled:
            return locking.NoLock()

        path = os.path.join(self.sandbox_config.common.zookeeper.root, "locks", "jobs")
        owner = self.sandbox_config.this.fqdn
        hosts = common.proxy.brace_expansion([self.sandbox_config.common.zookeeper.hosts], join=",")

        def on_disconnect():
            logger.warning("Lost ZooKeeper connection, terminate asap...")
            self.request_stop()

        def on_stepdown():
            logger.warning("Too many locks acquired, drop this one...")
            self.request_stop()

        return locking.FairZookeeperLock(
            path, self.zk_name, owner, hosts, on_disconnect, on_stepdown,
            watcher_delay=3600  # work at least for an hour
        )

    def _try_acquire_lock(self):
        """
        Retry lock acquisition (10 s per attempt) until it succeeds or stop is requested.

        :return: True if the lock was acquired, False if stop was requested first
        """
        self._waiting_for_lock = True
        try:
            while not self._stop_requested.is_set():
                try:
                    self.lock.acquire(timeout=10)
                except locking.LockTimeout:
                    continue
                return True
            return False
        finally:
            # Clear the flag on every exit path, including giving up and unexpected errors
            self._waiting_for_lock = False

    def run_forever(self):
        """
        Acquire the exclusive lock, then run the service; the lock is destroyed on exit.
        """
        try:
            if not self._try_acquire_lock():
                # Service was stopped during lock acquisition
                return
            super(SingletonService, self).run_forever()
        finally:
            # Also covers the "stopped while waiting" path, where the lock object was already created
            del self.lock


class DistributedService(SingletonService):
    """
    Service with multiple execution entities performing different work.
    By default, it checks its workers' status on every iteration and restarts those that are dead
    """

    __metaclass__ = abc.ABCMeta

    # Worker entity factory; called as `worker_class(target=callable)` and then `.start()`-ed
    worker_class = abc.abstractproperty()
    # Factory for `stop_event`; the object must provide `set()`, `is_set()` and `wait(timeout)`
    event_class = abc.abstractproperty()
    # State object created per target and passed to stateful targets
    worker_state_class = DefaultWorkerState

    # tick() does not do real work here so we do not measure it. Durations of threaded targets are measured separately.
    push_tick_duration = False

    class Target(object):
        """
        Class that stores worker-related options; return a list of these from targets() property
        """

        def __init__(self, function, interval=None, name=None, log_execution=True, stateful=False):
            """
            :param function: callable to execute. If `name` is not supplied, it is assumed that
                the callable has __name__ attribute
            :param interval: how frequently to run `function`, in seconds. Unless specified,
                it defaults to the value of `tick_interval`
            :param name: a string which identifies the worker among the others; used mostly for logging
            :param log_execution: log start and end of execution on every loop step
            :param stateful: when True, `function` will be fed with result of its previous execution,
                passed as the first argument. To use sane defaults on the first iteration,
                overload `worker_state_class` property
            """

            self.function = function
            self.interval = interval
            self.name = name or function.__name__
            self.stateful = stateful
            self.log_execution = log_execution

    def __init__(self, *args, **kwargs):
        super(DistributedService, self).__init__(*args, **kwargs)

        # Shared stop flag for all workers; set in `on_stop`
        self.stop_event = self.event_class()
        # Target name -> worker entity
        self._workers = {}

    @abc.abstractproperty
    def targets(self):
        """ Target workers to execute, they must return result message to log and list of futures to wait """
        pass

    def worker_loop(self, target):
        """
        Run `target.function` repeatedly until `stop_event` is set.

        Each iteration records `last_run`/`next_run` under `threads.<target name>` of this service's
        document, runs the target (passing the current state when `target.stateful`), waits for the
        futures it returned, saves `persistent_state` and pushes a duration signal. Exceptions raised
        by the target are logged and do not stop the loop.

        :param target: `DistributedService.Target` to execute
        """
        def unwrap_return_value(ret):
            # Normalize into (message, futures_to_wait, next_state); a non-tuple is the message alone,
            # and a tuple without a state gets a fresh `worker_state_class()`
            if isinstance(ret, tuple):
                try:
                    msg, futures_to_wait, next_state = ret
                except ValueError:
                    msg, futures_to_wait, next_state = ret + (self.worker_state_class(),)
            else:
                msg = ret
                futures_to_wait = []
                next_state = self.worker_state_class()

            return msg, futures_to_wait, next_state

        state = self.worker_state_class()
        if hasattr(state, "persistent_state"):
            # Restore the state saved by a previous run of this target, if any
            obj = mapping.Service.objects(name=self.name).first()
            if obj:
                persistent_state = obj.threads.get(target.name, {}).get("persistent_state")
                if persistent_state is not None:
                    state.persistent_state = persistent_state

        while not self.stop_event.is_set():
            worker_tick_interval = state.interval if state.interval is not None else target.interval
            loop_start = dt.datetime.utcnow()
            # Set the schedule keys individually: assigning the whole `threads.<name>` sub-document would
            # also erase `persistent_state`, losing it if the worker dies or fails before saving it again
            mapping.Service.objects(name=self.name).update_one(**{
                "set__threads__{}__last_run".format(target.name): loop_start,
                "set__threads__{}__next_run".format(target.name): (
                    loop_start + dt.timedelta(seconds=worker_tick_interval)
                ),
            })
            state.interval = worker_tick_interval

            if target.log_execution:
                logger.info("[%s] Start", target.name)

            try:
                result = target.function(state) if target.stateful else target.function()
                message, futures, state = unwrap_return_value(result)
                # Tolerate `None` instead of a list; falsy entries are skipped
                concurrent.futures.wait(filter(None, futures or ()))
                if hasattr(state, "persistent_state"):
                    mapping.Service.objects(name=self.name).update_one(**{
                        "set__threads__{}__persistent_state".format(target.name): state.persistent_state,
                    })
            except Exception:
                logger.exception("[%s] unhandled exception!", target.name)
                message = "Terminated with exception"

            utcnow = dt.datetime.utcnow()
            duration = (utcnow - loop_start).total_seconds()

            self.signaler.push(dict(
                type=ctss.SignalType.SERVICE_STATISTICS,
                date=utcnow,
                timestamp=utcnow,
                service="{}.{}".format(self.name, target.name),
                duration=int(duration * 1000)
            ))

            # Interval is measured between starts, so subtract the time already spent
            wait_for = max(0, worker_tick_interval - duration)

            if target.log_execution:
                logger.info(
                    "[%s] End: %s (duration: %0.2fs, next run in %0.2fs)",
                    target.name, message, duration, wait_for
                )

            self.stop_event.wait(wait_for)

    def _make_worker(self, target):
        """
        Start a `worker_class` entity running `worker_loop(target)`; a missing target interval
        defaults to `tick_interval`.
        """
        if target.interval is None:
            target.interval = self.tick_interval

        def inner():
            try:
                self.worker_loop(target)
            except BaseException:
                logger.exception("[%s] Unhandled exception, the thread has just died", target.name)
                raise

        worker = self.worker_class(target=inner)
        worker.start()
        return worker

    def spawn_workers(self):
        """ Start one worker per target, replacing the current worker mapping. """
        self._workers = {
            target.name: self._make_worker(target)
            for target in self.targets
        }

    def tick(self):
        """ Spawn workers on the first call, afterwards restart any worker that is no longer alive. """
        if not self._workers:
            self.spawn_workers()

        for target in self.targets:
            worker = self._workers[target.name]
            if worker.is_alive():
                continue
            logger.warning("[%s] RESPAWNING dead %s", target.name, self.worker_class.__name__.lower())
            self._workers[target.name] = self._make_worker(target)

    def on_stop(self):
        """ Signal all workers to stop and wait for each of them to finish. """
        self.stop_event.set()
        for k, th in self._workers.items():
            logger.info("[%s] waiting for the stop", k)
            th.join()

        super(DistributedService, self).on_stop()


class ThreadedService(DistributedService):
    # Distributed service that runs every target in its own thread of the current process.
    __metaclass__ = abc.ABCMeta

    @property
    def event_class(self):
        """ Stop flag type shared by the worker threads. """
        return threading.Event

    @property
    def worker_class(self):
        """ Worker entity type: one thread per target. """
        return threading.Thread


class MultiprocessedService(DistributedService):
    # Distributed service that runs every target in its own `multiprocessing.Process`.
    __metaclass__ = abc.ABCMeta

    @property
    def worker_class(self):
        return multiprocessing.Process

    @property
    def event_class(self):
        return multiprocessing.Event

    def tick(self):
        """
        Spawn worker processes on the first call, afterwards respawn those that have died.

        The database connection is dropped around every spawn (initial and respawn alike),
        presumably so that forked children do not share the parent's connection; each child
        connects on its own in `worker_loop`. Dead workers found after `stop_event` is set are
        forgotten instead of respawned.
        """
        if not self._workers:
            mapping.disconnect()  # the connection is guaranteed to be restored before workers' execution
            try:
                self.spawn_workers()
            finally:
                mapping.ensure_connection()

        dead_targets = []
        for target in self.targets:
            worker = self._workers.get(target.name)
            if worker is None or worker.is_alive():
                continue

            if self.stop_event.is_set():
                del self._workers[target.name]
            else:
                dead_targets.append(target)

        if not dead_targets:
            return

        # Respawning forks too, so handle the connection exactly like the initial spawn above
        mapping.disconnect()
        try:
            for target in dead_targets:
                logger.warning("[%s] the process has died; respawn it", target.name)
                self._workers[target.name] = self._make_worker(target)
        finally:
            mapping.ensure_connection()

    def worker_loop(self, target):
        """ Open this process's own database connection, then run the generic worker loop. """
        mapping.ensure_connection(
            uri=self.sandbox_config.server.mongodb.connection_url,
            max_pool_size=100,
        )
        super(MultiprocessedService, self).worker_loop(target)


class WalleSingletonService(SingletonService):
    # Singleton service with a REST client (`walle_client`) for the API configured at `common.walle.api`.

    #: Page size requested from the hosts API
    MAX_LIMIT = 10000

    def __init__(self, *args, **kwargs):
        super(WalleSingletonService, self).__init__(*args, **kwargs)
        # The token file is only read when server auth is enabled; otherwise requests go unauthenticated
        if self.sandbox_config.server.auth.enabled:
            auth = common.utils.read_settings_value_from_file(self.sandbox_config.common.walle.token)
        else:
            auth = None
        self.walle_client = common.rest.Client(self.sandbox_config.common.walle.api, auth=auth)

    def get_host_names(self, **query):
        """
        Collect names of all hosts matching `query`, following cursor-based pagination.

        :param query: filters passed to `walle_client.hosts.read`; `fields`, `limit` and `cursor`
            are set by this method and override caller-supplied values
        :return: set of host names
        """
        names = set()

        query["fields"] = ["name"]
        query["limit"] = self.MAX_LIMIT

        cursor = 0
        # A missing `next_cursor` (mapped to -1) or an empty page ends the iteration
        while cursor != -1:
            query["cursor"] = cursor

            response = self.walle_client.hosts.read(**query)
            hosts = response.get("result", [])
            if not hosts:
                break

            names.update(host["name"] for host in hosts)
            cursor = response.get("next_cursor", -1)

        return names
