from __future__ import absolute_import

import os
import sys
import time
import stat
import shutil
import logging
import tempfile
import itertools
import collections

import six

from .. import config
from .. import format
from .. import errors
from .. import patterns
from ..types import client as ctc

#: Limit download speed for all resources greater than 512MiB if backbone mode was requested (value in bytes)
BACKBONE_DL_LIMIT_SIZE_THRESHOLD = 512 * 1024 * 1024

logger = logging.getLogger(__name__)  # module-level logger


def _check_cqueue_results(results):
    """
    Summarize errors happened during ``api.cqueue.Client()``'s run and re-raise them in a single exception.

    :param results: return value of ``api.cqueue.Client().run().wait()``
    :return: None
    :raise: Exception with multiple errors and stacktraces, if available
    """

    errors = []
    for host, result, failure in results:
        if failure:
            if hasattr(failure, '_traceback'):
                errors.append('Error at %s: %s' % (host, failure._traceback))
            else:
                errors.append('Error at %s: %s' % (host, failure))

    if errors:
        raise Exception('\n'.join(errors))


@patterns.singleton
def supported_copier_transports():
    """
    Return a flat list of all transport "magics" known to the copier.

    ``transportsInfo()`` yields ``(magics, <something>)`` pairs; only the ``magics`` iterables are flattened.
    The result is memoized by ``patterns.singleton``.
    """
    import api.copier
    transports_info = api.copier.copierClass().transportsInfo()
    return [magic for magics, _ in transports_info for magic in magics]


def files_torrent(torrent_id, timeout=5 * 60, **kwargs):
    """
    Fetch the file listing (metadata) of a skynet resource.

    :param torrent_id: resource identifier, starts with "rbtorrent"
    :param timeout: maximal amount of time to wait (presumably seconds)
    :param kwargs: forwarded as-is to ``api.copier.Copier().list()``
    :rtype: list of dicts, one per entry, with keys

        :name: path relative to the resource root (e.g. `my file` or `path/to/my file`)
        :size: size in bytes
        :executable: whether the entry is executable
        :type: file, dir or symlink
        :md5sum: md5 checksum of the whole file
    """
    import api.copier
    listing = api.copier.Copier().list(torrent_id, timeout=timeout, **kwargs)
    return listing.files()


def get_share_size(skynet_id, **kwargs):
    """
    Return the total size in bytes of all entries of type ``file`` in a share.

    Directories and symlinks are not counted; ``kwargs`` are forwarded to ``files_torrent``.
    """
    total = 0
    for entry in files_torrent(skynet_id, **kwargs):
        if entry['type'] == 'file':
            total += entry['size']
    return total


def calculate_timeout(share_size, mbps=96, fixed_time=600):
    """
    Estimate a timeout for ``sky get`` from the share size and the channel throughput (1Gbit / 7 concurrent threads).

    :param share_size: share size in bytes
    :param mbps: expected throughput in megabits per second (``1 << 17`` bytes is one megabit)
    :param fixed_time: lower bound for the result
    :return: the estimated transfer time, but never less than ``fixed_time``

    XXX: This method doesn't count number of copies of the torrent.
    """
    estimated = share_size / mbps / (1 << 17)
    return fixed_time if estimated < fixed_time else estimated


class ShareSandboxResource:
    """
    Task that shares ``files`` (relative to ``cwd``) with the copier and returns the resource id.

    It can be run directly or shipped to a remote host through cqueue. ``osUser`` is set only for a
    truthy ``user``.
    """
    marshaledModules = __name__.split('.', 1)[:1]  # Avoid marshalling anything except 'common' package.

    def __init__(self, user, files, cwd=None):
        if user:
            self.osUser = user
        self.files, self.cwd = files, cwd

    def run(self):
        import api.copier
        sharing = api.copier.Copier().createExEx(self.files, cwd=self.cwd)
        return sharing.wait().resid()


