# coding: utf-8

import os
import sys
import abc
import json
import time
import copy
import Queue
import errno
import signal
import httplib
import logging
import threading
import xmlrpclib
import collections
import datetime as dt

import psutil

from sandbox import common
import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt

from sandbox.yasandbox import services

from sandbox.web import response


logger = logging.getLogger(__name__)


def make_reply_to_large_request(source, ex):
    """
    Build an error reply describing exception `ex` in the format matching the request source.

    :param source: request origin; compared against `ctt.RequestSource.RPC`.
    :param ex: exception to report.
    :return: a marshalled XML-RPC fault response (string) for RPC requests, otherwise
        a `response.HttpExceptionResponse` with HTTP code 422 (Unprocessable Entity).
    """
    if source != ctt.RequestSource.RPC:
        return response.HttpExceptionResponse(code=httplib.UNPROCESSABLE_ENTITY)

    # `BaseException.message` is deprecated (PEP 352) and is empty unless the exception was created
    # with exactly one argument, so fall back to the string form of the exception in that case.
    details = getattr(ex, 'message', None) or str(ex)
    fault = xmlrpclib.Fault(
        common.proxy.ReliableServerProxy.ErrorCodes.ERROR,
        '{0}: {1}'.format(ex.__class__.__name__, details)
    )
    # Empty method name + `methodresponse=True` produce a method response packet carrying the fault.
    return xmlrpclib.dumps(fault, '', True)


class FakeRequest(object):
    """
    Lightweight stand-in for a request object, used on reply for non-picklable responses.
    Only a fixed set of attributes is carried over from the original request.
    """
    # Names of the request attributes copied to the stand-in, in copy order.
    _COPIED_ATTRS = (
        'id',
        'internal_id',
        'ctx',
        'handler',
        'user',
        'source',
        'measures',
        'read_preference',
        'session',
        'quota_owner',
        'headers',
        'query_string',
    )

    def __init__(self, req):
        """
        :param req: original request object; must expose every attribute listed in `_COPIED_ATTRS`.
        """
        for attr in self._COPIED_ATTRS:
            setattr(self, attr, getattr(req, attr))


class ThreadsPoolBase(object):
    """
    Abstract base for a pool of threads fed through a shared queue.
    Subclasses provide the thread body by implementing `main()`.
    """
    __metaclass__ = abc.ABCMeta

    def __init__(self, size, queue_type=Queue.Queue):
        """
        :param size: amount of threads in the pool.
        :param queue_type: queue factory, called with the queue capacity as the only argument.
        """
        self.size = size
        # Capacity is a third of the pool size (integer division);
        # for the stdlib `Queue` classes a zero capacity means "unbounded".
        self.queue = queue_type(size / 3)
        self.threads = []
        for _ in xrange(size):
            self.threads.append(threading.Thread(target=self.main))

    def start(self):
        """ Start every thread of the pool. """
        logger.info("Initializing pool of %d threads.", len(self.threads))
        for worker in self.threads:
            worker.start()

    def stop(self):
        """ Enqueue one `None` marker per thread, then wait for all the threads to finish. """
        logger.info('Signaling workers to finish.')
        for _ in self.threads:
            self.queue.put(None)
        logger.info('Waiting for thread pool to finish.')
        for worker in self.threads:
            worker.join()

    @property
    def qsize(self):
        """ Approximate amount of items currently waiting in the pool's queue. """
        return self.queue.qsize()

    @abc.abstractmethod
    def main(self):
        """ Thread body; has to be provided by a subclass. """
        pass


