#!/usr/bin/env python
# -*- coding: UTF-8 -*-
from __future__ import print_function
import os
import logging
import urllib2

import sandbox.common.errors as ce
from sandbox import sdk2
from sandbox.sdk2.helpers import subprocess as sp
from sandbox.sdk2.helpers import ProcessLog
from sandbox.projects.kikimr.resources import KikimrRollingRestartBinary

# Logger for the task's own diagnostic messages.
logger = logging.getLogger('sandbox-task')
# Logger handed to ProcessLog, which receives the rolling-restart subprocess output.
logger_sp = logging.getLogger('sandbox-subprocess')
for _configured in (logger, logger_sp):
    _configured.setLevel(logging.DEBUG)
del _configured


class KikimrRollingRestart(sdk2.Task):
    """Sandbox task that performs a rolling restart of a kikimr cluster.

    The task looks up a ``KikimrRollingRestartBinary`` resource released as
    ``stable``, translates the task parameters into command-line flags for that
    binary and runs it inside an ``sdk2.ssh.Key`` context for the robot SSH key
    taken from the Sandbox vault.
    """

    # Seconds to wait for the cssh download; without a timeout a stalled HTTP
    # connection could hang the task.
    _CSSH_DOWNLOAD_TIMEOUT = 60

    class Requirements(sdk2.Task.Requirements):
        disk_space = 2 * 1024

    class Parameters(sdk2.Task.Parameters):
        description = 'Rolling restart kikimr cluster.'

        addr = sdk2.parameters.List('Comma separated list of CMS addresses', sdk2.parameters.String, required=True)

        with sdk2.parameters.String('Restart action type') as service:
            service.values['cloud-dynamic'] = service.Value('cloud-dynamic')
            service.values['cloud-storage'] = service.Value('cloud-storage')
            service.values['kikimr'] = service.Value('kikimr', default=True)
            service.values['kikimr-storage'] = service.Value('kikimr-storage')
            service.values['kikimr-dynamic'] = service.Value('kikimr-dynamic')
            service.values['rtmr'] = service.Value('rtmr')
            service.values['sqs'] = service.Value('sqs')

        with sdk2.parameters.String('Cluster availability mode for CMS requests (default: max)') as availability_mode:
            availability_mode.values['max'] = availability_mode.Value('max')
            availability_mode.values['keep'] = availability_mode.Value('keep')
            availability_mode.values['force'] = availability_mode.Value('force')

        filter_hosts = sdk2.parameters.List('Host filter', sdk2.parameters.String)
        exclude_hosts = sdk2.parameters.List('Exclude hosts', sdk2.parameters.String)
        filter_version = sdk2.parameters.String('Service version filter (format: [!=<>]version)')
        tenant = sdk2.parameters.String('Tenant name')

        continue_ = sdk2.parameters.Bool('Continue previous restart', default=False)
        verbose = sdk2.parameters.Bool('Verbose', default=False)

        token_vault_owner = sdk2.parameters.String('Name of sandbox vault owner for oauth token')
        token_vault_key = sdk2.parameters.String('Name of sandbox vault entry for oauth token')

        cms_user = sdk2.parameters.String('CMS username (default: rolling-restart)')
        cms_max_wait_time = sdk2.parameters.String('Limit CMS retry time in seconds (default: 60)')
        cms_api_timeout = sdk2.parameters.String('CMS API response timeout in seconds (default: 60)')

        max_static_nodeid = sdk2.parameters.String('Max Static NodeId (default: 1000)')

        drain = sdk2.parameters.Bool('Drain tablets before restart', default=True)
        drain_time_limit = sdk2.parameters.String('Drain tablet time limit in seconds (default: 300)')
        drain_via_api = sdk2.parameters.Bool('Use API for drain tablets', default=False)

        grpc = sdk2.parameters.Bool('Use grpc', default=False)

        ssh_user = sdk2.parameters.String('SSH username', required=True)
        ssh_key_vault_owner = sdk2.parameters.String('Name of sandbox vault owner for robot ssh key', required=True)
        ssh_key_vault_entry = sdk2.parameters.String('Name of sandbox vault entry for robot ssh key', required=True)
        ssh_pool_size = sdk2.parameters.String('SSH processes pool size (default: 10)')
        ssh_logging = sdk2.parameters.Bool('Enable ssh command output logging', default=False)
        ssh_cssh = sdk2.parameters.Bool('Use cssh for connection through Bastion', default=False)
        ssh_cssh_download_url = sdk2.parameters.String('Url to download cssh (default: https://s3.mds.yandex.net/bastion/cssh)', default='https://s3.mds.yandex.net/bastion/cssh')

        restart_duration = sdk2.parameters.String('Expected max restart time in seconds (default: 60)')
        restart_retry_number = sdk2.parameters.String('Retry number of restart (default: 3)')

    def on_execute(self):
        """Entry point: locate the binary, build its command line and run it.

        :raises ce.TaskError: if the binary resource is missing, the cssh
            download fails, or the rolling restart exits with a non-zero code.
        :raises ce.TaskFailure: if the OAuth token cannot be read from the vault.
        """
        logger.debug('calling on_execute')

        script_path = self._find_rolling_restart_binary()
        cmd = self._build_rolling_restart_cmd(script_path)

        msg = 'Run rolling restart: %s.' % cmd
        logger.debug(msg)
        self.set_info(msg)

        self._run_rolling_restart(cmd)

        self.set_info('Done')

    def _find_rolling_restart_binary(self):
        """Return the local path of a stable-released rolling-restart binary.

        Uses the first resource returned by ``find(attrs={'released': 'stable'})``.

        :raises ce.TaskError: if no such resource exists.
        """
        resource = KikimrRollingRestartBinary.find(
            attrs=dict(released='stable')
        ).first()

        if resource is None:
            msg = 'Cannot find %s resource' % KikimrRollingRestartBinary.name
            logger.debug(msg)
            raise ce.TaskError(msg)

        msg = 'Using resource id %s.' % resource.id
        logger.debug(msg)
        self.set_info(msg)

        return str(sdk2.ResourceData(resource).path)

    @staticmethod
    def _append_flag_value(cmd, flag, value):
        """Append ``flag value`` to *cmd* only when *value* is non-empty.

        Both ``''`` and ``None`` count as "unset". A plain ``!= ''`` check would
        let ``None`` through into the argv list, and ``Popen`` rejects it.
        """
        if value:
            cmd.extend([flag, value])

    def _build_rolling_restart_cmd(self, script_path):
        """Translate the task parameters into the argv list for the binary.

        Side effects: may write the OAuth token file and download cssh into the
        task's current working directory.
        """
        params = self.Parameters
        append = self._append_flag_value

        cmd = [
            script_path,
            '--addr', ','.join(params.addr),
            '--service', params.service,
        ]

        append(cmd, '--availability-mode', params.availability_mode)

        if params.filter_hosts:
            cmd.extend(['--hosts', ','.join(params.filter_hosts)])

        if params.exclude_hosts:
            cmd.extend(['--exclude-hosts', ','.join(params.exclude_hosts)])

        append(cmd, '--version', params.filter_version)
        append(cmd, '--tenant', params.tenant)

        if params.continue_:
            cmd.append('--continue')

        if params.verbose:
            cmd.append('--verbose')

        if params.token_vault_key and params.token_vault_owner:
            cmd.extend(['--auth-file', self._write_token_file()])

        append(cmd, '--cms-user', params.cms_user)
        append(cmd, '--max-cms-wait-time', params.cms_max_wait_time)
        append(cmd, '--api-timeout', params.cms_api_timeout)
        append(cmd, '--max-static-nodeid', params.max_static_nodeid)

        if params.drain:
            cmd.append('--drain')
            # Drain sub-options are only meaningful together with --drain.
            append(cmd, '--drain-time-limit', params.drain_time_limit)
            if params.drain_via_api:
                cmd.append('--drain-via-api')

        if params.grpc:
            cmd.append('--grpc')

        cmd.extend(['--ssh-user', params.ssh_user])

        append(cmd, '--ssh-pool-size', params.ssh_pool_size)

        if params.ssh_logging:
            cmd.append('--ssh-logging')

        if params.ssh_cssh:
            cmd.extend(['--ssh-cmd', self._download_cssh()])

        append(cmd, '--restart-duration', params.restart_duration)
        append(cmd, '--restart-retry-number', params.restart_retry_number)

        return cmd

    def _write_token_file(self):
        """Read the OAuth token from the vault and store it in an owner-only file.

        :returns: absolute path of the written token file.
        :raises ce.TaskFailure: if the vault entry cannot be read.
        """
        owner = self.Parameters.token_vault_owner
        key = self.Parameters.token_vault_key
        try:
            token = sdk2.Vault.data(owner, key)
        except ce.VaultError:
            raise ce.TaskFailure('Cannot get token using token_vault_key: %s and token_vault_owner: %s' % (
                key, owner
            ))

        token_path = os.path.abspath('token')
        # Create with mode 0600 and also fchmod before writing: os.open's mode
        # only applies to newly created files, so a pre-existing file could
        # otherwise keep wider permissions while holding the secret.
        fd = os.open(token_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'w') as token_file:
            os.fchmod(token_file.fileno(), 0o600)
            token_file.write(token)

        return token_path

    def _download_cssh(self):
        """Download the cssh wrapper, make it executable and return its path.

        :raises ce.TaskError: if the download fails (network/HTTP error or timeout).
        """
        url = self.Parameters.ssh_cssh_download_url
        try:
            response = urllib2.urlopen(url, timeout=self._CSSH_DOWNLOAD_TIMEOUT)
            try:
                payload = response.read()
            finally:
                response.close()
        except IOError as err:
            # In Python 2, urllib2.URLError/HTTPError and socket timeouts are IOErrors.
            raise ce.TaskError('Cannot download cssh from %s: %s' % (url, err))

        cssh_path = os.path.abspath('./cssh')
        # Binary mode: the payload is an executable and must be written verbatim.
        with open(cssh_path, 'wb') as cssh_file:
            cssh_file.write(payload)
        os.chmod(cssh_path, 0o775)

        return cssh_path

    def _run_rolling_restart(self, cmd):
        """Run *cmd* inside the robot ``sdk2.ssh.Key`` context.

        stdout and stderr both go to the ProcessLog backed by ``logger_sp``.

        :raises ce.TaskError: if the process exits with a non-zero code.
        """
        with sdk2.ssh.Key(self, self.Parameters.ssh_key_vault_owner, self.Parameters.ssh_key_vault_entry):
            with ProcessLog(self, logger=logger_sp) as pl:
                proc = sp.Popen(cmd, shell=False, stdout=pl.stdout, stderr=pl.stdout)
                proc.communicate()
                if proc.returncode != 0:
                    msg = 'Rolling restart failed with returncode %s.' % proc.returncode
                    logger.debug(msg)
                    raise ce.TaskError(msg)
