from __future__ import print_function, absolute_import
import os
import sys
import pwd
import errno
import shutil

from .utils import suppress
from .filesystem import py_ptsname
from .sysutils import pipe, prepare_env, quote_shell_arg
from .handlers import SocketOutHandler, InOutHandler, WatchHandler, TelnetHandler
from .portotools import (
    get_portoconn,
    get_container_root,
    get_container_proc_root,
    find_in_container,
    make_utils_volume,
    make_home_volume,
    make_yatool_volume,
)

import six
from porto.exceptions import InvalidValue, PermissionError
from porto.api import rpc as rpc_pb2

from infra.skylib.porto import (
    set_capabilities,
    get_parent_container,
    get_container_user_and_group,
)
from infra.skylib import safe_container_actions
from library.python.capabilities import capabilities
from .slots.slot import YpLiteSlot, YpSlot


# Base URL of the Deploy web UI keyed by deploy environment name.
# build_motd_lines() looks the environment up here (taken from the pod's
# 'environ' label) and falls back to the "prod" entry for unknown names.
DEPLOY_ENVIRON_TO_UI_URL_INDEX = {
    "dev": "https://test.deploy.yandex-team.ru",
    "pre": "https://man-pre.deploy.yandex-team.ru",
    "prod": "https://deploy.yandex-team.ru",
}


def is_motd_enabled(slot):
    """Tell whether a message of the day should be shown for ``slot``.

    The MOTD is enabled only for ``YpLiteSlot`` / ``YpSlot`` instances that
    have non-empty ``pod_labels`` without a truthy
    ``portoshell_motd_disabled`` label.

    :param slot: slot object or ``None``
    :return: ``True`` when the MOTD should be sent, ``False`` otherwise
    """
    if slot is None or not isinstance(slot, (YpLiteSlot, YpSlot)):
        return False

    labels = slot.pod_labels
    if not labels:
        return False

    # Any truthy value of the label switches the MOTD off.
    return not labels.get('portoshell_motd_disabled')


def motd_line(header, value=None):
    """Format one MOTD line.

    With a truthy ``value`` the result is ``"<header>:"`` left-aligned in a
    15-character column, two spaces, then the value.  Without a value only the
    header is returned, left-aligned in a 14-character column.

    :param header: caption of the line
    :param value: optional value shown after the caption
    :return: the formatted line
    """
    if not value:
        return '{0:<14}'.format(header)

    caption = header + ':'
    return '{0:<15}  {1}'.format(caption, value)


def make_yasm_url(slot, node_hostname):
    """Build the Yasm "Porto-container" panel URL for a YP-lite slot.

    :param slot: slot object; only ``YpLiteSlot`` instances produce a URL
    :param node_hostname: hostname of the node, put into the ``node`` field
    :return: the panel URL, or ``None`` for any other kind of slot
    """
    if not isinstance(slot, YpLiteSlot):
        return None

    template = (
        'https://yasm.yandex-team.ru/template/panel/Porto-container/hosts={slot.mtn_hostname};'
        'itype={slot.yasm_tags.itype};ctype={slot.yasm_tags.ctype};prj={slot.yasm_tags.prj};node={node_hostname}/'
    )
    return template.format(slot=slot, node_hostname=node_hostname)