class StatisticsThread(object):
    """
    Statistics collector thread. The main point here is to wake up at a very strict time period to collect
    correct statistics.

    Keeps absolute request counters, request rate per time frame and per-process CPU/memory figures
    for the current process, the service process and the worker processes.
    """
    # Statistics collector frame sizes in seconds.
    FRAMES = (5, 15, 30)

    class RequestsCounters(common.patterns.Abstract):
        """ Absolute amount of requests processed during server's lifetime. """
        __slots__ = ('low_priority', 'high_priority')
        __defs__ = (0, 0)

    class RPS(common.patterns.Abstract):
        """ Amount of requests per frame (5, 15 and 30 seconds). """
        class Slot(common.patterns.Abstract):
            """ Single slot representation: requests counter and the frame number it belongs to. """
            __slots__ = ('count', 'ts')
            __defs__ = (0, 0)

        __slots__ = ('s5', 's15', 's30')
        __defs__ = [[None, None]] * 3

        def __init__(self):
            # Every field holds a `(frame size, slots)` pair with one slot per second of the frame.
            super(StatisticsThread.RPS, self).__init__(*(
                (frame, [self.Slot() for _ in xrange(frame)])
                for frame in StatisticsThread.FRAMES
            ))

        def avg(self):
            """ Average requests per second per frame, rounded to two decimal places. """
            return {k: round(sum(s.count for s in slots) / float(len(slots)), 2) for k, (frame, slots) in self}

    class Process(common.patterns.Abstract):
        """ Process statistics. """
        class CPU(common.patterns.Abstract):
            """ Amount of CPU times per frame (5, 15 and 30 seconds). """
            # Sampling period in seconds.
            CHECK_PERIOD = 5

            class Slot(common.patterns.Abstract):
                """ Single CPU slot representation """
                __slots__ = ('user', 'system', 'idle')
                __defs__ = (.0, .0, .0)

                def __sub__(self, rval):
                    """ Per-field difference with `rval`; negative differences are clamped to zero. """
                    return self.__class__(*[
                        max(0, v) for v in (self.user - rval.user, self.system - rval.system, self.idle - rval.idle)
                    ])

            __slots__ = ('prev', 's5', 's15', 's30')
            __defs__ = [[None, None]] * 4

            def __init__(self):
                # `prev` keeps the last absolute sample (its "frame" is `None`, which marks it as not a frame);
                # the other fields hold `(frame size, slots)` pairs with one slot per check period.
                super(StatisticsThread.Process.CPU, self).__init__(
                    [None, self.Slot()],
                    *(
                        (frame, [self.Slot() for _ in xrange(frame / self.CHECK_PERIOD)])
                        for frame in StatisticsThread.FRAMES
                    )
                )

            def avg(self):
                """ Average CPU times per second per frame, rounded to two decimal places. """
                return {
                    k: self.Slot(*(
                        round(sum(getattr(s, attr) for s in slots) / float(len(slots)) / self.CHECK_PERIOD, 2)  # noqa
                        for attr in self.Slot.__slots__
                    ))
                    for k, (frame, slots) in self if k != 'prev'
                }

        class Memory(common.patterns.Abstract):
            """ Memory consumption (virtual and resident) in kibibytes. """
            __slots__ = ('rss', 'vms')
            __defs__ = (0, 0)

        __slots__ = ('pid', 'cpu', 'mem')
        __defs__ = (0, None, None)

    class Processes(common.patterns.Abstract):
        """ Per-process statistics aggregate. """
        __slots__ = ('self', 'service', 'workers')
        __defs__ = (None, ) * 3

    def __init__(self, service_pid, worker_pids):
        """
        :param service_pid: PID of the service process to collect statistics for.
        :param worker_pids: iterable of worker processes' PIDs.
        """
        self.requests = self.RequestsCounters()
        self.rps = self.RPS()
        self.processes = self.Processes(
            self.Process(os.getpid(), self.Process.CPU(), self.Process.Memory()),
            self.Process(service_pid, self.Process.CPU(), self.Process.Memory()),
            [self.Process(pid, self.Process.CPU(), self.Process.Memory()) for pid in worker_pids]
        )
        self.stopping = None
        self.thread = None
        self.started = None

    def start(self):
        """ Start the collector thread and remember the start time. """
        self.stopping = threading.Event()
        self.thread = threading.Thread(target=self._run)
        self.thread.start()
        self.started = time.time()

    def stop(self):
        """ Ask the collector thread to finish; `start()` must have been called before. """
        self.stopping.set()

    def join(self):
        """ Wait for the collector thread to finish. """
        self.thread.join()

    def count_request(self, hp):
        """
        Register a single request in the absolute counters and in the request rate slots.

        :param hp: truthy for a high-priority request.
        """
        if hp:
            self.requests.high_priority += 1
        else:
            self.requests.low_priority += 1

        ts = int(time.time())
        for frame, slots in self.rps.itervalues():
            quotient, reminder = divmod(ts, frame)
            slot = slots[reminder]
            # The slot still holds data of a previous pass over the frame - start it over.
            if slot.ts != quotient:
                slot.ts = quotient
                slot.count = 0
            slot.count += 1

    def _update_process(self, p, now, ts, period):
        """
        Take a CPU/memory sample of the process described by `p` and store it into its statistics.

        :param p: `Process` record to update.
        :param now: current time as a float timestamp.
        :param ts: `now` truncated to whole seconds.
        :param period: sampling period in seconds.
        :raise psutil.NoSuchProcess: if the process does not exist (anymore).
        """
        data = psutil.Process(p.pid)
        # `is_running` is a method: the bare attribute is always truthy, so it has to be called.
        if not data.is_running():
            return
        cpu = data.get_cpu_times()
        uptime = now - data.create_time
        # Everything the process did not spend in user or system mode is accounted as idle time.
        curr = self.Process.CPU.Slot(cpu.user, cpu.system, uptime - cpu.system - cpu.user)
        prev, p.cpu.prev[1] = p.cpu.prev[1], curr
        diff = curr - prev
        for frame, slots in p.cpu.itervalues():
            if not frame:
                # That is the `prev` pseudo-frame, not a real one.
                continue
            slots[ts % (frame / period) - 1] = diff
        mem = data.get_memory_info()
        # Shift by 10 bits converts the values to units of 1024.
        p.mem.rss = mem.rss >> 10
        p.mem.vms = mem.vms >> 10

    def _run(self):
        """ Collector thread body: samples all the processes once per check period until asked to stop. """
        period = self.Process.CPU.CHECK_PERIOD
        while not self.stopping.is_set():
            now = time.time()
            ts = int(now)
            processes = [self.processes.self, self.processes.service] + self.processes.workers
            for p in processes:
                try:
                    self._update_process(p, now, ts, period)
                except psutil.NoSuchProcess:
                    # A vanished process must neither prevent collecting statistics for the remaining ones
                    # nor terminate the collector thread.
                    continue

            # Cleanup request rate statistics
            for offset in xrange(period):
                for frame, slots in self.rps.itervalues():
                    quotient, reminder = divmod(ts - offset, frame)
                    slot = slots[reminder]
                    if slot.ts != quotient:
                        slot.ts = quotient
                        slot.count = 0

            # Sleep till the next timestamp which is a multiple of the period.
            now = time.time()
            will = int((now + period) / period) * period
            self.stopping.wait(will - now)


