import logging
import os
import re
import shlex
import sys
import tempfile

from google.protobuf import text_format

from sandbox import sdk2
from sandbox.common.errors import TaskFailure
from sandbox.common.utils import Enum
from sandbox.sandboxsdk.svn import Arcadia
from sandbox.sdk2.helpers import subprocess as sp
from sandbox.projects import resource_types as rt
from sandbox.projects.common import error_handlers as eh
from sandbox.projects.common import file_utils as fu
from sandbox.projects.common.arcadia import sdk as arcadia_sdk
from sandbox.projects.common.constants import constants as sdk_constants
from sandbox.projects.common import binary_task
from sandbox.projects.WizardRuntimeBuild.ya_make.YaMake import YaMake
from sandbox.projects.common.build import parameters as build_parameters


# Arcadia-relative directory with the tests of build_models_for_runtime.
TESTS_DIR = "quality/neural_net/build_models_for_runtime/tests"
# Test target that is recanonized (`ya make -rAZ`) after configs are updated.
CANONIZE_TESTS_DIR = os.path.join(TESTS_DIR, "canonize")
# Directory with the `input_gen` tool and `inputs.yson`, regenerated for added/removed external layers.
TESTS_INPUT_DIR = os.path.join(TESTS_DIR, "input")

# Validates the "env_vars" parameter: one or more whitespace-separated NAME=value
# pairs, where value is single-quoted, double-quoted, or an unquoted run without
# whitespace and quotes.
ENV_VAR_PATTERN = r'^\s*(\b[a-zA-Z_]\w*=((\'[^\']*\')|(\"[^\"]*\")|([^\s\'\"]+))(\s+|$))+$'
# Matches vault references "$(vault:file|value:OWNER:NAME)" inside env var values.
VAULT_PATTERN = r'\$\(vault:(?P<dst>file|value):(?P<owner>[^:]+):(?P<name>[^:]+)\)'


class RunMode(Enum):
    # Values of the task's "Mode" (run_mode) parameter.
    ADD = "add"  # add a new model version to a flag config
    REMOVE = "remove"  # remove a flag config (and model versions no longer used)
    CANONIZE_TESTS = "canonize"  # change no configs; only regenerate/canonize tests


class CommitMode(Enum):
    # Values of the task's "Commit mode" parameter.
    DRY_RUN = "dry_run"  # only store diffs in the task context, commit nothing
    CREATE_REVIEW = "create_review"  # commit with arcanum:review=new / review-publish=yes revprops
    COMMIT_CHANGES = "commit_changes"  # commit with the arcanum:check-skip=yes revprop


def get_name_with_version(name, version):
    """Return ``<name>_v_<version>``: the versioned identifier of a model or layer."""
    return "_v_".join((name, version))


def get_model_filename(name, version):
    """Return the ``.dssm`` file name used for *version* of model *name*."""
    versioned_name = get_name_with_version(name, version)
    return "%s.dssm" % (versioned_name,)