def build_motd_lines(slot, node_hostname):
    """Compose the message-of-the-day lines for ``slot``.

    ``YpSlot`` instances get a Deploy-oriented banner (stage URL, pod set,
    deploy unit, optional box), ``YpLiteSlot`` instances get a Nanny-oriented
    one (Yasm panel, workdir, service URL, configuration, state).  Entries may
    be ``None``; such entries are meant to be skipped by the sender.

    :param slot: slot object describing the pod
    :param node_hostname: hostname of the node the pod runs on
    :return: list of lines (``str`` or ``None``); empty for other slot types
    """
    if isinstance(slot, YpSlot):
        labels = slot.pod_labels
        env_name = labels.get('environ', 'prod') or 'prod'
        deploy = labels.get("deploy", {})
        unit_id = deploy.get("deploy_unit_id")
        stage = deploy.get("stage_id")
        # Unknown environments are pointed at the production UI.
        base_url = DEPLOY_ENVIRON_TO_UI_URL_INDEX.get(env_name) or DEPLOY_ENVIRON_TO_UI_URL_INDEX["prod"]
        stage_link = "{ui_url}/stages/{stage_id}".format(ui_url=base_url, stage_id=stage)

        if slot.box:
            lines = [motd_line("Welcome to box", slot.get_persistent_box_fqdn())]
        else:
            lines = [motd_line("Welcome to pod", slot.get_persistent_pod_fqdn())]

        lines.append(motd_line("Node", node_hostname))
        lines.append(motd_line("Pod ID", slot.pod))
        lines.append(motd_line("IP", slot.mtn_interfaces[0] if slot.mtn_interfaces else "---"))
        lines.append(motd_line("Container", slot.container))
        lines.append("")
        lines.append(motd_line("Stage", stage_link))
        lines.append(motd_line("Pod set", slot.pod_set_id))
        lines.append(motd_line("Deploy Unit", unit_id))
        # The "Box" line is present only for box sessions; None is a placeholder.
        lines.append(motd_line("Box", slot.box) if slot.box else None)
        lines.append("")
        lines.append(" " * 10 + '"ya tool strace|perf|gdb" is already available in this container!')
        return lines

    if isinstance(slot, YpLiteSlot):
        service_url = slot.pod_labels.get('deploy_engine_url')
        if not service_url and slot.api_url:
            # No explicit label: derive the service page from the API host.
            netloc = six.moves.urllib.parse.urlparse(slot.api_url).netloc
            service_url = "https://{}/ui/#/services/catalog/{}/".format(netloc, slot.service)

        lines = []
        lines.append(motd_line("Welcome to pod", slot.mtn_hostname))
        lines.append(motd_line("Node", node_hostname))
        lines.append(motd_line("Pod ID", slot.pod))
        lines.append(motd_line("IP", slot.mtn_interfaces[0] if slot.mtn_interfaces else "---"))
        lines.append(motd_line("Pod Yasm panel", make_yasm_url(slot, node_hostname)))
        lines.append("")
        lines.append(motd_line("Workdir", slot.instance_dir))
        lines.append(motd_line("Meta container", slot.container))
        lines.append("")
        lines.append(motd_line("Service", service_url))
        lines.append(motd_line("Conf ID", slot.configuration_id))
        lines.append(motd_line("State", "{slot.state} -> {slot.target_state}".format(slot=slot)))
        lines.append("")
        lines.append('"ya tool strace|perf|gdb" already available in this container!')
        return lines

    return []


def send_motd(channel, motd_lines):
    """Send the MOTD to ``channel`` in a single ``sendall`` call.

    ``None`` entries are dropped, the remaining lines are joined with CRLF and
    the text is framed by one leading and two trailing CRLF sequences.

    :param channel: object providing ``sendall(data)``
    :param motd_lines: iterable of lines; ``None`` items are skipped
    """
    visible = [line for line in motd_lines if line is not None]
    payload = '\r\n' + '\r\n'.join(visible) + '\r\n' + '\r\n'
    channel.sendall(payload)


def can_join_mount_ns():
    """Report whether ``cap_sys_admin`` is set in the effective capability set.

    :return: result of ``Capabilities().is_set(cap_sys_admin, Effective)``
    """
    return capabilities.Capabilities().is_set(capabilities.cap_sys_admin, capabilities.Effective)


def openpty(root_pid):
    """Allocate a pseudo-terminal pair via ``safe_container_actions.make_fds``.

    The allocation callback opens the pty and resolves the slave device name,
    preferring ``py_ptsname`` of the master and falling back to
    ``os.ttyname`` of the slave.

    :param root_pid: pid passed through to ``make_fds``
        (presumably selects the container whose namespaces are used -- confirm
        against ``safe_container_actions``)
    :return: ``(ptm, pts, ttyname)``; ``ttyname`` is ``None`` when no name
        could be determined
    """
    def allocate():
        master_fd, slave_fd = os.openpty()
        name_by_master = py_ptsname(master_fd)
        try:
            name_by_slave = os.ttyname(slave_fd)
        except EnvironmentError:
            name_by_slave = None

        # An empty byte string stands in for "no name" in the callback result.
        return (master_fd, slave_fd), (name_by_master or name_by_slave or b'')

    fds, name = safe_container_actions.make_fds(allocate, None, root_pid)
    ptm, pts = fds
    return ptm, pts, (name if name else None)


