# -*- coding: utf-8 -*-
import logging
import os
import tempfile
import tarfile
import shutil
import json
import time
import re

from os.path import join as pj

from sandbox import sdk2
from sandbox.sdk2.helpers import subprocess as sp
from sandbox.common.types.resource import State
from sandbox.common.types.task import Status
from sandbox.common.errors import TaskFailure, TemporaryError

from sandbox.projects.common.solomon import push_to_solomon_v2

from sandbox.projects.blender.resource_types import BlenderModel, BlenderNirvanaOutput
from sandbox.projects.blender.util.arcadia import change_dir, svn_export, selective_checkout, svn_diff, \
    build_directory, run_tests
from sandbox.projects.blender.util.subtask import ensure_subtask_is_successful

# Arcadia directory with the helper scripts run by this task
# (extract_fml.py, deduplicate_model_storage.py, rm_fml.py, update_make_file.py).
_RD_TOOLS_DIR = 'search/web/rearrs_upper/rearrange.dynamic'
# Arcadia path of the script used to patch fml_config.json.
_FML_CONFIG_PATCHER = 'yweb/blender/scripts/patch_fml_config/patch_fml_config.py'
# Separator of the resource id list: a single whitespace character or a comma.
# The pattern is a raw string because '\s' is not a valid escape in a plain string literal.
_SPLITTER = re.compile(r'\s|,')
# File extensions a model in a source archive is allowed to have.
_KNOWN_EXTENSIONS = set(['info', 'mnmc', 'xtd', 'regtree', 'cbm'])
_FILES_TO_SKIP_IN_SOURCE_ARCHIVE = set(['tar_info.json'])  # used in build bundle grid


