import logging
import os

from google.protobuf import text_format

from sandbox import sdk2
from sandbox.common.utils import Enum
from sandbox.projects.neural_net.resources import resources as nn_resources
from sandbox.projects.jupiter.ReleaseExtFiles import JupiterOptimizedRthubBundle
from sandbox.projects.common import context_managers
from sandbox.projects.common.nanny import nanny
from sandbox.projects.common.arcadia import sdk as arcadia_sdk
from sandbox.projects.common.constants import constants as sdk_constants
from sandbox.projects.common import binary_task
from sandbox.sdk2.helpers import subprocess
from sandbox.sdk2.vcs.svn import Arcadia


# Output paths of the resources created by the task, one per search component
# (base, begemot, mmeta, runtime_models and shmick_base respectively).
# All paths are relative; presumably they resolve against the task working
# directory -- TODO confirm.
BASE_MODEL_FILE = "optimized_model.dssm"
# Tar archive the task assembles from the "begemot_optimized_models" directory.
BEGEMOT_MODELS_FILE = "begemot_optimized_models.tar"
MMETA_MODEL_FILE = "mmeta_optimized_model.dssm"
RUNTIME_MODELS_MODEL_FILE = "runtime_models_optimized_model.dssm"
SHMICK_BASE_MODEL_FILE = "shmick_base_optimized_model.dssm"

# Arcadia-relative build target of the model builder and the path of the
# binary inside it; the task joins the latter with the Arcadia source root.
OPTIMIZED_MODEL_BUILDER_DIR = "quality/neural_net/build_models_for_runtime"
OPTIMIZED_MODEL_BUILDER_BINARY_PATH = os.path.join(OPTIMIZED_MODEL_BUILDER_DIR, "build_models_for_runtime")

# File read by set_meta_info(); it must contain the "Revision" and
# "FlagsConfigs" keys. Presumably written by the builder binary because of the
# --save-separate-meta-info flag -- TODO confirm.
META_INFO_FILE = "optimized_model_meta.txt"

# Variables for the rthub search component.
# RTHUB_MODEL_FILE is the model file that gets moved into the pack directory
# (it is reused by the rthub_migration component as well).
# RTHUB_PACK_ID is the name of that directory; the published archive is
# RTHUB_PACK_ID + ".tar".
RTHUB_MODEL_FILE = "optimized_rthub_model.dssm"
RTHUB_PACK_ID = "WEB_RTHUB_OPTIMIZED_MODELS_DSSM_PACK"

# Variables for the rthub_migration search component.
# Name of the directory that receives RTHUB_MODEL_FILE and is published as is
# (no tar archive is created for this component).
RTHUB_MIGRATION_PACK_ID = "JUPITER_OPTIMIZED_RTHUB_BUNDLE"


class SearchComponent(Enum):
    # Search components offered by the task's "Search component" parameter.
    # The selected value is compared against these members to pick the output
    # resource and is passed verbatim to the builder binary through
    # --search-component.
    BASE = "base"
    BEGEMOT = "begemot"
    MMETA = "mmeta"
    RTHUB = "rthub"
    # NOTE(review): the value has no underscore, unlike the other multi-word
    # values below -- presumably this is the spelling the builder expects.
    RTHUB_MIGRATION = "rthubmigration"
    RUNTIME_MODELS = "runtime_models"
    SHMICK_BASE = "shmick_base"


