from __future__ import absolute_import

import os
import re
import sys
import time
import shutil
import logging
import itertools
import functools

import requests
import concurrent.futures

from . import fs
from . import share
from . import config
from .types import misc as ctm


# Configuration registry shared by the functions and classes of this module.
registry = config.Registry()


# Names exported by ``from <this module> import *``.
__all__ = (
    "directory_version",
    "package_version",
    "PackageOutdatedError",
    "PackageUpdater",
    "RevisionChecker",
)


class ResourceDownloadError(Exception):
    """Raised by :func:`download` when a resource could not be fetched over any network."""


def download(res, downloads_dir, try_rbtorrent_first=False, use_fastbone=None, logger=None):
    """Download resource `res` into `downloads_dir` with the copier and return the local file path.

    :param res: resource description dict. Keys read here: ``skynet_id``, ``rsync``, ``http``
        (source lists), ``size``, ``id`` (used for logging only) and ``file_name``.
    :param downloads_dir: directory to download into.
    :param try_rbtorrent_first: put the ``skynet_id`` source before the rsync/http ones
        instead of after them.
    :param use_fastbone: network to try first -- fastbone if true, backbone if false.
        Defaults to ``registry.common.network.fastbone``. The other network is tried
        if the first attempt fails.
    :param logger: logger to use; defaults to this module's logger.
    :return: path of the downloaded file inside `downloads_dir`. When ``file_name`` has a
        directory part, the file is renamed to ``<dir>_<basename>``, with path separators of
        the directory part replaced by ``_`` so the result stays directly in `downloads_dir`.
    :raises ResourceDownloadError: if the download failed on both networks.
    """
    if logger is None:
        logger = logging.getLogger(__name__)

    import api.copier
    copier = api.copier.Copier()

    # Source id for the copier: "any:" followed by every known source, comma-separated.
    # The skynet id goes first or last depending on `try_rbtorrent_first`.
    skynet_id = res.get("skynet_id")
    rbtorrents = [skynet_id] if skynet_id else ()
    resid = "any:" + ",".join(itertools.chain(
        rbtorrents if try_rbtorrent_first else (),
        res.get("rsync") or (),  # `or ()` also tolerates an explicit None value
        res.get("http") or (),
        () if try_rbtorrent_first else rbtorrents
    ))
    # `size` is multiplied by 1024 before computing the timeout
    # (presumably KiB -> bytes; TODO confirm against share.calculate_timeout).
    timeout = share.calculate_timeout((res.get("size") or 0) << 10)
    network_selector = {
        True: api.copier.Network.Fastbone,
        False: api.copier.Network.Backbone
    }
    if use_fastbone is None:
        use_fastbone = registry.common.network.fastbone
    copier_get = functools.partial(copier.get, resid, downloads_dir, network=None, user=True)
    try:
        copier_get(network=network_selector[use_fastbone]).wait(timeout)
    except Exception as error:
        # Best effort: retry once over the other network before giving up.
        logger.warning(
            "Can not download resource #%s:\n%s\nTrying to download it via %s.",
            res.get("id"), error, "backbone" if use_fastbone else "fastbone"
        )
        try:
            copier_get(network=network_selector[not use_fastbone]).wait(timeout)
        except Exception as error:
            raise ResourceDownloadError(str(error))

    dir_name, file_name = os.path.split(res["file_name"])
    if not dir_name:
        return os.path.join(downloads_dir, file_name)

    # This code assumes the copier leaves the file directly in `downloads_dir` under its base
    # name. Prefix it with its directory part, flattened, so that a nested (or absolute)
    # directory part does not point the rename target at a missing subdirectory or outside
    # `downloads_dir`.
    flat_name = "_".join((dir_name.replace(os.sep, "_"), file_name))
    file_path = os.path.join(downloads_dir, flat_name)
    os.rename(os.path.join(downloads_dir, file_name), file_path)
    return file_path