class RemoteListFiles:
    """
    Task that lists the files under ``path`` on the host it runs on.

    ``run()`` returns a ``(names, is_file)`` pair:

    * ``path`` is a file -- ``([basename(path)], 1)``;
    * otherwise -- every file below ``path`` as a path relative to it (bottom-up ``os.walk`` order), and ``0``.

    :raise errors.TaskError: if ``path`` does not exist
    """
    marshaledModules = __name__.split('.', 1)[:1]  # Avoid marshalling anything except 'common' package.

    def __init__(self, user, path):
        if user:
            self.osUser = user
        self.path = path

    def run(self):
        base = self.path
        if not os.path.exists(base):
            import socket
            raise errors.TaskError(
                "Remote directory '{}' at '{}' does not exist".format(base, socket.gethostname())
            )

        if os.path.isfile(base):
            return [os.path.basename(base)], 1

        listing = [
            os.path.relpath(os.path.join(root, name), base)
            for root, _, names in os.walk(base, topdown=False)
            for name in names
        ]
        return listing, 0


class ShareChangingFiles:
    """
    Task that copies ``files`` (relative to ``cwd``) into a fresh temporary directory and shares that copy.

    The relative directory layout of ``files`` is recreated inside the temporary directory. ``run()`` returns
    ``(share_dir, resource_id)``. ``share_dir`` is left in place, so the caller has to remove it
    (e.g. with ``DeleteSharedFiles``).
    """
    marshaledModules = __name__.split('.', 1)[:1]  # Avoid marshalling anything except 'common' package.

    def __init__(self, user, files, cwd):
        # ``osUser`` is set only for a truthy ``user``
        if user:
            self.osUser = user
        self.files = files
        self.cwd = cwd

    def run(self):
        import api.copier
        share_dir = tempfile.mkdtemp()

        # copy files
        for rel_path in self.files:
            rel_dir = os.path.dirname(rel_path)
            target_dir = os.path.join(share_dir, rel_dir)
            # Several files may live in the same sub-directory. os.makedirs() raises if the directory
            # already exists, so create it only once.
            if rel_dir and not os.path.isdir(target_dir):
                os.makedirs(target_dir)
            shutil.copy(os.path.join(self.cwd, rel_path), target_dir)

        # share
        c = api.copier.Copier()
        handler = c.create(os.listdir(share_dir), cwd=share_dir)
        return share_dir, handler.resid()


class ShareAndCopyHeadFiles:
    """
    Task that copies the first ``head`` lines of every file in ``files`` (relative to ``cwd``) into a fresh
    temporary directory, makes it world-accessible and shares it.

    ``run()`` returns ``(share_dir, resource_id)``. ``share_dir`` is not removed by this class.
    """
    marshaledModules = __name__.split('.', 1)[:1]  # Avoid marshalling anything except 'common' package.

    def __init__(self, user, files, cwd, head):
        # ``osUser`` is set only for a truthy ``user``
        if user:
            self.osUser = user
        self.files = files
        self.cwd = cwd
        self.head = head

    @staticmethod
    def chmod_recursive(path, mode):
        """
        Set ``mode`` on ``path`` and on every directory and file below it, following symlinked directories.

        A single ``os.walk`` already descends into every sub-directory, so each entry is chmod-ed exactly once.
        """
        for root, _, names in os.walk(path, followlinks=True):
            os.chmod(root, mode)
            for name in names:
                os.chmod(os.path.join(root, name), mode)

    def run(self):
        """
        Copy the file heads into a new temporary directory and share it.

        :return: ``(share_dir, resource_id)``
        """
        import api.copier
        share_dir = tempfile.mkdtemp()
        # a non-positive ``head`` produces empty copies
        line_limit = max(self.head, 0)

        # copy head for files
        for rel_path in self.files:
            rel_dir = os.path.dirname(rel_path)
            target_dir = os.path.join(share_dir, rel_dir)
            # Several files may share a sub-directory. os.makedirs() raises if it already exists.
            if rel_dir and not os.path.isdir(target_dir):
                os.makedirs(target_dir)
            with open(os.path.join(share_dir, rel_path), "wb") as dst:
                with open(os.path.join(self.cwd, rel_path), "rb") as src:
                    dst.writelines(itertools.islice(src, line_limit))
        mode_all_rwx = (
            stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR |
            stat.S_IRGRP | stat.S_IWGRP | stat.S_IXGRP |
            stat.S_IROTH | stat.S_IWOTH | stat.S_IXOTH
        )
        ShareAndCopyHeadFiles.chmod_recursive(share_dir, mode_all_rwx)

        # share
        c = api.copier.Copier()
        handler = c.create(os.listdir(share_dir), cwd=share_dir)
        return share_dir, handler.resid()