class Watchdog(threading.Thread):
    """
    Sub-processes watchdog thread. Monitors all the sub-processes at once: as soon as a watched
    child process exits with a non-zero status, the stopper callback is called and the thread finishes.
    """

    def __init__(self, workers, stopper):
        """
        :param workers: object exposing a `stopping` flag, which tells that the server is shutting down.
        :param stopper: callable to be called when a watched child process terminates abnormally.
        """
        super(Watchdog, self).__init__()
        self._stopper = stopper
        self._workers = workers
        self.logger = logger
        self._watch = set()

    def watch(self, pid):
        """ Add the given PID to the set of watched child processes. """
        self._watch.add(pid)

    def run(self):
        self.logger.info('Watchdog thread started.')
        while True:
            try:
                pid, rc, _ = os.wait3(0)
            except OSError as ex:
                if ex.errno == errno.EINTR and not self._workers.stopping:
                    # The call was just interrupted by a signal - wait again.
                    continue
                # Typically that is `ECHILD` - there are no child processes left to wait for.
                # Finish quietly instead of dying with a traceback.
                self.logger.info('Watchdog cannot wait for child processes anymore: %s', ex)
                break
            if self._workers.stopping:
                break
            # `rc` is the raw wait status here: any non-zero value means an abnormal termination.
            if pid in self._watch and rc:
                self.logger.warning('Child process #%d exited with code %d. Stopping the server.', pid, rc)
                self._stopper()
                break
        self.logger.info('Watchdog thread stopped.')