class BuildConfigsUpdater(object):
    """Edit the build configs and ya.make of a single model directory.

    The directory is expected to contain ``ya.make``, ``nn_applier.config.in``
    (the model config) and ``update.config.in`` (build-config templates for a
    new model version); both configs are text-format ``TModelConfig`` messages.
    """

    # Repeated build-config fields of TModelConfig, in processing order.
    _BUILD_CONFIG_FIELDS = (
        "WebBaseBuildConfigs",
        "WebBegemotBuildConfigs",
        "WebMMetaBuildConfigs",
        "WebRthubBuildConfigs",
        "WebRthubMigrationBuildConfigs",
        "WebRuntimeModelsBuildConfigs",
    )

    def __init__(self, model_directory):
        self.ya_make_path = os.path.join(model_directory, "ya.make")
        self.model_config_path = os.path.join(model_directory, "nn_applier.config.in")
        self.update_config_path = os.path.join(model_directory, "update.config.in")

    def get_model_name(self):
        """Return the ``Name`` of the model described by the model config."""
        return self._load_model_config().Name

    def add_model_build_configs(self, model_version, model_resource_id):
        """Add build configs and a ya.make resource for a new model version.

        Every build-config template from ``update.config.in`` (at most one per
        build-config field) is stamped with *model_version* and appended to the
        model config, and ``<name>_v_<version>.dssm`` is registered in ya.make as
        sandbox resource *model_resource_id*. Both files are rewritten.

        Returns:
            list of ``(layer_name, expected_size)`` tuples for the external layers
            of the added build configs.

        Raises:
            TaskFailure: if a build-config field has more than one template, no
                template exists at all, or the model file is already in ya.make.
        """
        model_config = self._load_model_config()
        model_filename = get_model_filename(model_config.Name, model_version)
        update_config = self._load_update_config()

        found = False
        external_layers = []
        for field_name in self._BUILD_CONFIG_FIELDS:
            templates = getattr(update_config, field_name)
            if len(templates) > 1:
                # field_name[:-1] drops the plural "s" for the error message.
                raise TaskFailure("Update config for model %s contains more than one %s" %
                                  (model_config.Name, field_name[:-1]))
            for build_config in templates:
                found = True
                build_config.Version = model_version
                external_layers += self._get_external_layers(build_config, True)
                getattr(model_config, field_name).add().CopyFrom(build_config)

        if not found:
            raise TaskFailure("Not found any templates of build configs in update config for model %s" %
                              model_config.Name)

        ya_make = self._load_ya_make()
        if ya_make.get_sandboxed(model_filename).resource != 0:
            raise TaskFailure("%s file is already present in ya.make, can't add new one" % model_filename)
        ya_make.update_sandbox_resource(model_filename, model_resource_id, None)

        self._dump_model_config(model_config)
        self._dump_ya_make(ya_make)
        return external_layers

    def remove_model_build_configs(self, model_version):
        """Remove all build configs and the ya.make resource of *model_version*.

        Returns:
            list of ``(layer_name, 0)`` tuples for the external layers of the
            removed build configs.

        Raises:
            TaskFailure: if no build config has this version or the model file is
                missing from ya.make.
        """
        model_config = self._load_model_config()
        model_filename = get_model_filename(model_config.Name, model_version)

        found = False
        external_layers = []
        for field_name in self._BUILD_CONFIG_FIELDS:
            build_configs = getattr(model_config, field_name)
            # Iterate over a snapshot: removing from the repeated field while
            # iterating it directly would skip the element after each removal.
            for build_config in list(build_configs):
                if build_config.Version == model_version:
                    found = True
                    external_layers += self._get_external_layers(build_config, False)
                    build_configs.remove(build_config)
        if not found:
            raise TaskFailure("Not found any build configs for %s" % model_filename)

        ya_make = self._load_ya_make()
        if ya_make.get_sandboxed(model_filename).resource == 0:
            raise TaskFailure("Not found %s in ya.make, can't remove it" % model_filename)
        ya_make.remove_sandbox_resource(model_filename)

        self._dump_model_config(model_config)
        self._dump_ya_make(ya_make)
        return external_layers

    def _get_external_layers(self, build_config, need_size):
        """Return ``(<To>_v_<Version>, size)`` for each external layer of *build_config*.

        ``size`` is the layer's ``ExpectedSize`` when *need_size* is true, else 0.
        """
        return [
            (
                get_name_with_version(external_layer.To, build_config.Version),
                external_layer.ExpectedSize if need_size else 0,
            )
            for external_layer in build_config.ExternalLayers
        ]

    def _load_model_config(self):
        return self._parse_model_config(self.model_config_path)

    def _dump_model_config(self, model_config):
        with open(self.model_config_path, "w") as model_config_file:
            text_format.PrintMessage(model_config, model_config_file, as_utf8=True)

    def _load_update_config(self):
        return self._parse_model_config(self.update_config_path)

    @staticmethod
    def _parse_model_config(path):
        """Parse the text-format ``TModelConfig`` stored at *path*."""
        from quality.neural_net.build_models_for_runtime.protos.model_config_pb2 import TModelConfig as ModelConfig

        model_config = ModelConfig()
        with open(path, "r") as config_file:
            text_format.Parse(config_file.read(), model_config)
        return model_config

    def _load_ya_make(self):
        return YaMake(self.ya_make_path)

    def _dump_ya_make(self, ya_make):
        with open(self.ya_make_path, "w") as ya_make_file:
            ya_make.dump(ya_make_file)