class BuildOptimizedModel(nanny.ReleaseToNannyTask2, binary_task.LastBinaryTaskRelease, sdk2.Task):
    """Build optimized DSSM models for one search component and publish them as a resource.

    The task creates the output resource for the selected search component,
    builds the bundle config directory, the model builder and the models
    directories from Arcadia, runs the builder binary, and finally packs (where
    needed) and marks the resource as ready.
    """
    __logger = logging.getLogger("TASK_LOGGER")
    __logger.setLevel(logging.DEBUG)

    class Parameters(sdk2.Parameters):
        arcadia_url = sdk2.parameters.ArcadiaUrl(
            "Arcadia url",
            default_value=Arcadia.ARCADIA_TRUNK_URL,
        )
        arcadia_patch = sdk2.parameters.String(
            "Apply patch (diff file rbtorrent, paste.y-t.ru link or plain text). Doc: https://nda.ya.ru/3QTTV4",
            multiline=True,
            default="",
        )
        # One selectable value per SearchComponent member; BASE is the default.
        with sdk2.parameters.String("Search component") as search_component:
            search_component.values[SearchComponent.BASE] = search_component.Value(value=SearchComponent.BASE, default=True)
            search_component.values[SearchComponent.BEGEMOT] = search_component.Value(value=SearchComponent.BEGEMOT)
            search_component.values[SearchComponent.MMETA] = search_component.Value(value=SearchComponent.MMETA)
            search_component.values[SearchComponent.RTHUB] = search_component.Value(value=SearchComponent.RTHUB)
            search_component.values[SearchComponent.RTHUB_MIGRATION] = search_component.Value(value=SearchComponent.RTHUB_MIGRATION)
            search_component.values[SearchComponent.RUNTIME_MODELS] = search_component.Value(value=SearchComponent.RUNTIME_MODELS)
            search_component.values[SearchComponent.SHMICK_BASE] = search_component.Value(value=SearchComponent.SHMICK_BASE)
        bundle_config_path = sdk2.parameters.String(
            "Path to bundle.config.in",
            description="Arcadia path to models bundle config",
            required=True,
        )
        threads_count = sdk2.parameters.Integer(
            "Threads count",
            description="Count of threads to use for model compression stage when building optimized_model.dssm",
            default_value=4,
        )
        disk_space = sdk2.parameters.Integer(
            "Required disk space (in GB)",
            default_value=32,
        )
        ram = sdk2.parameters.Integer(
            "Required RAM size (in GB)",
            default_value=64,
        )
        tasks_archive_resource = binary_task.binary_release_parameters(stable=True)

    def get_arcadia_src_dir(self):
        """Return a context manager that yields the Arcadia source directory.

        When fuse is available the Arcadia url is mounted (with fallback
        enabled); otherwise the directory returned by
        Arcadia.get_arcadia_src_dir() is wrapped into a no-op context manager
        so that both variants can be used in a ``with`` statement.
        """
        if arcadia_sdk.fuse_available():
            return arcadia_sdk.mount_arc_path(self.Parameters.arcadia_url, fallback=True)
        return context_managers.nullcontext(Arcadia.get_arcadia_src_dir(self.Parameters.arcadia_url))

    def on_enqueue(self):
        """Copy the user-supplied resource requirements into the task requirements."""
        # Parameters are given in GB; the factor of 1024 presumably converts
        # them to the MB units expected by Requirements -- TODO confirm.
        self.Requirements.disk_space = self.Parameters.disk_space * 1024
        self.Requirements.ram = self.Parameters.ram * 1024
        # One core per builder thread.
        self.Requirements.cores = self.Parameters.threads_count

    def on_execute(self):
        """Create the output resource, build and run the model builder, publish the result."""
        # The resource is created before the (long) build, as it always was.
        output_resource = self.__create_output_resource()

        self.__logger.info("Preparing arcadia directory: build build_models_for_runtime"
                           " binary and models directories")
        with self.get_arcadia_src_dir() as arcadia_dir:
            if self.Parameters.arcadia_patch:
                Arcadia.apply_patch(arcadia_dir, self.Parameters.arcadia_patch, self.path())
            self.__build_arcadia_targets(arcadia_dir)
            # The builder binary lives inside arcadia_dir, so it has to run
            # while the source directory is still available.
            self.__run_model_builder(arcadia_dir)

        self.__publish_output_resource(output_resource)

    def on_release(self, additional_parameters):
        """Run both the generic sdk2 release handler and the Nanny release handler."""
        sdk2.Task.on_release(self, additional_parameters)
        nanny.ReleaseToNannyTask2.on_release(self, additional_parameters)

    # The helpers below use name-mangled (double underscore) names so that
    # they cannot clash with attributes of the base classes.

    def __create_output_resource(self):
        """Create the output resource matching the selected search component.

        Returns the created resource, or None when the selected component
        equals none of the SearchComponent members.
        """
        component = self.Parameters.search_component
        if component == SearchComponent.BASE:
            return nn_resources.WEB_BASE_OPTIMIZED_MODEL_DSSM(
                self,
                "Web basesearch optimized_model.dssm",
                BASE_MODEL_FILE,
            )
        if component == SearchComponent.BEGEMOT:
            return nn_resources.WEB_BEGEMOT_OPTIMIZED_MODELS_DSSM_PACK(
                self,
                "Archive with web begemot dssm optimized models",
                BEGEMOT_MODELS_FILE,
            )
        if component == SearchComponent.MMETA:
            return nn_resources.WEB_MMETA_OPTIMIZED_MODEL_DSSM(
                self,
                "Web middlesearch optimized dssm model",
                MMETA_MODEL_FILE,
            )
        if component == SearchComponent.RTHUB:
            return nn_resources.WEB_RTHUB_OPTIMIZED_MODELS_DSSM_PACK(
                self,
                "Archive with optimized dssm model for udf at rthub",
                RTHUB_PACK_ID + ".tar",
            )
        if component == SearchComponent.RTHUB_MIGRATION:
            return JupiterOptimizedRthubBundle(
                self,
                "Folder with optimized dssm model for rthub migration tool at jupiter",
                RTHUB_MIGRATION_PACK_ID,
            )
        if component == SearchComponent.RUNTIME_MODELS:
            return nn_resources.WEB_RUNTIME_MODELS_OPTIMIZED_MODEL_DSSM(
                self,
                "Web runtime models optimized dssm model",
                RUNTIME_MODELS_MODEL_FILE,
            )
        if component == SearchComponent.SHMICK_BASE:
            return nn_resources.WEB_SHMICK_BASE_OPTIMIZED_MODEL_DSSM(
                self,
                "Shmick basesearch optimized_model.dssm",
                SHMICK_BASE_MODEL_FILE,
            )
        return None

    def __build_arcadia_targets(self, arcadia_dir):
        """Build the bundle config directory, then the builder and the models directories.

        Build results are placed into ``arcadia_dir`` itself.
        """
        # The config directory is built first, without recursing into other
        # targets, because get_models_directories() below reads the config
        # file from arcadia_dir (presumably the file is generated by this
        # build -- TODO confirm).
        arcadia_sdk.do_build(
            build_system=sdk_constants.YMAKE_BUILD_SYSTEM,
            source_root=arcadia_dir,
            targets=[os.path.dirname(self.Parameters.bundle_config_path)],
            results_dir=arcadia_dir,
            clear_build=False,
            ignore_recurses=True,
        )
        arcadia_sdk.do_build(
            build_system=sdk_constants.YMAKE_BUILD_SYSTEM,
            source_root=arcadia_dir,
            targets=[OPTIMIZED_MODEL_BUILDER_DIR] + get_models_directories(arcadia_dir, self.Parameters.bundle_config_path),
            results_dir=arcadia_dir,
            clear_build=False,
        )

    def __run_model_builder(self, arcadia_dir):
        """Run the freshly built model builder binary for the selected search component.

        stdout and stderr of the binary go to the "optimized_model_builder"
        process log. Raises subprocess.CalledProcessError on a non-zero exit
        code (the behavior of check_call).
        """
        # Arguments are passed to the logger instead of being pre-formatted,
        # so the message is only rendered when the record is actually emitted.
        self.__logger.info("Building optimized models for search component %s", self.Parameters.search_component)
        build_models_for_runtime_binary = os.path.join(arcadia_dir, OPTIMIZED_MODEL_BUILDER_BINARY_PATH)
        command = [
            build_models_for_runtime_binary,
            "--search-component", self.Parameters.search_component,
            "--save-separate-meta-info",
            "-j", str(self.Parameters.threads_count),
            "-b", os.path.join(arcadia_dir, self.Parameters.bundle_config_path),
        ]
        with sdk2.helpers.ProcessLog(self, logger="optimized_model_builder") as optimized_model_builder_pl:
            subprocess.check_call(command, stdout=optimized_model_builder_pl.stdout,
                                  stderr=optimized_model_builder_pl.stderr)

    def __publish_output_resource(self, output_resource):
        """Pack the builder output where needed, attach meta info and mark the resource ready.

        Does nothing when ``output_resource`` is None, i.e. when no resource
        was created for the selected search component.
        """
        if output_resource is None:
            return

        component = self.Parameters.search_component
        if component == SearchComponent.RTHUB_MIGRATION:
            # Published as a plain directory containing the rthub model file.
            # NOTE(review): no meta info is attached for this component, as in
            # the original implementation -- confirm this is intentional.
            subprocess.check_call(['mkdir', RTHUB_MIGRATION_PACK_ID])
            subprocess.check_call(['mv', RTHUB_MODEL_FILE, RTHUB_MIGRATION_PACK_ID])
        else:
            if component == SearchComponent.BEGEMOT:
                # tar runs inside the models directory and writes the archive
                # one level up, so archive members are relative to that
                # directory.
                subprocess.check_call(['tar', '-cf', os.path.join('..', BEGEMOT_MODELS_FILE), '.'], cwd="begemot_optimized_models")
            elif component == SearchComponent.RTHUB:
                # The model file is wrapped into a directory named after the
                # pack id and that directory is archived.
                subprocess.check_call(['mkdir', RTHUB_PACK_ID])
                subprocess.check_call(['mv', RTHUB_MODEL_FILE, RTHUB_PACK_ID])
                subprocess.check_call(['tar', '-cf', RTHUB_PACK_ID + ".tar", RTHUB_PACK_ID])
            set_meta_info(output_resource)

        sdk2.ResourceData(output_resource).ready()