class StartupException(Exception):
    """Raised by ``error()`` in API mode when a session cannot be started.

    :param msg: human readable description of the failure; available both as
        ``exc.msg`` and through the standard ``str(exc)`` / ``exc.args``
    """

    def __init__(self, msg):
        # Pass the message to the base class so that str(exc), exc.args,
        # tracebacks and log records are not empty (on Python 2 they are
        # populated only by Exception.__init__).
        super(StartupException, self).__init__(msg)
        self.msg = msg


class Context(object):
    """Connection-wide settings plus the registry of per-channel sessions.

    Sessions are stored in ``self.sessions`` keyed by channel id: an ``int``
    channel is used as the key directly, any other channel object is asked
    for ``get_id()``.
    """

    def __init__(self, log):
        self.log = log

        # connection options, filled in by the owner of the context
        self.user = None
        self.unset_env = None
        self.tools_tarball = None
        self.extra_files = None
        self.telnet_timeout = 3600
        self.interactive_cmd = False
        self.api_mode = False
        self.streaming = True
        self.enable_shellwrapper = False
        self.sessionleader_user = None
        self.sessionleader_session_id = None

        # channel id -> Session
        self.sessions = {}

        # by default we do nothing with warnings
        self.send_warning = lambda msg, *args: None

        # message of the day state, see setup_motd()
        self.is_motd_enabled = False
        self.motd_lines = []

    @staticmethod
    def _channel_id(channel):
        """Return the ``sessions`` key for ``channel``."""
        if isinstance(channel, int):
            return channel
        return channel.get_id()

    def make_session(self, channel):
        """Return the session bound to ``channel``, creating it on first use."""
        key = self._channel_id(channel)
        if key not in self.sessions:
            self.sessions[key] = Session(self.log, self)
        return self.sessions[key]

    def setup_motd(self, slot, node_hostname):
        """Decide whether the MOTD is enabled for ``slot`` and pre-build its lines."""
        enabled = is_motd_enabled(slot)
        self.is_motd_enabled = enabled
        if enabled:
            self.motd_lines = build_motd_lines(slot, node_hostname)

    def close_session(self, channel):
        """Forget the session of ``channel``, closing its agent and removing its dirs.

        Does nothing when no session is registered for the channel.
        """
        session = self.sessions.pop(self._channel_id(channel), None)
        if session is None:
            return

        if session.agent:
            self.log.info("removing agent")
            session.agent.close()
            session.agent = None

        self.log.info("removing dirs: %s", session.dirs_to_remove)
        for path in session.dirs_to_remove:
            shutil.rmtree(path, ignore_errors=True)

    def set_env(self, channel, key, value):
        """Prepend ``(key, value)`` to the extra environment of the channel's session."""
        session = self.make_session(channel)
        session.extra_env.insert(0, (key, value))

    def set_winsize(self, channel, rows, cols, xpixel, ypixel):
        """Remember the terminal size and push it to the session handler.

        :return: ``True`` when the handler accepted the new size; ``False`` when
            the handler has no ``set_winsize`` or the call failed
        """
        session = self.make_session(channel)
        session.pty_params = {'width': cols, 'height': rows}

        handler = session.handler
        if not hasattr(handler, 'set_winsize'):
            return False

        try:
            handler.set_winsize(rows=rows, cols=cols, xpixel=xpixel, ypixel=ypixel)
        except Exception as e:
            self.log.warning("failed to set terminal size: %s", e)
            return False
        return True

    def make_agent(self, channel, klass):
        """Instantiate ``klass`` on the channel's transport and attach it to the session."""
        agent = klass(channel.get_transport(), log=self.log.getChild('agent'))
        self.make_session(channel).agent = agent

    def finalize(self):
        """Close the agents of all sessions that are still registered."""
        for session in six.itervalues(self.sessions):
            if not session.agent:
                continue
            session.agent.close()