class DeleteSharedFiles:
    """
    Task that recursively removes ``path`` on the host it runs on.

    Any error raised by ``shutil.rmtree`` (e.g. a missing path) propagates to the caller.
    """
    marshaledModules = __name__.split('.', 1)[:1]  # Avoid marshalling anything except 'common' package.

    def __init__(self, user, path):
        if user:
            self.osUser = user
        self.path = path

    def run(self):
        target = self.path
        shutil.rmtree(target)


def skynet_share(cwd, sub_path):
    """
    Share ``sub_path`` (relative to ``cwd``) from the local host and return the resulting resource id.
    """
    # some functions do not accept unicode strings, so coerce both arguments to a regular ``str``
    cwd, sub_path = str(cwd), str(sub_path)
    logger.info("Share with copier '%s' in folder '%s'", sub_path, cwd)
    task = ShareSandboxResource(None, [sub_path], cwd=cwd)
    return task.run()


def skynet_get(skynet_id, data_dir, timeout=10800, fallback_to_bb=False, size=None, log_progress=True, **kwargs):
    """
    Download share. Equivalent to sky get -u -t ...

    :param skynet_id: identifier of the resource to fetch
    :param data_dir: destination directory (passed to the copier as ``dest``)
    :param timeout: time limit for the download. An explicit ``0`` means "estimate it from the share size"
        (see ``calculate_timeout``). ``None`` is passed to the copier unchanged (presumably "no limit").
    :param fallback_to_bb: on a network-related failure (not allowed by network, download error, timeout),
        retry once via Backbone. Ignored if the first attempt already used Backbone.
    :param size: share size in bytes, if already known. Used for the timeout estimation and to decide
        whether Backbone speed limits apply.
    :param log_progress: not used by this function
    :param kwargs: extra options for ``api.copier.Copier().get()``. An optional ``logger`` entry replaces
        the module logger.
    :raise errors.TemporaryError: when the copier reports "database or disk is full"
    """
    import api.copier

    copier = api.copier.Copier()
    _logger = kwargs.pop("logger", None) or logger

    settings = config.Registry()
    _backbone_dl_limits = {
        "network": api.copier.Network.Backbone,
        "max_dl_speed": settings.common.network.backbone.limits.dl,
        "max_ul_speed": settings.common.network.backbone.limits.ul,
    }

    kwargs.update({"dest": data_dir, "user": True, "subproc": True})

    if sys.platform.startswith("darwin"):
        _logger.info("Using Backbone on OSX host")
        default_network = api.copier.Network.Backbone
    else:
        default_network = api.copier.Network.Fastbone
    kwargs.setdefault("network", default_network)

    if ctc.Tag.STORAGE in settings.client.tags:
        kwargs["deduplicate"] = api.copier.Deduplicate.Hardlink

    # Only a falsy non-None timeout (i.e. 0) triggers estimation. None is passed through to the copier.
    if not timeout and timeout is not None:
        size = size or get_share_size(skynet_id, **kwargs)
        timeout = calculate_timeout(size)
    _logger.debug(
        "Resource %s size is %s, timeout is: %r.",
        skynet_id, format.size2str(size or 0), format.td2str(int(timeout)) if timeout else timeout
    )
    # ``size`` may still be unknown (None). Treat it as below the threshold instead of comparing None
    # with an int, which raises TypeError on Python 3.
    if kwargs.get("network") == api.copier.Network.Backbone and (size or 0) > BACKBONE_DL_LIMIT_SIZE_THRESHOLD:
        kwargs.update(_backbone_dl_limits)

    def _reportive_get(skynet_id, kwargs):
        prev_t = collections.namedtuple("Prev", ("stage", "ts"))
        prev = prev_t(None, 0)
        h = copier.get(skynet_id, **kwargs)
        try:
            try:
                # A kind of hack - get skybone's job ID
                skyjid = ":".join(map(str, (h._IGet__slave.job.sid, h._IGet__slave.job.jid)))
            except Exception:
                skyjid = None
            _logger.info("Fetching resource %s with skybone {%s} with arguments %r", skynet_id, skyjid, kwargs)

            # log progress on every stage change and at most every 15 seconds otherwise
            for pr in h.iter(timeout=timeout, state_version=1):
                now = time.time()
                if prev.stage != pr.stage or now - prev.ts > 15:
                    _logger.debug("%s progress: %s.", skynet_id, pr)
                    prev = prev_t(pr.stage, now)
            h.wait(timeout=timeout)
        except BaseException:
            # Save the original exception before h.stop(), which may raise (and, on py2, clobber exc_info).
            ei = sys.exc_info()
            try:
                h.stop()
            except BaseException:
                pass
            six.reraise(ei[0], ei[1], ei[2])

    try:
        _reportive_get(skynet_id, kwargs)
    except Exception as ex:
        if kwargs.get("network") == api.copier.Network.Backbone:
            fallback_to_bb = False
        if isinstance(ex, api.copier.errors.ResourceNotAllowedByNetwork) and fallback_to_bb:
            _logger.warning("Resource '%s' is unavailable via fastbone (%s). Trying backbone.", skynet_id, ex)
        elif isinstance(ex, api.copier.errors.ResourceDownloadError) and fallback_to_bb:
            _logger.warning("Resource '%s' cannot be downloaded via fastbone (%s). Trying backbone.", skynet_id, ex)
        elif isinstance(ex, api.copier.errors.Timeout) and fallback_to_bb:
            _logger.warning("Resource '%s' timed out via fastbone (%s). Trying backbone.", skynet_id, ex)
        elif "database or disk is full" in str(ex):
            raise errors.TemporaryError(ex)
        else:
            raise
        kwargs.update(_backbone_dl_limits)
        _reportive_get(skynet_id, kwargs)

    _logger.debug("Done resource %s fetching.", skynet_id)


