# coding: utf-8
import os
import sys
import glob
import shutil
import contextlib
from collections import defaultdict

import six
import gevent
try:
    from gevent.coros import RLock
except ImportError:
    from gevent.lock import RLock

import api.skycore.errors as exceptions

from ..service_config import ServiceConfig
from ..framework.component import Component
from ..framework.greendeblock import Deblock
from ..framework.utils import Path
from ..kernel_util.sys import TempDir
from ..downloader import download, parse_advertised_release
from .. import tartools as tarfile
from ..dirutils import unpack_archive_fileobj
from .service import Service, ServiceCorrupted, State


def replace_dir(source_root, target_root, log=None):
    """
    Move directory ``source_root`` into place as ``target_root``.

    An existing ``target_root`` is first renamed to ``target_root + '.old'``
    so it can be restored if the move fails; it is not removed here.  A stale
    ``.old`` path left over from an earlier replacement is deleted first,
    otherwise the rename would fail.

    On failure ``source_root`` is removed, any partially created
    ``target_root`` is cleared, the previous ``target_root`` is restored and
    the original exception is re-raised.

    :param str source_root: directory to move into place
    :param str target_root: destination path; chmod'ed to 0755 after the move
    :param log: optional logger receiving errors from directory removal
    """
    def log_error(func, path, exc_info):
        if log is not None:
            log.error("replace_dir: failed to %r path %r:", func.__name__, path, exc_info=exc_info)

    def remove_path(path):
        # rmtree() refuses symlinks and plain files, so unlink those instead
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path, onerror=log_error)
        else:
            os.unlink(path)

    old_temp = target_root + '.old'
    renamed = False
    if os.path.exists(target_root):
        if os.path.lexists(old_temp):
            # leftover of a previous replacement: rename() cannot overwrite a non-empty dir
            remove_path(old_temp)
        os.rename(target_root, old_temp)
        renamed = True

    try:
        try:
            shutil.move(source_root, target_root)
            os.chmod(target_root, 0o755)
        except BaseException:
            exc_info = sys.exc_info()
            try:
                shutil.rmtree(source_root, onerror=log_error)
            finally:
                six.reraise(*exc_info)
    except BaseException:
        exc_info = sys.exc_info()
        if renamed:
            # target_root can only exist now if our (partial) move created it;
            # it has to go before the previous directory can be renamed back
            if os.path.lexists(target_root):
                remove_path(target_root)
            os.rename(old_temp, target_root)
        six.reraise(*exc_info)


