# -*- coding: utf-8 -*-
import os
import subprocess
import json
import logging

from sandbox.projects.common import task_env
from sandbox.projects.websearch.begemot.common import add_realtime_version_file
from sandbox.projects.websearch.begemot.resources import BEGEMOT_REALTIME_PACKAGE
from sandbox.sandboxsdk import environments
import sandbox.sdk2 as sdk2


# Directory (relative to the task's working directory) under which one
# sub-directory per begemot rule is created and later archived.
WORKING_DIR = "realtime_data"

# Placeholder entry of `monitoring_attr_name_lists` meaning "read no monitoring
# attributes for this resource".
EMPTY_MONITORING_LIST = 'empty'


def _build_monitoring_attribute_paths():
    """Return the YT attribute paths monitored for the search models.

    The result lists, in order: the "skipped states" pool attributes for the
    `odd` and `log_dt` targets over each time window, followed by the metric
    attributes (plain value and its delta variants) of the `log_dt` and `odd`
    models.
    """
    root = '//home/izolenta/real_time_training/search'
    windows = ('1d', '7d', '14d', '30d')
    metric_variants = ('', '_delta', '_week_ago_delta', '_rounded_week_ago_delta')

    paths = []
    for target in ('odd', 'log_dt'):
        for window in windows:
            paths.append(
                '{0}/{1}/pools/@search_{1}_skipped_states_{2}'.format(root, target, window)
            )
    for target, metric in (('log_dt', 'CE'), ('odd', 'NZMSE')):
        for variant in metric_variants:
            paths.append(
                '{0}/{1}/models/@metric-search_{1}_dssm_{2}{3}'.format(root, target, metric, variant)
            )
    return paths


# Joined with ',' these paths form the default monitoring list of the search
# resource; each entry is a full path to a YT node attribute ("<node>/@<attr>").
_TABLES = _build_monitoring_attribute_paths()