def set_meta_info(optimized_model_resource):
    """Copy build meta information from META_INFO_FILE onto the given resource.

    The file is parsed with cyson; its "Revision" value is stored in the
    resource's ``revision`` attribute and its "FlagsConfigs" value in
    ``flags_configs``. A missing key raises KeyError.
    """
    # Imported lazily, exactly as before: the module is only needed here.
    import cyson

    with open(META_INFO_FILE, "r") as source:
        raw_meta = source.read()

    parsed_meta = cyson.loads(raw_meta)
    optimized_model_resource.revision = parsed_meta["Revision"]
    optimized_model_resource.flags_configs = parsed_meta["FlagsConfigs"]


def get_models_directories(arcadia_dir, bundle_config_path):
    """Return the models directories listed in the bundle config.

    ``bundle_config_path`` is the config path relative to ``arcadia_dir``.
    The config is parsed as a text-format TModelsBundleConfig message. Every
    entry of its ``ModelsDirectories`` field is joined with the directory of
    ``bundle_config_path`` and normalized, so the returned list holds paths
    relative to ``arcadia_dir``, in config order.
    """
    # Imported lazily, exactly as before: the generated protobuf module is
    # only needed by this function.
    from kernel.dssm_applier.optimized_model.protos.bundle_config_pb2 import TModelsBundleConfig as BundleConfig

    parsed_config = BundleConfig()
    with open(os.path.join(arcadia_dir, bundle_config_path), "r") as config_file:
        config_text = config_file.read()
    text_format.Parse(config_text, parsed_config)

    config_dir = os.path.dirname(bundle_config_path)
    return [
        os.path.normpath(os.path.join(config_dir, models_dir))
        for models_dir in parsed_config.ModelsDirectories
    ]