class Namespace(Component):
    """
    A named group of services installed under a common directory layout.

    Bookkeeping kept per service name:

    * ``services`` -- the :class:`Service` object;
    * ``services_meta`` -- the :class:`InstallationUnit` it was installed from;
    * ``hashes`` -- md5 of that installation (``rhashes``: md5 -> set of names);
    * ``deps`` -- names it depends on (``rdeps``: name -> set of dependents).

    Most public operations run under ``operation_lock`` (a gevent RLock).
    """

    def __init__(self,
                 starter,
                 deblock,
                 output_logger_factory,
                 context_changed_event,
                 name=None,
                 workdir=None,
                 rundir=None,
                 linkdir=None,
                 apidir=None,
                 skynetdir=None,
                 supervisordir=None,
                 registry=None,
                 parent=None,
                 context=None,
                 stats=None,
                 reporter=None,
                 restart_all=None,
                 ):
        """
        :param Starter starter: starter of processes
        :param ..framework.greendeblock.Deblock deblock: deblock for i/o operations
        :param output_logger_factory: factory passed through to every Service
        :param gevent.coros.Semaphore context_changed_event: released whenever context data changes
        :param str name: namespace name
        :param str workdir: where service sources are stored
        :param str rundir: where service data is stored
        :param str linkdir: where links to service parts are aggregated
        :param str apidir: where files with API are stored
        :param str skynetdir: root dir of skynet
        :param str supervisordir: supervisor dir of skynet
        :param ConfigUpdater registry: config registry to use for services
        :param Component parent: component owning this one
        :param dict context: saved state of namespace with services (as produced by ``context``)
        :param ..stats.Stats stats: stats controller
        :param .statereporter.StateReporter reporter: object reporting services status changes to heartbeat
        :param str restart_all: if 'ALL', all imported, if 'SKYDEPS', all skydeps dependent services will be restarted
        """
        self.starter = starter
        self.deblock = deblock

        self.name = name
        self.workdir = workdir
        self.rundir = rundir
        self.linkdir = linkdir
        self.apidir = apidir
        self.skynetdir = skynetdir
        self.supervisordir = supervisordir
        self.registry = registry
        self.reporter = reporter
        self.output_logger_factory = output_logger_factory
        self.context_changed_event = context_changed_event
        self.deps = defaultdict(set)     # service name -> names it depends on
        self.rdeps = defaultdict(set)    # service name -> names depending on it
        self.services = {}               # service name -> Service
        self.services_meta = {}          # service name -> InstallationUnit
        self.hashes = {}                 # service name -> installation md5
        self.rhashes = defaultdict(set)  # installation md5 -> service names
        self.stats = stats

        self.operation_lock = RLock()

        super(Namespace, self).__init__(parent=parent, logname=None, log_msg_prefix=name)

        with self.operation_lock:
            try:
                if context is not None:
                    # restore a previously saved state (the shape produced by ``context``);
                    # relative directories are resolved against ``base``
                    self.name = context['namespace']
                    self.workdir = (
                        context['workdir']
                        if os.path.isabs(context['workdir'])
                        else os.path.join(self.base, context['workdir'])
                    )
                    self.rundir = (
                        context['rundir']
                        if os.path.isabs(context['rundir'])
                        else os.path.join(self.base, context['rundir'])
                    )
                    for service_ctx in context['services']:
                        try:
                            service = Service(namespace=self,
                                              starter=self.starter,
                                              deblock=self.deblock,
                                              registry=registry,
                                              output_logger_factory=self.output_logger_factory,
                                              context_changed_event=self.context_changed_event,
                                              apifile=os.path.join(self.apidir, service_ctx['service'] + '.api'),
                                              fieldsfile=os.path.join(self.apidir,
                                                                      service_ctx['service'] + '.fields'),
                                              context=service_ctx,
                                              stats=self.stats,
                                              reporter=self.reporter,
                                              force_restart=restart_all,
                                              )
                            unit_dict = service_ctx.get('unit')
                            if unit_dict is not None:
                                unit = InstallationUnit.from_dict(unit_dict, self.deblock)
                            else:
                                # context without a serialized unit: rebuild one from stored meta/md5
                                meta = service_ctx.get('meta', {'md5': service_ctx.get('md5')})
                                unit = InstallationUnit('<old installation>', meta, self.deblock)
                        except ServiceCorrupted:
                            # best effort: skip the broken service, keep loading the others
                            self.log.warning("skipping corrupted service %r", service_ctx.get('service'),
                                             exc_info=sys.exc_info())
                        else:
                            self._add_service(service, unit)
                            self.log.debug("added service %s", service)
            except Exception:
                self.log.exception("Namespace creation failed", exc_info=sys.exc_info())
                if parent is not None:
                    parent.childs.remove(self)
                raise

    def _check_services(self, names):
        """
        Validate service names against this namespace.

        :param names: service names; a falsy value means "all services"
        :return: list of names to operate on
        :raises exceptions.ServiceNotFoundError: if any name is unknown
        """
        names = names or list(self.services.keys())
        # a real list (not filter()) so the emptiness test also works on Python 3
        unknown_services = [name for name in names if name not in self.services]
        if unknown_services:
            raise exceptions.ServiceNotFoundError(
                "Service(s): {} not found in this namespace".format(unknown_services)
            )

        return names

    def subscribe(self, name, dependency):
        """Record that service ``name`` depends on service ``dependency``."""
        self.deps[name].add(dependency)
        self.rdeps[dependency].add(name)

    def _update_deps_state(self, name):
        """Tell every dependent of ``name`` whether all of its dependencies are now installed."""
        services = self.rdeps[name].copy()
        for srvc in services:
            ready = all(dep in self.services for dep in self.deps[srvc])
            self.services[srvc].deps_state_changed(ready)

    def start_services(self, names=None, timeout=90.0):
        """
        Start services concurrently and wait for them.

        :param names: service names, falsy for all
        :param float timeout: seconds to wait for all starts to finish
        :raises exceptions.SkycoreError: if not all starts finished within ``timeout``
        """
        with self.operation_lock:
            names = self._check_services(names)
            starts = [gevent.spawn(self.services[name].start_service) for name in names]

            # joinall's own timeout: a `with Timeout` + `except gevent.Timeout` would also
            # catch (and mislabel) a Timeout belonging to a caller
            gevent.joinall(starts, timeout=timeout)
            if not all(start.ready() for start in starts):
                not_running = [name for name in names if self.services[name].state != State.RUNNING]
                raise exceptions.SkycoreError("services (%s) start timed out" % (', '.join(not_running),))

    def stop_services(self, names=None, timeout=180.0):
        """
        Stop services concurrently and wait for them.

        :param names: service names, falsy for all
        :param float timeout: per-service stop timeout and overall wait limit, in seconds
        :raises exceptions.SkycoreError: if not all stops finished within ``timeout``
        """
        with self.operation_lock:
            names = self._check_services(names)
            stops = [gevent.spawn(self.services[name].stop_service, timeout=timeout) for name in names]

            # see start_services: don't swallow a foreign gevent.Timeout
            gevent.joinall(stops, timeout=timeout)
            if not all(stop.ready() for stop in stops):
                not_stopped = [name for name in names if self.services[name].state != State.STOPPED]
                raise exceptions.SkycoreError("services (%s) stop timed out" % (', '.join(not_stopped),))

    def _add_service(self, service, unit):
        """
        Register ``service`` (installed from ``unit``) in all bookkeeping maps.

        Also updates dependency readiness in both directions, creates the
        ``<workdir>/<service name>`` symlink, writes the API files, installs
        linkdir symlinks (when ``linkdir`` is set) and reports the state.
        """
        md5 = unit.md5
        self.log.info("Installing service %s [%s]", service.name, md5)
        self.services[service.name] = service
        self.hashes[service.name] = md5
        self.rhashes[md5].add(service.name)
        self.services_meta[service.name] = unit
        for dep in service.dependencies:
            self.subscribe(service.name, dep)

        # dependents of this service may have become ready, and this one may be ready too
        self._update_deps_state(service.name)
        service.deps_state_changed(all(dep in self.services for dep in self.deps[service.name]))

        dst = os.path.join(self.workdir, service.name)
        self.log.info("Making symlink: %s -> %s", dst, service.cfg.basepath)
        service.use_symlink(dst)
        service.write_api()
        if self.linkdir is not None:
            service.install_service_symlinks(self.linkdir)

        if self.reporter:
            self.reporter.send_state_report(
                self.name,
                service.name,
                service.state,
                str(unit.release.get('version', None))
            )

    def remove_service(self, name):
        """
        Drop service ``name`` from all bookkeeping maps.

        Its dependents are notified that a dependency disappeared.  Processes are
        not stopped here.

        :return: the removed Service, or None if it was not registered
        """
        md5 = self.hashes.get(name)
        if md5 is not None:
            self.log.info("Removing service %s [%s]", name, md5)

        self.deps.pop(name, None)

        # .values() instead of py2-only .itervalues()
        for rdeps in self.rdeps.values():
            rdeps.discard(name)

        srvc = self.services.pop(name, None)
        self._update_deps_state(name)

        if not self.rdeps[name]:
            self.rdeps.pop(name, None)

        if srvc is not None and srvc in self.childs:
            self.childs.remove(srvc)

        self.hashes.pop(name, None)
        if md5 is not None:
            self.rhashes[md5].discard(name)
            if not self.rhashes[md5]:
                self.rhashes.pop(md5)

        self.services_meta.pop(name, None)

        return srvc

    def get_service_api(self, name, kind):
        """Return the API of service ``name`` in the requested ``kind``."""
        return self.services[name].get_api(kind=kind)

    def get_service_field(self, name, field, raw):
        """Return ``field`` of service ``name`` (``raw`` is passed through to the service)."""
        return self.services[name].get_field(field, raw)

    def get_service_metainfo(self, name):
        """Return the InstallationUnit service ``name`` was installed from, or None."""
        return self.services_meta.get(name)

    def _service_version(self, name):
        """Return the advertised release version of service ``name``, or None."""
        unit = self.services_meta.get(name)
        if unit is None:
            return None
        return unit.release.get('version', None)

    def get_states(self, names=None, new_format=False):
        """
        Report state and version of services.

        :param names: service names, falsy for all
        :param bool new_format: return dicts with 'state', 'version' and
            'state_uptime' instead of ``(state, version)`` tuples
        :return: dict keyed by service name
        :raises exceptions.ServiceNotFoundError: for unknown names
        """
        names = self._check_services(names)
        res = {}

        for name in names:
            service = self.services[name]
            # services_meta holds InstallationUnit objects: the version lives in unit.release
            version = self._service_version(name)
            if new_format:
                res[name] = {
                    'state': service.state,
                    'version': version,
                    'state_uptime': service.current_state_uptime,
                }
            else:
                res[name] = (service.state, version)

        return res

    @property
    def base(self):
        """Base directory of the parent (parent is called to dereference it), or a falsy value without one."""
        return self.parent and self.parent().base

    @property
    def context(self):
        """Serializable state of the namespace; directories are stored relative to ``base``."""
        return {
            'namespace': self.name,
            'workdir': Path(self.workdir).relto(self.base),
            'rundir': Path(self.rundir).relto(self.base),
            'services': [service.context for service in list(self.services.values())],
        }

    def install_services(self, unit):
        """
        Place the contents of ``unit`` into ``<workdir>/<md5>`` and create services from it.

        Generator yielding ``(Service, unit)`` pairs, one per service spec found.
        Nothing is yielded when this md5 is already installed and none of its
        services is marked dirty; dirty installations get their files replaced.
        ``operation_lock`` is held while the generator is suspended.

        :param InstallationUnit unit: downloaded and extracted installation
        :raises Exception: if files cannot be moved, specs cannot be read or a service cannot be created
        """
        with self.operation_lock:
            md5 = unit.md5
            target_root = os.path.join(self.workdir, md5)

            if md5 in self.rhashes:
                dirty_units = []
                for name in self.rhashes[md5]:
                    meta = self.get_service_metainfo(name)
                    if meta and meta.dirty:
                        self.log.warning(
                            "old installation md5 mismatch for %r and appears to be broken, replacing files",
                            name
                        )
                        dirty_units.append(meta)
                if not dirty_units:
                    return
                # several services can share one installation dir: replace it once,
                # a second move would find the source gone and the '.old' slot taken
                unit.move(target_root, self.log)
                for meta in dirty_units:
                    meta.dirty = False
            else:
                try:
                    unit.move(target_root, self.log)
                except Exception as e:
                    raise Exception("cannot move install srcdir %r to %r: %s" % (unit.datadir, target_root, e))

            try:
                # FIXME use cfgs already loaded in unit (basepath has to be changed)
                service_cfgs = unit.find_services_in_dir(target_root, self.log.getChild('srvc'))
            except (TypeError, ValueError, EnvironmentError) as e:
                raise Exception("cannot read service spec from %r: %s" % (target_root, e))

            for service_cfg in service_cfgs:
                try:
                    service = Service(cfg=service_cfg,
                                      namespace=self,
                                      starter=self.starter,
                                      deblock=self.deblock,
                                      registry=self.registry,
                                      apifile=os.path.join(self.apidir, service_cfg.name + '.api'),
                                      fieldsfile=os.path.join(self.apidir, service_cfg.name + '.fields'),
                                      output_logger_factory=self.output_logger_factory,
                                      context_changed_event=self.context_changed_event,
                                      stats=self.stats,
                                      reporter=self.reporter)
                except Exception as e:
                    raise Exception("failed to create service %r infra: %s" % (service_cfg.name, e))
                # yield outside the try: an exception thrown into the generator by the
                # consumer must not be reported as a service creation failure
                yield service, unit

    def uninstall_service(self, name):
        """
        Stop service ``name`` (with its dependents), run its uninstall script and remove it.

        :return: the removed Service, or None if ``name`` is not installed
        :raises Exception: if stopping the service and its dependents fails
        """
        with self.operation_lock:
            # look the service up only once the lock is held: acquiring it may switch
            # greenlets, and a concurrent operation could replace or remove it meanwhile
            service = self.services.get(name)
            if service is None:
                return None
            to_stop = list(self.rdeps.get(name, [])) + [name]

            if service in self.childs:  # component is already stopped otherwise
                try:
                    self.stop_services(to_stop, timeout=Service.STATE_CHANGE_TIMEOUT * 3.5)
                except Exception:
                    self.log.warning("cannot stop the service with its rdeps, timeout: %s", name)
                    raise Exception("Service stop timed out: %s" % (name,))

            try:
                service.on_uninstall(upgrade=False)
            except Exception:
                self.log.exception("uninstall script failed, nevertheless removing it", exc_info=sys.exc_info())

            removed = self.remove_service(name)
            if removed is not None:
                removed.stop()

            if self.stats:
                self.stats.service_remove_inmemory(self.name, name)
            self.context_changed_event.release()
            return removed

    def report_rusage(self):
        """Ask every service to report its resource usage."""
        for service in list(self.services.values()):
            service.report_rusage()

    @Deblock.wrap_fun
    def cleanup(self):
        """Remove stale entries from apidir, linkdir and workdir (in that order)."""
        self._cleanup_apidir()
        self._cleanup_linkdir()
        self._cleanup_workdir()

    def _cleanup_workdir(self):
        """
        Remove workdir entries that belong to no installed service.

        Kept: directories named after an active md5 and symlinks named after an
        active service.  Other symlinks are removed when dangling or when they
        do not resolve to an active md5 directory inside workdir; other
        directories are removed recursively, stray files unlinked.  The workdir
        itself is removed when it ends up empty.  Errors are logged, not raised.
        """
        def log_error(func, path, exc_info):
            self.log.error("cleanup: failed to %r path %r:", func.__name__, path, exc_info=exc_info)

        directory = self.workdir
        # do nothing if workdir is missing
        if not os.path.exists(directory):
            return

        try:
            # trailing separator so '/x/work' does not match '/x/work2/...'
            real_prefix = os.path.join(os.path.realpath(directory), '')
            # sets: O(1) membership tests in the loop
            active_hashes = set(self.hashes.values())
            active_services = set(self.hashes.keys())
            for item in os.listdir(directory):
                dirname = os.path.join(directory, item)

                if os.path.isdir(dirname) and item in active_hashes:
                    continue
                elif os.path.islink(dirname) and item in active_services:
                    continue

                if os.path.islink(dirname):
                    realpath = os.path.realpath(dirname)
                    if (not os.path.exists(dirname)  # invalid symlink
                            or not realpath.startswith(real_prefix)
                            or os.path.basename(realpath) not in active_hashes):
                        os.unlink(dirname)
                elif os.path.isdir(dirname):  # uninstalled service
                    shutil.rmtree(dirname, onerror=log_error)
                else:  # stray file: rmtree() cannot remove it
                    os.unlink(dirname)
            if not os.listdir(directory):
                os.rmdir(directory)
        except Exception as e:
            self.log.error("workdir cleanup failed (%s): %s", self.workdir, e)

    def _cleanup_linkdir(self):
        """
        Remove linkdir entries not named after an installed service.

        Symlinks and files are unlinked, directories removed recursively; the
        linkdir itself is removed when empty.  Errors are logged, not raised.
        """
        def log_error(func, path, exc_info):
            self.log.error("cleanup: failed to %r path %r:", func.__name__, path, exc_info=exc_info)

        directory = self.linkdir
        if directory is None or not os.path.exists(directory):
            return

        try:
            active_names = set(self.services)
            for item in os.listdir(directory):
                if item in active_names:
                    continue

                dirname = os.path.join(directory, item)
                if os.path.islink(dirname) or not os.path.isdir(dirname):
                    os.unlink(dirname)
                else:
                    shutil.rmtree(dirname, onerror=log_error)

            if not os.listdir(directory):
                os.rmdir(directory)
        except Exception as e:
            self.log.error("linkdir cleanup failed: %s", e)

    def _cleanup_apidir(self):
        """
        Remove apidir entries other than ``<service>.api`` / ``<service>.fields`` of installed services.

        The apidir itself is removed when empty.  Errors are logged, not raised.
        """
        def log_error(func, path, exc_info):
            self.log.error("cleanup: failed to %r path %r:", func.__name__, path, exc_info=exc_info)

        directory = self.apidir
        if directory is None or not os.path.exists(directory):
            return

        try:
            active_names = set(self.services)
            for item in os.listdir(directory):
                # strip the known suffixes; anything else maps to None and is removed
                short = (item[:-4] if item.endswith('.api')
                         else item[:-7] if item.endswith('.fields')
                         else None)
                if short in active_names:
                    continue

                name = os.path.join(directory, item)
                if os.path.islink(name) or not os.path.isdir(name):
                    os.unlink(name)
                else:
                    shutil.rmtree(name, onerror=log_error)

            if not os.listdir(directory):
                os.rmdir(directory)
        except Exception as e:
            self.log.error("apidir cleanup failed: %s", e)

    @contextlib.contextmanager
    def _auto_rollback_service(self, log, message, new_service, old_service, old_uninstalled, new_installed):
        """
        Context manager rolling back the injection of ``new_service`` on any error in its body.

        Rollback steps: log the failure, pause the new service FSA and detach
        it from ``childs``; run its uninstall script if ``new_installed``;
        re-run the old service install script if there is an old service and
        either flag is set; stop the new service; hand processes back to the
        old service and start it again.  Whatever happens, a generic Exception
        is raised at the end (the original error is only logged).

        :param log: logger to use
        :param str message: phase name used in log messages
        :param bool old_uninstalled: the old service's uninstall script has already run
        :param bool new_installed: the new service's install script has already run
        """
        try:
            yield
        except Exception:
            try:
                if old_service is None:
                    log.exception("new service %s %s failed, rolling it back", message, new_service)
                else:
                    log.exception("new service %s %s failed, rolling back to %s", message, new_service, old_service)

                new_service.pause_fsa()
                if new_service in self.childs:
                    self.childs.remove(new_service)

                if new_installed:
                    try:
                        # log.debug("[inject] new service on_uninstall: %s", new_service)
                        new_service.on_uninstall(upgrade=old_service is not None)
                    except Exception:
                        log.exception("new service %s uninstall script failed, nevertheless continuing", new_service)

                if old_service is not None and (old_uninstalled or new_installed):
                    try:
                        # log.debug("[inject] old service on_install: %s", old_service)
                        old_service.on_install(upgrade=True)
                    except Exception:
                        log.exception("old service %s install script failed, nevertheless continuing", old_service)

                new_service.stop()

                if old_service is not None:
                    old_service.attach_procs(new_service)
                    old_service.unpause_fsa()
                    old_service.start_service()
            finally:
                # raising from finally replaces the original (already logged) exception
                raise Exception("failed to enable service %s, service is removed" % (new_service,))

    def inject_one(self, service, unit, log=None, start_on_install=True):
        """
        Replace (or add) one service with a freshly installed ``service``.

        Sequence: start the new service loops (FSA paused); run the old
        service uninstall script and the new install script; stop the old
        service if a restart is needed or start is skipped; take over the old
        processes (and, without a restart, its state); start the new service
        when required; finally swap the bookkeeping entries.  Each risky phase
        is wrapped in ``_auto_rollback_service``.

        :param Service service: new service
        :param InstallationUnit unit: installation the service comes from
        :param log: logger, defaults to ``self.log``
        :param bool start_on_install: start the service after injection
        """
        log = log or self.log
        old_service = self.services.get(service.name)
        # log.debug("[inject] removed old service %s", old_service)
        needs_restart = service.needs_restart or old_service is None

        # log.debug("[inject] starting new service loops: %s", service)
        with self._auto_rollback_service(log, "enable", service, old_service, False, False):
            service.start()

            if service not in self.childs:
                raise Exception("service loops start failed")

        service.pause_fsa()

        if old_service:
            needs_restart |= old_service.needs_restart
            log.debug(
                'service %r needs_restart=%r (old_service: %r, new_service: %r)',
                service.name, needs_restart, old_service.needs_restart, service.needs_restart
            )
            try:
                # log.debug("[inject] old service on_uninstall: %s", old_service)
                old_service.on_uninstall(upgrade=True)
            except Exception:
                log.exception("preupgrade uninstall script failed, nevertheless continuing", exc_info=sys.exc_info())
        else:
            log.debug('service %r needs_restart=%r', service.name, needs_restart)

        with self._auto_rollback_service(log, "install script", service, old_service, True, False):
            # log.debug("[inject] new service on_install: %s", service)
            service.on_install(upgrade=old_service is not None)

        if old_service and (needs_restart or not start_on_install):
            # log.debug("[inject] stopping old service %s", old_service)
            old_service.stop_service(timeout=Service.STATE_CHANGE_TIMEOUT * 3.5)

        service.attach_procs(old_service)

        if old_service and not needs_restart and start_on_install:
            # no restart: the new service adopts the old one's state as is
            # log.debug("[inject] migrating old service procs and state to new service %s", service)
            service._state_controller.state = old_service.state
            service.set_required_state(old_service._state_controller.required_state)
            old_service.pause_fsa()

        service.unpause_fsa()
        if not start_on_install:
            log.info("skipping start for service %r", service.name)
        elif needs_restart:
            log.info("service %r will be started", service.name)
            with self._auto_rollback_service(log, "start service", service, old_service, True, True):
                service.start_service()

            if old_service is not None:
                old_service.stop()
        else:
            log.info("service %r doesn't need restart upon upgrade", service.name)

        # log.debug("[inject] attaching new service to namespace: %s", service)
        if old_service is not None:
            log.info("removing old service %s from namespace", old_service)
            self.remove_service(old_service.name)
        log.info("saving new service %s into namespace", service)
        self._add_service(service, unit)
        self.context_changed_event.release()

    def inject_services(self, installed, log=None, start_on_install=True):
        """
        Inject services concurrently (see ``inject_one``) and wait for all of them.

        :param installed: iterable of ``(Service, InstallationUnit)`` pairs
        :param log: logger passed to ``inject_one``, defaults to ``self.log``
        :param bool start_on_install: passed to ``inject_one``
        :raises Exception: naming the services whose injection did not succeed
        """
        log = log or self.log
        installed = list(installed)
        with self.operation_lock:
            starts = [
                gevent.spawn(self.inject_one, srvc, unit, log, start_on_install=start_on_install)
                for srvc, unit in installed
            ]

            # log.debug("[inject] waiting for services to start: %s", installed)
            gevent.joinall(starts)

            # derive the outcome from the finished greenlets themselves; a name
            # counts as injected if any of its greenlets succeeded
            succeeded = set(
                srvc.name for (srvc, _), start in zip(installed, starts) if start.successful()
            )
            failed = []
            for srvc, _ in installed:
                if srvc.name not in succeeded and srvc.name not in failed:
                    failed.append(srvc.name)

            if failed:
                self.log.error("%r: failed to inject services: %s", self.name, failed)
                raise Exception("failed to inject services: %s" % (failed,))

    def start(self):
        """Spawn the component loops and start child components; drop services whose start failed."""
        # list(...): the dict is updated inside the loop; .items() also works on Python 3
        for meth, loop in list(self.loops.items()):
            assert loop is None, '%r %r already running' % (meth, loop)
            self.loops[meth] = gevent.spawn(meth, **meth.extra)
            self.log.debug("Started loop %r %r", meth, self.loops[meth])

        for child in list(self.childs):
            if isinstance(child, Service):
                child.start()
                # a service that failed to start is no longer among our children
                if child not in self.childs:
                    self.remove_service(child.name)
            else:
                child.start()

        return self

    def __eq__(self, other):
        if not isinstance(other, Namespace):
            return NotImplemented
        return (self.name == other.name
                and self.workdir == other.workdir
                and self.rundir == other.rundir
                and self.deps == other.deps
                and self.services == other.services
                and self.hashes == other.hashes
                )

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __str__(self):
        return '%s "%s"' % (self.__class__.__name__, self.name)