def directory_version(root):
    """Return the integer revision stored in ``<root>/.revision``, or None.

    Only the first line of the file is used, with surrounding whitespace ignored.
    None is returned when the file cannot be read, `root` is not a valid path
    component, or the first line is not an integer.
    """
    try:
        revision_path = os.path.join(root, ".revision")
        with open(revision_path, "r") as revision_file:
            first_line = revision_file.readline()
        return int(first_line.strip())
    except (IOError, OSError, TypeError, ValueError):
        return None


def package_version(pkg):
    """Return the installed revision of package `pkg`, or None if it cannot be read.

    :param pkg: one of ``"venv"`` (the environment two levels above ``sys.executable``),
        ``"docs"`` or ``"tasks"`` (paths taken from the registry).
    :return: revision as read by :func:`directory_version` (None when unavailable).
    :raises ValueError: if `pkg` is not a known package type.
    """
    if pkg == "venv":
        path = os.path.dirname(os.path.dirname(sys.executable))
    elif pkg == "docs":
        path = registry.server.web.static.docs_link
    elif pkg == "tasks":
        path = registry.client.tasks.code_dir
    else:
        # Previously fell through with `path` unbound and died with UnboundLocalError.
        raise ValueError("Unknown package type: {!r}".format(pkg))

    return directory_version(path)


class PackageInfo(object):
    """Installed vs. target revision of a package, plus the link path it is served from.

    :param name: package name; also passed to :func:`package_version` to look up the
        currently installed revision.
    :param resource: resource dict; its ``revision`` key (default 0) is the target revision.
    :param link: link path for the package (stored as given).
    """

    def __init__(self, name, resource, link):
        self.name, self.resource, self.link = name, resource, link

        installed = package_version(name)
        # Fall back to 0 when the installed revision cannot be determined.
        self.current_revision = installed if installed else 0
        self.target_revision = int(resource.get("revision", 0))


def get_package_info(package_name, resource):
    """Build a :class:`PackageInfo` for an updatable package.

    :param package_name: ``"tasks"`` or ``"docs"``.
    :param resource: resource dict describing the target package version.
    :return: PackageInfo whose link is the package's configured location.
    :raises ValueError: for any other package name (a subclass of ``Exception``, so
        callers catching ``Exception`` keep working).
    """
    if package_name == "tasks":
        return PackageInfo("tasks", resource, registry.client.tasks.code_dir)
    if package_name == "docs":
        return PackageInfo("docs", resource, registry.server.web.static.docs_link)
    raise ValueError("Unknown package: {}".format(package_name))


def get_revisions_from_sandbox_servers(nodes, logger=None):
    """Ask every other server in `nodes` which tasks revision it runs.

    Each server except this one (``registry.this.fqdn``) is queried concurrently with
    ``GET http://<server>:<registry.server.api.port>/http_check`` (30 s timeout). The
    ``X-Tasks-Revision`` response header is parsed as an integer.

    :param nodes: iterable of server host names (may include this server).
    :param logger: logger to use; defaults to this module's logger.
    :return: set of distinct revisions reported. Servers that did not respond, or sent a
        missing or non-integer header, are logged and left out.
    """
    if logger is None:
        logger = logging.getLogger(__name__)

    servers = set(nodes)
    servers.discard(registry.this.fqdn)  # remove current server from ping list
    if not servers:
        # ThreadPoolExecutor rejects max_workers=0, so bail out before creating a pool.
        logger.info("No other servers to ask for their versions")
        return set()

    def ping(server):
        try:
            r = requests.get("http://{}:{}/http_check".format(server, registry.server.api.port), timeout=30)
        except Exception as e:
            logger.error("Server %s didn't respond in a timely manner: %s", server, e)
            return None

        revision = r.headers.get("X-Tasks-Revision", None)
        if revision:
            try:
                return int(revision)
            except ValueError:
                pass  # malformed value: report it like any other unexpected header
        logger.warning("Server %s responded for http_check, but with weird headers: %r", server, r.headers)
        return None

    logger.info("Asking other servers for their versions: %s", sorted(servers))
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(servers)) as pool:
        revisions = {rev for rev in pool.map(ping, servers) if rev is not None}
    logger.info("Other servers versions: %s", revisions)

    return revisions


class PackageOutdatedError(Exception):
    """The new package could not be installed, yet another server already runs it."""


