from . import resources  # noqa


# -*- coding: utf-8 -*-
import json
import logging
from sandbox import sdk2
import sandbox.common.types.resource as ctr
from sandbox.sdk2.helpers import subprocess as sp

from sandbox.projects.adv_machine.common import process_wrapper
from sandbox.projects.ads.tsar.lib import make_difacto_dump, get_last_resource
from sandbox.projects.ads.tsar.lib.resources import *   # noqa


logger = logging.getLogger(__name__)


# Ids of pinned Sandbox resources; TsarTransport.on_execute looks each one up
# with sdk2.Resource.find(id=...) and uses the synced data path in the config.
NETWORK_DSSM_MODEL = 1063609024     # network DSSM model (ModelType 1, Version 1 in the config)
IDENTITY_50_PROJECTOR = 1063622526  # projector shared by both DSSM model entries
SEARCH_DSSM_MODEL = 1243166002      # search DSSM model (ModelType 1, Version 4), optional

# Value of the "model_id" attribute used to select the released-stable
# TORCH_TSAR_BANNER_MODEL resource.
TORCH_PRODUCTION_MODEL_ID = 1


class TsarTransport(sdk2.Task):
    """Transport tsar to bs server.

    Builds a JSON config listing the tsar banner models to apply (a difacto dump
    of PRODUCTION_TASK_ID, a network DSSM model and, when enabled, a pytorch model
    and a search DSSM model) and runs the ``tsar_banner_transport_binary`` resource
    with ``--config <path>``.
    """
    # Task id handed to make_difacto_dump() to obtain the difacto model dump.
    PRODUCTION_TASK_ID = "robdrynkin_bsdev_75076_prod_v3"

    class Requirements(sdk2.Task.Requirements):
        cores = 1

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Task.Parameters):

        with sdk2.parameters.Group('YT parameters') as yt_block:
            yt_proxy = sdk2.parameters.String('YT proxy', required=True, default='hahn')
            yt_token_vault = sdk2.parameters.String('YT_TOKEN vault name', required=True, default='robot_tsar_token')
            yt_pool = sdk2.parameters.String('YT_POOL', required=True, default='robot-tsar')

        with sdk2.parameters.Group('Resources') as resources_block:
            tsar_banner_transport_binary = sdk2.parameters.LastReleasedResource(
                'tsar_binary',
                resource_type=resources.TsarBannerTransportBinary,
                state=(ctr.State.READY,),
                required=True,
            )

        with sdk2.parameters.Group('TransportParameters') as transport_parameters_block:
            cluster = sdk2.parameters.String('Cluster')
            input_table = sdk2.parameters.String('AdvMachineExportTable')
            genocide_table = sdk2.parameters.String('Genocide results table')
            output_table = sdk2.parameters.String('Destination table path')
            scoring_data_size_per_job_gb = sdk2.parameters.Integer('ScoringDataSizePerJobGB')
            scoring_mapper_memory_limit_gb = sdk2.parameters.Integer('ScoringMapperMemoryLimitGB')
            validation_table = sdk2.parameters.String('ValidationTable')
            metrics_tolerance = sdk2.parameters.Float('MetricsTolerance')
            genocide_result_table_name = sdk2.parameters.String('GenocideResultTableName')
            # NOTE(review): this rebinds the `yt_pool` name already declared in the
            # 'YT parameters' group above, so that declaration (required, default
            # 'robot-tsar') is shadowed by this one. Kept as-is to avoid changing the
            # task's parameter interface; one of the two should be removed.
            yt_pool = sdk2.parameters.String('YtPool')
            versions_prefix = sdk2.parameters.String(
                "VersionsPrefix",
                default="//home/yabs-cs/ads_quality/versions",
                required=False
            )
            target_domain_table = sdk2.parameters.String(
                "TargetDomainTable",
                default="//home/yabs/dict/TargetDomain",
                required=False
            )
            acceptable_data_weight_deviation = sdk2.parameters.Float(
                "AcceptableDataWeightDeviation",
                default=0.3,
                required=False
            )
            apply_genocide = sdk2.parameters.Bool(
                "ApplyGenocideOnTensorTransport",
                default=True,
                required=False
            )
            filter_non_active_banners = sdk2.parameters.Bool(
                "FilterNonActiveBannerOnTransport",
                default=True,
                required=False
            )

            apply_pytorch_model = sdk2.parameters.Bool(
                "ApplyPytorchModel",
                default=False,
                required=False
            )

            apply_search_dssm_model = sdk2.parameters.Bool(
                "ApplySearchDSSMModel",
                default=False,
                required=False
            )

    @staticmethod
    def _resource_path(resource_id):
        """Return the local data path (as str) of the Sandbox resource ``resource_id``.

        :raises RuntimeError: if no resource with this id is found.
        """
        resource = sdk2.Resource.find(id=resource_id).first()
        if resource is None:
            # Fail with a clear message instead of an opaque error from ResourceData(None).
            raise RuntimeError("Sandbox resource {} not found".format(resource_id))
        return str(sdk2.ResourceData(resource).path)

    def on_execute(self):
        """Resolve model resources, write the transport config and run the binary.

        The optional pytorch and search DSSM models are only looked up when the
        corresponding ``apply_*`` parameter is enabled, because make_config() only
        uses them in that case.
        """
        tsar_banner_transport_binary_path = str(sdk2.ResourceData(self.Parameters.tsar_banner_transport_binary).path)

        difacto_tsar_model = make_difacto_dump(self.PRODUCTION_TASK_ID)

        torch_tsar_model = None
        if self.Parameters.apply_pytorch_model:
            # Latest READY, stable-released torch model with the production model_id.
            torch_tsar_model = get_last_resource({
                "type": "TORCH_TSAR_BANNER_MODEL",
                "state": "READY",
                "attrs": {
                    "model_id": TORCH_PRODUCTION_MODEL_ID,
                    "released": "stable",
                },
            })

        network_dssm_model = self._resource_path(NETWORK_DSSM_MODEL)
        identity_projector = self._resource_path(IDENTITY_50_PROJECTOR)

        search_dssm_model = None
        if self.Parameters.apply_search_dssm_model:
            search_dssm_model = self._resource_path(SEARCH_DSSM_MODEL)

        conf = self.make_config(
            last_tsar_dump_path=difacto_tsar_model,
            network_dssm_model=network_dssm_model,
            identity_projector=identity_projector,
            torch_tsar_model=torch_tsar_model,
            search_dssm_model=search_dssm_model
        )

        cmd = [
            tsar_banner_transport_binary_path,
            "--config", conf
        ]

        # NOTE(review): passing `env` replaces the child's whole environment, so
        # only the variables set here are visible to the binary (no PATH/HOME).
        env = {'MR_RUNTIME': 'YT'}
        if self.Parameters.yt_token_vault:
            env['YT_TOKEN'] = sdk2.Vault.data(self.Parameters.yt_token_vault)
        if self.Parameters.yt_pool:
            env['YT_POOL'] = self.Parameters.yt_pool

        # stdout/stderr of the binary go to the streams provided by process_wrapper.
        with process_wrapper(self, logger='banner_transport') as pl:
            sp.check_call(cmd, stdout=pl.stdout, stderr=pl.stderr, env=env)

    def make_config(self, last_tsar_dump_path, network_dssm_model, identity_projector, torch_tsar_model, search_dssm_model):
        """Write the transport binary config to ``config.json`` in the working directory.

        :param last_tsar_dump_path: model path for the ModelType 0 / Version 2 entry (difacto dump).
        :param network_dssm_model: model path for the ModelType 1 / Version 1 entry.
        :param identity_projector: projector path shared by both ModelType 1 entries.
        :param torch_tsar_model: model path for the ModelType 2 / Version 3 entry;
            only read when ``apply_pytorch_model`` is set (may be None otherwise).
        :param search_dssm_model: model path for the ModelType 1 / Version 4 entry;
            only read when ``apply_search_dssm_model`` is set (may be None otherwise).
        :return: the relative path of the written config file.
        """
        config_path = "config.json"
        # ModelType / Version / ModelDestination codes are interpreted by the
        # transport binary; their meaning is not visible here.
        models = [
            {
                "ModelType": 0,
                "Version": 2,
                "VectorSize": 51,
                "MinValue": -1.0,
                "MaxValue": 1.0,
                "ModelPath": last_tsar_dump_path
            },
            {
                "ModelType": 1,
                "Version": 1,
                "VectorSize": 50,
                "MinValue": -1.0,
                "MaxValue": 1.0,
                "ModelPath": network_dssm_model,
                "ProjectorPath": identity_projector
            },
        ]

        if self.Parameters.apply_pytorch_model:
            models.append({
                "ModelType": 2,
                "Version": 3,
                "VectorSize": 51,
                "MinValue": -1.0,
                "MaxValue": 1.0,
                "ModelPath": torch_tsar_model
            })
        if self.Parameters.apply_search_dssm_model:
            models.append({
                "ModelType": 1,
                "Version": 4,
                "VectorSize": 50,
                "MinValue": -1.0,
                "MaxValue": 1.0,
                "ModelPath": search_dssm_model,
                "ProjectorPath": identity_projector,
                "ModelDestination": 1
            })

        config = {
            "Cluster": self.Parameters.cluster,
            "AdvMachineExportTable": self.Parameters.input_table,
            "GenocideTable": self.Parameters.genocide_table,
            "OutputTable": self.Parameters.output_table,
            "ScoringDataSizePerJobGB": self.Parameters.scoring_data_size_per_job_gb,
            "ScoringMapperMemoryLimitGB": self.Parameters.scoring_mapper_memory_limit_gb,
            "MetricTable": self.Parameters.validation_table,
            "MetricsTolerance": self.Parameters.metrics_tolerance,
            "GenocideResultTableName": self.Parameters.genocide_result_table_name,
            "YtPool": self.Parameters.yt_pool,
            "ModelFolder": "LEGACY_FIELD",
            "Models": models,
            "AcceptableDataWeightDeviation": self.Parameters.acceptable_data_weight_deviation,
            "VersionsPrefix": self.Parameters.versions_prefix,
            "TargetDomainTableName": self.Parameters.target_domain_table,
            "ApplyGenocide": self.Parameters.apply_genocide,
            "FilterNonActiveBanners": self.Parameters.filter_non_active_banners
        }

        logger.info("%s", config)
        # Close (and flush) the file before the binary reads it.
        with open(config_path, 'w') as config_file:
            json.dump(config, config_file)
        return config_path
