import logging

from sandbox import sdk2, common
import sandbox.common.types.task as ctt

from jinja2 import Environment, BaseLoader
import tarfile
import tempfile
import os


# Resource type used as a filter when CacheTestGraphCompareParent reads the
# resources listed in its `autocheck_resource_ids` parameter.
GRAPH_RESOURCE_TYPE = 'AUTOCHECK_LOGS'

# Jinja2 template for the task info shown by CacheTestGraphCompare when the
# comparison reports errors. It is rendered with three variables:
#   stat       - the attribute dict of the object returned by compare_graphs
#   bm_res_id  - id of the benchmark resource
#   tst_res_id - id of the test resource
# The rendered text is passed to set_info(..., do_escape=False), so the HTML
# tags below are emitted as markup rather than escaped.
ERROR_MESSAGE_TEMPLATE = '''Node count: benchmark: {{ stat.benchmark_graph_node_count }}, test: {{stat.test_graph_node_count }}
autocheck_logs: <a href=/resource/{{ bm_res_id }}/view>benchmark</a>, <a href=/resource/{{ tst_res_id }}/view>test</a>
<b>Important:</b>
- unmatched node count: benchmark: {{ stat.benchmark_unmatched_node_count }}, test: {{stat.test_unmatched_node_count }}
- uid only changes: {{ stat.uid_only_changes_count }}
- uid not changed but should: {{ stat.significant_but_no_uid_changes_count }}
- other significant changes: {{ stat.significant_changes_count }}
<b>Ordinary:</b>
- uid and deps only changes: {{ stat.uid_and_deps_only_changes_count }}
- other changes: {{ stat.insignificant_changes_count }}
'''