class BuildModelsRealtimePackage(sdk2.Task):
    """Build a tar package with realtime model data for a begemot shard.

    For every configured resource the task downloads the current-state file
    (or table) from YT into a per-rule directory, collects model states and
    optional monitoring attributes, writes them as JSON files next to the
    downloaded data, archives all rule directories and publishes the archive
    as a BEGEMOT_REALTIME_PACKAGE resource.

    The list parameters `resource_path_prefixes`, `resource_begemot_rules`,
    `monitoring_attr_name_lists` and `resource_path_suffixes` are parallel:
    the i-th elements of each describe one resource.
    """

    class Requirements(task_env.TinyRequirements):
        disk_space = 70 * 1024  # 70 Gb
        ram = 48 * 1024  # 48 Gb
        environments = [
            environments.PipEnvironment("yandex-yt"),
        ]

    class Parameters(sdk2.Parameters):
        begemot_shard = sdk2.parameters.String("begemot_shard", hint=True, default='rtmodels')
        # NOTE(review): begemot_rule, resource_name and resource_path_prefix are
        # not read anywhere in this class; presumably kept for compatibility.
        begemot_rule = sdk2.parameters.String("begemot_rule", hint=True, default='RealTimeTraining')
        resource_proxy_url = sdk2.parameters.String("resource_proxy_url", hint=True, default='arnold.yt.yandex.net')
        current_state_name = sdk2.parameters.String('current_state_name', hint=True, default='CurrentState')
        model_states_attr_name = sdk2.parameters.String('model_states_attr_name', hint=True, default='ModelStates')
        model_names_attr_name = sdk2.parameters.String('model_names_attr_name', hint=True, default='ModelNames')
        # One entry per resource: either EMPTY_MONITORING_LIST or a
        # comma-separated list of full YT attribute paths.
        monitoring_attr_name_lists = sdk2.parameters.List(
            'monitoring_attr_name_lists', default=[
                EMPTY_MONITORING_LIST,  # yabs
                ','.join(_TABLES),  # search
                EMPTY_MONITORING_LIST,  # search fpm, TODO(filmih@): add monitoring attrs
                EMPTY_MONITORING_LIST,  # search fpm 1 day ago
                EMPTY_MONITORING_LIST,  # search fpm 3 days ago
                EMPTY_MONITORING_LIST,  # search fpm 1 week ago
                EMPTY_MONITORING_LIST,  # search fpm frozen
            ])
        resource_name = sdk2.parameters.String("resource_name", hint=True, default='optimized_model.dssm')
        resource_path_prefix = sdk2.parameters.String("resource_prefix", hint=True,
                                                      default='//home/izolenta/real_time_training/yabs/bundles')
        resource_path_prefixes = sdk2.parameters.List(
            "resource_path_prefixes",
            default=[
                '//home/izolenta/real_time_training/yabs/bundles',
                '//home/izolenta/real_time_training/search/bundles',
                '//home/web_personalization/features_personal_model/publish/htxt',
                '//home/web_personalization/features_personal_model/publish/htxt_1_day_ago',
                '//home/web_personalization/features_personal_model/publish/htxt_3_days_ago',
                '//home/web_personalization/features_personal_model/publish/htxt_1_week_ago',
                '//home/web_personalization/features_personal_model/publish/htxt_frozen_2021-12-01T21:00:00',
            ])
        # The label used to be a copy-paste of "resource_path_prefixes", which
        # made the two list parameters indistinguishable by label.
        resource_begemot_rules = sdk2.parameters.List(
            "resource_begemot_rules",
            default=[
                'RealTimeTraining',
                'RealTimeTrainingWeb',
                'FeaturesPersonalModelLoader',
                'FeaturesPersonalModelLoader',
                'FeaturesPersonalModelLoader',
                'FeaturesPersonalModelLoader',
                'FeaturesPersonalModelLoader',
            ])
        resource_path_suffixes = sdk2.parameters.List(
            "resource_path_suffixes",
            default=[
                'optimized_model.dssm',
                'optimized_model.dssm',
                'features_personal_model_v1.htxt',
                'features_personal_model_v1_1_day_ago.htxt',
                'features_personal_model_v1_3_days_ago.htxt',
                'features_personal_model_v1_1_week_ago.htxt',
                'features_personal_model_v1_frozen.htxt',
            ])
        yt_token_secret_name = sdk2.parameters.String("yt_token_secret_name",
                                                      hint=True,
                                                      default='robot_itditp_yt_token')

    class Context(sdk2.Context):
        # rule name -> {model name -> file name (path suffix) holding the model}
        rule2model2file = dict()
        # rule name -> {model name -> model state}
        rule2model2state = dict()
        # rule name -> {monitoring attribute short name -> attribute value}
        rule2monitoring_data = dict()

    def _configured_yt(self):
        """Return the `yt.wrapper` module configured with the task's proxy and token.

        The import is done lazily inside the method rather than at module
        level; presumably because the `yandex-yt` package is only installed
        through the task's PipEnvironment requirement.
        """
        import yt.wrapper as yt
        yt.config["proxy"]["url"] = self.Parameters.resource_proxy_url
        yt.config["token"] = sdk2.Vault.data(self.owner, self.Parameters.yt_token_secret_name)
        return yt

    def _download_yt_resource(self, path_prefix, path_suffix, dir="."):
        """Download `<path_prefix>/<current state>/<path_suffix>` from YT into `dir`.

        The current state name is read from the attribute of `path_prefix`
        named by the `current_state_name` parameter. Tables are fetched via
        `sky get` on a torrent produced by `yt.sky_share`; files are read
        through the YT client and written to `dir/<path_suffix>`.

        Raises:
            RuntimeError: if the node is neither a table nor a file. Without
                this the resource would be silently missing from the package.
            subprocess.CalledProcessError: if `sky get` fails.
        """
        yt = self._configured_yt()

        latest_state = yt.get(path_prefix + '/@{}'.format(self.Parameters.current_state_name))
        path2latest = yt.ypath_join(path_prefix,
                                    latest_state,
                                    path_suffix)

        # Fetch the node type once instead of once per branch.
        node_type = yt.get_attribute(path2latest, 'type')
        if node_type == "table":
            torrent_id = yt.sky_share(path2latest)
            subprocess.check_call(['sky', 'get', '-wud', dir, torrent_id])
        elif node_type == "file":
            with open(os.path.join(dir, path_suffix), 'wb') as f:
                f.write(yt.read_file(path2latest).read())
        else:
            raise RuntimeError(
                "Unsupported YT node type {!r} for {}".format(node_type, path2latest)
            )

    def _make_archive_for_begemot(self, archive_path, dirs, working_dir):
        """Create a tar archive of `dirs` (relative to `working_dir`) plus a version file.

        `dirs` may contain repeated names (several resources can share one
        begemot rule); each directory is archived only once, in order of first
        appearance, so that tar does not store the same files several times.

        The version file is produced by `add_realtime_version_file` from
        `working_dir` and the task id; its return value is passed to tar as
        one more member (assumed to be a path resolvable from `working_dir`).
        """
        version_file = add_realtime_version_file(working_dir, self.id)

        unique_dirs = []
        seen = set()
        for dir_name in dirs:
            if dir_name not in seen:
                seen.add(dir_name)
                unique_dirs.append(dir_name)

        subprocess.check_call(['tar', '-C', working_dir, '-cf', archive_path] + unique_dirs + [version_file])

    def _read_model_states(self, resource_path_prefix):
        """Return a `{model name: state}` dict read from attributes of `resource_path_prefix`.

        Model names and states come from the attributes named by the
        `model_names_attr_name` and `model_states_attr_name` parameters and
        are paired positionally.
        """
        yt = self._configured_yt()

        models = yt.get(resource_path_prefix + '/@{}'.format(self.Parameters.model_names_attr_name))
        logging.info("%s: %s", self.Parameters.model_names_attr_name, str(models))

        states = yt.get(resource_path_prefix + '/@{}'.format(self.Parameters.model_states_attr_name))
        logging.info("%s: %s", self.Parameters.model_states_attr_name, str(states))

        return dict(zip(models, states))

    def _read_monitoring_data(self, monitoring_attribute_names):
        """Read monitoring attributes from YT.

        Args:
            monitoring_attribute_names: comma-separated full attribute paths
                of the form `<node path>/@<attribute>`.

        Returns:
            Dict mapping the attribute name (last path component without the
            leading "@") to its value. Paths that do not exist are logged and
            skipped.
        """
        yt = self._configured_yt()

        result = dict()
        for attr_name in monitoring_attribute_names.split(','):
            short_name = attr_name.split('/')[-1][1:]  # take attr name without "@"
            if yt.exists(attr_name):
                result[short_name] = yt.get(attr_name)
                logging.info("%s: %s", short_name, str(result[short_name]))
            else:
                logging.info("no data found using path: %s", attr_name)
        return result

    def _merge_dicts(self, from_dict, to_dict):
        """Copy all items of `from_dict` into `to_dict`, failing on any key collision."""
        for key, value in from_dict.items():
            assert key not in to_dict, "Found collision when merging dicts, key = " + key
            to_dict[key] = value

    def _write_dict_to_file(self, data, file_path):
        """Write `data` to `file_path` as JSON indented by 4 spaces."""
        with open(file_path, 'w') as f:
            json.dump(data, f, indent=4)

    def on_execute(self):
        """Download all resources, write per-rule metadata, archive and publish the package."""
        params = self.Parameters
        resources_count = len(params.resource_path_prefixes)
        assert resources_count == len(params.resource_begemot_rules), \
            "resource_begemot_rules must have as many items as resource_path_prefixes"
        assert resources_count == len(params.monitoring_attr_name_lists), \
            "monitoring_attr_name_lists must have as many items as resource_path_prefixes"
        # zip() below stops at the shortest list, so a short suffix list would
        # silently drop the trailing resources from the package.
        assert resources_count == len(params.resource_path_suffixes), \
            "resource_path_suffixes must have as many items as resource_path_prefixes"

        for path_prefix, rule_name, monitoring_attribute_names, path_suffix in zip(
            params.resource_path_prefixes,
            params.resource_begemot_rules,
            params.monitoring_attr_name_lists,
            params.resource_path_suffixes,
        ):
            rule_dir = os.path.join(WORKING_DIR, rule_name)
            if not os.path.exists(rule_dir):
                os.makedirs(rule_dir)

            if rule_name not in self.Context.rule2model2file:
                self.Context.rule2model2file[rule_name] = dict()
            # Resources of one rule share a directory, so their file names must differ.
            assert path_suffix not in self.Context.rule2model2file[rule_name].values(), "Duplicating file for rule " + rule_name
            self._download_yt_resource(path_prefix=path_prefix, path_suffix=path_suffix, dir=rule_dir)

            if rule_name not in self.Context.rule2model2state:
                self.Context.rule2model2state[rule_name] = dict()
            model2state = self._read_model_states(path_prefix)
            self._merge_dicts(model2state, self.Context.rule2model2state[rule_name])

            if rule_name not in self.Context.rule2monitoring_data:
                self.Context.rule2monitoring_data[rule_name] = dict()
            if monitoring_attribute_names != EMPTY_MONITORING_LIST:
                self._merge_dicts(self._read_monitoring_data(monitoring_attribute_names), self.Context.rule2monitoring_data[rule_name])

            # Every model of this resource is served from the file just downloaded.
            for model in model2state:
                self.Context.rule2model2file[rule_name][model] = path_suffix

        # Metadata is written once per distinct rule, after all its resources are merged.
        for rule_name in set(params.resource_begemot_rules):
            rule_dir = os.path.join(WORKING_DIR, rule_name)
            self._write_dict_to_file(self.Context.rule2model2state[rule_name], os.path.join(rule_dir, 'states'))
            self._write_dict_to_file(self.Context.rule2monitoring_data[rule_name], os.path.join(rule_dir, 'monitoring_data'))
            self._write_dict_to_file(self.Context.rule2model2file[rule_name], os.path.join(rule_dir, 'models_files'))

        archive_path = os.path.abspath('realtime_wizard_data.tar')
        self._make_archive_for_begemot(archive_path,
                                       dirs=params.resource_begemot_rules,
                                       working_dir=WORKING_DIR)

        # Constructing the resource object registers the archive as a task resource.
        _ = BEGEMOT_REALTIME_PACKAGE(
            self,
            BEGEMOT_REALTIME_PACKAGE.name + ' for {}'.format(params.begemot_shard),
            archive_path,
            released='stable',
            shard=params.begemot_shard,
            ttl=3,
            version=self.id,
        )
