# -*- coding: utf-8 -*-
import json
import logging
import shutil
import hashlib
import urlparse

from sandbox import sdk2
from sandbox.common import errors
from sandbox.projects.common.vcs import arc
from sandbox.projects.common.nanny import nanny
from sandbox.projects.websearch.cfg_models import resources
from sandbox.sandboxsdk import environments


def _write_optional_strings_array(key, values):
    if not values:
        return ''
    if not isinstance(values, list):
        values = [values]
    return '{}: [{}]\n'.format(key, ','.join('"' + i + '"' for i in values))


class BuildCfgModelsArchive(nanny.ReleaseToNannyTask2, sdk2.Task):
    """
        Builds a resource with the models and the config for the CFG_MODELS service.

        The list of models is taken either from a JSON file in Arcadia
        (optionally patched through task parameters) or directly from a task
        parameter. Every model file is copied/downloaded into the output
        bundle, next to a generated ``bundle.cfg`` and an ``always_loaded.txt``
        file list. If a READY bundle built from the same list already exists,
        it is reused instead of building a new one.
    """

    class Parameters(sdk2.Task.Parameters):
        # Where the list of models comes from: a file in Arcadia or the task parameters.
        with sdk2.parameters.RadioGroup("Take config from:") as config_mode:
            config_mode.values["arcadia"] = config_mode.Value("Arcadia")
            config_mode.values["manual"] = config_mode.Value("Task parameters")
        with config_mode.value["arcadia"]:
            arcadia_path = sdk2.parameters.ArcadiaUrl("Svn or arc url", required=True)
            # Entries added to the Arcadia list; a base entry with the same "source" is replaced.
            models_patch = sdk2.parameters.JSON("Add/replace config", default=[], required=True)
            # "source" values of the base entries that must be dropped.
            models_remove = sdk2.parameters.List("Sources to remove")
            # Only read when arcadia_path uses the arc scheme.
            arc_token = sdk2.parameters.YavSecret("ARC_TOKEN secret identifier")
        with config_mode.value["manual"]:
            models_json = sdk2.parameters.JSON("Config", required=True)
        # Only read when at least one model is downloaded from YT.
        yt_token = sdk2.parameters.YavSecret("YT_TOKEN secret identifier")
        with sdk2.parameters.Output():
            bundle_resource = sdk2.parameters.Resource("Bundle with models", resource_type=resources.CfgModelsArchive)

    class Requirements(sdk2.Task.Requirements):
        # Presumably provides the ``yt.wrapper`` module imported in _load_from_yt.
        environments = (
            environments.PipEnvironment('yandex-yt', use_wheel=True),
        )

    def on_execute(self):
        """
            Resolve the list of models, then reuse an existing bundle or build a new one.

            A new bundle gets every model file plus two generated files:
            ``bundle.cfg`` (one ``Model { ... }`` section per model) and
            ``always_loaded.txt`` (names of the files whose model has no
            ``only_load_in`` restriction, one per line).

            Raises:
                errors.TaskFailure: a model has unknown keys, no ``signal_name``,
                    or a ``source`` with an unknown storage scheme.
        """
        models_json = self._resolve_models_json()

        # A bundle is identified by the md5 of the effective list of models:
        # a READY resource carrying the same md5 is returned as the result.
        h = hashlib.md5(json.dumps(models_json).encode('utf-8')).hexdigest()
        prev_resource = resources.CfgModelsArchive.find(state="READY", attrs={"bundle_json_md5": h}).first()
        if prev_resource is not None:
            self.Parameters.bundle_resource = prev_resource
            return

        with self.memoize_stage.create_output_resource:
            self.Parameters.bundle_resource = resources.CfgModelsArchive(self, self.Parameters.description, "cfg_models_bundle")
            self.Parameters.bundle_resource.bundle_json_md5 = h

        self.output_bundle = sdk2.ResourceData(self.Parameters.bundle_resource)
        self.output_bundle.path.mkdir(0o755, parents=True, exist_ok=True)

        # Storage scheme (the part of "source" before the first ':') -> loader.
        loaders = {
            "sbr": self._load_from_sandbox_resource,
            "yt": self._load_from_yt,
            "berts_storage": self._load_from_berts_storage,
        }

        # State shared with the loaders.
        self.loaded_files = set()
        self.yt_token = None
        self.yt_clients = {}

        config_chunks = []
        always_loaded = []
        for idx, model in enumerate(models_json):
            self._check_model_keys(idx, model)
            source = model["source"]
            # Raised as a task failure (not asserted), like the other config
            # errors here; an assert would also vanish under "python -O".
            if not model.get("signal_name"):
                raise errors.TaskFailure("signal_name is required field, at %s" % source)

            source_scheme, _, source_path = source.partition(':')
            if source_scheme not in loaders:
                raise errors.TaskFailure("Unknown storage type for model source '{}'".format(source))
            source_file = loaders[source_scheme](source_path, model.get("local_name"))

            if not model.get("only_load_in", []):
                always_loaded.append(source_file.encode('utf-8') + '\n')
            config_chunks.append(self._render_model(source_file, model))

        self.output_bundle.path.joinpath("bundle.cfg").write_bytes(''.join(config_chunks))
        self.output_bundle.path.joinpath("always_loaded.txt").write_bytes(''.join(always_loaded))
        self.output_bundle.ready()

    def _resolve_models_json(self):
        """
            Return the effective list of model descriptions.

            In "arcadia" mode the base list is read from ``arcadia_path``
            (through an arc mount for arc urls, through ``Arcadia.cat``
            otherwise); base entries whose "source" is listed in
            ``models_remove`` or is present in ``models_patch`` are dropped and
            the patch entries are appended. In any other mode the
            ``models_json`` parameter is returned as is.
        """
        if self.Parameters.config_mode != "arcadia":
            return self.Parameters.models_json

        parsed_path = urlparse.urlparse(self.Parameters.arcadia_path)
        if parsed_path.scheme == sdk2.svn.Arcadia.ARCADIA_ARC_SCHEME:
            arc_token = self.Parameters.arc_token.data()[self.Parameters.arc_token.default_key]
            # The url fragment is passed to mount_path and the url path is the
            # location of the file inside the mount.
            with arc.Arc(arc_oauth_token=arc_token).mount_path('', parsed_path.fragment, fetch_all=False) as mount_path:
                with open(mount_path + parsed_path.path, 'r') as f:
                    models_json_basetext = f.read()
        else:
            models_json_basetext = sdk2.svn.Arcadia.cat(self.Parameters.arcadia_path)
        logging.info("got config from arcadia: %s", models_json_basetext)
        models_json_base = json.loads(models_json_basetext)

        models_remove = self.Parameters.models_remove
        models_patch = self.Parameters.models_patch
        models_json = []
        for model in models_json_base:
            if model["source"] in models_remove:
                continue
            # Replaced by the patch entry with the same source, appended below.
            if any(model["source"] == patch["source"] for patch in models_patch):
                continue
            models_json.append(model)
        models_json += models_patch
        return models_json

    def _check_model_keys(self, idx, model):
        """Raise errors.TaskFailure if model #idx contains keys this task does not handle."""
        known_keys = (
            "source", "local_name", "version", "layout", "signal_name", "targets",
            "max_batch_size", "max_job_size", "only_load_in", "skip_doc_if_omitted_inputs",
            "cache", "degrade", "comment",
        )
        badkeys = ["'{}'".format(key) for key in model if key not in known_keys]
        if badkeys:
            raise errors.TaskFailure("Don't know what to do with the key(s) {} in model #{}".format(','.join(badkeys), idx))

    def _render_model(self, source_file, model):
        """
            Return the ``Model { ... }`` section of bundle.cfg for one model.

            ``source_file`` is the name of the model file inside the bundle;
            ``model`` is the model description, already checked to have a
            non-empty "signal_name".
        """
        layout = model.get("layout")
        version = model.get("version", "-")  # an explicit null suppresses the Version line
        max_batch_size = model.get("max_batch_size")
        max_job_size = model.get("max_job_size")

        lines = ['Model {\n', '    Path: "{}"\n'.format(source_file)]
        if layout is not None:
            lines.append('    OpenLayout: "{}"\n'.format(layout))
        if version is not None:
            lines.append('    Version: "{}"\n'.format(version))
        lines.append('    SignalName: "{}"\n'.format(model["signal_name"]))
        if max_batch_size is not None:
            lines.append('    MaxBatchSize: {}\n'.format(int(max_batch_size)))
        if max_job_size is not None:
            lines.append('    MaxJobSize: {}\n'.format(int(max_job_size)))
        lines.append(_write_optional_strings_array('    OnlyLoadIn', model.get("only_load_in", [])))
        lines.append(_write_optional_strings_array('    SkipDocIfOmittedInputs', model.get("skip_doc_if_omitted_inputs", [])))
        lines.extend(self._render_targets(model.get("targets", {})))
        lines.extend(self._render_cache(model.get("cache", {})))
        lines.extend(self._render_degrade(model.get("degrade", {})))
        lines.append('}\n')
        return ''.join(lines)

    @staticmethod
    def _render_targets(targets):
        """
            Return the lines of the ``Targets`` sections, ordered by target id.

            ``targets`` maps a target id to a ``{"<slice>:<index>": predict_name}`` dict.
        """
        lines = []
        for target_id in sorted(targets):
            target = targets[target_id]
            lines.append('    Targets {\n')
            lines.append('        Id: "{}"\n'.format(target_id))
            for factor in sorted(target):
                predict = target[factor]
                # Exactly one ':' is expected; anything else raises ValueError.
                factor_slice, factor_idx = factor.split(':')
                factor_idx = int(factor_idx)
                lines.append('        FactorTargets {\n')
                lines.append('            PredictName: "{}"\n'.format(predict))
                lines.append('            TargetSlice: "{}"\n'.format(factor_slice))
                lines.append('            TargetIndex: {}\n'.format(factor_idx))
                lines.append('        }\n')
            lines.append('    }\n')
        return lines

    @staticmethod
    def _render_cache(cache):
        """
            Return the lines of the ``Cache`` sections, one per key of ``cache``.

            The key is the ctype (an empty key emits no ``Ctype`` line); the
            value must contain "max_cache_size" and may contain "targets",
            "query_keys" and "doc_keys".
        """
        lines = []
        for ctype in cache:
            settings = cache[ctype]
            lines.append('    Cache {\n')
            if ctype:
                lines.append('        Ctype: "{}"\n'.format(ctype))
            lines.append(_write_optional_strings_array('        Targets', settings.get('targets')))
            lines.append(_write_optional_strings_array('        QueryKeys', settings.get('query_keys')))
            lines.append(_write_optional_strings_array('        DocKeys', settings.get('doc_keys')))
            lines.append('        MaxCacheSize: {}\n'.format(int(settings['max_cache_size'])))
            lines.append('    }\n')
        return lines

    @staticmethod
    def _render_degrade(degrade):
        """
            Return the lines of the ``Degrade`` sections, one per key of ``degrade``.

            The key is the ctype (an empty key emits no ``Ctype`` line); the
            value may contain "stat_decay_time", "gpu" (a dict with
            StartDegradationF, EndDegradationF and MultLowerBound) and
            "gpu_hard_limit".
        """
        lines = []
        for ctype in degrade:
            settings = degrade[ctype]
            lines.append('    Degrade {\n')
            if ctype:
                lines.append('        Ctype: "{}"\n'.format(ctype))
            stat_decay_time = settings.get('stat_decay_time')
            if stat_decay_time is not None:
                lines.append('        StatDecayTime: "{}"\n'.format(stat_decay_time))
            gpu = settings.get('gpu')
            if gpu is not None:
                lines.append('        Gpu {\n')
                lines.append('            StartDegradationF: {}\n'.format(float(gpu['StartDegradationF'])))
                lines.append('            EndDegradationF: {}\n'.format(float(gpu['EndDegradationF'])))
                lines.append('            MultLowerBound: {}\n'.format(float(gpu['MultLowerBound'])))
                lines.append('            AutoModeType: "exp"\n')
                lines.append('        }\n')
            gpu_hard_limit = settings.get('gpu_hard_limit')
            if gpu_hard_limit is not None:
                lines.append('        GpuHardLimit: {}\n'.format(float(gpu_hard_limit)))
            lines.append('    }\n')
        return lines

    def _make_unique_source_name(self, source_name):
        """
            Reserve and return a file name not yet used inside the bundle.

            The requested name is returned when it is free. Otherwise "_2",
            "_3", ... is inserted before the first dot of the name (or appended
            when there is no dot) until a free name is found. The returned name
            is recorded in ``self.loaded_files``.
        """
        if source_name not in self.loaded_files:
            self.loaded_files.add(source_name)
            return source_name
        # Split on the first dot: "a.tar.gz" becomes "a_2.tar.gz".
        name, ext1, ext2 = source_name.partition('.')
        dupidx = 2
        while True:
            source_name = name + "_" + str(dupidx) + ext1 + ext2
            if source_name not in self.loaded_files:
                self.loaded_files.add(source_name)
                return source_name
            dupidx += 1

    def _load_from_sandbox_resource(self, source_path, source_local_name):
        """
            Copy the data of a Sandbox resource into the bundle.

            ``source_path`` is the resource id; when ``source_local_name`` is
            empty, the name of the resource file is used. Returns the name of
            the file inside the bundle.
        """
        resource_id = int(source_path)
        resource = sdk2.Resource[resource_id]
        data = sdk2.ResourceData(resource)
        if not source_local_name:
            source_local_name = data.path.name
        source_local_name = self._make_unique_source_name(source_local_name)
        shutil.copy(str(data.path), str(self.output_bundle.path.joinpath(source_local_name)))
        logging.info('copied sbr:{} -> {}'.format(resource_id, source_local_name))
        return source_local_name

    def _load_from_yt(self, source_path, source_local_name):
        """
            Download a file from YT into the bundle.

            ``source_path`` has the form ``<cluster>:<path on cluster>``; when
            ``source_local_name`` is empty, the part of ``source_path`` after
            the last '/' is used. The token and the per-cluster clients are
            created on first use and cached on the task. Returns the name of
            the file inside the bundle.
        """
        if not source_local_name:
            # Everything after the last '/', or the whole string if there is none.
            source_local_name = source_path.rpartition('/')[2]
        source_local_name = self._make_unique_source_name(source_local_name)
        logging.info('downloading from YT: {} -> {}'.format(source_path, source_local_name))
        # Imported here rather than at module level; presumably the package
        # only becomes available through the task's PipEnvironment requirement.
        import yt.wrapper as yt
        if self.yt_token is None:
            self.yt_token = self.Parameters.yt_token.data()[self.Parameters.yt_token.default_key]
        cluster, _, path_on_cluster = source_path.partition(':')
        if cluster not in self.yt_clients:
            self.yt_clients[cluster] = yt.YtClient(proxy=cluster, config={'token': self.yt_token})
        yt_client = self.yt_clients[cluster]
        stream = yt_client.read_file(path_on_cluster)
        with self.output_bundle.path.joinpath(source_local_name).open('wb') as model_file:
            for chunk in stream.chunk_iter():
                model_file.write(chunk)
        return source_local_name

    def _load_from_berts_storage(self, source_path, source_local_name):
        """
            Download ``bert.htxt`` of a model kept in the prod berts storage on hahn.

            ``source_path`` is the model directory relative to the storage
            root; when ``source_local_name`` is empty, it is ``source_path``
            with '/' replaced by '_' plus the ".htxt" extension. Returns the
            name of the file inside the bundle.
        """
        if not source_local_name:
            source_local_name = source_path.replace('/', '_') + '.htxt'
        return self._load_from_yt("hahn://home/ranking/prod_berts_storage/models/" + source_path + "/bert.htxt", source_local_name)