class BundleConfigUpdater(object):
    """Edit a models bundle config (text-format ``TModelsBundleConfig``).

    The bundle config holds flag configs (``FlagsConfigs`` with an ``Id``, a
    ``Default`` marker and ``UsedModels`` name/version entries) and the list of
    bundled submodels (``SubmodelsInfos`` name/version entries).
    """

    def __init__(self, bundle_config_path):
        self.bundle_config_path = bundle_config_path

    def add_new_model(self, model_name, model_version, flag_version):
        """Make flag config *flag_version* use *model_version* of *model_name*.

        If no flag config with that id exists, a non-default copy of the flag
        config with the greatest ``Id`` (compared with ``<``) is created first.
        Previous versions of the model are dropped from the flag config; their
        ``SubmodelsInfos`` entries are removed if no other flag config uses them.
        A ``SubmodelsInfos`` entry for the new version is appended.

        Returns:
            the dropped ``UsedModels`` entries that no flag config uses anymore.

        Raises:
            TaskFailure: if a new flag config is needed but the bundle config has
                no flag config to copy it from.
        """
        bundle_config = self._load_bundle_config()
        candidates_to_remove = []

        def add_submodel_to_flag_config(flag_config):
            removed = False
            # Iterate over a snapshot: removing from the repeated field while
            # iterating it directly would skip the element after each removal.
            for used_model in list(flag_config.UsedModels):
                if used_model.Name == model_name:
                    candidates_to_remove.append(used_model)
                    flag_config.UsedModels.remove(used_model)
                    removed = True
            if removed:
                # Cleaned up before the new version is added, as before.
                self._remove_extra_submodels_infos(bundle_config, candidates_to_remove)
            used_model = flag_config.UsedModels.add()
            used_model.Name = model_name
            used_model.Version = model_version

        found = False
        last_flag_config = None
        for flag_config in bundle_config.FlagsConfigs:
            if flag_config.Id == flag_version:
                found = True
                add_submodel_to_flag_config(flag_config)
            if last_flag_config is None or last_flag_config.Id < flag_config.Id:
                last_flag_config = flag_config
        if not found:
            if last_flag_config is None:
                raise TaskFailure("Bundle config doesn't contain any flag config to create flag config %s from" %
                                  flag_version)
            flag_config = bundle_config.FlagsConfigs.add()
            flag_config.CopyFrom(last_flag_config)
            flag_config.Default = False
            flag_config.Id = flag_version
            add_submodel_to_flag_config(flag_config)

        new_submodel_info = bundle_config.SubmodelsInfos.add()
        new_submodel_info.Name = model_name
        new_submodel_info.Version = model_version

        self._dump_bundle_config(bundle_config)
        return candidates_to_remove

    def change_default_flag_config(self, new_default_flag_version):
        """Mark flag config *new_default_flag_version* as the only default one.

        Raises:
            TaskFailure: if no flag config has that id.
        """
        bundle_config = self._load_bundle_config()

        found = False
        for flag_config in bundle_config.FlagsConfigs:
            # Only touch Default where needed, so untouched configs keep their
            # current field presence in the dumped text.
            if flag_config.Default:
                flag_config.Default = False
            if flag_config.Id == new_default_flag_version:
                found = True
                flag_config.Default = True
        if not found:
            raise TaskFailure("Bundle config doesn't contain flag config with version %s" % new_default_flag_version)

        self._dump_bundle_config(bundle_config)

    def remove_flag_config(self, flag_version):
        """Remove flag config *flag_version* and submodels used only by it.

        A missing flag config is silently ignored.

        Returns:
            the ``UsedModels`` entries of the removed flag config that no
            remaining flag config uses.
        """
        bundle_config = self._load_bundle_config()

        candidates_to_remove = []
        # Iterate over a snapshot because matching entries are removed.
        for flag_config in list(bundle_config.FlagsConfigs):
            if flag_config.Id == flag_version:
                candidates_to_remove = list(flag_config.UsedModels)
                bundle_config.FlagsConfigs.remove(flag_config)
        self._remove_extra_submodels_infos(bundle_config, candidates_to_remove)

        self._dump_bundle_config(bundle_config)
        return candidates_to_remove

    def _remove_extra_submodels_infos(self, bundle_config, candidates_to_remove):
        """Drop unused candidates' ``SubmodelsInfos`` entries.

        *candidates_to_remove* is filtered in place: entries still used by some
        flag config are removed from it. For every remaining candidate, each
        equal ``SubmodelsInfos`` entry is deleted.
        """
        for flag_config in bundle_config.FlagsConfigs:
            for used_model in flag_config.UsedModels:
                if used_model in candidates_to_remove:
                    candidates_to_remove.remove(used_model)
        for candidate_to_remove in candidates_to_remove:
            # Iterate over a snapshot because matching entries are removed.
            for submodel_info in list(bundle_config.SubmodelsInfos):
                if candidate_to_remove == submodel_info:
                    bundle_config.SubmodelsInfos.remove(submodel_info)

    def _load_bundle_config(self):
        from kernel.dssm_applier.optimized_model.protos.bundle_config_pb2 import TModelsBundleConfig as BundleConfig

        bundle_config = BundleConfig()
        with open(self.bundle_config_path, "r") as bundle_config_file:
            text_format.Parse(bundle_config_file.read(), bundle_config)

        return bundle_config

    def _dump_bundle_config(self, bundle_config):
        with open(self.bundle_config_path, "w") as bundle_config_file:
            text_format.PrintMessage(bundle_config, bundle_config_file, as_utf8=True)


