from __future__ import absolute_import

import os
import time
import uuid
import atexit
import select
import logging
import threading

import six
from six.moves.urllib import parse as urlparse

from sandbox.common import config
from sandbox.common import random as common_random
from sandbox.common import patterns
from sandbox.sdk2.helpers import subprocess as sp

logger = logging.getLogger("vcs.tunnel")

# Timeout, in seconds, for a single `ssh -O check` probe run by cleanup_dead_control_sockets()
CONTROL_SOCKET_CHECK_TIMEOUT = 0.5


class SshConnectionFailed(Exception):
    """Raised by SshTunnel.establish() when every attempt to start the master connection fails."""


class SshTunnel(object):
    """A single SSH ControlMaster connection to ``username@hostname:port``.

    The master runs the remote ``tunnel`` command. It is considered established once that
    command prints a ``node:<fqdn>`` line on stderr. Other ssh invocations can then reuse the
    connection through the control socket at ``control_path``.
    """

    CONTROL_PATH_PREFIX = "ssh_mux"

    def __init__(self, username, hostname, port, identity_file=None, on_terminate=None):
        """
        :param username: remote login name
        :param hostname: remote host
        :param port: remote ssh port (string or int)
        :param identity_file: private key passed via ``-i``; ignored if the file does not exist
        :param on_terminate: callable invoked with no arguments from the watchdog thread
            once an established tunnel process has exited
        """
        self.username = username
        self.hostname = hostname
        self.port = port
        self.identity_file = identity_file
        self.on_terminate = on_terminate

        # create random path for a control socket, so that unprivileged tasks won't mess with each other
        # unix sockets don't work well with overlayfs
        # https://bugs.launchpad.net/ubuntu/+source/linux/+bug/1262287
        self.control_path = os.path.join(
            self.control_path_dir,
            "{}_{}_{}_{}_{}".format(
                self.CONTROL_PATH_PREFIX, common_random.random_string(),
                self.username, self.hostname, self.port
            )
        )

        # FQDN reported by the remote side in its "node:" line; None until then
        self.remote_fqdn = None
        # set when ssh reports that the control path is too long for a unix socket
        self.control_socket_is_broken = False

        self._id = uuid.uuid4().hex[:8]  # short tag to tell tunnels apart in the log
        self._process = None
        # serialises Popen.wait()/poll() in the watchdog against signalling in close()
        self._process_wait_lock = threading.Lock()
        self._process_wait_thread = None

    @patterns.singleton_classproperty
    def control_path_dir(cls):
        """Directory for control sockets: the runtime dir from the sandbox config."""
        return config.Registry().common.dirs.runtime

    def _terminate_tunnel_process(self):
        """Terminate the tunnel process, escalating to kill() if it is still alive after 5 seconds.

        Must be called with ``_process_wait_lock`` held.
        """
        if self._process.poll() is not None:
            # already exited and reaped: signalling it again could hit a recycled pid
            return
        self._process.terminate()
        try:
            self._process.wait(timeout=5)
        except sp.TimeoutExpired:
            self._process.kill()
            self._process.wait()

    def _wait_tunnel_process(self):
        """Watchdog body: log the tunnel's stderr until EOF, reap the process, then call on_terminate."""
        readline = self._process.stderr.readline
        while True:
            raw_line = readline()
            # EOF is b"" for a binary pipe on Python 3 (and "" on Python 2 or in text mode);
            # comparing against "" alone would never terminate on Python 3
            if not raw_line:
                break
            logger.error(
                "[%s] tunnel said: %s", self._id, six.ensure_str(raw_line, errors="replace").strip()
            )

        with self._process_wait_lock:
            self._process.wait()

        logger.error("[%s] tunnel exited with code: %d", self._id, self._process.returncode)

        if self.on_terminate:
            self.on_terminate()

    def _start_watchdoge(self):
        """Start the daemon thread that relays stderr and notices when the tunnel process exits."""
        self._process_wait_thread = threading.Thread(target=self._wait_tunnel_process)
        # non-daemon threads are joined before atexit hooks
        self._process_wait_thread.daemon = True
        self._process_wait_thread.start()

    @property
    def _ssh_command(self):
        """argv for the master connection; every option precedes the destination and remote command."""
        cmd = [
            "ssh",
            "-o", "StrictHostKeyChecking=no",
            "-o", "UserKnownHostsFile=/dev/null",
            "-o", "ConnectTimeout=60",
            "-o", "ControlMaster=yes",
            "-o", "ControlPersist=no",
            "-o", "ControlPath={}".format(self.control_path),
            "-6",
            # Popen requires string arguments, so accept an int port as well
            "-p", str(self.port),
        ]
        if self.identity_file and os.path.exists(self.identity_file):
            cmd.extend(["-i", self.identity_file])
        cmd.extend(["{}@{}".format(self.username, self.hostname), "tunnel"])

        return cmd

    @staticmethod
    def _read_lines_with_timeout(fileobj, timeout):
        """Yield decoded lines from *fileobj* until EOF or until *timeout* seconds have passed.

        Reads the raw descriptor with select() so a silent peer cannot block past the deadline.
        Bytes are split into lines before decoding, so a multibyte character cut by a read
        boundary is not mangled. Undecodable bytes are replaced. An unterminated trailing
        fragment is yielded last.
        """
        read_fd = fileobj.fileno()
        tail = b""
        while timeout > 0:
            started = time.time()
            if select.select([read_fd], [], [], timeout)[0]:
                data = os.read(read_fd, 4096)
                if not data:
                    break
                lines = (tail + data).split(b"\n")
                tail = lines.pop()
                for line in lines:
                    yield six.ensure_str(line, errors="replace")
            timeout -= time.time() - started
        if tail:
            yield six.ensure_str(tail, errors="replace")

    def _kill_failed_tunnel_process(self):
        """Kill and reap a tunnel process that never reported success, and release its resources."""
        self._process.kill()
        self._process.wait()
        for pipe in (self._process.stdin, self._process.stderr):
            if pipe is not None:
                pipe.close()
        logger.error("[%s] tunnel exited with code: %d", self._id, self._process.returncode)
        # a SIGKILLed ssh cannot remove its control socket; a leftover file at this path would
        # stop the next ControlMaster=yes attempt (which reuses the path) from listening on it
        try:
            os.unlink(self.control_path)
        except OSError:
            pass

    def _try_establish_ssh_tunnel(self):
        """Make a single attempt to start the master connection.

        :return: True if the remote side printed a ``node:`` line within 30 seconds. The
            watchdog is then running and ``close`` is registered with atexit. Otherwise
            False, in which case the ssh process has been killed and reaped.
        """
        cmd = self._ssh_command
        logger.debug("Establishing ssh multiplexing tunnel: %s", " ".join(cmd))
        with open(os.devnull, "w") as devnull:
            # the child gets its own copy of the descriptor, so the parent's copy can be closed now
            self._process = sp.Popen(cmd, stdin=sp.PIPE, stdout=devnull, stderr=sp.PIPE)

        established = False
        try:
            # wait until tunnel is established (or fails miserably)
            for line in self._read_lines_with_timeout(self._process.stderr, timeout=30):
                logger.error("[%s] tunnel said: %s", self._id, line.strip())

                if "too long for Unix domain socket" in line:
                    # can't use control socket, mark tunnel as broken
                    self.control_socket_is_broken = True

                if line.startswith("node:"):
                    # everything is great, tunnel has been established successfully
                    self.remote_fqdn = line[len("node:"):].strip()
                    established = True
                    break
        finally:
            if not established:
                # tunnel has failed, timed out, or reading its output raised
                self._kill_failed_tunnel_process()

        if established:
            self._start_watchdoge()
            atexit.register(self.close)
        return established

    def establish(self, retries=2):
        """Start the master connection, making up to ``retries + 1`` attempts.

        Returns normally on success. It also returns without raising if ssh reported that
        the control socket path is unusable (``control_socket_is_broken`` is then True).

        :raises SshConnectionFailed: if every attempt failed for any other reason
        """
        for attempt in range(1, retries + 2):
            if self._try_establish_ssh_tunnel():
                return
            if self.control_socket_is_broken:
                logger.error("Failed to setup control socket, mark tunnel as broken")
                return
            logger.error("Failed to establish ssh tunnel, attempt #%d", attempt)

        logger.error("No attempts left, there will be no tunnel :(")
        raise SshConnectionFailed(
            "Failed to setup master connection to {}@{}:{}".format(self.username, self.hostname, self.port)
        )

    def close(self):
        """Terminate the tunnel process and wait for the watchdog thread to finish.

        This is registered with atexit, so it may run after an explicit call. Repeated calls
        do not re-signal an already reaped process. It is a no-op if no process was ever started.
        """
        if self._process is None:
            return
        # Popen.send_signal is not threadsafe in respect to Popen.wait
        with self._process_wait_lock:
            self._terminate_tunnel_process()
        if self._process_wait_thread is not None:
            self._process_wait_thread.join()