class Session(object):
    """A single shell or command session run in a child porto container.

    The session keeps the parameters collected from the client (command,
    working directory, extra environment, pty geometry) together with the porto
    container, the I/O handler and the ssh agent created for it.
    """

    # Terminal geometry used when ``pty_params`` lacks a dimension.
    DEFAULT_PTY_ROWS = 24
    DEFAULT_PTY_COLS = 80

    def __init__(self, log, ctx):
        """
        :param log: logger used by the session
        :param ctx: owning ``Context`` with connection-wide settings
        """
        self.log = log
        self.ctx = ctx
        self.tag = ''
        self.container_name = None
        self.extra_env = []
        self.cwd = None
        self.cmd = None
        self.shell = None
        self.dirs_to_remove = []
        self.pty_params = None

        self.container = None
        self.handler = None
        self.agent = None

    def _get_command(self):
        """Return the shell command line: ``<shell> -c <cmd>`` or ``<shell> -i``."""
        if self.cmd:
            return '%s -c %s' % (self.shell, quote_shell_arg(self.cmd))
        else:
            return '%s -i' % (self.shell,)

    def _pty_size(self):
        """Return ``(rows, cols)`` from ``pty_params``, defaulting to 24x80."""
        params = self.pty_params or {}
        return (
            params.get('height', self.DEFAULT_PTY_ROWS),
            params.get('width', self.DEFAULT_PTY_COLS),
        )

    def _configure_home(self, user):
        """Make sure ``user`` has a home directory inside the container.

        Uses the conventional home (``/root`` or ``/home/<user>``) when it
        already exists or can be created under the container root; otherwise
        falls back to ``make_home_volume``.

        :return: a pair whose first item is a temporary directory to remove on
            session close (or ``None``) and whose second item is the home path
            as seen from inside the container
        """
        home_naive = b'/root' if user == 'root' else os.path.join('/home', user).encode('utf-8')
        if find_in_container(self.container_name, (home_naive,), '', isdir=True):
            return None, home_naive

        root = get_container_root(self.container_name)
        if root == '/':
            return make_home_volume(self.container_name, self.ctx.extra_files)

        path = os.path.join(root, home_naive[1:])  # strip first char to make it relative
        try:
            os.makedirs(path, mode=0o755)
        except EnvironmentError as e:
            # An already existing directory is fine, anything else is not.
            if e.errno != errno.EEXIST:
                return make_home_volume(self.container_name, self.ctx.extra_files)
        try:
            uid = pwd.getpwnam(user).pw_uid
            os.chown(path, uid, -1)
        except EnvironmentError:
            return make_home_volume(self.container_name, self.ctx.extra_files)

        return None, home_naive

    def _restrict_porto_access(self, user):
        """Set ``enable_porto`` to ``child-only`` when possible, else disable it.

        ``nobody`` never gets porto access.
        """
        if user != 'nobody':
            try:
                self.container.SetProperty("enable_porto", 'child-only')
                return
            except (InvalidValue, PermissionError):
                pass

        self.container.SetProperty("enable_porto", False)

    def _set_virt_mode(self, user):
        """Set ``virt_mode`` to ``job`` when accepted, otherwise to ``app``.

        ``nobody`` always gets ``app``.
        """
        if user != 'nobody':
            try:
                self.container.SetProperty('virt_mode', 'job')
                return
            except InvalidValue:
                pass

        self.container.SetProperty('virt_mode', 'app')

    def _prepare_agent(self, user):
        """Start the ssh agent (if any) and export ``SSH_AUTH_SOCK``.

        The agent is first started under the parent container's procfs root;
        on failure the container volume root is tried.  If both attempts fail
        a warning is sent to the client and the session continues without an
        agent.
        """
        agent_path = None
        try:
            parent = get_parent_container(get_portoconn(), self.container_name)
            root = get_container_proc_root(parent.name)
            agent_path = start_agent(self.agent, root, user)
        except Exception as e:
            self.log.error("failed to start ssh agent using procfs, will try volume path: %s", e)

            try:
                root = get_container_root(self.container_name)
                agent_path = start_agent(self.agent, root, user)
            except Exception as volume_err:
                self.log.error("failed to start ssh agent using volume path: %s", volume_err)
                self.ctx.send_warning("failed to start ssh agent, agent will be unavailable: %s", volume_err)

        if agent_path:
            self.extra_env.append(('SSH_AUTH_SOCK', agent_path))

    def _prepare_yatool_cache(self):
        """Bind the ya tool cache volume and export the related environment.

        Failures are logged and reported to the client as a warning; they do
        not abort the session.
        """
        self.log.debug("Binding yatool cache volume")
        try:
            if make_yatool_volume(self.container_name):
                self.extra_env.append(('YA_CACHE_DIR_TOOLS', '/ya_tool_cache/tools'))
                self.extra_env.append(('YA_TC', '0'))
                self.extra_env.append(('YA_OAUTH_EXCHANGE_SSH_KEYS', 'false'))
                # insert it into the beginning in case if user wants to override the PATH
                self.extra_env.insert(0, ('PATH', "/ya_tool_cache:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"))
        except Exception as e:
            self.log.exception("failed to bind ya_tool_cache volume: %s", e)
            self.ctx.send_warning("failed to bind ya_tool_cache volume, 'ya tool' cache will not be available: %s", e)

    def prepare_container(self):
        """Configure ``self.container`` so that it is ready to be started.

        Sets up the ssh agent and the ya tool cache, builds the command line
        (optionally wrapped by shellwrapper / sessionleader), resolves the
        home directory, applies porto properties (env, user, group, ownership,
        porto access, virt mode, cwd), binds the utils volume, chowns the
        stdio file descriptors to the target user and sets capabilities.
        """
        portoconn = get_portoconn()
        parent = get_parent_container(portoconn, self.container_name)
        parent_root = get_container_root(parent.name)

        user, group = get_container_user_and_group(portoconn, parent, self.ctx.user)
        self._prepare_agent(user)
        self._prepare_yatool_cache()

        self.container.SetProperty('owner_containers', 'self')

        command = self._get_command()
        if self.ctx.tools_tarball is not None and self.ctx.enable_shellwrapper and parent_root != '/':
            command = '/portoshell_utils/shellwrapper ' + command
        if self.ctx.tools_tarball is not None and self.ctx.sessionleader_user is not None and parent_root != '/':
            command = ('/portoshell_utils/sessionleader "%s" ' % self.ctx.sessionleader_user) + command

        home_tmp, home = self._configure_home(user)
        if home_tmp:
            self.dirs_to_remove.append(home_tmp)
        self.extra_env.insert(0, ('HOME', home))
        # not a misspelling, see SKYDEV-1399
        if self.cwd:
            self.extra_env.insert(0, ('SWD', self.cwd))
        if self.ctx.sessionleader_session_id is not None:
            self.extra_env.append(('PORTOSHELL_SESSION_ID', self.ctx.sessionleader_session_id))

        # check if we're really in container
        own_user = self.container.GetProperty('owner_user')

        self.log.debug("using shell=%r, user=%r, extra_env=%s, unset_env=%s", self.shell, user, self.extra_env, self.ctx.unset_env)
        for opt, val in six.iteritems({
            'respawn': False,
            'isolate': False,
            'command': command,
            'env': prepare_env(self.shell, user, self.extra_env, self.ctx.unset_env),
            'user': user,
            'group': group,
            'private': ('PORTOSHELL %s' % (self.tag,)) if self.tag else 'PORTOSHELL',
        }):
            self.container.SetProperty(opt, val)

        if own_user == 'root':
            # Hand the container over to the owner of the parent container.
            for opt, val in six.iteritems({
                'owner_user': parent.GetProperty('owner_user'),
                'owner_group': parent.GetProperty('owner_group'),
            }):
                self.container.SetProperty(opt, val)

        self._restrict_porto_access(user)
        self._set_virt_mode(user)

        # 'weak' is best effort: failure to set it is deliberately ignored.
        try:
            self.container.SetProperty('weak', True)
        except Exception:
            pass

        if self.ctx.tools_tarball is not None and parent_root != '/':
            try:
                make_utils_volume(self.container_name, self.ctx.tools_tarball)
            except Exception as e:
                self.log.exception("failed to bind utils volume: %s", e)
                if self.ctx.enable_shellwrapper or self.ctx.sessionleader_user:
                    self.container.SetProperty('command', self._get_command())  # rollback to no shellwrapper
                self.ctx.send_warning("failed to bind utils volume, tools like 'scp' or 'busybox' can be unavailable: %s", e)

        if self.cwd is not None:
            self.container.SetProperty('cwd', self.cwd)

        # stdio passed as /dev/fd/N must be owned by the user running the command
        uid = pwd.getpwnam(user).pw_uid
        for path in ('stdin_path', 'stdout_path', 'stderr_path'):
            val = self.container.GetProperty(path)
            if val.startswith('/dev/fd/'):
                fd = int(val[len('/dev/fd/'):])
                os.fchown(fd, uid, -1)

        set_capabilities(portoconn, self.container, user, own_user)

    def _create_container(self):
        """Create (without starting) the porto container named ``container_name``."""
        job_spec = rpc_pb2.TContainerSpec()
        job_spec.name = self.container_name
        job_spec.owner_containers = "self"

        self.container = get_portoconn().CreateSpec(job_spec, start=False)

    def start_server(self, sock):
        """Create, configure and start the container for a socket client.

        A non-interactive command gets pipes and a ``SocketOutHandler``;
        otherwise a pty is allocated and served by a ``TelnetHandler``.

        :param sock: client socket
        :return: the handler serving the session
        :raises RuntimeError: when API / non-streaming mode is requested
            without a command
        """
        self.shell = find_shell(self.container_name)

        try:
            self._create_container()
        except InvalidValue:
            error("Cannot join slot container: it doesn't exist (%r)" % (self.container_name,), self.ctx.api_mode)

        pts = None

        try:
            if self.cmd and not self.ctx.interactive_cmd:
                stdout_r, stdout_w = pipe()
                stderr_r, stderr_w = pipe()

                self.container.SetProperty('stdout_path', os.path.join('/dev/fd', str(stdout_w)))
                self.container.SetProperty('stderr_path', os.path.join('/dev/fd', str(stderr_w)))

                handler = SocketOutHandler(sock,
                                           self.container,
                                           stdout=stdout_r,
                                           stderr=stderr_r,
                                           api_mode=self.ctx.api_mode,
                                           streaming=self.ctx.streaming,
                                           log=self.log)
            elif not self.cmd and (self.ctx.api_mode or not self.ctx.streaming):
                raise RuntimeError("api_mode and non-streaming mode aren't supported without command specified")
            else:
                root_pid = int(get_parent_container(get_portoconn(), self.container_name).GetProperty('root_pid'))
                self.log.debug('preparing TelnetHandler for root_pid %r', root_pid)
                ptm, pts, ptsname = openpty(root_pid)

                for prop in ('stdin_path', 'stdout_path', 'stderr_path'):
                    self.container.SetProperty(prop, os.path.join('/dev/fd', str(pts)))

                handler = TelnetHandler(self.container, sock, ptm, timeout=self.ctx.telnet_timeout)
                if self.pty_params:
                    handler.set_winsize(*self._pty_size())

            self.log.debug('preparing container')
            self.prepare_container()
            self.container.Start()

            if self.pty_params is not None and self.ctx.is_motd_enabled:
                send_motd(sock, self.ctx.motd_lines)

            return handler
        finally:
            # Our copy of the slave side is no longer needed once the container
            # is started, and must not leak when startup fails.
            if pts is not None:
                os.close(pts)

    def start_ssh(self, channel):
        """Create, configure and start the container for an ssh channel.

        Without a pty request the container is wired to pipes served by an
        ``InOutHandler``; with one, a pty is allocated and served by a
        ``TelnetHandler`` (also stored in ``self.handler``).  On any failure
        the container is destroyed and the original exception is re-raised.

        :param channel: ssh channel of the client
        :return: the handler serving the session
        """
        self.shell = find_shell(self.container_name)

        try:
            self._create_container()
        except InvalidValue:
            raise Exception("Cannot join slot container: it doesn't exist (%r)" % (self.container_name,))

        pts = None

        try:
            if self.pty_params is None:
                self.log.debug('preparing InOutHandler')
                stdin_r, stdin_w = pipe()
                stdout_r, stdout_w = pipe()
                stderr_r, stderr_w = pipe()

                self.container.SetProperty('stdin_path', os.path.join('/dev/fd', str(stdin_r)))
                self.container.SetProperty('stdout_path', os.path.join('/dev/fd', str(stdout_w)))
                self.container.SetProperty('stderr_path', os.path.join('/dev/fd', str(stderr_w)))

                handler = InOutHandler(channel=channel,
                                       container=self.container,
                                       stdin=stdin_w,
                                       stdout=stdout_r,
                                       stderr=stderr_r,
                                       log=self.log)
            else:
                root_pid = int(get_parent_container(get_portoconn(), self.container_name).GetProperty('root_pid'))
                self.log.debug('preparing TelnetHandler for root_pid %r', root_pid)
                ptm, pts, ptsname = openpty(root_pid)

                for prop in ('stdin_path', 'stdout_path', 'stderr_path'):
                    self.container.SetProperty(prop, os.path.join('/dev/fd', str(pts)))

                if ptsname is not None:
                    self.extra_env.append(('SSH_TTY', ptsname))
                self.handler = handler = TelnetHandler(self.container, channel, ptm, timeout=self.ctx.telnet_timeout)
                handler.set_winsize(*self._pty_size())

            self.log.debug('preparing container')
            self.prepare_container()
            self.container.Start()

            if self.pty_params is not None and self.ctx.is_motd_enabled:
                send_motd(channel, self.ctx.motd_lines)

            return handler
        except BaseException:
            # Do not leave a half-configured container behind.
            exc_info = sys.exc_info()
            suppress(self.container.Destroy)
            six.reraise(*exc_info)
        finally:
            # Close our copy of the slave side on success and on failure alike.
            if pts is not None:
                os.close(pts)

    def watch_ssh(self, channel):
        """Attach a ``WatchHandler`` to the already existing container.

        :param channel: ssh channel of the client
        :return: the created ``WatchHandler``
        :raises Exception: when the container does not exist
        """
        try:
            self.container = get_portoconn().Find(self.container_name)
        except InvalidValue:
            raise Exception("Cannot join slot container: it doesn't exist (%r)" % (self.container_name,))

        handler = WatchHandler(channel=channel,
                               container=self.container,
                               log=self.log,
                               )
        return handler