class UpdateOptimizedModelConfigs(binary_task.LastBinaryTaskRelease, sdk2.Task):
    """Update optimized_model.dssm configs in Arcadia and recanonize their tests.

    The "Mode" parameter selects the operation:

    * ADD adds a new model version to a flag config.
    * REMOVE removes a flag config.
    * CANONIZE_TESTS changes no configs and only recanonizes.

    After that, build configs and ya.make entries of model versions no longer
    used by any flag config are removed. The tests input is regenerated for
    affected external layers and tests are canonized. The result is committed,
    sent for review, or only diffed, according to "Commit mode".
    """
    __logger = logging.getLogger("TASK_LOGGER")
    __logger.setLevel(logging.DEBUG)

    class Parameters(sdk2.Parameters):
        arcadia_url = sdk2.parameters.ArcadiaUrl(
            "Arcadia url",
            default_value=Arcadia.ARCADIA_TRUNK_URL,
        )
        bundle_config_path = sdk2.parameters.String(
            "Path to bundle.config.in",
            description="Arcadia path to models bundle config",
            required=True,
        )
        models_directories = sdk2.parameters.List(
            "Paths to models directories",
            description="Arcadia paths to directories with models and configs (strictly one path is needed in addition mode)",
            value_type=sdk2.parameters.String,
            default=[],
        )
        with sdk2.parameters.String("Mode") as run_mode:
            run_mode.values[RunMode.ADD] = run_mode.Value(value=RunMode.ADD)
            run_mode.values[RunMode.REMOVE] = run_mode.Value(value=RunMode.REMOVE)
            run_mode.values[RunMode.CANONIZE_TESTS] = run_mode.Value(value=RunMode.CANONIZE_TESTS, default=True)
        flag_version = sdk2.parameters.String(
            "Flag version",
            description="Flag version to add new version of model to / to remove",
            required=True,
        )
        with run_mode.value[RunMode.ADD]:
            model_resource = sdk2.parameters.Resource(
                "Model resource",
                description="Sandbox resource with model of new version in nn_applier format",
                resource_type=rt.DSSM_MODEL,
                required=False,
            )
            model_version = sdk2.parameters.String(
                "Model version",
                description="New version of model",
                required=True,
            )
        with run_mode.value[RunMode.REMOVE]:
            new_default_flag_version = sdk2.parameters.String(
                "New default flag version",
                description="If given, will set this flag version as default.",
                required=False,
            )
        with sdk2.parameters.String("Commit mode") as commit_mode:
            commit_mode.values[CommitMode.DRY_RUN] = commit_mode.Value(value=CommitMode.DRY_RUN, default=True)
            commit_mode.values[CommitMode.CREATE_REVIEW] = commit_mode.Value(value=CommitMode.CREATE_REVIEW)
            commit_mode.values[CommitMode.COMMIT_CHANGES] = commit_mode.Value(value=CommitMode.COMMIT_CHANGES)
        with commit_mode.value[CommitMode.CREATE_REVIEW], commit_mode.value[CommitMode.COMMIT_CHANGES]:
            commit_message = sdk2.parameters.String(
                "Commit message",
                default_value="Updating optimized_model.dssm configs",
            )
            commit_user = sdk2.parameters.Staff(
                "Commit user",
                default_value="zomb-sandbox-rw",
            )
        cpu = sdk2.parameters.Integer(
            "Required number of cpu cores",
            default_value=30,
        )
        disk_space = sdk2.parameters.Integer(
            "Required disk space (in GB)",
            default_value=32,
        )
        ram = sdk2.parameters.Integer(
            "Required RAM size (in GB)",
            default_value=128,
        )
        env_vars = build_parameters.EnvironmentVarsParam()
        tasks_archive_resource = binary_task.binary_release_parameters(stable=True)

    def deref(self, s):
        """Replace every ``$(vault:file|value:OWNER:NAME)`` reference in *s*.

        ``value`` references are replaced with the secret itself. ``file``
        references are replaced with the path of a temporary file holding the
        secret.
        """
        def deref_vault(match):
            secret = sdk2.Vault.data(match.group('owner'), match.group('name'))
            if match.group('dst') == 'file':
                # mkstemp keeps the created file on disk, so the secret is written
                # to a file this process created. NamedTemporaryFile().name deletes
                # its file at once and leaves the name open to reuse.
                fd, deref_path = tempfile.mkstemp()
                os.close(fd)
                fu.write_file(deref_path, secret)
                return deref_path
            return secret

        return re.sub(VAULT_PATTERN, deref_vault, s)

    def get_env_vars(self):
        """Parse the "env_vars" parameter into a dict, resolving vault references.

        The parameter must match ``ENV_VAR_PATTERN``: whitespace-separated
        ``NAME=value`` pairs, where values may be quoted. An empty or unset
        parameter yields an empty dict.
        """
        env_vars = self.Parameters.env_vars
        if not env_vars:
            # shlex.split(None) would read from stdin (or raise on newer Pythons).
            return {}
        if not re.match(ENV_VAR_PATTERN, env_vars):
            eh.check_failed("Incorrect 'Environment variables' parameter '{}'".format(env_vars))
        return {k: self.deref(v) for k, v in (x.split('=', 1) for x in shlex.split(env_vars))}

    class Context(sdk2.Context):
        precanonize_diff = ""
        final_diff = ""
        created_review = ""

    def on_enqueue(self):
        self.Requirements.disk_space = self.Parameters.disk_space * 1024
        self.Requirements.ram = self.Parameters.ram * 1024
        self.Requirements.cores = self.Parameters.cpu

    def on_execute(self):
        env = self.get_env_vars()
        binary_task.LastBinaryTaskRelease.on_execute(self)

        if self.Parameters.run_mode == RunMode.ADD:
            # Validate addition-mode parameters before the expensive clone and build.
            if len(self.Parameters.models_directories) != 1:
                raise TaskFailure("Expected strictly one model directory in addition mode")
            if self.Parameters.model_resource is None:
                raise TaskFailure("Model resource is required in addition mode")

        self.__logger.info("Preparing arcadia directory")

        arcadia_dir = arcadia_sdk.do_clone(self.Parameters.arcadia_url, self)
        arcadia_sdk.do_build(
            build_system=sdk_constants.YMAKE_BUILD_SYSTEM,
            source_root=arcadia_dir,
            targets=self.Parameters.models_directories + [os.path.dirname(self.Parameters.bundle_config_path), TESTS_INPUT_DIR],
            results_dir=arcadia_dir,
            clear_build=False,
            checkout=True,
            env=env,
        )

        self.__logger.info("Updating configs")

        bundle_config_updater = BundleConfigUpdater(os.path.join(arcadia_dir, self.Parameters.bundle_config_path))

        external_layers = []
        models_to_remove = []
        if self.Parameters.run_mode == RunMode.ADD:
            build_configs_updater = BuildConfigsUpdater(os.path.join(arcadia_dir, self.Parameters.models_directories[0]))
            model_version = str(self.Parameters.model_version)
            external_layers += build_configs_updater.add_model_build_configs(model_version, self.Parameters.model_resource.id)
            models_to_remove = bundle_config_updater.add_new_model(
                build_configs_updater.get_model_name(), model_version, str(self.Parameters.flag_version))
        elif self.Parameters.run_mode == RunMode.REMOVE:
            if self.Parameters.new_default_flag_version:
                bundle_config_updater.change_default_flag_config(str(self.Parameters.new_default_flag_version))
            models_to_remove = bundle_config_updater.remove_flag_config(str(self.Parameters.flag_version))

        # Drop build configs of model versions that no flag config uses anymore.
        for model_directory in self.Parameters.models_directories:
            build_configs_updater = BuildConfigsUpdater(os.path.join(arcadia_dir, model_directory))
            model_name = build_configs_updater.get_model_name()
            # Iterate over a snapshot: matched entries are removed from the list,
            # and removing while iterating it directly would skip entries.
            for model_to_remove in list(models_to_remove):
                if model_to_remove.Name == model_name:
                    external_layers += build_configs_updater.remove_model_build_configs(model_to_remove.Version)
                    models_to_remove.remove(model_to_remove)
        if models_to_remove:
            raise TaskFailure("Not all extra models were removed from build configs")

        self.__logger.info("Modifying tests input according to new build configs")
        for external_layer_name, external_layer_size in external_layers:
            with sdk2.helpers.ProcessLog(self, logger="tests_input_gen") as tests_input_gen_pl:
                sp.check_call([
                    os.path.join(arcadia_dir, TESTS_INPUT_DIR, "input_gen"),
                    "-i", os.path.join(arcadia_dir, TESTS_INPUT_DIR, "inputs.yson"),
                    "-o", os.path.join(arcadia_dir, TESTS_INPUT_DIR, "inputs.yson"),
                    "-n", external_layer_name,
                    "-s", str(external_layer_size),
                ], stdout=tests_input_gen_pl.stdout, stderr=tests_input_gen_pl.stderr)
        self.Context.precanonize_diff = Arcadia.diff(url=arcadia_dir)

        self.__logger.info("Canonizing tests")
        with sdk2.helpers.ProcessLog(self, logger="tests_canonize") as tests_canonize_pl:  # TODO(filmih@): replace with arcadia_sdk.do_build(...) after DEVTOOLS-4160
            ya = os.path.join(arcadia_dir, 'ya')
            sp.check_call([ya, "make", "-rAZ", "--checkout", os.path.join(arcadia_dir, CANONIZE_TESTS_DIR)],
                          stdout=tests_canonize_pl.stdout, stderr=tests_canonize_pl.stderr, env=env)

        self.Context.final_diff = Arcadia.diff(url=arcadia_dir)
        if self.Parameters.commit_mode == CommitMode.DRY_RUN:
            self.__logger.info("Dry run: will not commit any changes")
            return

        commit_changes = self.Parameters.commit_mode == CommitMode.COMMIT_CHANGES
        self.__logger.info("Commiting changes" if commit_changes else "Creating review with changes")
        revprops = ["arcanum:check-skip=yes"] if commit_changes else ["arcanum:review=new", "arcanum:review-publish=yes"]
        try:
            Arcadia.commit(arcadia_dir, self.Parameters.commit_message, user=self.Parameters.commit_user, with_revprop=revprops)
        except Exception:
            # Presumably a review request makes the commit fail with a message
            # that carries the created review id; extract it. Any other failure
            # is a real commit error and must not be swallowed.
            review_ids = re.findall(
                r"Check status can be monitored using this special review request: ([0-9]+)",
                str(sys.exc_info()[1]),
            )
            if not review_ids:
                raise
            self.Context.created_review = review_ids