class PackageUpdater(object):
    """Downloads, unpacks and activates new revisions of packages.

    Archives are downloaded into a ``downloads`` directory that sits next to the packages
    directory. They are unpacked into ``<packages_dir>/<name>.<target_revision>``, after
    which the package link is switched to the new directory.
    """

    def __init__(self, logger=None):
        self._logger = logger or logging.getLogger(__name__)
        # Packages live in the parent directory of the tasks code directory...
        self._packages_dir = os.path.dirname(registry.client.tasks.code_dir)
        # ...and downloads go to a "downloads" directory next to the packages directory.
        self._download_dir = os.path.join(os.path.dirname(self._packages_dir), "downloads")

    def _package_is_up_to_date(self, pkg_info):
        """Return True if no update is needed: unknown target (0) or target <= current."""
        if pkg_info.target_revision == 0 or pkg_info.target_revision <= pkg_info.current_revision:
            self._logger.debug("Package %s:%s is up to date", pkg_info.name, pkg_info.current_revision)
            return True

        self._logger.debug(
            "Update for package detected: %s:%s -> %s",
            pkg_info.name, pkg_info.current_revision, pkg_info.target_revision
        )
        return False

    def _remove_stale_downloads(self, pkg_info, keep=None):
        """Delete leftovers of previous downloads of `pkg_info` from the downloads directory.

        Every entry whose name contains the package name is removed (files and symlinks are
        unlinked, directories removed recursively), except the path `keep`. That is an
        archive the caller already downloaded and is about to unpack.
        """
        keep = os.path.abspath(keep) if keep else None
        for node in os.listdir(self._download_dir):
            if pkg_info.name not in node:
                continue
            path = os.path.join(self._download_dir, node)
            if keep is not None and os.path.abspath(path) == keep:
                continue
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path)
            else:
                os.unlink(path)

    def _download_resource(self, pkg_info, pkg_path=None):
        """Download (unless `pkg_path` is given) and unpack the package; return the unpacked dir.

        :raises ResourceDownloadError: propagated from :func:`download`.
        """
        # Remove garbage from previous downloads, but never the archive we were handed.
        self._remove_stale_downloads(pkg_info, keep=pkg_path)

        # download and unpack resource
        pkg_path = pkg_path or download(
            pkg_info.resource, self._download_dir,
            try_rbtorrent_first=True, use_fastbone=False, logger=self._logger
        )
        pkg_dir = fs.make_folder(
            os.path.join(self._packages_dir, "{}.{}".format(pkg_info.name, pkg_info.target_revision)),
            log=self._logger
        )
        fs.untar_archive(pkg_path, pkg_dir, log=self._logger)

        return pkg_dir

    def _replace_current_package(self, pkg_info, pkg_dir):
        """Re-point the symlink `pkg_info.link` at `pkg_dir` and drop the old target directory.

        The previous target is removed only when it differs from `pkg_dir` and lies in the
        same directory as the link itself, so targets elsewhere are never deleted.
        """
        try:
            prev_dir = os.readlink(pkg_info.link)
        except (OSError, IOError):
            prev_dir = None  # no link yet, or not a symlink

        fs.make_symlink(pkg_dir, pkg_info.link, force=True, log=self._logger)

        if prev_dir and prev_dir != pkg_dir and os.path.dirname(prev_dir) == os.path.dirname(pkg_info.link):
            self._logger.info("Dropping previous package version directory %r", prev_dir)
            shutil.rmtree(prev_dir)

    def update_package(self, package_name, resource, force=False, pkg_path=None):
        """Check for updates for `package_name` and perform the update if there are any.

        :param package_name:    Package name to check and update.
        :param resource:        Resource to use.
        :param force:           Force update.
        :param pkg_path:        Path to package archive if it downloaded already

        :return: True if the package was updated; False if it is already up to date
                 (and `force` is not set) or the download failed
        :rtype: bool
        """
        pkg_info = get_package_info(package_name, resource)

        if self._package_is_up_to_date(pkg_info) and not force:
            return False

        try:
            pkg_dir = self._download_resource(pkg_info, pkg_path)
        except ResourceDownloadError:
            self._logger.exception("Failed to download resource with new package")
            return False

        self._replace_current_package(pkg_info, pkg_dir)
        return True

    def update_tasks_with_node_id_check(self, resource, nodes):
        """Check for updates for `tasks` package and tries to perform the update if there are any.

        If server has odd number it will update immediately.
        Servers with even numbers will wait until there is at least one server with new version.

        :param resource: Resource to use.
        :param nodes: List of all available nodes.
        :return: True if update was successful, False otherwise
        :rtype: bool
        :raises PackageOutdatedError: if the download failed but another server already
            reports the target revision.
        """
        pkg_info = get_package_info("tasks", resource)

        if self._package_is_up_to_date(pkg_info):
            return False

        pkg_dir = None

        try:
            pkg_dir = self._download_resource(pkg_info)
        except ResourceDownloadError:
            self._logger.exception("Failed to download resource with new package")

        if pkg_dir is None:
            # Check if someone already updated. If yes -- we should also update immediately.
            revisions = get_revisions_from_sandbox_servers(nodes, self._logger)
            if pkg_info.target_revision in revisions:
                self._logger.warning("Another server has new package, can't risk working with outdated")
                raise PackageOutdatedError

            # Failed to download resource, but there are no updated servers yet,
            # so we can try again later without a forced update.
            return False

        def pkg_update_allowed():
            # Get current server number and check if it's odd.
            # Examples:
            # sandbox-server01 -> 1 -> True
            # sandbox-preprod06 -> 6 -> False
            # myt3-531 -> 531 -> True
            # charmander -> None -> False
            m = re.match(r".*?(\d+)$", registry.this.id)
            if m:
                node_id = int(m.group(1))
                if node_id % 2 == 1:
                    return True

            # Ask other servers about their revision
            revisions = get_revisions_from_sandbox_servers(nodes, self._logger)
            return pkg_info.target_revision in revisions

        if pkg_update_allowed():
            self._logger.info("There are either servers with new package or this server is a brave one")
            self._replace_current_package(pkg_info, pkg_dir)
            return True

        self._logger.warning(
            "There are no alive servers with revision %d. Don't update to new package, try again later",
            pkg_info.target_revision
        )
        # Discard the unpacked directory; a later attempt downloads and unpacks again.
        shutil.rmtree(pkg_dir)
        return False