class InstallationUnit(object):
    """
    One release archive of a config section plus its local download/extract state.

    ``release`` is the advertised release description.  This class reads its
    ``'md5'``, ``'filename'`` and optional ``'size'`` keys; ``'urls'`` is split
    off into :attr:`urls`.
    """

    def __init__(self, section, release, deblock):
        self.deblock = deblock
        self.section = section
        # work on a copy: popping 'urls' must not mutate the caller's dict
        # (e.g. a saved service context or the 'release' of from_dict data)
        release = dict(release)
        self.urls = release.pop('urls', {})
        self.release = release
        self.content_md5 = None
        self.content_paths = None
        self.archive = None
        self.tempdir = None
        self.service_cfgs = {}
        self.conflicting = False
        self.dirty = False

    @classmethod
    def from_cfg(cls, section, section_cfg, deblock):
        """
        Build a unit from the ``'config'`` entry of ``section_cfg``.

        :return: the unit, or None when ``section_cfg`` or its config is None
        :raises Exception: if the advertised release cannot be parsed
        """
        if section_cfg is None or section_cfg['config'] is None:
            return

        try:
            release = parse_advertised_release(section_cfg['config'])
        except Exception as e:
            raise Exception("failed to parse release for section %r: %s" % (section, e))

        return cls(section, release, deblock)

    @property
    def download_size(self):
        """Expected download size in bytes, or 0 when the release does not advertise one."""
        # sandbox lies about exact size, so we have to add 1KB for each resource being downloaded
        if 'size' not in self.release:
            return 0
        return self.release['size'] + 1024

    @property
    def md5(self):
        return self.release['md5']

    @property
    def datadir(self):
        """Directory the archive was extracted into, or None before ``extract``."""
        return self.tempdir.dir() if self.tempdir is not None else None

    def download(self, downloaddir, log, skybone_available=True):
        """
        Download the release archive into ``downloaddir`` and remember its path in ``archive``.

        :raises Exception: if the download fails or returns nothing
        """
        filename = self.release['filename']

        log.info("will download %s (section %r)", self.md5, self.section)
        try:
            binary = download(downloaddir, filename, self.urls, self.md5, self.deblock, skybone_available=skybone_available)
        except Exception as e:
            raise Exception("failed to download %s (section %r): %s" % (self.md5, self.section, e))

        if binary is None:
            raise Exception("failed to download %s (section %r)" % (self.md5, self.section))

        self.archive = os.path.join(downloaddir, filename)

    @Deblock.wrap_fun
    def extract(self, downloaddir, log, fsyncqueue=None):
        """
        Extract ``archive`` into a new temporary directory under ``downloaddir``.

        Stores the content md5 and member paths returned by module-level ``extract``.
        The temporary directory stays open until ``close``.

        :raises Exception: if extraction fails
        """
        self.tempdir = TempDir(dir=downloaddir)
        self.tempdir.open()

        try:
            self.content_md5, self.content_paths = extract(self.archive,
                                                           self.tempdir.dir(),
                                                           log=log.getChild('extract'),
                                                           fsyncqueue=fsyncqueue,
                                                           )
        except Exception as e:
            raise Exception("failed to extract %s (section %r): %s" % (self.md5, self.section, e))

    @Deblock.wrap_fun
    def move(self, target_root, log):
        """Move the extracted data into ``target_root`` (see ``replace_dir``)."""
        replace_dir(self.datadir, target_root, log)

    @Deblock.wrap_fun
    def find_services_in_dir(self, datadir, log):
        """
        Load every ``*.scsd`` service spec found directly in ``datadir``.

        :return: list of ServiceConfig objects
        :raises TypeError, ValueError, EnvironmentError: if a spec cannot be read (logged first)
        """
        service_cfgs = []
        # bound before the loop so the handler can always name a path, even if
        # globbing fails before the first spec is reached
        path = datadir
        try:
            for path in glob.iglob(os.path.join(datadir, '*.scsd')):
                # log.debug("[install] loading service config from %r", path)
                service_cfgs.append(ServiceConfig.from_path(path, log))
        except (TypeError, ValueError, EnvironmentError):
            log.warning("Cannot read service spec from %r", path, exc_info=sys.exc_info())
            raise
        return service_cfgs

    def collect_services(self, log):
        """
        Load and validate the service specs of the extracted data into ``service_cfgs``.

        :raises ValueError: if no service spec is found
        """
        for cfg in self.find_services_in_dir(self.datadir, log):
            # we're just checking description validity before install
            self.service_cfgs[cfg.name] = cfg

        if not self.service_cfgs:
            log.warning("No services found in %r", self.datadir)
            raise ValueError("No services found in %r" % (self.datadir,))

    @Deblock.wrap_fun
    def close(self):
        """Close (and forget) the temporary extraction directory, if any."""
        if self.tempdir is not None:
            self.tempdir.close()
            self.tempdir = None

    def as_dict(self):
        """Serializable state; transient fields (urls, tempdir, service_cfgs, conflicting) are reset."""
        return {
            'section': self.section,
            'urls': None,
            'release': self.release,
            'content_md5': self.content_md5,
            'content_paths': self.content_paths,
            'archive': self.archive,
            'tempdir': None,
            'service_cfgs': {},
            'conflicting': False,
            'dirty': self.dirty,
        }

    @classmethod
    def from_dict(cls, data, deblock):
        """Rebuild a unit from the output of ``as_dict``."""
        unit = cls(data['section'], data['release'], deblock)
        for attr in ('urls',
                     'content_md5',
                     'content_paths',
                     'archive',
                     'tempdir',
                     'service_cfgs',
                     'conflicting',
                     'dirty'):
            # keys missing from older saved states keep the constructor
            # defaults instead of being overwritten with None
            if attr in data:
                setattr(unit, attr, data[attr])

        return unit


def extract(archive, dest, log=None, fsyncqueue=None):
    """
    Unpack ``archive`` into ``dest``.

    When running as root the extracted tree is chowned to 0:0.

    :param str archive: path to the archive file
    :param str dest: directory to extract into
    :return: tuple ``(md5, paths)``: ``tar.md5()`` as computed by tartools, and the
        member names made relative to '.' (ordered by raw member name, '.' excluded)
    """
    data = unpack_archive_fileobj(archive)
    with tarfile.open(name=archive, fileobj=data, log=log, fsyncqueue=fsyncqueue) as tar:
        tar.extractall(dest)
        if os.getuid() == 0:
            Path(dest).chown(0, 0, rec=True)
        md5 = tar.md5()
        # a real list built while the archive is open; filter() would be lazy on Python 3
        relative_names = (os.path.relpath(member.name, '.')
                          for member in sorted(tar, key=lambda member: member.name))
        paths = [name for name in relative_names if name != '.']
        return md5, paths