def skynet_copy(host, srcdir, dstdir, unstable=None, exclude=None, user=None):
    """
    Copy ``srcdir`` from the remote ``host`` into the local ``dstdir`` using skynet.

    :param host: remote host to copy from
    :param srcdir: remote directory (or single file) to copy. If it is a file, it is shared under
        ``basename(dstdir)`` and downloaded into ``dirname(dstdir)``.
    :param dstdir: local destination directory. Created with ``os.mkdir`` if missing and made world-writable.
    :param unstable: names (relative to ``srcdir``) that are first copied into a temporary directory on
        ``host`` and downloaded from there. That temporary directory is removed afterwards, also when the
        download fails.
    :param exclude: names (relative to ``srcdir``) that are not copied at all
    :param user: OS user for the remote tasks (a falsy value leaves it unset)
    :raise Exception: if any remote task fails (see ``_check_cqueue_results``)
    """
    unstable = [] if unstable is None else unstable
    exclude = [] if exclude is None else exclude

    skynet_id = None
    # some functions do not accept unicode strings, so convert unicode string to regular string
    host = str(host)
    srcdir = str(srcdir)
    dstdir = str(dstdir)

    logger.info(
        "Copy using skynet: host %s, srcdir %s, dstdir %s, unstable files [%s], excluded files [%s]",
        host, srcdir, dstdir, ','.join(unstable), ','.join(exclude)
    )

    from api.cqueue import Client

    client = Client(implementation='cqudp')

    # get file list
    results = list(client.run({host}, RemoteListFiles(user, srcdir)).wait())
    _check_cqueue_results(results)
    files, not_dir = results[0][1]

    # unstable files are fetched separately below, excluded ones are skipped entirely
    skipped = set(unstable) | set(exclude)
    files = [name for name in files if name not in skipped]
    if files:
        if not_dir:
            share_task = ShareSandboxResource(user, [(srcdir, os.path.basename(dstdir))])
        else:
            share_task = ShareSandboxResource(user, files, cwd=srcdir)
        results = list(client.run({host}, share_task).wait())
        _check_cqueue_results(results)
        skynet_id = results[0][1]

    # download resource
    if not_dir:
        dstdir = os.path.dirname(dstdir)

    if not os.path.exists(dstdir):
        os.mkdir(dstdir)
    mode_all_rwx = (
        stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR |
        stat.S_IRGRP | stat.S_IWGRP | stat.S_IXGRP |
        stat.S_IROTH | stat.S_IWOTH | stat.S_IXOTH
    )
    os.chmod(dstdir, mode_all_rwx)

    if skynet_id:
        skynet_get(skynet_id, dstdir)

    # download remaining files
    if unstable:
        # copy changing files on remote side
        results = list(client.run({host}, ShareChangingFiles(user, unstable, srcdir)).wait())
        _check_cqueue_results(results)
        share_dir, skynet_id = results[0][1]

        # download in the same dir; always try to drop the remote temporary copy afterwards
        downloaded = False
        try:
            skynet_get(skynet_id, dstdir)
            downloaded = True
        finally:
            try:
                results = list(client.run({host}, DeleteSharedFiles(user, share_dir)).wait())
                _check_cqueue_results(results)
            except Exception:
                if downloaded:
                    raise
                # the download has already failed: log the cleanup error instead of masking the original one
                logger.exception("Failed to remove temporary share directory %s on %s", share_dir, host)

    os.chmod(dstdir, mode_all_rwx)


