# -*- coding: utf-8 -*-

import sandbox.common.types.client as ctc
import sandbox.projects.common.constants as consts
import sandbox.projects.release_machine.mixins.build as rm_build_mixin
import sandbox.projects.common.search.gdb as gdb
import sandbox.sandboxsdk.parameters as sdk_parameters

from sandbox.projects.common.build.CommonBuildTask import CommonBuildTask
from sandbox.projects.common.nanny import nanny
from sandbox.projects.release_machine import rm_notify
from sandbox.projects.voicetech.resource_types import VOICETECH_TTS_SERVER_GPU, VOICETECH_TTS_SERVER_RUNNER
from sandbox.projects.voicetech.resource_types import VOICETECH_EVLOGDUMP


class CudaVersion(sdk_parameters.SandboxStringParameter):
    """String task parameter carrying the CUDA version to build against.

    BuildTtsServerGpu.on_enqueue reads it from the task context as
    ``cuda_version`` and turns it into a ``-DCUDA_VERSION=...`` flag, but only
    when the ``with_hfg`` parameter is off.
    """

    # Key of the value in the task context.
    name = 'cuda_version'
    required = True
    default_value = '9.1'
    # Label of the parameter.
    description = 'CUDA_VERSION'


class WithHfg(sdk_parameters.SandboxBoolParameter):
    """Boolean task parameter switching on the Hfg TRT build.

    When the value is truthy, BuildTtsServerGpu.on_enqueue uses a fixed set of
    CUDA / TensorRT / cuDNN / OS SDK definition flags and does not consult the
    ``cuda_version`` parameter.
    """

    # Key of the value in the task context.
    name = 'with_hfg'
    required = True
    default_value = True
    # Label of the parameter.
    description = 'Build Hfg TRT support'


@rm_notify.notify2()
class BuildTtsServerGpu(rm_build_mixin.ComponentReleaseTemplate, CommonBuildTask, nanny.ReleaseToNannyTask):
    """Build tts-server with CUDA.

    On enqueue the task prepends GPU-related ``-D`` definition flags to the
    ``definition_flags`` context value and then delegates to CommonBuildTask.
    Non-trunk builds additionally get gdb appended to the release.
    """

    type = 'BUILD_TTS_SERVER_GPU'
    execution_space = 80000  # 80 Gb
    input_parameters = CommonBuildTask.input_parameters + [CudaVersion, WithHfg]
    client_tags = ctc.Tag.Group.LINUX
    TARGET_RESOURCE_TYPES = (
        VOICETECH_TTS_SERVER_GPU,
        VOICETECH_TTS_SERVER_RUNNER,
        VOICETECH_EVLOGDUMP,
    )

    def _is_trunk_build(self):
        """Return True when the Arcadia URL in the context points into trunk.

        A missing or empty URL is treated as "not trunk" instead of raising,
        so on_enqueue and do_execute agree on the same check.
        """
        return '/trunk/' in (self.ctx.get(consts.ARCADIA_URL_KEY) or '')

    def on_enqueue(self):
        """Adjust disk requirements and inject GPU definition flags.

        For trunk builds the requested execution space is lowered (never
        raised). The GPU flags are put in front of any flags already present
        in ``definition_flags``, separated by a single space.
        """
        if self._is_trunk_build():
            decreased_space = 50 * 1024
            if decreased_space < self.execution_space:
                self.execution_space = decreased_space  # Decrease trunk build execution space

        if self.ctx.get('with_hfg'):
            # Hfg TRT build uses a fixed toolchain; 'cuda_version' is not consulted here.
            gpu_params = '-DCUDA_VERSION=11.0 -DTENSORRT_VERSION=7 -DCUDNN_VERSION=8.0.5 -DOS_SDK=ubuntu-16'
        else:
            gpu_params = '-DCUDA_VERSION={}'.format(self.ctx.get('cuda_version', '9.1'))

        # Join with an explicit separator: plain concatenation used to glue the
        # last GPU flag to the first existing flag ("...ubuntu-16-DFOO=1").
        # A None value stored under the key is treated as "no extra flags".
        existing_flags = self.ctx.get('definition_flags') or ''
        self.ctx['definition_flags'] = ' '.join(
            flags for flags in (gpu_params, existing_flags) if flags
        )
        CommonBuildTask.on_enqueue(self)

    def do_execute(self):
        """Run the common build, adding gdb to the release for non-trunk builds."""
        # Not add gdb to trunk builds
        if not self._is_trunk_build():
            gdb.append_to_release(self)

        CommonBuildTask.do_execute(self)

    def on_release(self, additional_parameters):
        """Store Startrek ticket ids (if passed) in the context, then release to Nanny.

        :param additional_parameters: mapping of release parameters; only the
            ``nanny.STARTREK_TICKET_IDS_KEY`` entry is read here, the whole
            mapping is forwarded to ``nanny.ReleaseToNannyTask.on_release``.
        """
        # FIXME(mvel) a bit of copypaste here
        if nanny.STARTREK_TICKET_IDS_KEY in additional_parameters:
            self.ctx[nanny.STARTREK_TICKET_IDS_KEY] = additional_parameters[nanny.STARTREK_TICKET_IDS_KEY]

        nanny.ReleaseToNannyTask.on_release(self, additional_parameters)


# Module-level alias for the task class defined above; presumably the name the
# Sandbox task loader looks up to discover the task in this module — TODO confirm.
__Task__ = BuildTtsServerGpu