class Window(object):
    """
    Sliding time window of samples: keeps the values pushed during the last `period` only.

    Values are pushed with the `<<` operator and are listed by iterating the window.
    All the operations on the samples collection are serialized with an internal lock.
    """

    class Sample(common.patterns.Abstract):
        """ A value together with the UTC time it was appended at. """
        __slots__ = ("value", "appended")
        __defs__ = (None,) * 2

    def __init__(self, period, max_uncheck=1000):
        """
        :param period: window size, either a `datetime.timedelta` or a number of seconds.
        :param max_uncheck: maximum amount of values pushed in a row without dropping the expired samples.
        """
        self.period = period if isinstance(period, dt.timedelta) else dt.timedelta(seconds=period)
        self._samples = collections.deque()
        self._max_uncheck = max_uncheck
        self._unchecked = 0
        self._lock = threading.Lock()

    def _timemark(self):
        """ Returns a pair of the expiration threshold and the current UTC time. """
        now = dt.datetime.utcnow()
        return now - self.period, now

    def _drop_expired(self):
        """ Drop the samples older than the window period. The lock must be held by the caller. """
        threshold, _ = self._timemark()
        self._samples = collections.deque(s for s in self._samples if s.appended > threshold)
        self._unchecked = 0

    def __lshift__(self, value):
        """ Push a new value into the window (`window << value`). Returns nothing. """
        # Without the lock a concurrent `actualize()` could rebind the samples collection
        # right before the append and the new sample would be lost.
        with self._lock:
            self._samples.append(Window.Sample(value, dt.datetime.utcnow()))
            self._unchecked += 1
            if self._unchecked > self._max_uncheck:
                self._drop_expired()

    def actualize(self):
        """ Drop the expired samples. """
        with self._lock:
            self._drop_expired()

    def __iter__(self):
        """ Iterates over the values of non-expired samples; works on a snapshot of the window. """
        self.actualize()
        with self._lock:
            snapshot = list(self._samples)
        for s in snapshot:
            yield s.value


