from sandbox import sdk2
from sandbox.common.types import task as sandbox_task

from sandbox.projects.crypta.common import (
    helpers,
    task,
)
from sandbox.projects.crypta.common.vault import get_vault_item

# Name of the config file rendered in on_prepare() and passed to the
# training binary via ``--config``.
CONFIG_FILENAME = 'config.yaml'


class CryptaMlTrainCustomModelBundle(sdk2.Resource):
    """Sandbox resource type for the custom-model training bundle.

    Used as ``bundle_resource_type`` by ``CryptaMlTrainCustomModelTask``.
    """

    # The resource may be released.
    releasable = True
    # Lifetime after release; presumably in days (Sandbox TTL convention) -- confirm.
    ttl_on_release = 30


class CryptaMlTrainCustomModelTask(task.CryptaTask):
    """Sandbox task wrapping the ``crypta-ml-train-custom-model`` binary.

    The binary is run as ``crypta-ml-train-custom-model --config config.yaml``.
    The config file is rendered from this task's parameters in ``on_prepare``.
    When ``model_name`` is set, the task acquires a per-model semaphore with
    capacity 1 on enqueue.
    """

    class Parameters(task.CryptaTask.Parameters):
        # NOTE: declaration order is kept as-is; sdk2 may rely on it (e.g. UI order).

        # --- Training sample inputs ---
        sample_table_path = sdk2.parameters.String(
            "Path to train sample table",
            required=True,
        )
        sample_file_path = sdk2.parameters.String(
            "Path to train sample file",
            required=False, default=None,
        )
        audience_id = sdk2.parameters.String(
            "Audience id with train sample file",
            required=False, default=None,
        )

        # --- Model identity and output location ---
        model_name = sdk2.parameters.String(
            "Industry model name, if exists",
            required=False, default=None,
        )
        output_dir_path = sdk2.parameters.String(
            "Path to directory to save the training results",
            required=False, default=None,
        )

        # --- Segment validation / sizing ---
        validate_segments = sdk2.parameters.Bool(
            "If classes distribution in resulting segments should be calculated",
            required=False, default=True,
        )
        positive_segment_size = sdk2.parameters.Integer(
            "The number of users that will be selected for calculated positive segment",
            required=False, default=100000,
        )
        negative_segment_size = sdk2.parameters.Integer(
            "The number of users that will be selected for calculated negative segment",
            required=False, default=100000,
        )

        # --- Reporting and decision making ---
        send_results_to_api = sdk2.parameters.Bool(
            "If metrics and errors need to be sent to lab api",
            required=False, default=False,
        )
        partner = sdk2.parameters.String(
            "Partner who provided the data",
            required=False, default=None,
        )
        if_make_decision = sdk2.parameters.Bool(
            "If sample should be added/rejected automatically",
            required=False, default=False,
        )
        login = sdk2.parameters.String(
            "Login of the user who adds new sample",
            required=False, default=None,
        )
        logins_to_share = sdk2.parameters.String(
            "Logins to add access to segments for existing model",
            required=False, default=None,
        )

        kill_timeout = 8 * 60 * 60  # seconds, i.e. 8 hours

    class Requirements(task.CryptaTask.Requirements):
        cores = 32
        disk_space = 20 * 1024  # presumably MiB (~20 GiB) -- confirm against sdk2 units
        ram = 30 * 1024  # presumably MiB (~30 GiB) -- confirm against sdk2 units

    class CryptaOptions(task.CryptaTask.CryptaOptions):
        bundle_resource_type = CryptaMlTrainCustomModelBundle
        # NOTE(review): on_enqueue() below overrides semaphores without calling
        # super(); confirm whether the base class's handling of use_semaphore
        # depends on that call.
        use_semaphore = True

        # Command line for the training binary; the config is rendered in on_prepare().
        cmd = [
            helpers.get_abspath("crypta-ml-train-custom-model"),
            "--config",
            helpers.get_abspath(CONFIG_FILENAME),
        ]

    def on_enqueue(self):
        """Acquire a per-model semaphore (capacity 1) when model_name is given.

        The semaphore is named ``<ClassName>-<model_name>``, so tasks for the
        same model name contend for a single slot. If model_name is empty, no
        semaphore is configured here.
        """
        model_name = self.Parameters.model_name
        if not model_name:
            return

        # str.join (not str.format) keeps unicode model names working on Python 2.
        semaphore_name = "-".join((self.__class__.__name__, model_name))
        acquire = sandbox_task.Semaphores.Acquire(
            semaphore_name,
            weight=1,
            capacity=1,
        )
        self.Requirements.semaphores = sandbox_task.Semaphores(acquires=[acquire])

    def on_prepare(self):
        """Run the base preparation, then render CONFIG_FILENAME from parameters.

        Some template variable names differ from the parameter names:
        positive/negative_segment_size are passed as
        positive/negative_output_segment_size.
        """
        super(CryptaMlTrainCustomModelTask, self).on_prepare()

        params = self.Parameters
        template_values = dict(
            sample_table_path=params.sample_table_path,
            sample_file_path=params.sample_file_path,
            audience_id=params.audience_id,
            model_name=params.model_name,
            output_dir_path=params.output_dir_path,
            validate_segments=params.validate_segments,
            positive_output_segment_size=params.positive_segment_size,
            negative_output_segment_size=params.negative_segment_size,
            send_results_to_api=params.send_results_to_api,
            partner=params.partner,
            login=params.login,
            if_make_decision=params.if_make_decision,
            logins_to_share=params.logins_to_share,
        )
        helpers.render_template(CONFIG_FILENAME, **template_values)

    def get_additional_env(self):
        """Return extra environment variables for the task process.

        Includes CRYPTA_ENVIRONMENT from the task parameters, plus tokens and
        TVM secrets fetched with get_vault_item. Presumably the base class merges
        this into the command's environment -- confirm in CryptaTask.
        """
        # (environment variable, secret reference), in the original lookup order.
        secret_refs = (
            ("SANDBOX_TOKEN", "sec-01csvzg7vtpvb7bgrx18ajsscj[token]"),
            ("PROFILE_TVM_SECRET", "sec-01dq7m7y6xb50x8h5hheh8b4s0[client_secret]"),
            ("CUSTOM_ML_TVM_SECRET", "sec-01fttw9304mryykfqsas90zq26[client_secret]"),
            ("ROBOT_SECRETARY_OAUTH_TOKEN", "sec-01csvzhx0pk7094n7hkpvd73qt[token]"),
        )

        env = {"CRYPTA_ENVIRONMENT": self.Parameters.environment}
        for env_name, secret_ref in secret_refs:
            env[env_name] = get_vault_item(secret_ref)
        return env
