import os
import sys
import signal

import gevent
import gevent.event
try:
    from gevent.coros import RLock
except ImportError:
    from gevent.lock import RLock
import gevent.socket
import psutil
import api.skycore.errors as exceptions
import py
from py import std

try:
    from porto import Connection as PortoConnection
except ImportError:
    PortoConnection = None
else:
    from ..portoutils import PortoAdapter


from ..kernel_util.sys.portoslave import same_container

from .portowatcher import PortoWatcher
from .configupdater import ConfigUpdater
from .serviceupdater import ServiceUpdater
from .namespace import InstallationUnit
from .cgroupscontroller import CgroupsController
from .statereporter import StateReporter
from ..framework.component import Component
from ..framework.utils import Path
from ..framework.greendeblock import Deblock
from ..db import Database
from ..rpc.server import RPC, Server as RPCServer
from ..procs.starter import Starter
from ..procs.mailbox import mailbox
from ..stats import Stats


# File name handed to ConfigUpdater (as `filename`) together with the <workdir>/conf directory.
ACTUAL_FILENAME = "actual.yaml"


class SkycoreDaemon(Component):
    """Skycore daemon: prepares the working directory layout, owns the RPC
    server and the updater components, and runs the main loop until asked to stop.

    RPC handlers are registered in :meth:`start`. Handlers that take a ``job``
    argument are mounted with ``typ='full'`` and verify the caller with
    :meth:`_check_perms`.
    """

    # Seconds an RPC handler waits for :meth:`start` to finish initialisation.
    STARTUP_WAIT_TIMEOUT = 90.

    def __init__(
        self, app_path, workdir, revision, output_logger_factory,
        skynetdir=None, supervisordir=None,
        yappi_mode=None,
        restart_all=None,
        paused=False,
    ):
        """Create the working directory layout and construct all daemon components.

        :param app_path: application root (path-like; ``share/db`` under it is
            searched for DB migrations in :meth:`init_db`).
        :param workdir: base working directory; ``var``, ``dl``, ``ns``, ``svcs``,
            ``conf`` and ``api`` subdirectories are created inside it.
        :param revision: passed to ServiceUpdater as ``core_revision``.
        :param output_logger_factory: passed to ServiceUpdater.
        :param skynetdir: passed to ServiceUpdater.
        :param supervisordir: passed to ServiceUpdater.
        :param yappi_mode: when not None, profiling stats are dumped on shutdown;
            its value is also the period returned by the yappi green loop.
        :param restart_all: forwarded to ``ServiceUpdater.read_context`` in :meth:`start`.
        :param paused: start with the config/service updaters paused (the initial
            config update is skipped).
        """
        self.yappi_mode = yappi_mode
        self.restart_all = restart_all

        super(SkycoreDaemon, self).__init__(logname='')
        self.log.debug('SkycoreDaemon init...')
        self.log.debug('  path  : %s', app_path)
        self.log.debug('  pid   : %d', os.getpid())
        self.log.debug('  nofile: %r', std.resource.getrlimit(std.resource.RLIMIT_NOFILE))
        if os.uname()[0].lower() != 'darwin' and sys.platform != 'cygwin':
            self.log.debug('  uids  : %r', std.os.getresuid())
        else:
            # getresuid() is not available on these platforms
            self.log.debug('  uids  : (%d, %d)', os.getuid(), os.geteuid())
        if restart_all:
            self.log.debug('  services restart forced ({})'.format(restart_all))
        if yappi_mode:
            self.log.debug('  yappi enabled')

        self.started_event = gevent.event.Event()
        self.psproc = psutil.Process(os.getpid())
        self.uid = os.getuid()
        gid = os.getgid()

        self.should_stop = False

        self.path = app_path
        self.workdir = workdir

        def _prepare_dir(name, perms):
            # Create <workdir>/<name> and enforce ownership and the perms pair
            # passed to ensure_perms (presumably directory/file modes).
            directory = Path(self.workdir).join(name)
            directory.ensure(dir=True)
            directory.ensure_perms(self.uid, gid, perms, mask=0)
            return directory

        # The run directory is world-writable with the sticky bit set.
        rundir = _prepare_dir("var", (0o1777, 0o1666))
        downloaddir = _prepare_dir("dl", (0o755, 0o644))
        nsdir = _prepare_dir("ns", (0o755, 0o644))
        linkdir = _prepare_dir('svcs', (0o755, 0o644))
        confdir = _prepare_dir("conf", (0o755, 0o644))
        apidir = _prepare_dir("api", (0o755, 0o644))

        self.rpc = RPC(self.log.getChild('rpc'))
        self.rpc_sock = Path(self.workdir).join('rpc.sock')
        self.rpc_server = RPCServer(self.log.getChild('rpc'), backlog=10, max_conns=1000, unix=self.rpc_sock.strpath)

        if PortoConnection is None or os.uname()[0].lower() != 'linux':
            portoconn = self.portoconn = None
            self.portowatcher = None
            self.log.info("porto is not available on host")
        else:
            portoconn = PortoConnection(timeout=20, auto_reconnect=False)
            # getChild, consistent with every other child logger in this class
            portoconn = self.portoconn = PortoAdapter(portoconn, log=self.log.getChild('porto'))
            self.portowatcher = PortoWatcher(conn=portoconn, parent=self)

        self.process_lock = RLock()

        self.deblock = Deblock(logger=self.log.getChild('deblock'), name='common')

        self.state_reporter = StateReporter(
            parent=self,
            deblock=self.deblock
        )
        self.config_updater = ConfigUpdater(
            parent=self,
            config_dir=confdir.strpath,
            filename=ACTUAL_FILENAME,
            hostname=os.getenv('SKYNET_HOSTNAME', None),
            deblock=self.deblock,
            paused=paused,
        )

        if sys.platform.startswith('linux'):
            self.cgroups_controller = CgroupsController(
                parent=self,
                workdir=self.workdir,
                registry=self.config_updater,
                lock=self.process_lock
            )
        else:
            self.cgroups_controller = None

        starter = Starter(
            porto=portoconn,
            portowatcher=self.portowatcher,
            workdir=rundir.strpath,
            process_lock=self.process_lock,
            cgroup_controller=self.cgroups_controller
        )

        if not paused:
            # update config only if we are not frozen
            try:
                self.config_updater.update_config()
            except Exception:
                self.log.warning("Failed to update config, will attempt later:", exc_info=sys.exc_info())

        self.config_updater.subscribe(starter.config_changed, [('skynet', 'skycore', 'config')], config_only=True)

        statefile = os.path.join(self.workdir, 'skycore.state')

        self.service_updater = ServiceUpdater(
            deblock=self.deblock,
            base=self.workdir,
            workdir=nsdir.strpath,
            rundir=rundir.strpath,
            linkdir=linkdir.strpath,
            apidir=apidir.strpath,
            downloaddir=downloaddir.strpath,
            statefile=statefile,
            skynetdir=skynetdir,
            supervisordir=supervisordir,
            output_logger_factory=output_logger_factory,
            starter=starter,
            config=self.config_updater,
            namespaces={},
            core_revision=revision,
            parent=self,
            reporter=self.state_reporter,
            paused=paused,
        )

    def _ensure_started(self, timeout=None):
        """Block until :meth:`start` has finished initialisation.

        :param timeout: seconds to wait; defaults to ``STARTUP_WAIT_TIMEOUT``.
        :raises Exception: if the daemon did not start within ``timeout``.
        """
        if timeout is None:
            timeout = self.STARTUP_WAIT_TIMEOUT
        if not self.started_event.wait(timeout):
            raise Exception("skycore couldn't start itself")

    # API methods

    def ping(self, wait_active=None, extended=False):
        """Liveness check; always returns True (arguments are accepted but unused)."""
        return True

    def query(self, path):
        """Return the config value at ``path`` (a sequence of keys).

        :raises exceptions.PathError: if the path is unknown or config is not ready.
        """
        self._ensure_started()

        self.log.debug('query: {}'.format('.'.join(path)))
        try:
            # there is no need for deepcopy, RPC will do it
            result = self.config_updater.query(path, deepcopy=False)
        except (KeyError, IndexError):
            # wrap in a 1-tuple: a tuple `path` would otherwise be unpacked by `%`
            raise exceptions.PathError('Unknown path: %s' % (path,))
        except LookupError:
            raise exceptions.PathError("Config is not ready yet")

        return result

    def pause(self, job):
        """Pause config and service updaters. Returns False if already paused."""
        self.log.debug('pause')
        self._check_perms(job)

        if self.config_updater.paused:
            # updater is paused already
            return False

        self.service_updater.paused = True
        self.config_updater.paused = True
        return True

    def unpause(self, job):
        """Resume config and service updaters. Returns False if not paused."""
        self.log.debug('unpause')
        self._check_perms(job)

        if not self.config_updater.paused:
            # updater is not paused yet
            return False

        self.service_updater.paused = False
        self.config_updater.paused = False
        return True

    def state(self):
        """Return a dict with config updater and cgroup controller state
        (the latter is None when there is no cgroup controller)."""
        self._ensure_started()

        state = dict()
        state['config_updater'] = self.config_updater.state()
        state['cgroup_controller'] = self.cgroups_controller and self.cgroups_controller.state()
        return state

    def start_services(self, job, namespace, services=None, timeout=90.0):
        """Start ``services`` in ``namespace``."""
        self._check_perms(job)
        self._ensure_started()

        ns = self._get_namespace(namespace)
        ns.start_services(names=services, timeout=timeout)
        return True

    def stop_services(self, job, namespace, services=None, timeout=90.0):
        """Stop ``services`` in ``namespace``."""
        self._check_perms(job)
        self._ensure_started()

        ns = self._get_namespace(namespace)
        ns.stop_services(names=services, timeout=timeout)
        return True

    def restart_services(self, job, namespace, services=None, timeout=90.0):
        """Stop then start ``services`` in ``namespace``.

        The work runs in a separate greenlet; this call waits for its result and
        re-raises any exception it produced.
        """
        self._check_perms(job)
        self._ensure_started()

        ns = self._get_namespace(namespace)
        result = gevent.event.AsyncResult()

        def _inner():
            try:
                ns.stop_services(names=services, timeout=timeout)
                ns.start_services(names=services, timeout=timeout)
            except Exception as e:
                result.set_exception(e)
            else:
                result.set(True)

        gevent.spawn(_inner)
        return result.get()

    def check_services(self, namespace, services=None, new_format=False):
        """Return service states for ``namespace`` as reported by ``ns.get_states``."""
        self._ensure_started()

        ns = self._get_namespace(namespace)
        return ns.get_states(services, new_format)

    def list_namespaces(self):
        """Return the list of known namespace names."""
        self._ensure_started()

        return list(self.service_updater.namespaces.keys())

    def list_services(self, namespace):
        """Return the list of service names in ``namespace``."""
        self._ensure_started()

        ns = self._get_namespace(namespace)
        return list(ns.services.keys())

    def get_service_field(self, namespace, service, field, raw=False):
        """Return a service field; sets/frozensets are converted to lists."""
        self._ensure_started()

        ns = self._get_namespace(namespace)
        result = ns.get_service_field(service, field, raw)
        if isinstance(result, (set, frozenset)):
            result = list(result)

        return result

    def get_service_api(self, namespace, service, kind='python'):
        """Return the service API of the given ``kind``."""
        self._ensure_started()

        ns = self._get_namespace(namespace)
        return ns.get_service_api(service, kind=kind)

    def install_tgz(self, job, namespace, path):
        """Install services from a local tgz archive at ``path`` into ``namespace``.

        The namespace is created if needed. The work runs in a separate
        greenlet; this call waits for its result.

        :raises exceptions.SkycoreError: if the archive contains no services.
        """
        self._check_perms(job)
        self._ensure_started()

        result = gevent.event.AsyncResult()

        def _inner():
            try:
                md5sum = self.deblock.apply(Path(path).computehash, 'md5')
                # Human-readable version: the archive path shortened to its last
                # 30 characters, plus the first 8 hex digits of its md5.
                version = '...' + path[-30:] if len(path) > 30 else path
                version = version + ' [%s]' % (md5sum[:8],)
                release = {
                    'svn_url': None,
                    'version': version,
                    'urls': [],
                    'md5': md5sum,
                    'filename': os.path.basename(path),
                    'size': 0,
                }
                unit = InstallationUnit('<from console>', release, self.deblock)
                unit.archive = path

                ns = self.service_updater.ensure_namespace(namespace)

                self.log.info("installing tgz %r into namespace %r", path, namespace)

                try:
                    unit.extract(self.service_updater.downloaddir, self.log.getChild('install_tgz'))

                    installed = list(self.deblock.apply(ns.install_services, unit))
                    if not installed:
                        raise exceptions.SkycoreError("No services found in tgz")

                    self.log.info("installed services: %r", installed)
                    ns.inject_services(installed, log=self.log)
                finally:
                    unit.close()
            except Exception as e:
                result.set_exception(e)
            else:
                result.set(True)

        gevent.spawn(_inner)
        return result.get()

    def uninstall_services(self, job, namespace, services):
        """Uninstall the named ``services`` from ``namespace``, then clean it up.

        :raises exceptions.SkycoreError: if ``services`` is empty.
        """
        self._check_perms(job)
        if not services:
            raise exceptions.SkycoreError("At least one service name should be specified (for security reasons)")

        self._ensure_started()

        ns = self._get_namespace(namespace)
        self.log.debug("required to remove: %s", services)
        result = gevent.event.AsyncResult()

        def _inner():
            try:
                for srvc in services:
                    ns.uninstall_service(srvc)
            except Exception as e:
                result.set_exception(e)
            else:
                result.set(True)

        gevent.spawn(_inner)
        try:
            return result.get()
        finally:
            ns.cleanup()

    def shutdown_call(self, job):
        """RPC ``shutdown``: set the stop flag and join this component."""
        self._check_perms(job)
        self._ensure_started()

        self.should_stop = True
        self.join()

        return True

    def get_stats(self):
        """Refresh resource usage stats and return all collected stats.

        :raises Exception: if the daemon did not start in time or no stats exist.
        """
        # Wait first: the RPC server accepts calls before init_stats() runs, so
        # self.stats may not exist yet at this point.
        self._ensure_started()

        if not getattr(self, 'stats', None):
            raise Exception("No stats available")

        self.service_updater.report_rusage()
        self._report_rusage()
        return self.stats.get_all()

    def force_update(self, job):
        """Update config and check all namespaces immediately."""
        self._check_perms(job)
        self._ensure_started()

        try:
            self.config_updater.update_config()
            self.service_updater.check_all_namespaces(self.service_updater.log, 'request by API')
        except ServiceUpdater.UpdateException as e:
            raise Exception("update failed: " + str(e))
        return True

    # intentionally non-documented method to start debug console
    def run_debug_console(self, job, port):
        self._check_perms(job)

        from ..reverse import ReversePythonShell
        ReversePythonShell("localhost", port).start()

    # API methods end

    def stop(self):
        """Request a stop without blocking; the main loop in :meth:`start` performs it."""
        # This will say to us "should stop" and quit immediately
        # avoiding blocking on RPC request
        self.should_stop = True
        return True

    def _dump_yappi_stats(self):
        """Stop yappi and save function stats in callgrind format under <workdir>/profiling."""
        import datetime
        import yappi

        yappi.stop()
        stats = yappi.get_func_stats()
        workdir = Path(self.workdir).join('profiling')
        workdir.ensure(dir=True)
        stats.save(os.path.join(workdir.strpath, 'callgrind.%s' % (datetime.datetime.now(),)), 'callgrind')

    def shutdown(self):
        """Stop all components, flush stats/profiling data and check threads exit.

        Arms ``signal.alarm(900)`` for the procedure. If extra threads are still
        alive after ~5 seconds, the whole process group is sent ``SIGKILL``.
        """
        self.should_stop = True  # ensure this set
        signal.alarm(900)

        super(SkycoreDaemon, self).stop()

        self._report_rusage()

        self.db.close()
        self.db.deblock.stop()

        self.service_updater.join()
        self.config_updater.join()
        if self.cgroups_controller:
            self.cgroups_controller.join()
        if self.portowatcher:
            self.portowatcher.join()

        self.deblock.stop()

        if self.yappi_mode is not None:
            self._dump_yappi_stats()

        if len(std.threading.enumerate()) > 1:
            # Give a last try all threads to stop gracefully after analyzing .should_stop
            # flag. Wait max 5 seconds for them.
            all_threads_stopped = False

            for _ in range(50):
                # Wait for all threads except the main thread and pidwaiter. Use <=
                # so the check still succeeds if pidwaiter also exits meanwhile.
                if len(std.threading.enumerate()) <= 2:
                    all_threads_stopped = True
                    break
                gevent.sleep(0.1)

            if not all_threads_stopped:
                self.log.critical(
                    'Cant shutdown properly -- some threads still alive! {}'.format(std.threading.enumerate())
                )
                os.killpg(0, signal.SIGKILL)
                os._exit(1)

        # NOTE(review): the RPC server is not stopped explicitly here; presumably
        # Component.stop() above takes care of it -- confirm.
        self.log.normal('Shutdown success')

    def start(self):
        """Start the RPC server and DB, mark the daemon started, and run the main loop.

        Returns after :meth:`shutdown` completes (triggered by ``should_stop`` or
        KeyboardInterrupt).
        """
        self.log.debug('Starting RPC server serving %r on %s' % (self, self.rpc_sock))

        try:
            self.rpc_sock.remove()
        except py.error.ENOENT:
            pass

        self.rpc_server.register_connection_handler(self.rpc.get_connection_handler())

        # register RPC handlers here
        self.rpc.mount(self.ping)
        self.rpc.mount(self.query)
        self.rpc.mount(self.pause, typ='full')
        self.rpc.mount(self.unpause, typ='full')
        self.rpc.mount(self.state)
        self.rpc.mount(self.get_service_field)
        self.rpc.mount(self.get_service_api)
        self.rpc.mount(self.start_services, typ='full')
        self.rpc.mount(self.stop_services, typ='full')
        self.rpc.mount(self.restart_services, typ='full')
        self.rpc.mount(self.check_services)
        self.rpc.mount(self.list_namespaces)
        self.rpc.mount(self.list_services)
        self.rpc.mount(self.install_tgz, typ='full')
        self.rpc.mount(self.uninstall_services, typ='full')
        self.rpc.mount(self.shutdown_call, name='shutdown', typ='full')
        self.rpc.mount(self.force_update, typ='full')
        self.rpc.mount(self.run_debug_console, typ='full')
        self.rpc.mount(self.get_stats)

        # Connections are accepted from here on; handlers block in
        # _ensure_started() until started_event is set below.
        self.rpc_server.start()
        self.rpc_sock.ensure_perms(self.uid, os.getgid(), 0o666)

        self.init_db()
        self.init_stats()

        self.service_updater.set_stats(self.stats)
        self.service_updater.read_context(restart_all=self.restart_all)
        self.service_updater.cleanup_liner_sockets()

        # Check greenlets
        assert gevent.spawn(lambda: True).get(), 'Greenlets are not working correctly!'

        self.log.normal('%r started' % self)
        super(SkycoreDaemon, self).start()

        self.started_event.set()
        self.log.debug('Allowed incoming connections')

        try:
            fail_count = 0

            while not self.should_stop:
                try:
                    gevent.sleep(1)
                    fail_count = 0
                except IOError as err:
                    # tolerate transient hub errors; give up after 1024 in a row
                    fail_count += 1
                    if fail_count >= 1024:
                        raise
                    log_level = self.log.debug if fail_count < 1000 else self.log.critical
                    log_level('Main daemon hub.switch() error #%d: %s' % (fail_count, err))
        except KeyboardInterrupt:
            self.log.normal('Caught KeyboardInterrupt (SIGINT), shutting down...')
            self.shutdown()
        else:
            self.log.normal('ShouldStop flag set, shutting down...')
            self.shutdown()

    def init_db(self):
        """Open <workdir>/skycore.db, apply migrations and schedule maintenance.

        Migration files are read from ``<app_path>/share/db``: ``fw_<N>_*`` files
        are forward migrations and ``bw_<N>_*`` files backward ones, keyed by
        the integer ``N``. Only ``.sql`` and ``.py`` files are considered.
        """
        db_path = Path(self.workdir).join('skycore.db')
        db_temp_dir = Path(self.workdir).join('tmp').ensure(dir=1)

        self.log.info('Opening database at %s', db_path)

        dbobj = Database(
            db_path.strpath,
            mmap=None,
            temp=db_temp_dir.strpath,
            parent=self
        )

        self.db_deblock = Deblock(keepalive=None, logger=dbobj.log.getChild('deblock'), name='db')
        self.db = self.db_deblock.make_proxy(dbobj, put_deblock='deblock')

        self.db.open(check=self.db.CHECK_IF_DIRTY, force=True)
        self.db.set_debug(sql=False, transactions=True)

        migrations_path = self.path.join('share', 'db')

        fw_migrations = {}
        bg_migrations = {}

        if migrations_path.check(exists=1, dir=1):
            for migration_file in migrations_path.listdir():
                basename = migration_file.basename
                if not basename.endswith(('.sql', '.py')):
                    continue
                if basename.startswith('fw_'):
                    target = fw_migrations
                elif basename.startswith('bw_'):
                    target = bg_migrations
                else:
                    continue

                version = int(basename.split('_', 2)[1])
                target[version] = migration_file

        if fw_migrations or bg_migrations:
            self.db.migrate(fw_migrations, bg_migrations)

        def _maintain(**kwargs):
            # Best-effort: maintenance failures are logged, never raised.
            try:
                with self.db.deblock.lock('dbmaintain'):
                    self.db.maintain(**kwargs)
            except Exception as ex:
                self.log.critical('Maintainance error: %s', str(ex))

        _maintain(vacuum=True, analyze=True, grow=True)

        def _mnt(t, **kwargs):
            # Run _maintain every `t` seconds until should_stop is set.
            while 1:
                gevent.sleep(t)
                if self.should_stop:
                    break
                _maintain(**kwargs)

        gevent.spawn(_mnt, 3600, vacuum=True, analyze=True)
        gevent.spawn(_mnt, 10, grow=True)

    def init_stats(self):
        """Create the Stats store on top of the DB and count this start."""
        self.stats = Stats(self.db)
        self.stats.main_inc_num('skycore_start_count')

    def _report_rusage(self):
        """Record current RSS and CPU time consumed since the previous report."""
        proc = self.psproc
        # psutil 2.0 renamed get_memory_info/get_cpu_times to memory_info/cpu_times
        # and 3.0 removed the old names; support both APIs.
        memory_info = getattr(proc, 'memory_info', None) or proc.get_memory_info
        cpu_times = getattr(proc, 'cpu_times', None) or proc.get_cpu_times

        rss = memory_info().rss
        cpu = cpu_times()
        cpu = cpu.user + cpu.system
        old_cpu = self.stats.main_get_val('skycore_last_cpu_usage', inmemory=True)
        if old_cpu is not None:
            cpu_diff = max(0, cpu - float(old_cpu))
            self.stats.main_inc_num('skycore_cpu_usage', cpu_diff)
        self.stats.main_set_val('skycore_rss', rss)
        self.stats.main_set_val('skycore_last_cpu_usage', cpu, inmemory=True)

    @Component.green_loop(logname='mbox')
    def _iterate_mailbox(self):
        # Take the next callable from the mailbox and run it.
        mailbox().get()()
        return 0.1

    @Component.green_loop(logname='rusage')
    def _report_rusage_loop(self):
        self._report_rusage()
        return 600.

    @Component.green_loop(logname='yappi')
    def _report_yappi_stats(self):
        """Periodically dump yappi stats and restart profiling with fresh counters."""
        if not self.yappi_mode:
            return 100000000000  # effectively "never": profiling is disabled

        self._dump_yappi_stats()

        import yappi
        yappi.clear_stats()
        yappi.start(True, True)

        return self.yappi_mode

    def _check_perms(self, job):
        """Authorize an RPC caller.

        The peer uid must be root or the daemon's own uid. On Linux the peer
        process must also be in the same container (checked via porto when
        available).

        :raises exceptions.AuthorizationError: if the caller is not allowed.
        """
        peer = job.peer_id[0]
        if peer != 0 and peer != self.uid:
            raise exceptions.AuthorizationError('Don`t have root privileges')

        if not sys.platform.startswith('linux'):
            # containers exist only in linux
            return

        pid = job.peer_pid

        if not same_container(pid):
            if self.portoconn is not None:
                # Failure to locate either process is treated as "no container";
                # GreenletExit/KeyboardInterrupt are deliberately not swallowed.
                try:
                    job_container = self.portoconn.LocateProcess(pid).name
                except Exception:
                    job_container = None

                try:
                    our_container = self.portoconn.LocateProcess(os.getpid()).name
                except Exception:
                    our_container = None

                if job_container != our_container:
                    raise exceptions.AuthorizationError(
                        'You are not allowed to use this RPC method from other container (%s)' % (job_container, )
                    )
            else:
                raise exceptions.AuthorizationError(
                    'You are not allowed to use this RPC method from other container'
                )

    def _get_namespace(self, namespace):
        """Return the namespace object or raise exceptions.NamespaceError."""
        try:
            return self.service_updater.namespaces[namespace]
        except KeyError:
            raise exceptions.NamespaceError('Unknown namespace: %s' % (namespace,))

    def __str__(self):
        return self.__class__.__name__