class Tunnels(object):
    """Process-wide registry of SSH tunnels, shared per (user, host, port)."""

    _tunnels = {}  # (user, host, port) -> SshTunnel
    # guards _tunnels; also held for the whole time a new tunnel is being established
    _lock = threading.Lock()

    @classmethod
    def get_tunnel(cls, user, host, port, identity_file=None):
        """Return the tunnel for (user, host, port), creating and caching it if there is none.

        Before a new tunnel is created, dead control sockets are cleaned up. The result is
        None if creation raised SshConnectionFailed, or if the tunnel's control socket is
        marked broken. A broken tunnel stays cached.
        """
        key = (user, host, port)
        with cls._lock:
            tunnel = cls._tunnels.get(key)
            if tunnel is None:
                # cleanup garbage sockets left from dead tunnels
                cleanup_dead_control_sockets()
                try:
                    tunnel = cls._create_tunnel(user, host, port, identity_file)
                except SshConnectionFailed:
                    tunnel = None
                else:
                    cls._tunnels[key] = tunnel

        unusable = tunnel and tunnel.control_socket_is_broken
        return None if unusable else tunnel

    @classmethod
    def _create_tunnel(cls, user, host, port, identity_file=None):
        """Build and establish a tunnel that drops itself from the registry when its process exits."""
        key = (user, host, port)

        def forget():
            with cls._lock:
                cls._tunnels.pop(key, None)

        new_tunnel = SshTunnel(user, host, port, identity_file, on_terminate=forget)
        new_tunnel.establish()
        return new_tunnel


