# -*- coding: utf-8 -*-
import logging
import os
import tempfile
import tarfile
import re
import time
import traceback

from sandbox import sdk2
from sandbox.common.types.task import Status
from sandbox.common.types.resource import State
from sandbox.sdk2.helpers import subprocess as sp
from sandbox.sdk2 import ssh
from sandbox.common.errors import TaskFailure

from sandbox.projects.blender.util.arcadia import selective_checkout, svn_diff, svn_export, build_directory, run_tests
from sandbox.projects.blender.util.subtask import ensure_subtask_is_successful


class BlenderOnlineLearningRtmrHelper(sdk2.Resource):
    """Resource type for the rtmr_helper binary used to dump models from RTMR."""


class BlenderOnlineLearningModelPack(sdk2.Resource):
    """Resource type for the tar.gz archive of fetched blender online models."""

    # Unix timestamp (seconds) stamped on the pack when it is created.
    creation_time = sdk2.Attributes.Integer('Pack creation time')


def _sleep_timeout(retry):
    timeouts = [10, 30, 60, 120, 300]
    return timeouts[retry] if retry < len(timeouts) else timeouts[-1]


# Arcadia directory holding the model list files and the ya.make that is
# rewritten to point at the freshly packed models resource.
MODEL_DATA_ARC_PATH = (
    'search/web/rearrs_upper/rearrange.fast/'
    'blender_online_learning/proto_models'
)
# Arcadia test target run against the patched model ya.make.
MODEL_TEST_ARC_PATH = (
    'search/web/rearrs_upper/tests/rearrange.fast/'
    'blender_online_learning/tests'
)
# Build target of the helper binary that dumps models from RTMR.
RTMR_HELPER_ARC_PATH = 'yweb/blender/online_learning/bin/rtmr_helper'
# Extension of model files inside the pack and in the ya.make file list.
MODEL_FILE_EXT = '.bin'