class Workers(common.process.Master):
    """
    Web server's workers pool. Singleton.

    Requests are first queued to a pool of threads of the master process (see `put()`), which are
    expected to pass them to the worker processes with `async_reply()`. Replies of the worker processes
    are dispatched back to the registered handlers by the response listener thread.
    """

    class ThreadsPool(ThreadsPoolBase):
        """ Master process threads pool; consumes `(priority, callable)` pairs queued by `Workers.put()`. """
        def main(self):
            while True:
                data = self.queue.get()
                if not data:
                    break
                try:
                    priority, process_req = data
                    # Queue priority `0` denotes a high priority request (see `Workers.put()`),
                    # while the callable expects a truthy flag for it.
                    process_req(int(not priority))
                except Exception:
                    logger.exception("Unhandled exception caught while processing the request")

    class ProcessWorker(common.process.Slave):
        """ Worker process: handles the requests in its own pool of threads. """

        class ThreadsPool(ThreadsPoolBase):
            """ Worker process threads pool; calls the request method and puts the result to the reply queue. """
            def __init__(self, size, reply_queue):
                self.reply_queue = reply_queue
                self.__lock = threading.Lock()
                super(Workers.ProcessWorker.ThreadsPool, self).__init__(size)

            def main(self):
                while True:
                    ret, data = None, self.queue.get()
                    if not data:
                        break
                    _, (rid, method, req) = data
                    st, pr, req_profiler = None, None, None
                    if req.profid:
                        st = time.time()
                        pr = common.profiler.Profiler()
                        pr.enable()
                    sort_field = req.headers.get(ctm.HTTPHeader.PROFILER)
                    if sort_field is not None:
                        # The header comes with the request, so a malformed value must not kill
                        # the thread and leave the request unanswered - just skip the profiling.
                        try:
                            req_profiler = common.profiler.Profiler(sort_field=int(sort_field))
                        except (TypeError, ValueError):
                            logger.warning("Ignoring malformed profiler header value %r", sort_field)
                        else:
                            req_profiler.enable()
                    try:
                        ret = method(req)
                    except response.HttpResponseBase as reply:
                        # Responses can be raised as exceptions by the request methods.
                        ret = reply
                    except Exception:
                        ret = response.HttpExceptionResponse()
                    finally:
                        if req_profiler is not None:
                            # On per-request profiling the original result is wrapped together with
                            # the profile dump into a JSON response.
                            req_profiler.disable()
                            if isinstance(ret, response.HttpResponse):
                                msg = ret.content
                            elif isinstance(ret, response.HTTPError):
                                msg = ret.msg
                            else:
                                msg = str(ret)
                            ret = response.HttpResponse(
                                "application/json",
                                json.dumps(
                                    {"result": msg, "profile": req_profiler.dump_to_str()},
                                    ensure_ascii=False, encoding="utf-8", cls=common.rest.Client.CustomEncoder
                                ),
                                httplib.OK,
                            )
                        if pr:
                            settings = common.config.Registry().server.profiler.performance
                            pr.disable()
                            th = settings.threshold
                            # Request processing duration in milliseconds.
                            ts = int((time.time() - st) * 1000)
                            # The profile is dumped if no threshold is set or the request took at least that long.
                            if not th or ts >= th:
                                path = os.path.join(settings.data_dir, '{}_{}'.format(ts, req.profid))
                                pr.dump_to_file(path)
                                logger.info("Request performance profile dump saved to '%s'", path)
                    with self.__lock:
                        try:
                            self.reply_queue.put((ret, req))
                        except common.errors.DataSizeError as ex:
                            logger.exception("Error processing response for request %r for method %r", req, method)
                            ret = make_reply_to_large_request(req.source, ex)
                            self.reply_queue.put((ret, FakeRequest(req)))
                        except Exception:
                            logger.exception("Error processing response for request %r for method %r", req, method)
                            ret = response.HttpExceptionResponse()
                            self.reply_queue.put((ret, FakeRequest(req)))

        def __init__(self, no, pool_size, queue):
            """
            :param no: worker's ordinal number; used in the PID file name and in the process title.
            :param pool_size: amount of threads to be started in the worker process.
            :param queue: queue object, passed to the base class as is.
            """
            self.no = no
            self.pool = None
            self.pool_size = pool_size
            pidfile = os.path.join(common.config.Registry().client.dirs.run, 'server.{}.pid'.format(no))
            super(Workers.ProcessWorker, self).__init__(logger, pidfile, queue)

        def on_start(self):
            from kernel.util import console
            console.setProcTitle('[sandbox] Web Server Worker #{}'.format(self.no))
            logger.info('Process #%d (web server worker): initializing managers and controllers.', self.pid)
            from sandbox.yasandbox import manager
            from sandbox.yasandbox import controller
            from sandbox.yasandbox.database import mapping

            logger.info("Establishing database connection.")
            mapping.ensure_connection()
            logger.info("Initializing manager objects.")
            manager.initialize_locally()
            logger.info("Initializing controllers")
            controller.initialize()
            logger.info("Setting up statistics processing")
            # NOTE(review): the instance is not stored; presumably `Signaler` is a singleton,
            # as `on_stop()` reaches it with a plain `Signaler()` call - confirm.
            common.statistics.Signaler(
                common.statistics.ServerSignalHandler(),
                logger=logger,
                component=ctm.Component.SERVER,
                update_interval=common.config.Registry().server.statistics.update_interval,
            )

            from sandbox.yasandbox import context
            from sandbox.serviceapi import trackers

            # Patch pymongo to enable event collection
            def mongodb_event_handler(event):
                if context.current:
                    context.current.add_span("mongodb_op", event.pop("time"), event)

            trackers.mongodb.install_tracker()
            trackers.mongodb.register_event_handler(mongodb_event_handler)

            # Patch serviceq.client to enable event collection
            def serviceq_event_handler(event):
                if context.current:
                    context.current.add_span("serviceq_call", event.pop("time"), event)

            trackers.serviceq.install_tracker()
            trackers.serviceq.register_event_handler(serviceq_event_handler)

            logger.info('Process #%d worker #%d started with pool of %d threads.', self.pid, self.no, self.pool_size)
            self.pool = self.ThreadsPool(self.pool_size, self.queue)
            self.pool.start()

        def on_stop(self):
            logger.info('Process #%d worker: waiting for threads to complete.', self.pid)
            self.pool.stop()
            common.statistics.Signaler().wait()

        def process(self, data):
            """ Pass the data to the worker's threads pool. """
            self.pool.queue.put(data)

    class ServiceQ(common.process.Slave):
        """ Sub-process which replaces itself with the Service Q server script. """
        def __init__(self):
            logger.info('Starting service Q.')
            pidfile = os.path.join(common.config.Registry().client.dirs.run, 'serviceq.pid')
            super(Workers.ServiceQ, self).__init__(logger, pidfile, None)

        def process(self, data):
            # Service Q cannot process any IPC messages.
            pass

        def main(self):
            import serviceq.bin
            cmd = [sys.executable, os.path.join(os.path.dirname(serviceq.bin.__file__), 'server.py')]
            self.logger.info('Service Q started with PID #%s. execv(%r)', self.mypid, cmd)
            os.execv(sys.executable, cmd)
            assert False, "This point should not be reached"

    class ServiceApi(common.process.Slave):
        """ Sub-process which replaces itself with the ServiceApi binary. """
        def __init__(self):
            logger.info("Starting ServiceApi")
            pidfile = os.path.join(common.config.Registry().client.dirs.run, "serviceapi.pid")
            super(Workers.ServiceApi, self).__init__(logger, pidfile, None)

        def process(self, data):
            # ServiceApi cannot process any IPC messages.
            pass

        def main(self):
            serviceapi_binary = os.path.join(common.config.Registry().common.dirs.data, "serviceapi/serviceapi")
            self.logger.info("ServiceApi started with PID #%s. execv(%r)", self.mypid, serviceapi_binary)
            os.execv(serviceapi_binary, [serviceapi_binary])
            assert False, "This point should not be reached"

    class Taskbox(common.process.Slave):
        """ Sub-process which replaces itself with the Taskbox binary. """
        def __init__(self):
            logger.info("Starting Taskbox")
            pidfile = os.path.join(common.config.Registry().client.dirs.run, "taskbox.pid")
            super(Workers.Taskbox, self).__init__(logger, pidfile, None)

        def process(self, data):
            # Taskbox cannot process any IPC messages.
            pass

        def main(self):
            taskbox_binary = os.path.join(common.config.Registry().common.dirs.data, "taskbox/taskbox")
            self.logger.info("Taskbox started with PID #%s. execv(%r)", self.mypid, taskbox_binary)
            os.execv(taskbox_binary, [taskbox_binary])
            assert False, "This point should not be reached"

    class TVMTool(common.process.Slave):
        """ Sub-process which replaces itself with the TVMTool binary. """
        def __init__(self):
            logger.info("Starting TVMTool")
            pidfile = os.path.join(common.config.Registry().client.dirs.run, "tvmtool.pid")
            super(Workers.TVMTool, self).__init__(logger, pidfile, None)

        def process(self, data):
            # TVMTool cannot process any IPC messages.
            pass

        def main(self):
            settings = common.config.Registry()
            tvmtool_binary = os.path.join(settings.common.dirs.data, "tvmtool/tvmtool")
            tvmtool_conf = os.path.join(settings.common.dirs.data, "configs/tvm.conf")
            port = str(settings.common.tvm.port)
            access_token = common.utils.read_settings_value_from_file(settings.common.tvm.access_token)
            self.logger.info("TVMTool started with PID #%s. execv(%r)", self.mypid, tvmtool_binary)
            # The binary gets an environment consisting of the access token only.
            os.execve(
                tvmtool_binary, [tvmtool_binary, "-c", tvmtool_conf, "--port", port, "-u"],
                {"QLOUD_TVM_TOKEN": access_token}
            )
            assert False, "This point should not be reached"

    def __init__(self, threads_per_process=None, processes_pool_size=None, stopper=None, last_requests_size=None):
        """
        :param threads_per_process: amount of threads per process; must be positive.
        :param processes_pool_size: amount of worker processes; must be non-zero.
        :param stopper: callable to be called by the watchdog when a watched sub-process dies.
        :param last_requests_size: period of the last requests window, 15 if not specified.
        """
        assert threads_per_process and processes_pool_size
        assert threads_per_process > 0
        last_requests_size = last_requests_size or 15
        self.pending = {}
        self.in_progress = {}
        self.last_requests = Window(last_requests_size)
        self._stopper = stopper
        self.watchdog = None
        self._stopping = threading.Event()
        self.processes_pool_size = processes_pool_size
        self.threads_pool_size = threads_per_process * processes_pool_size
        self.threads_per_process = threads_per_process
        self.threads = self.ThreadsPool(self.threads_per_process, Queue.PriorityQueue)

        self.listener = None
        self.statistics = None
        super(Workers, self).__init__(logger)

    def _slaves_builder(self):
        """ Yields the sub-process objects: the enabled auxiliary services first, then the workers. """
        services_settings = common.config.Registry().server.services
        if services_settings.serviceq.enabled:
            yield self.ServiceQ()

        if services_settings.serviceapi.enabled:
            yield self.ServiceApi()

        if services_settings.taskbox.enabled:
            yield self.Taskbox()

        if services_settings.tvmtool.enabled:
            yield self.TVMTool()

        for i in xrange(self.processes_pool_size):
            yield self.ProcessWorker(i, self.threads_per_process, self.queue)

    def _listener_main(self):
        """ Response listener thread body: passes the replies read from the queue to the pending handlers. """
        logger.info('Response listener thread started.')
        try:
            while True:
                try:
                    data = self.queue.get()
                    if not data:
                        # Empty data is the stop marker (see `stop()`).
                        break
                    reply, req = data
                except EOFError:
                    break
                h = self.pending.pop(req.internal_id, None)
                if not h:
                    logger.warning('No handler for response #%s (internal_id=%s)', req.id, req.internal_id)
                    continue
                h.reply(reply, req)
        finally:
            logger.info('Response listener thread stopped.')

    def start(self):
        """ Start the sub-processes, the threads pool, the response listener, the statistics and the watchdog. """
        self.init()
        super(Workers, self).start()
        self.threads.start()
        self.listener = threading.Thread(target=self._listener_main)
        self.listener.start()
        brigadier = services.Brigadier()
        self.statistics = StatisticsThread(brigadier.processes[0].pid, [p.pid for p in self.processes])
        self.statistics.start()
        self.watchdog = Watchdog(self, self._stopper)
        for p in self.processes + brigadier.processes:
            self.watchdog.watch(p.pid)
        self.watchdog.start()

    def put(self, opt, high_priority=False):
        """
        Queue the request processing callable to the master's threads pool.

        :param opt: callable, which will be called with the priority flag as the only argument.
        :param high_priority: whether the request should be served before the low-priority ones.
        """
        self.statistics.count_request(high_priority)
        # `Queue.PriorityQueue` retrieves the lowest valued entries first, so high priority requests
        # get `0` here - exactly like in `async_reply()`.
        self.threads.queue.put((int(not high_priority), opt))

    @property
    def qsize(self):
        """ Returns approximate amount of requests currently in queue. """
        return self.queue.qsize() + self.threads.qsize

    @property
    def processing(self):
        """ Returns approximate amount of requests currently processing. """
        return len(self.in_progress)

    @property
    def stopping(self):
        """ Whether the stop has been requested. Assigning ANY value to the property raises the flag. """
        return self._stopping.is_set()

    @stopping.setter
    def stopping(self, _):
        self._stopping.set()

    def stop(self):
        """ Stop everything started by `start()`; does nothing if the watchdog is not set. """
        if not self.watchdog:
            return

        self._stopping.set()
        self.threads.stop()
        self.statistics.stop()
        super(Workers, self).stop()

        # NOTE(review): only ServiceQ and ServiceApi are interrupted explicitly here,
        # Taskbox and TVMTool are not - confirm that is intended.
        for process in self.processes:
            if process.__class__ in (self.ServiceQ, self.ServiceApi):
                logger.info("Asking %s (pid %s) to stop.", process.__class__.__name__, process.pid)
                try:
                    os.kill(process.pid, signal.SIGINT)
                except OSError as ex:
                    logger.error("Error stopping %s: %s", process.__class__.__name__, str(ex))
        self.join()
        logger.info('Waiting for statistics thread.')
        self.statistics.join()
        logger.info('Waiting for listener thread.')
        # `None` makes the response listener thread leave its loop.
        self.queue.put_to_output(None)
        self.listener.join()
        self.watchdog.join()
        self.watchdog = None

    def async_reply(self, obj, hp, method, request):
        """
        Register the reply handler for the request and pass the request for processing.

        :param obj: handler object; its `reply(reply, request)` method is called on the response arrival.
        :param hp: truthy for a high-priority request.
        :param method: callable to be called with the request in a worker thread.
        :param request: the request object; must provide `id` and `internal_id` attributes.
        """
        self.pending[request.internal_id] = obj
        self.process((int(not hp), (request.id, method, request)))