def cleanup_dead_control_sockets():
    """Remove leftover SSH control sockets whose master connection is gone.

    Each entry in ``SshTunnel.control_path_dir`` whose name starts with
    ``SshTunnel.CONTROL_PATH_PREFIX`` is probed with ``ssh -O check``. A nonzero exit
    status means the socket is dead, and it is unlinked (best-effort). A probe that does
    not finish within CONTROL_SOCKET_CHECK_TIMEOUT seconds is killed, and its socket is left alone.
    """
    control_dir = SshTunnel.control_path_dir
    for name in os.listdir(control_dir):
        if not name.startswith(SshTunnel.CONTROL_PATH_PREFIX):
            continue

        path = os.path.join(control_dir, name)

        with open(os.devnull, "w") as devnull:
            probe = sp.Popen(
                ["ssh", "-S", path, "-O", "check", "dummy"],
                stdout=devnull, stderr=devnull
            )
            try:
                ret = probe.wait(timeout=CONTROL_SOCKET_CHECK_TIMEOUT)
            except sp.TimeoutExpired:
                logger.info(
                    "Didn't clean up control socket %s in %ss -- it may be blocked by a frozen process",
                    path, CONTROL_SOCKET_CHECK_TIMEOUT
                )
                # don't leave the stuck probe running (and later a zombie) behind
                probe.kill()
                probe.wait()
                ret = None

        if ret:
            logger.info("Control socket %s is dead, remove it", path)
            try:
                os.unlink(path)
            except OSError:
                # probably someone else has removed this socket
                pass


def ensure_ssh_multiplexing_tunnel(url, identity_file):
    """Return the shared SSH tunnel for the user, host and port in *url*, or None.

    The port defaults to 22. It is parsed before anything else, so a malformed port
    raises ValueError even when the URL lacks a username or hostname. A URL missing
    either of those yields None without touching the tunnel registry.
    """
    parts = urlparse.urlparse(url)
    port = str(parts.port or 22)
    user, host = parts.username, parts.hostname

    if not user or not host:
        # a tunnel needs both a login and a host to connect to
        return None

    return Tunnels.get_tunnel(user, host, port, identity_file)