class UpdateBlenderOnlineModels(sdk2.Task):
    """ Task for updating blender online models

    Pipeline (driven by on_execute):
      1. optionally build rtmr_helper in a YA_MAKE subtask;
      2. collect the model names (custom list or Arcadia model list files);
      3. dump every model from RTMR and pack them into one tar.gz resource;
      4. rewrite the model ya.make FROM_SANDBOX section to use that resource;
      5. optionally run tests with the patched ya.make;
      6. commit the patch, or only report it when do_commit is off.
    """

    class Requirements(sdk2.Task.Requirements):
        disk_space = 50 * 1024

        # force cores1 using: https://wiki.yandex-team.ru/sandbox/cookbook/#cores1multislot
        cores = 1
        ram = 8192

        class Caches(sdk2.Requirements.Caches):
            pass  # means that task do not use any shared caches

    class Parameters(sdk2.Task.Parameters):
        arcadia_url = sdk2.parameters.ArcadiaUrl('Svn url for arcadia', required=True)
        rtmr_host = sdk2.parameters.String('RTMR host', default_value='rtmr')

        use_custom_model_list = sdk2.parameters.Bool('Use custom model list',
                                                     description='To overwrite using models from arcadia')
        with use_custom_model_list.value[True]:
            # Whitespace-separated model names.
            custom_model_list = sdk2.parameters.String('Model list', multiline=True, required=True)

        build_rtmr_helper = sdk2.parameters.Bool('Build rtmr_helper')
        run_tests = sdk2.parameters.Bool('Run tests', default_value=True)

        do_commit = sdk2.parameters.Bool('Do commit')
        with do_commit.value[True]:
            ssh_key_owner = sdk2.parameters.String('Owner of the vault ssh key',
                                                   default_value='BLENDER')
            ssh_key_name = sdk2.parameters.String('Name of the vault ssh key',
                                                  default_value='robot-blendr-priemka-private')
            ssh_user_name = sdk2.parameters.String('Ssh user name',
                                                   default_value='robot-blendr-priemka')

        test_timeout = sdk2.parameters.Integer('Test timeout', description='Timeout for ya tests in sec',
                                               default_value=900)
        # Number of retries after the first attempt (total attempts = retries + 1).
        fetch_retries = sdk2.parameters.Integer('Model fetch retries', default_value=2)
        skip_unavailable_models = sdk2.parameters.Bool('Dont fail if some models are unavailable')
        use_previous_version_on_unavailability = sdk2.parameters.Bool('Use previous version of model in case of unavailability')

    class Context(sdk2.Task.Context):
        models = ()  # sorted tuple of model names handled by this run
        pack_id = None  # id of the BlenderOnlineLearningModelPack resource
        patch = None  # svn diff of the model ya.make
        build_rtmr_helper_task_id = None  # YA_MAKE subtask building rtmr_helper
        test_task_id = None  # subtask running the MODEL_TEST_ARC_PATH tests

    def _build_binary(self, target, resource_type, force_checkout=False):
        """Enqueue a YA_MAKE subtask building ``target`` into ``resource_type``.

        :param target: Arcadia path of the build target.
        :param resource_type: resource class whose name is used as result type.
        :param force_checkout: request a manual full checkout for the build.
        :return: id of the enqueued subtask.
        """
        kwargs = dict()
        if force_checkout:
            kwargs['checkout_mode'] = 'manual'
            kwargs['checkout'] = True
        task = sdk2.Task['YA_MAKE'](
            self,
            description='Build ' + target,
            owner=self.owner,
            checkout_arcadia_from_url=self.Parameters.arcadia_url,
            # The built artifact is expected at <target>/<basename(target)>.
            arts=os.path.join(target, os.path.basename(target)),
            targets=target,
            result_rt=resource_type.name,
            # NOTE(review): placeholder resource description, consider a real one.
            result_rd='here should be some description',
            result_single_file=True,
            build_type='release',
            build_arch='linux',
            **kwargs
        )
        return task.enqueue().id

    def _run_tests(self, target, patch):
        """Enqueue a YA_MAKE subtask running tests of ``target`` with ``patch``.

        NOTE: on_execute uses the imported ``run_tests`` helper instead; this
        method is kept for backward compatibility.

        :return: id of the enqueued subtask.
        """
        task = sdk2.Task['YA_MAKE'](
            self,
            description='Run tests for ' + target,
            owner=self.owner,
            checkout_arcadia_from_url=self.Parameters.arcadia_url,
            targets=target,
            build_type='release',
            build_arch='linux',
            arcadia_patch=patch,
            test=True
        )
        return task.enqueue().id

    def _checkout_model_data(self):
        """Selectively check out MODEL_DATA_ARC_PATH.

        :return: presumably the checkout root; callers join MODEL_DATA_ARC_PATH to it.
        """
        return selective_checkout(MODEL_DATA_ARC_PATH, arcadia_url=self.Parameters.arcadia_url, task=self)

    @classmethod
    def _add_model_extension(cls, model_name):
        """Return the model file name for ``model_name`` (appends MODEL_FILE_EXT)."""
        return model_name + MODEL_FILE_EXT

    @classmethod
    def _cut_model_extension(cls, file_name):
        """Return ``file_name`` without MODEL_FILE_EXT, or None if it lacks it."""
        if file_name.endswith(MODEL_FILE_EXT):
            return file_name[:-len(MODEL_FILE_EXT)]
        return None

    def _fetch_model_from_rtmr(self, model_name, helper_bin, dst_path):
        """Dump one model from RTMR into ``dst_path`` using rtmr_helper.

        :param model_name: RTMR key of the model.
        :param helper_bin: path to the rtmr_helper executable.
        :param dst_path: directory for the dumped model file.
        :return: path of the written model file.
        :raises sp.CalledProcessError: if rtmr_helper exits with non-zero status.
        """
        model_file = os.path.join(dst_path, self._add_model_extension(model_name))
        cmd = [
            helper_bin, 'dump_model',
            '-k', model_name,
            '-f', model_file,
            '-s', self.Parameters.rtmr_host,
        ]
        logging.info('fetch model %s to %s', model_name, model_file)
        with sdk2.helpers.ProcessLog(self, logger='fetch_model') as pl:
            sp.check_call(cmd, stdout=pl.stdout, stderr=sp.STDOUT)
        return model_file

    def _init_models(self):
        """Collect the names of the models to update.

        Names come either from the whitespace-separated custom list parameter or
        from the model list files exported from MODEL_DATA_ARC_PATH (blank lines
        and ``#`` comments ignored).

        :return: sorted tuple of unique, non-empty model names.
        """
        models = set()
        if self.Parameters.use_custom_model_list:
            # str.split() without a separator collapses whitespace runs and drops
            # empty strings; re.split('\s', ...) used to yield '' entries for
            # trailing/repeated whitespace or \r\n, which were then "fetched" from
            # RTMR. The set also removes duplicates, which would otherwise trip the
            # 'Duplicate model' assertion in _fetch_models.
            models.update((self.Parameters.custom_model_list or '').split())
        else:
            tmp_dir = svn_export(MODEL_DATA_ARC_PATH, self.Parameters.arcadia_url)
            for file_with_models in ('model_list.txt', 'fastres2_model_list.txt'):
                with open(os.path.join(tmp_dir, file_with_models)) as f:
                    for line in f:
                        line = line.strip()
                        if line and not line.startswith('#'):
                            models.add(line)
        return tuple(sorted(models))

    def _fetch_models(self, work_dir, rtmr_helper):
        """Fetch every model in Context.models from RTMR, retrying failures.

        Only models that failed are retried; between attempts the task sleeps
        according to _sleep_timeout.

        :param work_dir: directory to dump model files into.
        :param rtmr_helper: path to the rtmr_helper executable.
        :return: dict model name -> dumped file path (successful models only).
        :raises Exception: if some models still fail after all attempts and
            skip_unavailable_models is off.
        """
        models_to_fetch = list(self.Context.models)
        failed_models, errors = [], []
        model2file = dict()
        # Always make at least one attempt, even for a negative retry count.
        attempts = max(0, self.Parameters.fetch_retries) + 1
        for retry in range(attempts):
            if retry > 0:
                time.sleep(_sleep_timeout(retry - 1))
            # Only the failures of the latest attempt are reported.
            failed_models, errors = [], []
            for mn in models_to_fetch:
                assert mn not in model2file, 'Duplicate model %s' % mn
                try:
                    model2file[mn] = self._fetch_model_from_rtmr(mn, rtmr_helper, work_dir)
                except sp.CalledProcessError:
                    failed_models.append(mn)
                    errors.append(traceback.format_exc())
            if not failed_models:
                break
            models_to_fetch = failed_models[:]
        if failed_models:
            msg = '{} models fetching failed with errors: {}'.format(', '.join(failed_models), '\n'.join(errors))
            if self.Parameters.skip_unavailable_models:
                logging.error(msg)
                self.set_info('Fail to fetch models: [{}]'.format(', '.join(failed_models)))
            else:
                raise Exception(msg)
        return model2file

    def _prepare_previous_version_models(self):
        """Build the current Arcadia model data and map model names to files.

        :return: dict model name -> real path of ``<name>MODEL_FILE_EXT`` found in
            MODEL_DATA_ARC_PATH after build_directory.
        """
        arcadia_dir = self._checkout_model_data()
        build_directory(MODEL_DATA_ARC_PATH, arcadia_dir=arcadia_dir, task=self)
        path = os.path.join(arcadia_dir, MODEL_DATA_ARC_PATH)
        m2f = dict()
        for fn in os.listdir(path):
            mn = self._cut_model_extension(fn)
            if mn:
                # realpath resolves links so the actual file content gets packed.
                m2f[mn] = os.path.realpath(os.path.join(path, fn))
        return m2f

    def _fetch_and_pack_models(self):
        """Fetch models from RTMR and pack them into a model pack resource.

        Models that could not be fetched are optionally replaced by their
        previous Arcadia version; the rest are skipped. Context.models is
        updated to the set of models actually packed.

        :return: id of the created BlenderOnlineLearningModelPack resource.
        :raises TaskFailure: if no READY rtmr_helper resource exists, or if no
            model is left to pack.
        """
        rtmr_helper = BlenderOnlineLearningRtmrHelper.find(
            task_id=self.Context.build_rtmr_helper_task_id, status=State.READY
        ).order(-sdk2.Resource.id).first()
        if rtmr_helper is None:
            raise TaskFailure('No READY rtmr_helper resource found (build task: {})'.format(
                self.Context.build_rtmr_helper_task_id
            ))
        logging.info('use rtmr_helper from resource %s', rtmr_helper.id)
        logging.info('models to fetch: %s', ', '.join(self.Context.models))
        models = self._fetch_models(tempfile.mkdtemp('models'), str(sdk2.ResourceData(rtmr_helper).path))
        logging.info('fetched models: %s', sorted(models.items()))
        self.set_info('Fetched models: [{}]'.format(', '.join(sorted(models))))
        prev_version_models = dict()
        if self.Parameters.use_previous_version_on_unavailability:
            prev_version_models = self._prepare_previous_version_models()
            logging.info('prev version models: %s', sorted(prev_version_models.items()))
        used_prev_version, skipped = [], []
        for m in self.Context.models:
            if m not in models:
                if m in prev_version_models:
                    models[m] = prev_version_models[m]
                    used_prev_version.append(m)
                else:
                    skipped.append(m)
        if used_prev_version:
            self.set_info('Use previous version for models [{}]'.format(', '.join(used_prev_version)))
        if skipped:
            self.set_info('Skip models [{}]'.format(', '.join(skipped)))
        if not models:
            # Fail before creating a never-expiring (ttl='inf') empty pack.
            raise TaskFailure('No models to pack')
        model_names = tuple(sorted(models))
        self.Context.models = model_names
        pack = BlenderOnlineLearningModelPack(
            self, description='blender online models pack: {}'.format(', '.join(model_names)),
            path='blender_online_models.tar.gz', ttl='inf', backup_task=True,
        )
        pack_data = sdk2.ResourceData(pack)
        logging.info('model pack path %s', pack_data.path)
        with tarfile.open(str(pack_data.path), 'w:gz') as tar:
            # Sorted order keeps the archive layout deterministic.
            for mn in model_names:
                tar.add(models[mn], self._add_model_extension(mn))
        pack.creation_time = int(time.time())
        pack_data.ready()
        return pack.id

    def _update_ya_make(self, ya_make_file):
        """Replace the FROM_SANDBOX section of ``ya_make_file`` with the new pack.

        Every existing FROM_SANDBOX(...) section is dropped, and a new one listing
        Context.models from resource Context.pack_id is inserted before END().
        The file is rewritten only after the new content has been validated.

        :raises TaskFailure: if there are no models, or the file has no END() line.
        """
        if not self.Context.models:
            raise TaskFailure('No models to update')
        with open(ya_make_file) as f:
            lines = f.readlines()
        from_sandbox_block = 'FROM_SANDBOX({} OUT\n    {}\n)\n'.format(
            self.Context.pack_id, '\n    '.join(map(self._add_model_extension, self.Context.models))
        )
        new_lines = []
        from_sandbox_section = False
        has_written_models = False
        for raw_line in lines:
            line = raw_line.strip()
            should_skip = False
            if line.startswith('FROM_SANDBOX'):
                from_sandbox_section = True
                should_skip = True
            # The first line ending with ')' closes the section (possibly the
            # FROM_SANDBOX line itself for a one-line section).
            if line.endswith(')') and from_sandbox_section:
                from_sandbox_section = False
                should_skip = True
            if from_sandbox_section or should_skip:
                logging.info('skip ya.make line %s', line)
                continue
            if line == 'END()':
                new_lines.append(from_sandbox_block)
                has_written_models = True
            new_lines.append(raw_line)
        if not has_written_models:
            # Raise before writing so the original ya.make is left intact.
            raise TaskFailure('No END() line found in {}'.format(ya_make_file))
        with open(ya_make_file, 'w') as f:
            f.writelines(new_lines)

    def _patch_ya_make(self, arcadia_dir):
        """Update the model ya.make inside ``arcadia_dir`` and return its svn diff."""
        ya_make = os.path.join(MODEL_DATA_ARC_PATH, 'ya.make')
        self._update_ya_make(os.path.join(arcadia_dir, ya_make))
        diff = svn_diff(arcadia_dir, ya_make)
        logging.info('model ya.make diff\n%s', diff)
        return diff

    def _ensure_child_task_is_successful(self, task_id):
        """Delegate to ensure_subtask_is_successful for subtask ``task_id``."""
        ensure_subtask_is_successful(task_id, self)

    def on_execute(self):
        """Run the pipeline; each stage is memoized across task restarts/waits."""
        if self.Parameters.build_rtmr_helper:
            with self.memoize_stage.start_build(commit_on_entrance=False):
                self.Context.build_rtmr_helper_task_id = self._build_binary(
                    RTMR_HELPER_ARC_PATH, BlenderOnlineLearningRtmrHelper, True
                )
                self.set_info('Wait rtmr helper build')
                raise sdk2.WaitTask(self.Context.build_rtmr_helper_task_id, Status.Group.FINISH | Status.Group.BREAK)
            self._ensure_child_task_is_successful(self.Context.build_rtmr_helper_task_id)

        with self.memoize_stage.init_models(commit_on_entrance=False):
            self.Context.models = self._init_models()
            self.set_info('Models to update: {}'.format(', '.join(self.Context.models)))

        with self.memoize_stage.fetch_models(commit_on_entrance=False):
            self.set_info('Start fetching models from rtmr')
            self.Context.pack_id = self._fetch_and_pack_models()

        with self.memoize_stage.patch_ya_make(commit_on_entrance=False):
            arcadia_dir = self._checkout_model_data()
            self.Context.patch = self._patch_ya_make(arcadia_dir)
            self.set_info('ya.make patch:\n{}'.format(self.Context.patch))

        if self.Parameters.run_tests:
            with self.memoize_stage.run_tests():
                self.Context.test_task_id = run_tests(
                    MODEL_TEST_ARC_PATH, patch=self.Context.patch, arcadia_url=self.Parameters.arcadia_url, task=self
                )
                self.set_info('Wait tests')
                raise sdk2.WaitTask(self.Context.test_task_id, Status.Group.FINISH | Status.Group.BREAK,
                                    timeout=self.Parameters.test_timeout)
            self._ensure_child_task_is_successful(self.Context.test_task_id)

        with self.memoize_stage.do_commit(commit_on_entrance=False):
            msg = '[diff-resolver:{}] SKIP_CHECK update blender online models: https://sandbox.yandex-team.ru/task/{}/view'.format(
                self.author, self.id
            )
            # Re-apply the stored patch onto a fresh checkout before committing.
            arcadia_dir = self._checkout_model_data()
            sdk2.svn.Arcadia.apply_patch(arcadia_dir, self.Context.patch, tempfile.mkdtemp())
            if self.Parameters.do_commit:
                with ssh.Key(self, self.Parameters.ssh_key_owner, self.Parameters.ssh_key_name):
                    commit_result = sdk2.svn.Arcadia.commit(arcadia_dir, msg, self.Parameters.ssh_user_name)
                    if not commit_result:
                        raise TaskFailure('Empty commit result')
                    self.set_info('Done commit with result: {}'.format(commit_result))
            else:
                self.set_info('Commit would be done with message: "{}"\nPatch:\n{}'.format(
                    msg, sdk2.svn.Arcadia.diff(arcadia_dir)
                ))