class KamikadzeThread(object):
    """
    神風糸, literally: "God wind thread"; common translation: "Divine wind yarn".
    A special thread, which will kill the process after a given amount of time
    if the `finish()` method is not called before. Singleton.
    """
    __metaclass__ = common.utils.SingletonMeta

    def __init__(self, wait, on_finish=None, daemon=False):
        """
        :param wait: maximum amount of seconds to wait for the normal shutdown.
        :param on_finish: optional callable to be called at the very end, regardless of the shutdown outcome.
        :param daemon: whether the thread should be a daemon one.
        """
        self.on_finish = on_finish
        self.event = threading.Event()
        self.thread = threading.Thread(target=self._run, args=(wait,))
        self.thread.daemon = daemon
        self.thread.start()
        Workers().stopping = True

    @classmethod
    def finish(cls):
        """ Report that the server has been stopped normally. Does nothing without an instance. """
        if not cls.instance:
            return
        # noinspection PyUnresolvedReferences
        cls.instance.event.set()

    @classmethod
    def join(cls):
        """ Wait for the thread to finish. Does nothing without an instance, like `finish()`. """
        if not cls.instance:
            return
        # noinspection PyUnresolvedReferences
        cls.instance.thread.join()

    def _run(self, wait):
        brigadier = services.Brigadier()
        logger.info('Will shutdown in %s seconds maximum ...', wait)
        brigadier.stop()
        if not self.event.wait(wait):
            # Nobody reported the normal shutdown in time - destroy everything forcedly.
            Workers().destroy()
            brigadier.destroy()
            logger.error(
                'Shutdown server forcedly. %s/%s worker processes/threads are still running, service process is %s.',
                len(filter(common.process.Process.is_alive, Workers().processes)),
                len(filter(threading.Thread.is_alive, Workers().threads.threads)),
                'running' if brigadier.worker and brigadier.worker.is_alive() else 'stopped'
            )
        else:
            logger.info('Server stopped normally.')
        # The callback is optional: it defaults to `None`, which cannot be called.
        if self.on_finish is not None:
            self.on_finish()
