# -*- coding: utf-8 -*-

from __future__ import absolute_import, division, print_function

import datetime as dt
import logging
from base64 import b64decode

from sandbox import sdk2
from sandbox.projects.jupyter_cloud.common.task_mixin import JupyterCloud3Task, ENVIRONMENTS

# Not referenced in this module's visible code; presumably consumed elsewhere
# (e.g. by the salt lib) — confirm before changing.
CONCURRENCY = 64
HOUR = 3600  # one hour, in seconds
# Default budget for the Salt state apply (used as the `state_apply_timeout` default).
STATE_APPLY_TIMEOUT = HOUR * 3
# Task kill timeout: the default state-apply budget plus one hour of headroom.
TASK_TIMEOUT = HOUR + STATE_APPLY_TIMEOUT


class JupyterCloudSalt(JupyterCloud3Task):
    """Sandbox task that drives a Salt action for JupyterCloud.

    The only action currently offered is 'state_apply'. Secrets are loaded in
    ``on_prepare`` and the actual work is delegated to ``SaltJobLib`` in
    ``on_execute``.
    """

    class Requirements(sdk2.Task.Requirements):
        """Multislot task requirements:
        - cores <= 16
        - ram <= 64GiB
        - no root
        - empty Caches
        """

        cores = 1
        disk_space = 2048  # presumably MiB (Sandbox convention) — confirm
        ram = 2048  # presumably MiB (Sandbox convention) — confirm

        class Caches(sdk2.Requirements.Caches):
            # Intentionally empty: multislot tasks must not declare caches.
            pass

    class Parameters(JupyterCloud3Task.Parameters):
        # Hard kill after the default state-apply budget plus one hour.
        # NOTE(review): raising `state_apply_timeout` above STATE_APPLY_TIMEOUT
        # does not raise this limit accordingly.
        kill_timeout = TASK_TIMEOUT

        # Selects the ENVIRONMENTS entry (on_enqueue) and the YAV env secret
        # (on_prepare).
        with sdk2.parameters.RadioGroup(
            'JupyterCloud environment', required=True, hint=True
        ) as environment:
            environment.values['production'] = environment.Value(default=True)
            environment.values['testing'] = environment.Value()

        # Replaced by ENVIRONMENTS[environment]['salt_masters'] in on_enqueue
        # when empty, None, or containing only empty strings.
        salt_masters = sdk2.parameters.List(
            'Salt master servers',
            default=[],
            description='Taken from environment if empty',
        )

        with sdk2.parameters.RadioGroup('Salt action', required=True) as salt_action:
            salt_action.values['state_apply'] = salt_action.Value(default=True)

        # Sub-parameters shown only when the 'state_apply' action is selected.
        with salt_action.value['state_apply']:
            do_restart_minions = sdk2.parameters.Bool('Restart unresponsive minions', default=True)
            state_name = sdk2.parameters.String('State name', default='user_env', required=True)

        state_apply_timeout = sdk2.parameters.Timedelta(
            'State apply timeout (seconds)',
            default=dt.timedelta(seconds=STATE_APPLY_TIMEOUT),
        )

        users = sdk2.parameters.List(
            'Users to apply',
            value_type=sdk2.parameters.Staff,
            default=[],
            description='Everyone if empty',
        )

        with sdk2.parameters.Group('Internal'):
            profile = sdk2.parameters.Bool('Profile this run', default=False)
            concurrency = sdk2.parameters.Integer('Internal concurrency', default=16)
            batchsize = sdk2.parameters.Integer('Batch size', default=20)

    def on_enqueue(self):
        """Runs on master right before execution
        You can still change Requirements and Parameters here

        Fills ``salt_masters`` from the selected environment when the user left
        it blank.
        """
        super(JupyterCloudSalt, self).on_enqueue()

        env = ENVIRONMENTS[self.Parameters.environment]
        # `any(None)` raises TypeError, so coalesce None to an empty tuple first.
        # This treats None, [] and lists of empty strings all as "not set".
        if not any(self.Parameters.salt_masters or ()):
            self.Parameters.salt_masters = env['salt_masters']

    def on_prepare(self):
        """Load the secrets used by the job into ``self.secrets``.

        Shared secrets come from ``yav_shared_secret``. ``salt_secret`` comes
        from the production or testing YAV secret, depending on the selected
        environment.
        """
        shared_secrets = self.Parameters.yav_shared_secret.data()

        self.secrets = {
            'yav_oauth': shared_secrets['yav-oauth-token'],
            'qyp_oauth': shared_secrets['qyp_oauth_token'],
            # files are base-64 encoded
            'minion_ssh_key': b64decode(shared_secrets['id_rsa']),
        }

        env_secret = (
            self.Parameters.yav_prod_secret
            if self.Parameters.environment == 'production'
            else self.Parameters.yav_test_secret
        )
        self.secrets['salt_secret'] = env_secret.data()['salt_secret']

    def on_execute(self):
        """Optionally start yappi profiling, then run the Salt job via SaltJobLib."""
        # Imported lazily so the module can be loaded without the lib's deps.
        from sandbox.projects.jupyter_cloud.salt.lib import SaltJobLib

        if self.Parameters.profile:
            import yappi

            # Wall-clock timing so time spent waiting on I/O is included.
            yappi.set_clock_type('wall')
            yappi.start()
            # NOTE(review): yappi is never stopped or dumped in this method;
            # presumably SaltJobLib (which receives the task) handles that —
            # confirm.

        self.logger = logging.getLogger('JCSalt')

        job = SaltJobLib(job=self)
        job.execute()
