import datetime
import json
import logging
import os
import subprocess
import textwrap
import shutil

from sandbox import sdk2
from sandbox.projects import resource_types
import sandbox.common.types.task as ctt
from sandbox.projects.quality.resources import resources as rs
from sandbox.common.errors import TaskFailure

from sandbox.projects.common.arcadia import sdk as arcadiasdk
from sandbox.projects.common.constants import constants as sdk_constants

import sandbox.projects.common.dynamic_models.archiver as models_archiver
import sandbox.projects.common.dynamic_models.compare as models_compare


class PublishBertModelsArchive(sdk2.Task):
    """ Publish new bert models config.

    The task runs three memoized stages:

    1. ``build_archive`` -- builds the ``build_archive`` binary from Arcadia
       trunk and runs it to produce the output ``BertModelsArchive`` resource.
    2. ``test_archive`` -- (optional, see ``run_check_models``) unpacks the
       archive and runs ``bert_models_tool check_model`` on every model in it.
    3. ``make_diff`` -- compares the new archive with the last READY
       ``BertModelsArchive`` resource, stores the result in ``Context.diff``
       and marks the new resource as ready.
    """

    class Requirements(sdk2.Requirements):
        tasks_resource = sdk2.Task.Requirements.tasks_resource(default=1620113038)

    class Parameters(sdk2.Parameters):
        kill_timeout = 3600

        yt_token_vault_owner = sdk2.parameters.String('Vault owner for yt token')
        yt_token_vault_name = sdk2.parameters.String('Vault name for yt token')
        run_check_models = sdk2.parameters.Bool('Run check_models for archive', default=True)

    def on_enqueue(self):
        """ Register the output archive resource and remember its id in the context. """
        self.Context.output = rs.BertModelsArchive(
            self,
            'Bert Models Archive',
            'bert_models_archive'
        ).id

    def _build(self, target, output_directory):
        """ Build an Arcadia ``target`` from trunk and put results into ``output_directory``. """
        with arcadiasdk.mount_arc_path(sdk2.svn.Arcadia.ARCADIA_TRUNK_URL) as aarcadia:
            arcadiasdk.do_build(
                build_system=sdk_constants.SEMI_DISTBUILD_BUILD_SYSTEM,
                source_root=aarcadia,
                targets=[target],
                results_dir=output_directory,
                clear_build=False,
            )

    def _fetch_last_resource(self, resource_type):
        """ Find a READY resource of ``resource_type`` and sync it locally.

        Returns a ``(resource_id, local_path)`` tuple.
        Raises ``TaskFailure`` when no READY resource of that type exists.
        """
        # NOTE(review): relies on the default ordering of Resource.find() to
        # return the most recent resource first -- confirm against the SDK.
        resource = sdk2.Resource.find(
            type=resource_type,
            state="READY",
        ).first()
        if resource is None:
            raise TaskFailure('No READY resource of type %s found' % str(resource_type))
        logging.info("Used %s resource id %s", str(resource.type), resource.id)
        return resource.id, str(sdk2.ResourceData(resource).path)

    @staticmethod
    def _run_command(cmd):
        """ Run ``cmd`` and return ``(exit_code, stdout, stderr)``.

        ``communicate()`` is used instead of ``wait()``: with both streams
        piped, ``wait()`` can block forever once the child fills a pipe buffer.
        """
        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, stderr = process.communicate()
        return process.returncode, stdout, stderr

    def _build_archive(self):
        """ Build the archive builder binary and run it to fill the output resource. """
        target = 'quality/neural_net/bert_models/build_archive'
        binary_name = 'build_archive'
        output_directory = 'archive_binary'
        self._build(target, output_directory)
        logging.info(os.listdir(output_directory))

        # The builder binary is run with YT credentials taken from the environment.
        token = sdk2.Vault.data(self.Parameters.yt_token_vault_owner, self.Parameters.yt_token_vault_name)
        os.environ['YT_PROXY'] = 'hahn'
        os.environ['YT_TOKEN'] = token

        bert_models_archive = sdk2.Resource[self.Context.output]
        bert_models_archive_path = str(sdk2.ResourceData(bert_models_archive).path)

        meta_info = {
            'sandbox_task_id': str(self.id),
            'sandbox_resource_id': str(bert_models_archive.id),
            'build_time': datetime.datetime.utcnow().isoformat(),
        }
        meta_info_path = '_metainfo'
        with open(meta_info_path, 'w') as fout:
            json.dump(meta_info, fout, indent=4)

        cmd = [
            os.path.join(os.curdir, output_directory, target, binary_name),
            '-o', bert_models_archive_path,
            '-m', meta_info_path,
        ]
        code, stdout, stderr = self._run_command(cmd)
        if code != 0:
            logging.error('STDOUT')
            logging.error(stdout)
            logging.error('STDERR')
            logging.error(stderr)
            # Do not go on to check/diff/publish an archive that was not built.
            raise TaskFailure('build_archive exited with code %d' % code)

        bert_models_archive.utc_time = meta_info['build_time']

    def _check_archive(self):
        """ Unpack the built archive and run ``check_model`` on every model in it.

        Marks the output resource as broken and raises ``TaskFailure`` if at
        least one model fails the check.
        """
        bert_models_archive = sdk2.ResourceData(sdk2.Resource[self.Context.output])
        bert_models_archive_path = str(bert_models_archive.path)

        # fetch last ready archiver tool
        _, archiver_path = self._fetch_last_resource(resource_types.ARCHIVER_TOOL_EXECUTABLE)

        models_list = models_archiver.get_list(archiver_path, bert_models_archive_path)
        logging.info('models_list')
        logging.info(json.dumps(models_list, indent=4))

        unpacked_models_path = 'unpacked_models'
        models_archiver.unpack(archiver_path, bert_models_archive_path, unpacked_models_path)

        # build bert_models_tool to check models from archive
        target = 'quality/relev_tools/bert_models/models_tool'
        binary_name = 'bert_models_tool'
        output_directory = 'models_tool_binary'
        self._build(target, output_directory)
        logging.info(os.listdir(output_directory))
        cmd = [
            os.path.join(os.curdir, output_directory, target, binary_name),
            'check_model',
        ]

        failed_check = False

        for model_name in os.listdir(unpacked_models_path):
            # skip nonmodel files like metainfo
            if model_name.startswith('_'):
                continue
            model_path = os.path.join(unpacked_models_path, model_name)
            code, _, stderr = self._run_command(cmd + ['-i', model_path])
            if code != 0:
                logging.error('Model %s has failed check', model_name)
                logging.error(stderr)
                # Keep going so that every failing model gets reported.
                failed_check = True

        if failed_check:
            bert_models_archive.broken()
            raise TaskFailure('Some model has failed check')

    def _make_diff(self):
        """ Diff the new archive against the last READY one and publish the new archive. """
        bert_models_archive = sdk2.ResourceData(sdk2.Resource[self.Context.output])
        bert_models_archive_path = str(bert_models_archive.path)
        _, archiver_path = self._fetch_last_resource(resource_types.ARCHIVER_TOOL_EXECUTABLE)
        # The new archive is not READY yet, so this finds a previously published one.
        _, last_archive_path = self._fetch_last_resource(rs.BertModelsArchive)

        self.Context.diff = models_compare.compare_archives(archiver_path, bert_models_archive_path, last_archive_path)
        bert_models_archive.ready()

    def on_execute(self):
        """ Run the build, check and diff stages, each at most once per task. """
        with self.memoize_stage.build_archive:
            self._build_archive()

        with self.memoize_stage.test_archive:
            if self.Parameters.run_check_models:
                self._check_archive()

        with self.memoize_stage.make_diff:
            self._make_diff()

    @sdk2.footer()
    def footer(self):
        """ Render the archive diff stored in ``Context.diff`` as an HTML snippet. """
        diff = self.Context.diff
        # The footer may be rendered before make_diff has run (or after a
        # failed run), in which case there is no diff dict to index into.
        if not isinstance(diff, dict):
            return 'Diff is not available yet'
        return textwrap.dedent("""
            <a target="_blank">Deleted: {}</a><br/>
            <a target="_blank">Added: {}</a><br/>
            <a target="_blank">Same: {}</a><br/>
            <a target="_blank">Changed: {}</a><br/>
        """).format(
            json.dumps(diff["deleted"]),
            json.dumps(diff["added"]),
            json.dumps(diff["same"]),
            json.dumps(diff["changed"]),
        ).strip()
