# -*- coding: utf-8 -*-

from __future__ import absolute_import, print_function, division

import os
import logging
import contextlib
import time

from sandbox import sdk2
from sandbox.common.errors import TaskFailure
from sandbox.sdk2.helpers import ProcessLog, ProcessRegistry, subprocess

from sandbox.projects.jupyter_cloud import JUPYTER_CLOUD_BACKUP
from sandbox.projects.jupyter_cloud.task_mixin import JupyterCloudTaskMixin

logger = logging.getLogger(__name__)

# File names inside the task's working directory.
BACKUP_FILE = 'backup.tar.bz2.gpg'  # encrypted archive written by a backup
PASSPHRASE_FILE = 'passphrase.file'  # temporary file holding the gpg passphrase
SSH_KEY_FILE = 'ssh.key'  # temporary file holding the ssh private key

# Retry policy shared by the backup and restore pipelines.
BACKUP_RETRIES = RESTORE_RETRIES = 10
RETRY_SLEEP = 2 * 60  # seconds to sleep between attempts (time.sleep argument)

# Paths relative to the user's home directory that are skipped by default
# when archiving.
DEFAULT_EXCLUDE_PATHS = [
    # ya tool / build directories
    '.ya/tools',
    '.ya/build',
    # Arcadia default kernels
    '.local/share/jupyter/kernels/arcadia_default_py2',
    '.local/share/jupyter/kernels/arcadia_default_py3',
    # caches
    '.cache/nile',
    '.cache/pip',
]


