import logging
import os
import subprocess
import threading

from concurrent import futures

from sandbox.common import rest
from sandbox.common.errors import TaskError, TaskFailure

import transport_pack


# Upper bound on how many base unpack commands may run at the same time.
# It sizes the semaphore shared by all per-base jobs in BinBasesProvider.provide;
# syncing and symlinking are not limited by it.
# NOTE(review): kept at 1, presumably to avoid running out of memory while
# unpacking (see the "hope there will be no OOM" remark) — confirm before raising.
MAX_CONCURRENT_UNPACKS = 1


def build_bin_base_list(res_ids, task_ids):
    """
    Collect binary base resource ids.

    The result starts with ``res_ids`` (copied into a new list). It is followed,
    in order, by the ``bin_base_res_ids`` entry of the context of every task in
    ``task_ids``.

    :param res_ids: iterable of explicitly requested resource ids
    :param task_ids: iterable of task ids whose generated bases should be added
    :return: list of resource ids
    :raises KeyError: if a task context has no ``bin_base_res_ids`` entry
    """
    client = rest.Client()
    collected = list(res_ids)
    for tid in task_ids:
        logging.info("Using bases generated by task %s", tid)
        context = client.task[tid].context.read()
        collected.extend(context['bin_base_res_ids'])
    return collected


class BinBasesProvider(object):
    """
    Distribute binary bases into the data directories of yabs servers.

    One _BaseProvider is kept per resource id in ``self.providers`` and reused
    across provide() calls. A base that is already synced and unpacked is
    therefore not fetched again.
    """

    # Number of threads used to run per-base sync/unpack/symlink jobs.
    # Unpacking itself is further limited by MAX_CONCURRENT_UNPACKS.
    max_provide_workers = 18

    def __init__(self, agentr, use_packed_bases):
        """
        :param agentr: object passed to every _BaseProvider; used there for ``resource_sync``
        :param use_packed_bases: passed to _BaseProvider as ``use_packed``
        """
        self.agentr = agentr
        self.providers = dict()
        self.use_packed_bases = use_packed_bases

    def provide(self, res_ids, *servers):
        """
        Ensure every base needed by ``servers`` is present in each server's data dir.

        Does NOT wait for the servers to report the bases as ready; that has to
        be done separately.

        :param res_ids: ids of binary base resources that may be needed
        :param servers: objects exposing ``get_bin_db_dir()`` and ``get_bin_db_list()``
        :raises TaskFailure: if a server reports an empty db list
        """
        # Forget the previous distribution; it is rebuilt from scratch below.
        for prv in self.providers.itervalues():
            prv.drop_dest_dirs()

        cl = rest.Client()
        for server in servers:
            data_dir = server.get_bin_db_dir()
            db_list = server.get_bin_db_list()
            try:
                # Names look like "<db_ver>.<tag>.yabs"; the server's version is taken from the first one.
                server_db_ver = db_list[0].split('.', 1)[0]
            except IndexError:
                raise TaskFailure("Unexpected data in db_list: {}. Expected base_ver.tag.yabs".format(db_list))
            wanted = set(db_list)
            for res_id in res_ids:
                provider = self.providers.get(res_id)
                if provider is None:
                    provider = _BaseProvider(res_id, self.agentr, cl, self.use_packed_bases)
                    self.providers[res_id] = provider
                # A base is delivered only to servers that list it under their own db version.
                if provider.get_name_by_ver(server_db_ver) in wanted:
                    provider.add_dest_dir(data_dir, server_db_ver)

        # Remove stale symlinks and bases no server needs any more.
        for prv in self.providers.itervalues():
            prv.drop_unused()

        unpack_semaphore = threading.Semaphore(MAX_CONCURRENT_UNPACKS)  # and hope there will be no OOM

        with futures.ThreadPoolExecutor(max_workers=self.max_provide_workers) as pool:
            # schedule syncs, unpacks and symlinks
            fts = [pool.submit(prv.provide, unpack_semaphore) for prv in self.providers.itervalues()]
            # wait for them to finish; the first failure is re-raised by result()
            for ft in futures.as_completed(fts):
                db_name = ft.result()
                # Providers with no destination dirs return None; nothing was provided for them.
                if db_name is not None:
                    logging.info("%s ready", db_name)

        logging.info("All bases of yabs-servers are ready.")