class CacheTestGraphCompareParent(sdk2.Task):
    """Start one CacheTestGraphCompare child per (config name, partition) pair.

    The resources listed in ``autocheck_resource_ids`` (filtered by the
    AUTOCHECK_LOGS type) are grouped by the config name found in the tags of
    the task each resource refers to and by the resource's ``partition``
    attribute. Every group must contain both a ``benchmark`` and a ``test``
    resource; a child task is created for each group to compare the two.

    After the children finish, every child whose status is not SUCCESS is
    listed in this task's info. The parent itself is not failed because of
    failed children.
    """

    class Parameters(sdk2.Task.Parameters):
        # Ids of the resources to compare; resources of other types are
        # filtered out by the server query in _organize_resources.
        autocheck_resource_ids = sdk2.parameters.List('autocheck_resource_ids', required=True)
        # Tag prefixes (matched case-insensitively, ':' is appended): the rest
        # of the matching tag is the config name / the job type.
        config_name_tag_prefix = sdk2.parameters.String('config_name_tag_prefix', default='cache-tests-config-name', required=True)
        job_type_tag_prefix = sdk2.parameters.String('job_type_tag_prefix', default='cache-tests-job-type', required=True)
        # Passed through unchanged to every child task.
        resource_ttl_days = sdk2.parameters.Integer('resource_ttl_days', default=7)

    def on_create(self):
        """Select the tasks binary whose ``tasks_bundle`` attribute is 'CacheTestGraphCompare'."""
        self.Requirements.tasks_resource = sdk2.service_resources.SandboxTasksBinary.find(
            owner=self.owner,
            attrs={'tasks_bundle': CacheTestGraphCompare.__name__},
        ).first()

    def on_execute(self):
        """Create and enqueue the children, wait for them, then report failed ones."""
        with self.memoize_stage.create_children:
            # Tags are lower-cased in _parse_tags, so the prefixes are lower-cased as well.
            config_name_tag_prefix = self.Parameters.config_name_tag_prefix.lower() + ':'
            job_type_tag_prefix = self.Parameters.job_type_tag_prefix.lower() + ':'

            resources = self._organize_resources(self.Parameters.autocheck_resource_ids, config_name_tag_prefix, job_type_tag_prefix)
            logging.debug('Autocheck resources: %s', resources)

            children_tasks = []
            for (config_name, partition), partition_resources in resources.items():
                child = CacheTestGraphCompare(
                    self,
                    description='Compare graphs for {} partition {}'.format(config_name, partition),
                    owner=self.owner,
                    notifications=self.Parameters.notifications,
                    benchmark_resource_id=partition_resources['benchmark'],
                    test_resource_id=partition_resources['test'],
                    resource_ttl_days=self.Parameters.resource_ttl_days,
                )
                children_tasks.append(child)

            # Start created tasks
            for task in children_tasks:
                task.enqueue()

            # Remember the ids in the context: they are read back below once
            # the task continues after the wait.
            self.Context.children_task_ids = [t.id for t in children_tasks]
            raise sdk2.WaitTask(children_tasks, [ctt.Status.Group.FINISH, ctt.Status.Group.BREAK], wait_all=True)

        failures = []
        for task_id in self.Context.children_task_ids:
            status = sdk2.Task[task_id].status
            if status != 'SUCCESS':
                failures.append('Task {} failed with status {}\n'.format(task_id, status))
        if failures:
            self.set_info(''.join(failures))

    def _parse_tags(self, task, config_name_tag_prefix, job_type_tag_prefix):
        """Extract the config name and the job type from a task's tags.

        :param task: mapping with a ``'tags'`` list of strings
        :param config_name_tag_prefix: lower-case prefix of the config name tag
        :param job_type_tag_prefix: lower-case prefix of the job type tag
        :return: ``(config_name, job_type)``; both values are lower-case
            because the tags are lower-cased before matching
        :raises RuntimeError: if a tag with one of the prefixes is missing or
            the job type is neither ``benchmark`` nor ``test``
        """
        tags = [t.lower() for t in task['tags']]

        def prefix_value(pfx):
            # The first tag with the prefix wins.
            for tag in tags:
                if tag.startswith(pfx):
                    return tag[len(pfx):]
            raise RuntimeError('Cannot find required tag with prefix {}'.format(pfx))

        config_name = prefix_value(config_name_tag_prefix)
        job_type = prefix_value(job_type_tag_prefix)
        if job_type not in {'benchmark', 'test'}:
            raise RuntimeError('Wrong job_type value: {} (expected "benchmark" or "test")'.format(job_type))
        return config_name, job_type

    def _organize_resources(self, resource_ids, config_name_tag_prefix, job_type_tag_prefix):
        """Group the requested resources into benchmark/test pairs.

        :param resource_ids: ids of the resources to read
        :param config_name_tag_prefix: see :meth:`_parse_tags`
        :param job_type_tag_prefix: see :meth:`_parse_tags`
        :return: ``{(config_name, partition): {job_type: resource_id}}`` where
            every value has both the ``benchmark`` and the ``test`` key
        :raises RuntimeError: if the task of a resource cannot be read, its
            tags cannot be parsed, or a group lacks one of the two job types
        """
        # Cannot use sdk2 here, because of error: 'UnknownTaskType: Unknown task type u'AUTOCHECK_BUILD_PARENT_2''
        # Sdk2 tries to create the task object by its type name and looks up the AutocheckBuildParent2 class to construct it,
        # but the class doesn't exist in the current task binary
        resources = self.server.resource.read(id=resource_ids, limit=len(resource_ids), type=GRAPH_RESOURCE_TYPE)['items']
        requested_count = len(set(resource_ids))
        if len(resources) < requested_count:
            # Not an error: the type filter may drop resources on purpose, but
            # a silent drop is hard to debug.
            logging.warning(
                'Only %d of %d requested resources were returned with type %s',
                len(resources), requested_count, GRAPH_RESOURCE_TYPE,
            )

        task_id_set = set(r['task']['id'] for r in resources)
        tasks = self.server.task.read(id=task_id_set, fields=['id', 'tags'], limit=len(task_id_set))['items']
        task_by_id = {t['id']: t for t in tasks}

        result = {}  # {(config_name, partition): {job_type: resource_id}}
        for resource in resources:
            partition = resource['attributes']['partition']
            task_id = resource['task']['id']
            task = task_by_id.get(task_id)
            if task is None:
                raise RuntimeError('Task for task_id={} not found'.format(task_id))
            config_name, job_type = self._parse_tags(task, config_name_tag_prefix, job_type_tag_prefix)
            by_job_type = result.setdefault((config_name, partition), {})
            if job_type in by_job_type:
                # The last resource wins (as before); make the replacement visible in the log.
                logging.warning(
                    'Resource %s replaces resource %s as %s for config_name=%s and partition=%s',
                    resource['id'], by_job_type[job_type], job_type, config_name, partition,
                )
            by_job_type[job_type] = resource['id']

        # Check completeness
        for (config_name, partition), res in result.items():
            for job_type in 'benchmark', 'test':
                if job_type not in res:
                    raise RuntimeError('Cannot find {} resource for config_name={} and partition={}'.format(
                                       job_type, config_name, partition))
        return result