class RevisionChecker(object):
    """Reads the ``.revision`` files of the installed service packages, throttled.

    The files are read at most once per ``CHECK_INTERVAL`` seconds. Calls made in between,
    and every call on a local installation, return an empty dict.
    """

    # Minimal number of seconds between two reads of the revision files.
    CHECK_INTERVAL = 600

    def __init__(self, logger=None):
        # Time of the last read; 0 makes the very first call perform a read.
        self._check_time = 0
        self._logger = logger or logging.getLogger(__name__)

    def get_revisions(self):
        """Return ``{package: contents of <service dir>/<package>/.revision}`` when a check is due.

        Only packages whose directory exists are included; a missing ``.revision`` file in
        an existing package directory is logged. File contents are returned as read
        (not stripped). Returns ``{}`` when no check is due or the installation is local.
        """
        revisions = {}
        if registry.common.installation == ctm.Installation.LOCAL:
            return revisions

        now = time.time()
        # Throttle: only read when CHECK_INTERVAL has elapsed since the last read.
        # (The previous condition was inverted, and the timestamp was refreshed on every
        # call, so frequent callers read every time and infrequent ones never did.)
        if now < self._check_time + self.CHECK_INTERVAL:
            return revisions

        root_dir = registry.common.dirs.service
        for pkg in ("server", "client", "sdk", "tasks", "venv"):
            pkg_path = os.path.join(root_dir, pkg)
            if not os.path.exists(pkg_path):
                continue
            rev_path = os.path.join(pkg_path, ".revision")
            if os.path.exists(rev_path):
                with open(rev_path) as f:
                    revisions[pkg] = f.read()
            else:
                self._logger.error("File with revision not found: '%s'", rev_path)
        self._check_time = now
        return revisions