class _BaseProvider(object):
    def __init__(self, res_id, agentr, rest_api_client=None, use_packed=False):
        rest_api_client = rest_api_client or rest.Client()
        res_info = rest_api_client.resource[res_id].read()

        self._agentr = agentr

        self._res_id = res_id
        attrs = res_info["attributes"]
        self._db_ver = attrs['db_ver']
        self._internal_db_ver = attrs.get('internal_ver')
        try:
            self._name_template = "{db_ver}" + ".{}.yabs".format(attrs.get('type') or attrs['tag'])
            self._name_by_attr = self._name_template.format(db_ver=self._db_ver)
        except KeyError as exc:
            raise TaskError(
                "Binary base resource {} has no attribute {}".format(res_id, str(exc))
            )
        self._filename_template = self._name_template + ('.zstd_7' if use_packed else '')

        pack_type = attrs.get('pack', 'not specified')
        self.unpack_cmd = self._get_unpack_cmd(pack_type, use_packed)
        self._path = None
        self._dest_dirs = set()
        self._links = set()

    def _get_unpack_cmd(self, pack_type, use_packed=False):
        if use_packed:
            if pack_type == 'tr':
                return "ln -sf {src} {dst}"

            raise TaskError("Cannot use non-packed bases with packer {}".format(pack_type))

        if pack_type == 'gz':
            return "gunzip -c {src} 1>{dst}_ && mv {dst}_ {dst}"
        elif pack_type == 'tr':
            return transport_pack.TransportPack().decompress_cmd()
        raise TaskError(
            "Binary base resource {} packed by unknown packer ({})".format(self._res_id, pack_type)
        )

    @property
    def name(self):
        """Full base name (123456.dbe.yabs)"""
        return self._name_by_attr

    def get_name_by_ver(self, server_db_ver):
        """Full base name (123456.dbe.yabs)"""
        return self._name_template.format(db_ver=server_db_ver)

    def drop_dest_dirs(self):
        self._dest_dirs = set()

    def add_dest_dir(self, dest, server_db_ver):
        self._dest_dirs.add((os.path.abspath(dest), server_db_ver))

    def provide(self, unpack_semaphore):
        """
        Ensure that all destionation directories have either the base or a symlink to it.
        Sync & unpack if there is no unpacked base.
        """
        try:
            return self._provide(unpack_semaphore)
        except:
            logging.exception("Failed to provide base %s", self.name)
            raise

    def drop_unused(self):
        for link in self._links:
            os.remove(link)
        if not self._dest_dirs and self._path is not None:
            logging.info("Base %s not needed, dropping", self.name)
            os.remove(self._path)
            self._path = None

    def _provide(self, unpack_semaphore):
        if not self._dest_dirs:
            logging.info("No destination dirs were added for base %s, not syncing and not unpacking", self.name)
            return

        if self._path and not os.path.exists(self._path):
            logging.warning("%s disappeared somehow", self._path)
            self._path = None
        paths = set()
        for dest_dir, server_db_ver in self._dest_dirs:
            paths.add(os.path.join(dest_dir, self._filename_template.format(db_ver=server_db_ver)))
            if int(self._db_ver) != int(server_db_ver):
                logging.warn('There is base with other ver {}. Version was changed to {}. Internal version is {}'.format(self._db_ver, server_db_ver, self._internal_db_ver))
        logging.info("Providing base %s to %s", self.name, paths)

        if self._path:
            logging.debug("Found %s at %s", self.name, self._path)
            paths.discard(self._path)
        else:
            self._path = paths.pop()
            self._synpack(unpack_semaphore)

        self._links = paths

        for link_path in self._links:
            try:
                os.symlink(self._path, link_path)
            except Exception as exc:
                raise RuntimeError(
                    "Failed to symlink base %s (%s->%s): %s",
                    self.name,
                    self._path,
                    link_path,
                    exc
                )
        return self.name

    def _synpack(self, unpack_semaphore):
        logging.info("Syncing %s (resource id %s)", self.name, self._res_id)
        src = self._agentr.resource_sync(self._res_id)

        cmdline = self.unpack_cmd.format(src=src, dst=self._path)
        logging.info('Synced %s, waiting for semaphore to unpack', self.name)
        with unpack_semaphore:
            logging.info('Unpacking %s via %s', self.name, cmdline)
            try:
                subprocess.check_output(cmdline, stderr=subprocess.STDOUT, shell=True)
            except subprocess.CalledProcessError as exc:
                raise TaskFailure("%s failed with exitcode %s:\n%s\n" % (exc.cmd, exc.returncode, exc.output))
        logging.info("Base %s ready", self.name)
