import itertools
import jinja2
import json
import logging
import os
import re
import requests
import sys

from collections import OrderedDict
from pathlib2 import Path
from sandbox import common, sdk2
from sandbox.common import errors
from sandbox.projects.common import binary_task, task_env
from sandbox.projects.common.vcs.arc import Arc
from sandbox.projects.mt.make.util import mount_arc_with_retries, post_pull_request_comment, run_mt_make_tool

import sandbox.common.types.notification as ctn
import sandbox.common.types.task as ctt


# gRPC endpoint of the public arc API; get_changed_lines opens a TLS channel to it.
ARC_HOST = "api.arc-vcs.yandex-team.ru:6734"
# Not referenced in the visible code of this file; presumably a retry budget for
# mounting arc consumed elsewhere — verify before changing.
ARC_MOUNT_MAX_ATTEMPTS = 5


class EvalMtdata(binary_task.LastBinaryTaskRelease, sdk2.Task):
    """Compare NMT translation quality between two arc revisions.

    The task is a resumable pipeline. Every stage stores its result in
    ``Context``, so re-entering ``on_execute`` (e.g. after ``sdk2.WaitTime``)
    skips work that is already done:

    1. detect (direction, service) pairs whose model files changed in
       ``dict/mt/data.yaml``;
    2. build and launch Nirvana eval graphs for the old and new revisions;
    3. poll until all graphs complete;
    4. download metrics from the finished graphs;
    5. optionally post a markdown report to a pull request.
    """

    class Parameters(sdk2.Task.Parameters):
        ext_params = binary_task.binary_release_parameters(stable=True)

        description = "Eval neural machine translation quality in commit"
        owner = "MT"

        notifications = [
            sdk2.Notification(
                statuses=[
                    ctt.Status.FAILURE,
                    ctt.Status.EXCEPTION,
                    ctt.Status.TIMEOUT
                ],
                recipients=["alexeynoskov@yandex-team.ru", "dronte@yandex-team.ru"],
                transport=ctn.Transport.EMAIL
            )
        ]

        # Arc revisions being compared; both are required.
        old_revision_hash = sdk2.parameters.String("Arc hash of old revision")
        new_revision_hash = sdk2.parameters.String("Arc hash of new revision")

        # Report is posted only when this is a positive integer.
        pull_request_id = sdk2.parameters.Integer("Id of pull request to write comment into")

        quota = sdk2.parameters.String("Name of nirvana quota", default="mt-eval")
        # This file reads the keys arc-token, nirvana-token, ya-token, yt-token and
        # arcanum-token; run_mt_make_tool receives the whole secret and may need more.
        secret = sdk2.parameters.YavSecret("YAV secret identifier (with optional version)")

    class Requirements(task_env.BuildLinuxRequirements):
        pass

    class Context(sdk2.Task.Context):
        # {direction: [service key, ...]}; keys are 'mt_service' / 'browser_mt_service'.
        updated_directions_services = None

        # {direction: {service: workflow info JSON written by eval_nmt}}; each info
        # holds at least 'workflow_id' and 'workflow_instance_id'.
        old_graph_ids = None
        new_graph_ids = None

        # {direction: {service: decoded JSON output of the metrics block}}
        old_metrics = None
        new_metrics = None
        monitoring_metrics = None

    @sdk2.header(title="Report")
    def header(self):
        """Render the task page header from ``header.tmpl``."""
        return self.render_template('header.tmpl')

    def on_execute(self):
        """Run (or resume) the evaluation pipeline.

        Raises:
            errors.TaskFailure: when a revision hash is missing or eval graphs fail.
            sdk2.WaitTime: while Nirvana graphs are still running.
        """
        # Explicit check instead of ``assert``: asserts vanish under ``python -O``
        # and did not catch empty strings.
        if not self.Parameters.old_revision_hash or not self.Parameters.new_revision_hash:
            raise errors.TaskFailure("Both old_revision_hash and new_revision_hash must be set")

        # First we need to determine which directions and services are updated
        if self.Context.updated_directions_services is None:
            logging.info("Detecting updated directions")
            self.Context.updated_directions_services = self.detect_changed_directions_services()

        if not self.Context.updated_directions_services:
            logging.info("No directions changed - exiting")
            return

        if self.Context.old_graph_ids is None or self.Context.new_graph_ids is None:
            self.build_and_run_all_graphs()

        api = self.create_nirvana_api()
        self.wait_graph_completion(api)

        if self.Context.old_metrics is None:
            self.Context.old_metrics = self.get_metrics_from_graphs(
                api, self.Context.old_graph_ids, 'get_metrics_from_ev')

        if self.Context.new_metrics is None:
            self.Context.new_metrics = self.get_metrics_from_graphs(
                api, self.Context.new_graph_ids, 'get_metrics_from_ev')

        if self.Context.monitoring_metrics is None:
            self.Context.monitoring_metrics = self.get_metrics_from_graphs(
                api, self.Context.new_graph_ids, 'get_monitoring_from_ev')

        # An unset Integer parameter may be None; ``None > 0`` is a TypeError on Python 3.
        if (self.Parameters.pull_request_id or 0) > 0:
            logging.info("Posting results")
            self.post_comment()

    def get_changed_lines(self):
        """Fetch the arc diff of ``dict/mt/data.yaml`` between the two revisions.

        Returns:
            list of diff chunks from arc ``DiffService.Diff``; each chunk has a
            ``Data`` attribute holding diff text. The response is iterated chunk
            by chunk, so it is presumably a server stream. It is fully consumed
            here so the gRPC channel can be closed before returning.
        """
        import grpc
        from arc.api.public.repo_pb2 import DiffRequest
        from arc.api.public.repo_pb2_grpc import DiffServiceStub
        from arc.api.public.shared_pb2 import FlatPath

        arc_token = self.Parameters.secret.data()['arc-token']

        creds = grpc.composite_channel_credentials(
            grpc.ssl_channel_credentials(),
            grpc.access_token_call_credentials(arc_token),
        )
        channel = grpc.secure_channel(ARC_HOST, creds)
        try:
            response = DiffServiceStub(channel).Diff(DiffRequest(
                FromRevision=self.Parameters.old_revision_hash,
                ToRevision=self.Parameters.new_revision_hash,
                Mode=FlatPath,
                PathFilter=['dict/mt/data.yaml'],
                ContextSize=0
            ))
            # Materialize while the channel is still open.
            return list(response)
        finally:
            channel.close()

    def get_services_from_changed_files(self, changed_files, revision):
        """Find (direction, service key) pairs whose configs depend on changed files.

        Runs the config generator with ``--deps`` on ``revision``. It then scans
        the produced ``*.json.deps`` files. Each line of such a file names one
        dependency file.

        Args:
            changed_files: set of repository-relative paths that changed.
            revision: arc revision to generate configs for.

        Returns:
            set of ``(direction, service_key)`` tuples. ``service_key`` is
            ``'mt_service'`` for ``.nmt.`` configs and ``'browser_mt_service'``
            for ``.browser-nmt.`` configs.
        """
        if not changed_files:
            # Nothing can match; skip mounting arc and running the generator.
            logging.info("No changed files, no changed services in %s", revision)
            return set()

        arc_token = self.Parameters.secret.data()['arc-token']
        arc = Arc(arc_oauth_token=arc_token)

        work_dir = Path(os.getcwd()).joinpath('config_generator_%s' % revision)
        # exist_ok: old and new revisions may be identical, which would reuse the dir.
        work_dir.mkdir(parents=True, exist_ok=True)

        with mount_arc_with_retries(arc, changeset=revision) as arc_mount_path:
            run_mt_make_tool(
                'dict/mt/make/tools/config_generator/config_generator',
                ['--deps'],
                arcadia_path=arc_mount_path,
                secrets=self.Parameters.secret.data(),
                work_dir=work_dir
            )

        services = set()
        for file_path in work_dir.glob('*.json.deps'):
            file_name = file_path.name
            direction = file_name.split('.')[0]

            if '.nmt.' in file_name:
                key = 'mt_service'
            elif '.browser-nmt.' in file_name:
                key = 'browser_mt_service'
            else:
                continue

            with file_path.open('r') as direction_deps:
                if any(dep.strip() in changed_files for dep in direction_deps):
                    services.add((direction, key))

        logging.info("Changed services in %s: %r", revision, services)
        return services

    def detect_changed_directions_services(self):
        """Return ``{direction: sorted [service keys]}`` affected by model changes.

        A model change is any added/removed diff line of ``dict/mt/data.yaml``
        that starts with ``data/...npz:``. Services are collected from both
        revisions so that additions and removals are both evaluated.
        """
        model_line = re.compile(r'[+-](data/.*\.npz):')

        # We track changes in models (for now)
        changed_model_files = set()
        for chunk in self.get_changed_lines():
            for line in chunk.Data.split('\n'):
                m = model_line.match(line)
                if m is not None:
                    changed_model_files.add(m.group(1))

        # Get services from both revisions
        old_services = self.get_services_from_changed_files(changed_model_files, self.Parameters.old_revision_hash)
        new_services = self.get_services_from_changed_files(changed_model_files, self.Parameters.new_revision_hash)

        # Group services by direction; sorting keeps service lists ordered.
        services_by_direction = {}
        for direction, service in sorted(old_services | new_services):
            services_by_direction.setdefault(direction, []).append(service)

        logging.info("Services to eval: %r", services_by_direction)
        return services_by_direction

    def build_and_run_all_graphs(self):
        """Build and launch eval graphs for whichever revision has none recorded.

        Writes a ``.vhrc`` (Nirvana token, quota, arcadia root) into the task's
        working directory, which is used as ``HOME`` for the child processes.
        It then checks out each revision in one arc mount and stores the
        launched graphs in ``Context``. Monitoring results ('yandex', 'google')
        are requested only for the new revision.
        """
        secrets = self.Parameters.secret.data()
        arc = Arc(arc_oauth_token=secrets['arc-token'])

        nirvana_token = secrets['nirvana-token']
        nirvana_quota = self.Parameters.quota

        home_path = Path(os.getcwd())

        env = {
            'HOME': str(home_path),
            'YA_TOKEN': secrets['ya-token'],
            'YT_TOKEN': secrets['yt-token']
        }

        with mount_arc_with_retries(arc) as arc_mount_path:
            vhrc_path = home_path.joinpath(".vhrc")
            vhrc_path.write_text('\n'.join([
                "--oauth-token=%s" % nirvana_token,
                "--quota=%s" % nirvana_quota,
                "--arcadia-root=%s" % arc_mount_path,
            ]))
            # The file contains a token; make it owner read-only.
            vhrc_path.chmod(0o400)

            if self.Context.old_graph_ids is None:
                arc.checkout(arc_mount_path, self.Parameters.old_revision_hash, force=True)
                self.Context.old_graph_ids = self.build_and_run_graphs(home_path, arc_mount_path, env, graph_name="old")

            if self.Context.new_graph_ids is None:
                arc.checkout(arc_mount_path, self.Parameters.new_revision_hash, force=True)
                self.Context.new_graph_ids = self.build_and_run_graphs(
                    home_path,
                    arc_mount_path,
                    env,
                    graph_name="new",
                    fetch_monitoring_systems=['yandex', 'google'])

    def build_and_run_graphs(self, home_path, arc_mount_path, env, graph_name, fetch_monitoring_systems=None):
        """Build ``eval_nmt`` and launch one Nirvana graph per updated (direction, service).

        Args:
            home_path: directory in which ``work-<graph_name>`` is created.
            arc_mount_path: root of the arc mount checked out at the wanted revision.
            env: environment for the ``ya make`` and ``eval_nmt`` subprocesses.
            graph_name: label used in log names and the work directory name.
            fetch_monitoring_systems: optional system names passed to
                ``--ev-fetch-results-system``.

        Returns:
            ``{direction: {service: workflow info JSON}}`` as written by
            ``eval_nmt --write-workflow-info``.
        """
        logging.info("Generating %s eval graphs", graph_name)

        binary_dir = os.path.join("dict", "mt", "make", "tools", "eval_nmt")
        binary_path = os.path.join(arc_mount_path, binary_dir, "eval_nmt")

        with sdk2.helpers.ProcessLog(self, logger='%s-build' % graph_name) as pl:
            sdk2.helpers.subprocess.check_call(
                [
                    os.path.join(arc_mount_path, "ya"),
                    "make",
                    "--yt-store", "-r",
                    binary_dir
                ],
                cwd=arc_mount_path, env=env,
                stdout=pl.stdout, stderr=pl.stderr,
            )
        work_path = home_path.joinpath('work-%s' % graph_name)
        work_path.mkdir(parents=True, exist_ok=True)

        # Same extra arguments for every graph of this revision.
        add_args = []
        if fetch_monitoring_systems:
            add_args.append('--ev-fetch-results-system')
            add_args.extend(fetch_monitoring_systems)

        direction_graphs = {}
        for direction, services in self.Context.updated_directions_services.items():
            direction_graphs[direction] = {}
            for service in services:
                wi_path = work_path.joinpath("%s_%s.json" % (direction, service))

                with sdk2.helpers.ProcessLog(self, logger='%s-graph-%s-%s' % (graph_name, direction, service)) as pl:
                    sdk2.helpers.subprocess.check_call(
                        [
                            str(binary_path),
                            direction, "--directory", direction,
                            "--nmt-service-key", service,
                            "--nmt", "--use-local-vocs", "--use-local-model",
                            "--mode", "nirvana", "--mtdata", "@arcadia",
                            "--ev-eval",
                            "--write-workflow-info", str(wi_path)
                        ] + add_args,
                        cwd=str(work_path), env=env,
                        stdout=pl.stdout, stderr=pl.stderr,
                    )

                direction_graphs[direction][service] = json.loads(wi_path.read_text())

        return direction_graphs

    def wait_graph_completion(self, api):
        """Block the task until every old and new eval graph has completed.

        Raises:
            sdk2.WaitTime: (180) while any graph has not reached ``completed``.
            errors.TaskFailure: if any completed graph's result is not ``success``.
        """
        wi_ids = []
        for direction, services in self.Context.new_graph_ids.items():
            for service in services:
                for graphs in (self.Context.old_graph_ids, self.Context.new_graph_ids):
                    wi_ids.append(graphs[direction][service]['workflow_instance_id'])
        wi_states = api.get_workflow_execution_states(wi_ids)

        if any(s['status'] != 'completed' for s in wi_states):
            raise sdk2.WaitTime(180)

        if any(s['result'] != 'success' for s in wi_states):
            raise errors.TaskFailure("Eval graphs didn't complete successfully")

    def get_metrics_from_graphs(self, api, graphs, block_name):
        """Collect ``block_name`` outputs as ``{direction: {service: metrics}}``."""
        return {
            direction: {
                service: self.get_metrics_from_graph(api, direction, graph, block_name)
                for service, graph in services.items()
            }
            for direction, services in graphs.items()
        }

    def get_metrics_from_graph(self, api, direction, graph, block_name):
        """Download and decode the JSON ``output_0`` of block ``block_name``.

        Args:
            api: Nirvana API client (see ``create_nirvana_api``).
            direction: translation direction; used only in error messages.
            graph: workflow info with ``workflow_id`` / ``workflow_instance_id``.
            block_name: name of the graph block producing the metrics.

        Returns:
            the decoded JSON document.

        Raises:
            errors.TaskFailure: if the block or its ``output_0`` result is missing.
            requests.RequestException: on download errors, timeouts or HTTP errors.
        """
        # Read timeout for the storage download; without one requests may block forever.
        download_timeout = 300

        workflow_id = graph['workflow_id']
        instance_id = graph['workflow_instance_id']

        workflow = api.get_workflow(workflow_id, instance_id)

        guids_by_name = {}
        for block in workflow['blocks']:
            name = block['name']
            if not name:
                continue
            if name in guids_by_name:
                logging.error("Duplicate block %s, guids %s and %s", name, block['blockGuid'], guids_by_name[name])
            # Last block with a given name wins, as before.
            guids_by_name[name] = block['blockGuid']

        block_guid = guids_by_name.get(block_name)
        if block_guid is None:
            raise errors.TaskFailure("Block %s not found in %s graph %s/%s" % (
                block_name, direction, workflow_id, instance_id))

        block_results = api.get_block_results(
            block_patterns=[{'guid': block_guid}],
            workflow_id=workflow_id,
            workflow_instance_id=instance_id,
        )
        outputs = {}
        for block in block_results:
            if block['blockGuid'] != block_guid:
                continue
            for result in block['results']:
                if result['endpoint'] != 'exception' and 'directStoragePath' in result:
                    outputs[result['endpoint']] = result['directStoragePath']

        output_url = outputs.get('output_0')
        if output_url is None:
            raise errors.TaskFailure("Block %s in %s graph %s/%s has no output_0 result" % (
                block_name, direction, workflow_id, instance_id))

        response = requests.get(output_url, timeout=download_timeout)
        response.raise_for_status()
        return json.loads(response.text)

    def create_nirvana_api(self):
        """Create a Nirvana API client authorized with the secret's nirvana-token."""
        from nirvana_api import NirvanaApi
        return NirvanaApi(self.Parameters.secret.data()['nirvana-token'])

    def post_comment(self):
        """Post the rendered ``comment.tmpl`` report to the configured pull request."""
        post_pull_request_comment(
            pull_request_id=self.Parameters.pull_request_id,
            comment=self.render_template("comment.tmpl"),
            arcanum_token=self.Parameters.secret.data()['arcanum-token']
        )

    def render_template(self, template_name):
        """Render a jinja2 template located next to this module.

        Safe to call at any pipeline stage: graph URLs and metrics are included
        only when the corresponding ``Context`` fields are filled in.

        Args:
            template_name: file name of the template relative to this module.

        Returns:
            the rendered text.
        """
        def make_graph_urls(graphs):
            if graphs is None:
                return {}
            return {
                direction: {
                    service: "https://nirvana.yandex-team.ru/flow/{}/{}/graph".format(
                        graph['workflow_id'],
                        graph['workflow_instance_id'])
                    for service, graph in services.items()
                }
                for direction, services in graphs.items()
            }

        def combine_testset_metrics(old, new, monitoring):
            # Monitoring results do not depend on the metric key; set them once.
            res = {
                'old': {},
                'new': {},
                'yandex': monitoring.get('yandex', {}),
                'google': monitoring.get('google', {}),
                'diff': {},
            }
            for k in set(old.keys()) | set(new.keys()):
                ov = old.get(k)
                nv = new.get(k)
                res['old'][k] = ov
                res['new'][k] = nv
                # Explicit None checks: a metric value of 0 is a real value, and
                # ``ov and nv and ...`` produced a wrong diff for it.
                res['diff'][k] = nv - ov if ov is not None and nv is not None else None
            return res

        def combine_direction_metrics(old, new, monitoring):
            return OrderedDict(
                (t,
                 combine_testset_metrics(
                     old.get(t, {}),
                     new.get(t, {}),
                     monitoring.get(t, {})))
                for t in sorted(set(old.keys()) | set(new.keys()))
            )

        old_metrics = self.Context.old_metrics
        new_metrics = self.Context.new_metrics
        if old_metrics is not None and new_metrics is not None:
            # Monitoring may still be missing while the header renders mid-run.
            monitoring_metrics = self.Context.monitoring_metrics or {}
            metrics = {}
            for direction in sorted(set(old_metrics.keys()) | set(new_metrics.keys())):
                old_direction_metrics = old_metrics.get(direction, {})
                new_direction_metrics = new_metrics.get(direction, {})
                monitoring_direction_metrics = monitoring_metrics.get(direction, {})
                metrics[direction] = {
                    service: combine_direction_metrics(
                        old_direction_metrics.get(service, {}),
                        new_direction_metrics.get(service, {}),
                        monitoring_direction_metrics.get(service, {}))
                    for service in set(old_direction_metrics) | set(new_direction_metrics)
                }
        else:
            metrics = None

        # Sorted testset names and sorted unique metric names per (direction, service).
        testsets_keys = {}
        metric_keys = {}
        for direction, services in (metrics or {}).items():
            testsets_keys[direction] = {}
            metric_keys[direction] = {}
            for service, testsets in services.items():
                testsets_keys[direction][service] = sorted(testsets)
                keys = set()
                for testset_metrics in testsets.values():
                    keys.update(testset_metrics['old'])
                    keys.update(testset_metrics['new'])
                metric_keys[direction][service] = sorted(keys)

        def ev_result_link(inner_value, result_id):
            if isinstance(inner_value, float):
                inner_value = '%.2f' % inner_value
            if result_id is None:
                return inner_value
            return '[{text}]({link})'.format(
                text=inner_value,
                link='https://text-translate-ev.n.yandex-team.ru/web/1/result?resultId=%d' % int(result_id)
            )

        def ev_result_diff(inner_value, old_result_id, new_result_id):
            if isinstance(inner_value, float):
                inner_value = '%+.3f' % inner_value
            return '[{text}]({link})'.format(
                text=inner_value,
                link='https://text-translate-ev.n.yandex-team.ru/web/1/result_diff?oldResultId=%d&newResultId=%d' % (
                    int(old_result_id),
                    int(new_result_id))
            )

        template_path = os.path.join(os.path.dirname(__file__), template_name)
        jinja2_env = jinja2.Environment()
        # Register filters on this environment only instead of mutating the
        # process-global jinja2.filters.FILTERS registry.
        jinja2_env.filters['ev_result_link'] = ev_result_link
        jinja2_env.filters['ev_result_diff'] = ev_result_diff

        template = jinja2_env.from_string(common.fs.read_file(template_path))

        template_context = {
            'updated_directions_services': self.Context.updated_directions_services,
            'old_graph_urls': make_graph_urls(self.Context.old_graph_ids),
            'new_graph_urls': make_graph_urls(self.Context.new_graph_ids),
            'metrics': metrics,
            'testsets_keys': testsets_keys,
            'metric_keys': metric_keys
        }

        # Intentionally left for debugging in case of errors
        sys.stderr.write('\ntemplate_context =\n')
        sys.stderr.write(json.dumps(template_context, indent=4))
        sys.stderr.write('\n')

        return template.render(template_context)