def skynet_run_and_copy(host, srcdir, dstdir, files=None, method=None, user=None):
    """
    Run a sharing task on the remote ``host`` and download its resource into the local ``dstdir``.

    :param host: remote host to run the task on
    :param srcdir: remote directory the files are taken from
    :param dstdir: local destination directory. Created with ``os.mkdir`` if missing and made world-writable.
    :param files: names relative to ``srcdir`` to share. If empty or omitted, the remote ``srcdir`` is listed
        with ``RemoteListFiles`` and all of its files are used.
    :param method: task to run on ``host``; its result must be a ``(share_dir, skynet_id)`` pair.
        Defaults to ``ShareChangingFiles(user, files, srcdir)``.
    :param user: OS user for the remote tasks (a falsy value leaves it unset)
    :raise Exception: if any remote task fails (see ``_check_cqueue_results``)

    A non-``None`` ``share_dir`` returned by the task is deleted on ``host`` after the download. This also
    happens when the download itself fails.
    """
    files = [] if files is None else files

    # some functions do not accept unicode strings, so convert unicode string to regular string
    host = str(host)
    srcdir = str(srcdir)
    dstdir = str(dstdir)

    from api.cqueue import Client

    client = Client(implementation='cqudp')

    if not files:
        # get file list
        results = list(client.run({host}, RemoteListFiles(user, srcdir)).wait())
        _check_cqueue_results(results)
        files, _ = results[0][1]

    logger.info(
        "Copy using skynet: host %s, srcdir %s, dstdir %s, files [%s]",
        host, srcdir, dstdir, ','.join(files)
    )

    if not os.path.exists(dstdir):
        os.mkdir(dstdir)

    mode_all_rwx = (
        stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR |
        stat.S_IRGRP | stat.S_IWGRP | stat.S_IXGRP |
        stat.S_IROTH | stat.S_IWOTH | stat.S_IXOTH
    )
    os.chmod(dstdir, mode_all_rwx)
    if method is None:
        method = ShareChangingFiles(user, files, srcdir)
    results = list(client.run({host}, method).wait())
    _check_cqueue_results(results)
    share_dir, skynet_id = results[0][1]

    # download in the same dir; always try to drop the remote temporary copy afterwards
    downloaded = False
    try:
        skynet_get(skynet_id, dstdir)
        downloaded = True
    finally:
        if share_dir is not None:
            try:
                results = list(client.run({host}, DeleteSharedFiles(user, share_dir)).wait())
                _check_cqueue_results(results)
            except Exception:
                if downloaded:
                    raise
                # the download has already failed: log the cleanup error instead of masking the original one
                logger.exception("Failed to remove temporary share directory %s on %s", share_dir, host)
    os.chmod(dstdir, mode_all_rwx)