class CacheTestGraphCompare(sdk2.Task):
    """Compare the graphs stored in a benchmark and a test resource.

    Both graphs are extracted from a tar archive inside the corresponding
    resource, decompressed and handed to ``compare_graphs``. If the comparison
    reports errors, a summary is written to the task info and the comparison
    output is published as a CacheTestGraphCompareResult resource. If it
    reports fatal errors, the ttl of both input resources is additionally set
    to ``resource_ttl_days`` and the task fails.
    """

    # Run on multislot
    class Requirements(sdk2.Requirements):
        cores = 2
        ram = 16384

        class Caches(sdk2.Requirements.Caches):
            pass  # Do not use any shared caches (required for running on multislot agent)

    class Parameters(sdk2.Task.Parameters):
        benchmark_resource_id = sdk2.parameters.Integer('benchmark_resource_id', required=True)
        test_resource_id = sdk2.parameters.Integer('test_resource_id', required=True)
        # Name of the tar archive inside each resource and the path of the
        # compressed graph inside that archive.
        tar_name = sdk2.parameters.String('tar_name', default='subtract_on_dist.tar')
        path_in_tar = sdk2.parameters.String('path_in_tar', default='subtract_on_dist/execute/subtract/left/graph.json.uc')
        # ttl of the result resource; also applied to both input resources
        # when the comparison reports fatal errors.
        resource_ttl_days = sdk2.parameters.Integer('resource_ttl_days', default=7)

    def on_create(self):
        """Select the tasks binary whose ``tasks_bundle`` attribute is 'CacheTestGraphCompare'."""
        self.Requirements.tasks_resource = sdk2.service_resources.SandboxTasksBinary.find(
            owner=self.owner,
            attrs={'tasks_bundle': CacheTestGraphCompare.__name__},
        ).first()

    def on_execute(self):
        """Extract both graphs, compare them and report the outcome.

        :raises common.errors.TaskFailure: if the comparison reports fatal errors
        """
        import shutil

        from sandbox.projects.devtools.CacheTestGraphCompare.compare import compare_graphs

        result_dir = self.path('compare_result')
        result_dir.mkdir(parents=True)

        bm_res = sdk2.Resource[self.Parameters.benchmark_resource_id]
        tst_res = sdk2.Resource[self.Parameters.test_resource_id]
        bm_res_data = sdk2.ResourceData(bm_res)
        tst_res_data = sdk2.ResourceData(tst_res)

        tmp_dir = tempfile.mkdtemp()
        try:
            bm_graph_path = os.path.join(tmp_dir, 'benchmark-graph.json')
            tst_graph_path = os.path.join(tmp_dir, 'test-graph.json')
            self._extract_graph(bm_res_data, bm_graph_path)
            self._extract_graph(tst_res_data, tst_graph_path)

            stat = compare_graphs(bm_graph_path, tst_graph_path, result_dir)
            self._report_result(stat, result_dir, bm_res, tst_res)
        finally:
            # mkdtemp() leaves removal to the caller; the decompressed graphs
            # are not needed once the outcome has been reported.
            shutil.rmtree(tmp_dir, ignore_errors=True)

    def _report_result(self, stat, result_dir, bm_res, tst_res):
        """Turn the comparison statistics into task info, a result resource and, on fatal errors, a failure."""
        if stat.total_error_count > 0:
            template = Environment(loader=BaseLoader).from_string(ERROR_MESSAGE_TEMPLATE)
            # The template contains HTML markup, hence no escaping.
            self.set_info(template.render(dict(stat=stat.__dict__, bm_res_id=bm_res.id, tst_res_id=tst_res.id)), do_escape=False)

            self._publish_result(result_dir)

            if stat.fatal_error_count > 0:
                logging.info('Set ttl for resources %d, %d to %d days', bm_res.id, tst_res.id, self.Parameters.resource_ttl_days)
                bm_res.ttl = self.Parameters.resource_ttl_days
                tst_res.ttl = self.Parameters.resource_ttl_days
                raise common.errors.TaskFailure('Benchmark and test graphs are different. See report and logs for details')

        else:
            self.set_info('Graphs have no differences. Node count={}'.format(stat.benchmark_graph_node_count))

    def _publish_result(self, result_dir):
        """Pack ``result_dir`` into compare_result.tgz and publish it as a CacheTestGraphCompareResult resource."""
        tar_path = self.path('compare_result.tgz')
        # Paths are converted with str(), as _extract_graph does, so that
        # tarfile gets plain strings regardless of the interpreter version.
        with tarfile.open(str(tar_path), mode='w:gz', compresslevel=5) as tar:
            tar.add(str(result_dir), 'compare_result')
        result_resource = CacheTestGraphCompareResult(self, 'Result of graphs comparing', tar_path, ttl=self.Parameters.resource_ttl_days)
        sdk2.ResourceData(result_resource).ready()

    def _extract_graph(self, res_data, dest_path):
        """Extract the graph from the resource's tar archive and decompress it to ``dest_path``.

        :param res_data: resource data object; its ``path`` is the directory
            that contains the ``tar_name`` archive
        :param dest_path: file path the decompressed graph is written to
        """
        import shutil

        from library.python import compress

        tmp_dir = tempfile.mkdtemp()
        try:
            with tarfile.open(str(res_data.path / self.Parameters.tar_name)) as tar:
                tar.extract(self.Parameters.path_in_tar, tmp_dir)
            compress.decompress(os.path.join(tmp_dir, self.Parameters.path_in_tar), dest_path)
        finally:
            # The extracted compressed file is only an intermediate artifact.
            shutil.rmtree(tmp_dir, ignore_errors=True)


class CacheTestGraphCompareResult(sdk2.Resource):
    """Gzipped tar archive with the graph comparison output.

    Created by CacheTestGraphCompare when the comparison reports errors.
    """