class JupyterCloudBackup(sdk2.Task, JupyterCloudTaskMixin):
    """Back up, re-spawn and restore a Jupyter Cloud user's home directory.

    Every stage is optional and controlled by a task parameter. ``on_execute``
    runs the enabled stages in this order:

    * ``spawn_new_vm`` -- spawn a new VM for the user through the selected
      JupyterHub instance;
    * ``do_backup`` -- archive ``/home/<username>`` on ``old_host`` over ssh,
      encrypt it with ``gpg --symmetric`` and publish it as a
      JUPYTER_CLOUD_BACKUP resource (otherwise ``backup_resource`` is used);
    * ``do_restore`` -- unpack the backup into ``/home/<username>`` on the
      spawned VM or on ``new_host``.
    """

    class Requirements(sdk2.Task.Requirements):
        cores = 2
        disk_space = 20 * 1024
        ram = 2 * 1024
        tasks_resource = None

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(JupyterCloudTaskMixin.Parameters):
        with sdk2.parameters.Group('Backup/Restore options') as backup_options:
            username = sdk2.parameters.Staff('User to backup/spawn/restore', required=True)

            do_backup = sdk2.parameters.Bool('Do backup', default=False)

            with do_backup.value[True]:
                old_host = sdk2.parameters.Url('Host to backup from', required=True)

                exclude_paths = sdk2.parameters.List(
                    "Paths to exclude from backup (relative to user's home dir)",
                    default=DEFAULT_EXCLUDE_PATHS,
                )

                # Assigned to the created resource's ``ttl``; presumably days — confirm.
                backup_ttl = sdk2.parameters.Integer('ttl for backup resource', default=183)

            with do_backup.value[False]:
                backup_resource = sdk2.parameters.Resource(
                    'Resource with this user backup',
                    resource_type=JUPYTER_CLOUD_BACKUP,
                )

            spawn_new_vm = sdk2.parameters.Bool('Spawn new vm', default=False)

            with spawn_new_vm.value[True]:
                with sdk2.parameters.RadioGroup('JH Instance', required=True) as jh_instance:
                    jh_instance.values['beta.jupyter.yandex-team.ru'] = None
                    jh_instance.values['new.jupyter.yandex-team.ru'] = None
                    jh_instance.values['jupyter.yandex-team.ru'] = None

                with sdk2.parameters.RadioGroup('Location', required=True) as location:
                    location.values.sas = None
                    location.values.man = None
                    location.values.vla = None
                    location.values.iva = None
                    location.values.myt = None

                with sdk2.parameters.RadioGroup('Account', required=True) as account:
                    account.values['abc:service:2142'] = account.Value(default=True)

                with sdk2.parameters.RadioGroup('Segment', required=True) as segment:
                    segment.values.dev = segment.Value()
                    segment.values.default = segment.Value(default=True)

                with sdk2.parameters.RadioGroup('Host type', required=True) as host_type:
                    host_type.values.cpu1_ram4_hdd24 = None
                    host_type.values.cpu2_ram16_hdd48 = None
                    host_type.values.cpu4_ram32_hdd96 = None

                    host_type.values.cpu1_ram4_ssd24 = None
                    host_type.values.cpu2_ram16_ssd48 = None
                    host_type.values.cpu4_ram32_ssd96 = None

            do_restore = sdk2.parameters.Bool('Do restore', default=False)

            with do_restore.value[True]:
                new_host = sdk2.parameters.Url(
                    'Host to restore to', description="leave empty if you spawn new VM"
                )

    def on_save(self):
        """Reject inconsistent parameter combinations.

        Raises:
            TaskFailure: ``new_host`` is set together with ``spawn_new_vm``, or
                a restore is requested without a target host or without a
                backup source.
        """
        JupyterCloudTaskMixin.on_save(self)

        if self.Parameters.spawn_new_vm and self.Parameters.new_host:
            raise TaskFailure(
                'non-empty value for new_host while spawn_new_vm==True'
            )

        if (
            self.Parameters.do_restore and
            not self.Parameters.new_host and
            not self.Parameters.spawn_new_vm
        ):
            raise TaskFailure(
                "to restore backup fill 'new_host' or select 'spawn_new_vm' option"
            )

        if (
            self.Parameters.do_restore and
            not self.Parameters.do_backup and
            not self.Parameters.backup_resource
        ):
            raise TaskFailure(
                "to restore backup fill 'do_backup' or 'backup_resource' option"
            )

    def on_execute(self):
        """Run the enabled stages: spawn a VM, make a backup, restore it."""
        from jupytercloud.tools.lib.environment import environment
        from jupytercloud.tools.lib.utils import setup_sentry

        with environment('production', self.Parameters.yav_oauth_token.data()):
            setup_sentry(__name__)

        # Restore target: the freshly spawned VM or an explicitly given host.
        host = None
        if self.Parameters.spawn_new_vm:
            host = self._spawn_new_vm()
        elif self.Parameters.new_host:
            host = self.Parameters.new_host

        # Backup source: a fresh backup or an existing resource.
        backup = None
        if self.Parameters.do_backup:
            backup = self._do_backup()
        elif self.Parameters.backup_resource:
            backup = self.Parameters.backup_resource

        if self.Parameters.do_restore:
            self._do_restore(host, backup)

    def _get_spawn_options(self):
        """Build the options dict passed to ``JupyterHubClient.spawn_user``.

        The instance spec is ``location;account;segment;host_type;False``.
        NOTE(review): the meaning of the trailing 'False' field is defined by
        the JupyterHub spawner, not visible here — confirm.
        """
        location_params = [
            self.Parameters.location,
            self.Parameters.account,
            self.Parameters.segment,
            self.Parameters.host_type,
            'False'
        ]

        return {
            'instance': [';'.join(location_params)],
            'force_new': 'True',
        }

    def _spawn_new_vm(self):
        """Spawn a new server for ``username`` via the selected JupyterHub.

        Returns:
            Host name of the spawned VM, assembled from the server state that
            the hub reports (``name_prefix`` and ``cluster``).

        Raises:
            TaskFailure: the user is not registered in the hub, the spawner
                reported a failure, or the hub reports no server after spawning.
            ValueError: the user already has a server in the hub.
        """
        from jupytercloud.tools.lib.jupyterhub import JupyterHubClient

        user = self.Parameters.username
        hub_host = self.Parameters.jh_instance

        jupyterhub_client = JupyterHubClient(
            host=hub_host,
            token=self._get_secret('sandbox_api_token'),
        )

        if jupyterhub_client.get_user(user) is None:
            raise TaskFailure('user {} is not registered in hub {}'.format(user, hub_host))

        server_info = jupyterhub_client.get_server_info(user)
        if server_info is not None:
            raise ValueError(
                'user {} already have server in hub {} with state {}'
                .format(user, hub_host, server_info)
            )

        spawn_options = self._get_spawn_options()

        logger.info('going to spawn vm for user %s with options %s', user, spawn_options)

        for status in jupyterhub_client.spawn_user(user, spawn_options):
            logger.info('spawner returned  message %s', status)

            if status.get('failed'):
                raise TaskFailure(
                    'spawn failed with message {}'.format(status.get('message'))
                )

        server_info = jupyterhub_client.get_server_info(user)
        if server_info is None:
            # The spawner finished without reporting a failure, yet the hub
            # knows no server: fail explicitly instead of a TypeError below.
            raise TaskFailure(
                'spawn finished, but hub {} reports no server for user {}'
                .format(hub_host, user)
            )

        server_state = server_info['state']
        vm_host = '{}{}.{}.yp-c.yandex.net'.format(
            server_state['name_prefix'],
            user,
            server_state['cluster'],
        )

        logger.info('server successfully spawned at %s', vm_host)

        return vm_host

    def _get_ssh_command(self, host):
        """Return the argv prefix that runs a command on ``host`` as the robot.

        Authenticates with the key in SSH_KEY_FILE, which must exist while the
        command runs (see ``_ssh_key``).
        """
        return ['ssh', '-i', SSH_KEY_FILE, '-v', 'robot-jupyter-cloud@{}'.format(host)]

    @staticmethod
    def _shell_quote(value):
        """Quote ``value`` as a single word for the remote POSIX shell.

        Strings made only of shell-safe characters (letters, digits and
        ``@%_-+=:,./``) are returned unchanged.
        """
        try:
            from shlex import quote
        except ImportError:  # Python 2
            from pipes import quote
        return quote(value)

    @staticmethod
    def _run_pipe(source_cmd, sink_cmd, process_log, stdin=None, stdout=None):
        """Run ``source_cmd | sink_cmd`` and wait for both processes.

        Args:
            source_cmd: argv of the producer; its stdout feeds ``sink_cmd``.
            sink_cmd: argv of the consumer.
            process_log: ProcessLog whose ``stderr`` receives both processes'
                stderr.
            stdin: stdin for ``source_cmd`` (None = inherited).
            stdout: stdout for ``sink_cmd`` (None = inherited).

        Returns:
            ``(source_returncode, sink_returncode)``.
        """
        source = subprocess.Popen(
            source_cmd,
            stdin=stdin,
            stdout=subprocess.PIPE,
            stderr=process_log.stderr,
        )
        try:
            sink = subprocess.Popen(
                sink_cmd,
                stdin=source.stdout,
                stdout=stdout,
                stderr=process_log.stderr,
            )
        finally:
            # Drop our copy of the read end so the source gets SIGPIPE when the
            # sink exits early (or never started).
            source.stdout.close()

        sink_ret = sink.wait()
        source_ret = source.wait()
        return source_ret, sink_ret

    @staticmethod
    def _run_with_retries(action, retries, attempt):
        """Call ``attempt`` until both exit codes it returns are zero.

        Args:
            action: 'backup' or 'restore'; used in log/error messages.
            retries: maximum number of attempts.
            attempt: callable returning a ``(tar_ret, gpg_ret)`` pair.

        Raises:
            TaskFailure: the last allowed attempt still failed. Earlier
                failures are logged and followed by a RETRY_SLEEP pause.
        """
        for i in range(retries):
            tar_ret, gpg_ret = attempt()
            if not (tar_ret or gpg_ret):
                return

            message = 'unexpected {} return codes: tar={}, gpg={}'.format(
                action, tar_ret, gpg_ret
            )
            if i == retries - 1:
                raise TaskFailure(message)

            logger.error(message)
            logger.warning('fail to %s, going to do retry %d', action, i + 1)
            time.sleep(RETRY_SLEEP)

    def _do_backup(self):
        """Archive ``/home/<username>`` on ``old_host`` into a backup resource.

        Runs ``ssh old_host 'tar ...' | gpg --symmetric > BACKUP_FILE`` up to
        BACKUP_RETRIES times, then registers BACKUP_FILE as a ready
        JUPYTER_CLOUD_BACKUP resource with ``backup_ttl``.

        Returns:
            The created JUPYTER_CLOUD_BACKUP resource.

        Raises:
            TaskFailure: every attempt ended with a non-zero exit code.
        """
        user = self.Parameters.username
        host = self.Parameters.old_host
        quote = self._shell_quote

        # exclude_paths is free-form task input interpreted by the remote
        # shell, so every argument is quoted.
        exclude = ' '.join(
            quote('--exclude=./{}'.format(path))
            for path in self.Parameters.exclude_paths or ()
        )

        backup_path = '/home/{}'.format(user)

        # All options precede the '.' operand: since GNU tar 1.29 --exclude is
        # positional and only affects operands that follow it.
        tar_command = self._get_ssh_command(host) + [
            'cd {} && sudo tar cjf - --ignore-failed-read {} .'.format(
                quote(backup_path), exclude,
            )
        ]
        gpg_command = [
            'gpg', '--symmetric', '--batch', '--no-use-agent', '--yes',
            '--passphrase-file', PASSPHRASE_FILE,
        ]

        with \
                ProcessRegistry, \
                ProcessLog(self, logger='backup', stderr_level=logging.WARNING) as process_log, \
                self._ssh_key(), \
                self._passphrase():
            logger.info(
                'running backup command: %s | %s > %s',
                tar_command, gpg_command, BACKUP_FILE
            )

            def attempt():
                # Reopened (truncated) on every attempt so a retry never keeps
                # bytes from a previous partial archive.
                with open(BACKUP_FILE, 'wb') as backup_file:
                    tar_ret, gpg_ret = self._run_pipe(
                        tar_command, gpg_command, process_log,
                        stdout=backup_file,
                    )
                return tar_ret, gpg_ret

            self._run_with_retries('backup', BACKUP_RETRIES, attempt)

        resource = JUPYTER_CLOUD_BACKUP(
            task=self,
            description="backup for {} from {}".format(user, host),
            path=BACKUP_FILE,
            user=user,
            host=host
        )
        resource.ttl = self.Parameters.backup_ttl

        resource_data = sdk2.ResourceData(resource)
        resource_data.ready()

        return resource

    def _do_restore(self, host, resource):
        """Unpack backup ``resource`` into ``/home/<username>`` on ``host``.

        Runs ``gpg < backup | ssh host 'tar x ...'`` up to RESTORE_RETRIES
        times; tar runs on the remote side as the user.

        Raises:
            TaskFailure: the resource belongs to another user, or every attempt
                ended with a non-zero exit code.
        """
        user = self.Parameters.username
        if resource.user != user:
            raise TaskFailure(
                "trying to restore backup from user {} to user {}"
                .format(resource.user, user)
            )

        resource_path = sdk2.ResourceData(resource).path
        backup_path = '/home/{}'.format(user)
        quote = self._shell_quote

        gpg_cmd = [
            'gpg', '--output', '-', '--batch', '--no-use-agent', '--yes',
            '--passphrase-file', PASSPHRASE_FILE,
        ]

        # '&&' rather than ';': if the home directory cannot be entered, do not
        # extract into whatever directory the ssh session started in.
        tar_cmd = self._get_ssh_command(host) + [
            '(cd {} && sudo -u {} tar xjf -)'.format(quote(backup_path), quote(user))
        ]

        with \
                ProcessRegistry, \
                ProcessLog(self, logger='restore', stderr_level=logging.WARNING) as process_log, \
                self._ssh_key(), \
                self._passphrase():

            logger.info('running restore command: %s < %s | %s', gpg_cmd, resource_path, tar_cmd)

            def attempt():
                with open(str(resource_path), 'rb') as backup_file:
                    gpg_ret, tar_ret = self._run_pipe(
                        gpg_cmd, tar_cmd, process_log,
                        stdin=backup_file,
                        stdout=process_log.stdout,
                    )
                return tar_ret, gpg_ret

            self._run_with_retries('restore', RESTORE_RETRIES, attempt)

    @contextlib.contextmanager
    def _secret_file(self, secret_name, path):
        """Write secret ``secret_name`` to ``path`` for the duration of the block.

        The file is created with mode 0600 (never readable by others, even
        before the chmod), made read-only (0400), and removed on exit -- also
        when writing it fails.
        """
        secret = self._get_secret(secret_name)
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        try:
            with os.fdopen(fd, 'w') as secret_file:
                secret_file.write(secret)
            os.chmod(path, 0o400)
            yield
        finally:
            os.remove(path)

    def _ssh_key(self):
        """Context manager exposing the 'id_rsa' secret in SSH_KEY_FILE."""
        return self._secret_file('id_rsa', SSH_KEY_FILE)

    def _passphrase(self):
        """Context manager exposing the 'backup_passphrase' secret in PASSPHRASE_FILE."""
        return self._secret_file('backup_passphrase', PASSPHRASE_FILE)