def error(message, api_mode):
    """Abort session startup with ``message``; never returns.

    :param message: description of the failure
    :param api_mode: when truthy a ``StartupException`` is raised; otherwise
        the message is printed to stderr and ``SystemExit(1)`` is raised
    """
    if not api_mode:
        print(message, file=sys.stderr)
        raise SystemExit(1)

    raise StartupException(message)


def start_agent(agent, root, user):
    """Start ``agent`` and return its socket path relative to ``root``.

    :param agent: agent object providing ``start(basedir, user)`` and
        ``get_filename()``; ``None`` means there is nothing to start
    :param root: root directory the agent is started under; ``'/'`` means no
        base directory is passed to the agent
    :param user: user passed to ``agent.start``
    :return: ``None`` without an agent; otherwise the agent file name, rewritten
        as an absolute path relative to ``root`` unless ``root`` is ``'/'``
    """
    if agent is None:
        return None

    on_host_root = (root == '/')
    agent.start(basedir=None if on_host_root else root, user=user)

    socket_path = agent.get_filename()
    if on_host_root:
        return socket_path
    return os.path.join('/', os.path.relpath(socket_path, root))


def check_for_scp(command, container_name):
    """Rewrite an ``scp ...`` command to use an scp binary found in the container.

    Commands that do not start with ``scp `` and commands for containers whose
    parent root is ``/`` are returned unchanged.

    :param command: command line as a byte string
    :param container_name: name of the session container
    :return: the original command, or the command with the leading ``scp``
        replaced by the located binary (``/portoshell_utils/scp`` as fallback)
    """
    if not command.startswith(b'scp '):
        return command

    parent = get_parent_container(get_portoconn(), container_name)
    if get_container_root(parent.name) == b'/':
        return command

    scp_candidates = (
        b'/usr/local/bin/scp',
        b'/usr/bin/scp',
        b'/bin/scp',
    )
    scp_binary = find_in_container(parent.name,
                                   candidates=scp_candidates,
                                   fallback=b'/portoshell_utils/scp')
    # Drop the leading "scp" (3 chars) but keep the space and the arguments.
    return scp_binary + command[3:]


def find_shell(container_name):
    """Pick the shell to run for a session in ``container_name``.

    When the parent container's root is ``/`` the candidates are probed
    directly on the local filesystem; otherwise the lookup is delegated to
    ``find_in_container`` with the busybox shell as fallback.

    :param container_name: name of the session container
    :return: path of the shell (or the busybox fallback command)
    :raises IndexError: when the parent root is ``/`` and none of the
        candidate shells is a regular file (exception type kept for backward
        compatibility with the previous ``filter(...)[0]`` implementation)
    """
    candidates = (
        b'/usr/local/bin/bash',
        b'/usr/bin/bash',
        b'/bin/bash',
        b'/bin/sh'
    )
    fallback = b'/portoshell_utils/busybox sh'
    parent = get_parent_container(get_portoconn(), container_name)
    if get_container_root(parent.name) != b'/':
        return find_in_container(parent.name, candidates, fallback)

    # os.path.isfile() already implies existence; an explicit loop also works
    # on Python 3, where filter() returns a non-subscriptable iterator.
    for path in candidates:
        if os.path.isfile(path):
            return path
    raise IndexError("no usable shell found, tried: %r" % (candidates,))