class CommitBlenderModel(sdk2.Task):
    """ Prepare and commit model to any extended formula storage (search/web/util/formula/storage.h) """

    # Resources requested from Sandbox for a single run of the task.
    class Requirements(sdk2.Task.Requirements):
        disk_space = 50 * 1024  # Mb
        # force cores1 using: https://wiki.yandex-team.ru/sandbox/cookbook/#cores1multislot
        cores = 1
        ram = 8192  # NOTE(review): presumably Mb, like disk_space - confirm

        # Left empty on purpose, see the comment on `pass`.
        class Caches(sdk2.Requirements.Caches):
            pass  # means that task do not use any shared caches

    # Input parameters of the task. Parameters declared inside a `with <flag>.value[True]:` block
    # are grouped under that boolean flag (presumably only relevant when the flag is set - confirm
    # against the sdk2 documentation).
    class Parameters(sdk2.Task.Parameters):
        # --- what to commit and where ---
        arcadia_url = sdk2.parameters.ArcadiaUrl('Svn url for arcadia', required=True)
        # Parsed with _SPLITTER in _untar_models; every id must point to a READY tar resource.
        resource_id = sdk2.parameters.String(
            'Resource ids', required=True, description='Resource ids splitted by space or comma, each resource is '
                                                       'archive with one or multiple models'
        )
        # Arcadia-relative directory of the model storage (joined with the checkout dir by the task).
        model_path = sdk2.parameters.String('Models arcadia path', required=True)
        # Arcadia-relative make file; when empty, _add_models uses <model_path>/ya.make.
        model_make_file = sdk2.parameters.String('Models make file',
                                                 description='by default use ya.make from model path')
        # 0 disables version handling; _read_monitoring_data treats a non-zero value as milliseconds.
        model_version = sdk2.parameters.Integer('Model version (timestamp)', default_value=0)

        # When set, already committed models with the same names are removed before adding new ones.
        # Rejected by _validate_params together with `deduplicate`.
        replace_existing_model = sdk2.parameters.Bool('Replace existing model', default_value=False)

        # --- optional test run on the resulting patch ---
        run_tests = sdk2.parameters.Bool('Run tests', default_value=False)
        with run_tests.value[True]:
            test_path = sdk2.parameters.String('Tests path', required=True)
            # A falsy value is passed to WaitTask as timeout=None (see on_execute).
            test_timeout = sdk2.parameters.Integer('Test timeout', description='Timeout for ya tests in sec')

        # --- commit settings; with do_commit unset the task only reports what would be committed ---
        do_commit = sdk2.parameters.Bool('Do commit')
        with do_commit.value[True]:
            ssh_key_owner = sdk2.parameters.String('Owner of the vault ssh key',
                                                   default_value='BLENDER')
            ssh_key_name = sdk2.parameters.String('Name of the vault ssh key',
                                                  default_value='robot-blendr-priemka-private')
            ssh_user_name = sdk2.parameters.String('Ssh user name',
                                                   default_value='robot-blendr-priemka')
            diff_resolver = sdk2.parameters.String('Diff resolver',
                                                   description='by default diff_resolver is equal to task author')
            # Upper bound for Context.commit_retries, checked in on_execute.
            max_commit_retries = sdk2.parameters.Integer('Number of commit retry', default_value=3)

        # --- monitoring: sensors pushed by _send_monitoring_data_to_solomon ---
        send_to_solomon = sdk2.parameters.Bool('Send monitoring data to solomon')
        with send_to_solomon.value[True]:
            oauth_token_secret_id = sdk2.parameters.YavSecret(
                label='OAuth token secret id',
                description='Secret should contain keys: oauth_token',
                default_value='sec-01dr8w2ax4x1e5bgnynmk2x1sj'
            )
            solomon_project = sdk2.parameters.String('Solomon project', default_value='blender')
            solomon_cluster = sdk2.parameters.String('Solomon cluster', default_value='production')
            solomon_service = sdk2.parameters.String('Solomon service', default_value='fast_models')

        # --- fml_config.json patching: arguments forwarded to the patch_fml_config.py script ---
        patch_fml_config = sdk2.parameters.Bool('Patch fml config')
        with patch_fml_config.value[True]:
            pfc_vertical = sdk2.parameters.String('Vertical name', required=True)
            pfc_locale = sdk2.parameters.String('Locale', required=True)
            # The patcher is invoked once per model and per listed ui.
            pfc_ui = sdk2.parameters.List('User interface', required=True)
            # Passed as --set_filter; an empty value omits the option.
            pfc_filter_mode = sdk2.parameters.String('Set filter mode', default_value='inherit_prod')
            pfc_config_path = sdk2.parameters.String(
                'fml_config.json path',
                default_value='search/web/rearrs_upper/rearrange.dynamic/blender/fml_config.json',
                description='arcadia path to fml_config.json'
            )
        # When set, the unpacked models go through deduplicate_model_storage.py before packing.
        deduplicate = sdk2.parameters.Bool('Deduplicate models')

    # State shared between the steps of on_execute.
    class Context(sdk2.Task.Context):
        # File names of the models found in the source archives (filled by _untar_models).
        models = None  # only main models
        # Set to True by _untar_models when the models are .xtd bundles.
        is_xtd = False
        # Id of the BlenderModel resource created by _build_resource.
        bundle_resource_id = None
        # Names of the files packed into the bundle resource.
        bundle_models = None  # all models (with extracted children)
        # Output of svn_diff for the prepared working copy; reset to None before a commit retry.
        patch = None
        # Value returned by run_tests(); later passed to ensure_subtask_is_successful.
        test_task = None
        # Number of failed commit attempts so far, compared with Parameters.max_commit_retries.
        commit_retries = 0
        # int(time.time()) taken right after _do_commit has finished its commit / dry-run step.
        commit_timestamp = 0
        # Result of Arcadia.commit, or a placeholder string when do_commit is not set.
        commit_result = None
        # Model file name -> dict returned by _read_monitoring_data for that model.
        monitoring_data = dict()
        # Content of abt_params.json in the nirvana output resource.
        abt_params = dict()

    def _untar_resource(self, resource_id):
        """Extract the tar archive of a READY resource into a fresh temporary directory.

        :param resource_id: id of the resource (a string; it is also used as the temp dir suffix)
        :return: path of the directory with the extracted content
        :raises TaskFailure: if no READY resource with this id is found
        :raises AssertionError: if the resource file is not a tar archive
        """
        logging.info('start unpacking resource %s', resource_id)
        found = sdk2.Resource.find(status=State.READY, id=resource_id).first()
        if not found:
            raise TaskFailure("Can't find resource {}".format(resource_id))
        archive_path = str(sdk2.ResourceData(found).path)
        extract_dir = tempfile.mkdtemp(resource_id)
        assert tarfile.is_tarfile(archive_path), 'input resource should be tar'
        logging.info('untar %s', archive_path)
        # extraction goes to the current directory, hence the directory switch
        with change_dir(extract_dir), tarfile.open(archive_path) as archive:
            archive.extractall()
        return extract_dir

    @staticmethod
    def _check_extension(model):
        """Assert that a model file name ends with one of the known extensions.

        :param model: model file name, e.g. ``some_model.xtd``
        :raises AssertionError: if the name contains no dot or the text after the last dot
            is not in ``_KNOWN_EXTENSIONS``
        """
        # rpartition reports whether a dot was found at all; without that check a file named
        # just 'info' (no extension) would be accepted as a model with the 'info' extension
        _, dot, ext = model.rpartition('.')
        if not dot:
            ext = ''
        assert ext in _KNOWN_EXTENSIONS, 'Unknown extension {} for model {}'.format(ext, model)

    def _untar_models(self):
        """Unpack every source resource and gather all models in one directory.

        Fills ``Context.models`` with the model file names and sets ``Context.is_xtd``
        when the models are .xtd bundles.

        :return: path of the directory that holds all collected model files
        :raises AssertionError: on an empty resource list, a duplicate model name, an unknown
            extension, no models at all, or a mix of .xtd and non-.xtd models
        """
        resource_ids = [chunk for chunk in _SPLITTER.split(self.Parameters.resource_id) if chunk]
        assert resource_ids, 'Empty source resources'
        models_dir = tempfile.mkdtemp('models')
        collected = set()
        for resource_id in resource_ids:
            source_dir = self._untar_resource(resource_id)
            for name in os.listdir(source_dir):
                if name in _FILES_TO_SKIP_IN_SOURCE_ARCHIVE:
                    logging.info('skip file %s in input resource', name)
                    continue
                assert name not in collected, 'Duplicate model name {}'.format(name)
                collected.add(name)
                self._check_extension(name)
                shutil.move(pj(source_dir, name), models_dir)
        self.Context.models = tuple(collected)
        assert len(self.Context.models) > 0, 'Can\'t find any models in source tar'
        xtd_flags = [name.endswith('.xtd') for name in self.Context.models]
        if any(xtd_flags):
            self.Context.is_xtd = True
            assert all(xtd_flags), 'Either all models should be .xtd bundle or none'
        self.set_info('Models to commit: {}'.format(', '.join(self.Context.models)))
        return models_dir

    def _set_model_version(self, version, model_file):
        assert self.Context.is_xtd, 'version handling is supported only for xtd model'
        logging.info('setting version %s to model %s', version, model_file)
        with open(model_file) as f:
            data = json.load(f)
        train_details = data.setdefault('Meta', dict()).setdefault('TrainDetails', dict())
        assert 'Version' not in train_details
        train_details['Version'] = version
        with open(model_file, 'w') as f:
            json.dump(data, f, indent=4)

    def _read_monitoring_data(self, model_file):
        assert self.Context.is_xtd, 'monitoring is supported only for xtd model'
        logging.info('read monitoring data for model %s', model_file)
        with open(model_file) as f:
            train_details = json.load(f).get('Meta', dict()).get('TrainDetails', dict())
        monitoring_data = train_details.get('Monitoring', dict())
        assert monitoring_data, 'Empty monitoring data'
        if 'WorkflowStartTimestamp' not in monitoring_data and self.Parameters.model_version:
            monitoring_data['WorkflowStartTimestamp'] = self.Parameters.model_version / 1000  # milliseconds to sec
        return monitoring_data

    def _get_model_version(self, model_file):
        assert self.Context.is_xtd, 'version handling is supported only for xtd model'
        logging.info('getting version from model %s', model_file)
        with open(model_file) as f:
            data = json.load(f)
            return data.get('Meta', dict()).get('TrainDetails', dict()).get('Version', 0)

    def _extract_xtd(self, model_file, model_dir, rd_tools_dir, out_dir):
        """Run extract_fml.py on one bundle and make sure the bundle itself ends up in ``out_dir``.

        :param model_file: file name of the bundle inside ``model_dir``
        :param model_dir: directory with the source models
        :param rd_tools_dir: directory with the exported helper scripts
        :param out_dir: directory that receives the extraction result
        """
        source_model = pj(model_dir, model_file)
        extract_cmd = [pj(rd_tools_dir, 'extract_fml.py'), source_model, '--out-dir', out_dir, '--out_log']
        logging.info('extract submodels from bundle %s', model_file)
        with sdk2.helpers.ProcessLog(self, logger='extract_submodels') as process_log:
            sp.check_call(extract_cmd, stdout=process_log.stdout, stderr=sp.STDOUT)
        # extract_fml.py does nothing for an xtd without internal bin models,
        # in that case the model has to be copied by hand
        extracted_model = pj(out_dir, model_file)
        if not os.path.exists(extracted_model):
            shutil.copy(source_model, out_dir)

    def _deduplicate(self, unpacked_models_dir, rd_tools_dir, arcadia_dir):
        """Run deduplicate_model_storage.py for the unpacked models against the model storage.

        The storage directory is built first; the script output goes to a new temp directory.

        :param unpacked_models_dir: directory with the unpacked models (script input)
        :param rd_tools_dir: directory with the exported helper scripts
        :param arcadia_dir: root of the arcadia working copy
        :return: path of the directory with the script output
        """
        logging.info('start models deduplication')
        build_directory(self.Parameters.model_path, arcadia_dir=arcadia_dir, task=self)
        storage_dir = pj(arcadia_dir, self.Parameters.model_path)
        result_dir = tempfile.mkdtemp('deduplicated_models')
        dedup_cmd = [pj(rd_tools_dir, 'deduplicate_model_storage.py'), 'added']
        dedup_cmd += ['-m', storage_dir]
        dedup_cmd += ['-i', unpacked_models_dir]
        dedup_cmd += ['-o', result_dir]
        dedup_cmd.append('--print_dup_info')
        with sdk2.helpers.ProcessLog(self, logger='dedup_models') as process_log:
            sp.check_call(dedup_cmd, stdout=process_log.stdout, stderr=sp.STDOUT)
        return result_dir

    def _build_resource(self, models_dir, rd_tools_dir, arcadia_dir):
        """Unpack the models and publish them as a ``BlenderModel`` tar.gz resource.

        xtd bundles are extracted with extract_fml.py, optionally stamped with the model version
        and, when monitoring is enabled, their monitoring data is stored in the context.
        Other models are copied as is. Fills ``Context.bundle_resource_id`` and
        ``Context.bundle_models``.

        :param models_dir: directory with the source models
        :param rd_tools_dir: directory with the exported helper scripts
        :param arcadia_dir: root of the arcadia working copy (needed for deduplication)
        :raises AssertionError: if nothing is left to pack
        """
        unpacked_dir = tempfile.mkdtemp('unpacked')
        for model in self.Context.models:
            if not self.Context.is_xtd:
                shutil.copy(pj(models_dir, model), unpacked_dir)
                continue
            self._extract_xtd(model, models_dir, rd_tools_dir, unpacked_dir)
            unpacked_model = pj(unpacked_dir, model)
            if self.Parameters.model_version:
                self._set_model_version(self.Parameters.model_version, unpacked_model)
            if self.Parameters.send_to_solomon:
                self.Context.monitoring_data[model] = self._read_monitoring_data(unpacked_model)
        bundle = BlenderModel(
            self, description='unpacked models', path='unpacked_models.tar.gz', ttl='inf', backup_task=True,
        )
        if self.Parameters.deduplicate:
            unpacked_dir = self._deduplicate(unpacked_dir, rd_tools_dir, arcadia_dir)
            assert os.listdir(unpacked_dir), 'No models after deduplication'
        bundle_data = sdk2.ResourceData(bundle)
        logging.info('start building bundle resource at path %s', bundle_data.path)
        packed = []
        with tarfile.open(str(bundle_data.path), 'w:gz') as archive:
            for name in os.listdir(unpacked_dir):
                # files are stored flat, under their bare names
                archive.add(pj(unpacked_dir, name), name)
                packed.append(name)
        assert packed
        bundle_data.ready()
        self.Context.bundle_resource_id = bundle.id
        self.Context.bundle_models = tuple(packed)

    def _remove_existing_models(self, arcadia_dir, rd_tools_dir):
        """Remove already committed models with the same names via rm_fml.py.

        Models absent from the storage are skipped. When a model version is given, an existing
        model with a greater version aborts the task.

        :param arcadia_dir: root of the arcadia working copy
        :param rd_tools_dir: directory with the exported helper scripts
        :raises TaskFailure: if an existing model is newer than the one being committed
        """
        logging.info('start removing existing models: %s', ', '.join(self.Context.models))
        build_directory(self.Parameters.model_path, arcadia_dir=arcadia_dir, task=self)
        storage_dir = pj(arcadia_dir, self.Parameters.model_path)
        new_version = self.Parameters.model_version
        outdated = []
        for name in self.Context.models:
            existing_file = pj(storage_dir, name)
            if not os.path.exists(existing_file):
                logging.info('model %s doesn\'t exist', name)
                continue
            if new_version:
                current_version = self._get_model_version(existing_file)
                # todo: support work in case of many models: just skip outdated model or something else?
                if current_version > new_version:
                    raise TaskFailure('Current model version more than new one: {} > {}'.format(
                        current_version, new_version))
            outdated.append(name)
        if not outdated:
            return
        rm_cmd = [os.path.join(rd_tools_dir, 'rm_fml.py')] + outdated + [
            '--model-path', storage_dir,
            '--make-files-only',
            '--with-subformulas',
            # we assume that there is no child intersection between models, if it exists
            # either we update all child models or tests will fail
            '--dont-check-deps',
        ]
        with sdk2.helpers.ProcessLog(self, logger='remove_existing_model') as process_log:
            sp.check_call(rm_cmd, stdout=process_log.stdout, stderr=sp.STDOUT)
        logging.info('diff after removing existing models:\n%s', svn_diff(arcadia_dir, self.Parameters.model_path))

    def _add_models(self, arcadia_dir, rd_tools_dir):
        """Register every bundle file in the make file with update_make_file.py.

        The make file is ``Parameters.model_make_file`` when given, otherwise
        ``<model_path>/ya.make``; the script is invoked once per file.

        :param arcadia_dir: root of the arcadia working copy
        :param rd_tools_dir: directory with the exported helper scripts
        """
        logging.info('start adding new models and their children: %s', self.Context.bundle_models)
        make_file_rel = self.Parameters.model_make_file
        if not make_file_rel:
            make_file_rel = pj(self.Parameters.model_path, 'ya.make')
        make_file = pj(arcadia_dir, make_file_rel)
        updater = os.path.join(rd_tools_dir, 'update_make_file.py')
        for bundle_file in self.Context.bundle_models:
            add_cmd = [updater, 'add', '-c', make_file, '-f', bundle_file, '-r', str(self.Context.bundle_resource_id)]
            with sdk2.helpers.ProcessLog(self, logger='add_model') as process_log:
                sp.check_call(add_cmd, stdout=process_log.stdout, stderr=sp.STDOUT)

    def _checkout_models(self):
        """Check out the model directory (plus the fml_config.json directory when it is patched).

        :return: whatever ``selective_checkout`` returns; callers use it as the working copy root
        """
        checkout_paths = [self.Parameters.model_path]
        if self.Parameters.patch_fml_config:
            config_dir = os.path.dirname(self.Parameters.pfc_config_path)
            checkout_paths.append(config_dir)
        return selective_checkout(checkout_paths, arcadia_url=self.Parameters.arcadia_url, task=self)

    @staticmethod
    def _cut_extention(path):
        return path.rsplit('.', 1)[0]

    def _build_abt_params(self):
        self.Context.abt_params = {'models': list(map(self._cut_extention, self.Context.models))}

    def _patch_fml_config(self, arcadia_dir):
        """Run the fml config patcher for every (model, ui) pair and extend ``Context.abt_params``.

        The patcher script is exported from arcadia first. After patching, the vertical, locale
        and ui list are added to ``Context.abt_params`` ('pad' is reported as 'tablet').

        :param arcadia_dir: root of the arcadia working copy holding fml_config.json
        """
        patcher_dir = svn_export(os.path.dirname(_FML_CONFIG_PATCHER), self.Parameters.arcadia_url, depth='files')
        patcher = pj(patcher_dir, os.path.basename(_FML_CONFIG_PATCHER))
        for model in self.Context.models:
            for ui in self.Parameters.pfc_ui:
                patch_cmd = [patcher, 'custom']
                patch_cmd += ['-c', pj(arcadia_dir, self.Parameters.pfc_config_path)]
                patch_cmd += ['-v', self.Parameters.pfc_vertical]
                patch_cmd += ['--ui', ui]
                patch_cmd += ['-t', self.Parameters.pfc_locale]
                patch_cmd += ['-f', model]
                patch_cmd += ['--do_not_resolve']
                patch_cmd += ['--fml_id', self._cut_extention(model)]
                if self.Parameters.pfc_filter_mode:
                    patch_cmd += ['--set_filter', self.Parameters.pfc_filter_mode]
                with sdk2.helpers.ProcessLog(self, logger='patch_fml_config') as process_log:
                    sp.check_call(patch_cmd, stdout=process_log.stdout, stderr=sp.STDOUT)
        abt_ui = []
        for ui in self.Parameters.pfc_ui:
            abt_ui.append('tablet' if ui == 'pad' else ui)
        self.Context.abt_params.update({
            'vertical': self.Parameters.pfc_vertical,
            'locale': self.Parameters.pfc_locale,
            'ui': abt_ui,
        })

    def _do_commit(self):
        """Apply the stored patch to a fresh checkout and commit it (or only report it).

        With ``Parameters.do_commit`` unset nothing is committed: the message and the patch are
        shown in the task info and a placeholder is stored as the commit result.
        Fills ``Context.commit_timestamp`` and ``Context.commit_result``.

        :raises TaskFailure: if the commit returns an empty result
        """
        resolver = self.Parameters.diff_resolver or self.author
        model_names = ', '.join(self.Context.models)
        msg = '[diff-resolver:{}] SKIP_CHECK commit models {} to {} from: https://sandbox.yandex-team.ru/task/{}/view'.format(
            resolver, model_names, self.Parameters.model_path, self.id
        )
        working_copy = self._checkout_models()
        sdk2.svn.Arcadia.apply_patch(working_copy, self.Context.patch, tempfile.mkdtemp())
        if not self.Parameters.do_commit:
            self.set_info('Commit would be done with message: "{}"\nPatch:\n{}'.format(msg, self.Context.patch))
            commit_result = 'No real commit, because do_commit=False'
        else:
            with sdk2.ssh.Key(self, self.Parameters.ssh_key_owner, self.Parameters.ssh_key_name):
                commit_result = sdk2.svn.Arcadia.commit(working_copy, msg, self.Parameters.ssh_user_name)
                if not commit_result:
                    raise TaskFailure('Empty commit result')
                self.set_info('Done commit with result: {}'.format(commit_result))
        self.Context.commit_timestamp = int(time.time())
        self.Context.commit_result = commit_result

    @staticmethod
    def _append_sensor(sensors, name, kind, value):
        sensors.append({
            'labels': {'sensor': name},
            'kind': kind,
            'value': value
        })

    def _send_monitoring_data_to_solomon(self):
        fast_data_dir = 'search/web/rearrs_upper/rearrange.fast/'
        assert self.Parameters.model_path.startswith(fast_data_dir), 'Monitoring make a sense only for fast models'
        path = self.Parameters.model_path[len(fast_data_dir):]
        assert self.Parameters.do_commit or self.Parameters.solomon_project != 'production', \
            'It is forbidden to send fake points to production cluster'
        assert self.Context.monitoring_data, 'Empty monitoring data'
        assert self.Context.commit_timestamp, 'Zero commit timestamp'
        oauth_token = self.Parameters.oauth_token_secret_id.data()['oauth_token']
        params = {
            'project': self.Parameters.solomon_project,
            'cluster': self.Parameters.solomon_cluster,
            'service': self.Parameters.solomon_service,
        }
        for model, data in self.Context.monitoring_data.items():
            common_labels = {
                'model': model.rsplit('.', 1)[0],
                'path': path,
            }
            sensors = [
                {
                    'labels': {'sensor': 'last_pool_ts'},
                    'kind': 'IGAUGE',
                    'value': data['LastPoolTimestamp']
                },
                {
                    'labels': {'sensor': 'pool_size'},
                    'kind': 'IGAUGE',
                    'value': data['PoolSize']
                },
                {
                    'labels': {'sensor': 'commit_ts'},
                    'kind': 'IGAUGE',
                    'value': self.Context.commit_timestamp
                },
                {
                    'labels': {'sensor': 'delay_from_last_pool_item_minutes'},
                    'kind': 'GAUGE',
                    'value': (self.Context.commit_timestamp - data['LastPoolTimestamp']) / 60.
                }
            ]
            workflow_start_ts = self.Context.monitoring_data.get('WorkflowStartTimestamp')
            if workflow_start_ts:
                sensors.extend([
                    {
                        'labels': {'sensor': 'workflow_start_ts'},
                        'kind': 'IGAUGE',
                        'value': workflow_start_ts,
                    },
                    {
                        'labels': {'sensor': 'workflow_working_time_minutes'},
                        'kind': 'GAUGE',
                        'value': (self.Context.commit_timestamp - workflow_start_ts) / 60.
                    }
                ])
            logging.info('upload to solomon: common_labels %s, sensors %s', common_labels, sensors)
            push_to_solomon_v2(oauth_token, params, sensors, common_labels)

    def _validate_params(self):
        if self.Parameters.replace_existing_model and self.Parameters.deduplicate:
            raise Exception('Model replacing and deduplication cannot be used together')
        if self.Parameters.patch_fml_config:
            assert 'rearrange.dynamic' in self.Parameters.model_path, "Shoudn't patch fml config for fast models"
            assert self.Parameters.pfc_vertical, 'Empty fml_config vertical'
            assert self.Parameters.pfc_ui, 'Empty fml_config ui'
            assert self.Parameters.pfc_locale, 'Empty fml_config locale'
            assert self.Parameters.pfc_config_path, 'Empty fml_config path'

    def _build_nirvana_output(self):
        """Pack abt params, commit result and patch into a ``BlenderNirvanaOutput`` tar.gz resource.

        Each artifact is first written to a file with the artifact name in the current
        directory and then added to the archive.
        """
        artifacts = dict()
        artifacts['abt_params.json'] = json.dumps(self.Context.abt_params, indent=4)
        artifacts['commit_result.txt'] = self.Context.commit_result
        artifacts['patch.txt'] = self.Context.patch
        output = BlenderNirvanaOutput(self, description='tar with nirvana artifacts', path='nirvana_output.tar.gz')
        output_data = sdk2.ResourceData(output)
        with tarfile.open(str(output_data.path), 'w:gz') as archive:
            for file_name, content in artifacts.items():
                with open(file_name, 'w') as artifact_file:
                    artifact_file.write(content)
                archive.add(file_name)
        output_data.ready()

    def on_execute(self):
        """Task entry point: prepare the patch, optionally test it, commit it and publish results.

        Steps, in order:
          1. validate the parameters;
          2. if the context holds no patch yet - check out the models, build the bundle resource,
             update the make file (and optionally fml_config.json) and store the svn diff;
          3. optionally start the tests on the patch, wait for them and check their result;
          4. commit the patch (or only report it, see ``_do_commit``), retrying through a task
             restart up to ``Parameters.max_commit_retries`` times;
          5. optionally push monitoring data to Solomon;
          6. build the nirvana output resource.

        :raises TaskFailure: on an empty patch or when the commit retries are exhausted
        :raises TemporaryError: on a failed commit attempt while retries are left
        """
        self._validate_params()

        # The patch lives in the context, so it is rebuilt only when the context has none:
        # on the first pass, or after a failed commit attempt has reset it to None (see below).
        if not self.Context.patch:
            arcadia_dir = self._checkout_models()
            rd_tools = svn_export(_RD_TOOLS_DIR, self.Parameters.arcadia_url, depth='files')
            models_dir = self._untar_models()
            self._build_resource(models_dir, rd_tools, arcadia_dir)
            if self.Parameters.replace_existing_model:
                self._remove_existing_models(arcadia_dir, rd_tools)
            self._add_models(arcadia_dir, rd_tools)
            # abt params must be (re)built before _patch_fml_config, which only extends them
            self._build_abt_params()
            if self.Parameters.patch_fml_config:
                self._patch_fml_config(arcadia_dir)
            self.Context.patch = svn_diff(arcadia_dir)
            if not self.Context.patch:
                raise TaskFailure('Empty patch')
            self.set_info(self.Context.patch)

        if self.Parameters.run_tests:
            # NOTE(review): memoize_stage presumably runs the wrapped section once across task
            # re-entries - confirm against the sdk2 documentation.
            with self.memoize_stage.run_tests():
                self.Context.test_task = run_tests(
                    self.Parameters.test_path, patch=self.Context.patch, arcadia_url=self.Parameters.arcadia_url,
                    task=self
                )
                self.set_info('Wait tests')
                # A falsy test_timeout is passed as None.
                raise sdk2.WaitTask(self.Context.test_task, Status.Group.FINISH | Status.Group.BREAK,
                                    timeout=self.Parameters.test_timeout or None)
            ensure_subtask_is_successful(self.Context.test_task, self)

        # NOTE(review): commit_on_entrance=False presumably marks the stage as done only after
        # the section finishes without an exception, so a failed commit is attempted again after
        # the restart - confirm against the sdk2 documentation.
        with self.memoize_stage.do_commit(commit_on_entrance=False):
            try:
                self._do_commit()
            except Exception as e:
                if self.Context.commit_retries < self.Parameters.max_commit_retries:
                    # drop the patch so that the next pass rebuilds it from a fresh checkout
                    self.Context.patch = None
                    self.Context.commit_retries += 1
                    raise TemporaryError(e)  # initiate task restart
                else:
                    raise TaskFailure(e)

        if self.Parameters.send_to_solomon:
            with self.memoize_stage.send_to_solomon(commit_on_entrance=False):
                self._send_monitoring_data_to_solomon()

        self._build_nirvana_output()
