import itertools
import jinja2
import logging
import math
import os
import re

from collections import OrderedDict
from pathlib2 import Path
from sandbox import common, sdk2
from sandbox.common.errors import ResourceNotFound
from sandbox.projects.common import binary_task, error_handlers, task_env
from sandbox.projects.mt.make.util import post_pull_request_comment
from sandbox.projects.resource_types import ARCADIA_PROJECT

import sandbox.common.types.client as ctc
import sandbox.common.types.notification as ctn
import sandbox.common.types.task as ctt


# host:port of the Arc VCS API endpoint; CheckMtdata opens a TLS gRPC channel to it to read mtdata.
ARC_HOST = 'api.arc-vcs.yandex-team.ru:6734'


class CheckMtdata(binary_task.LastBinaryTaskRelease, sdk2.Task):
    """Validate the Sandbox resources referenced from ``dict/mt/data.yaml`` (mtdata).

    mtdata lines have the form ``key: location``; only ``sbr:<id>`` locations are
    checked. If ``old_revision_hash`` is set, only lines added between the old and
    new revisions are checked; otherwise every line of the file at
    ``new_revision_hash`` is. Each referenced resource must exist, be owned by
    ``MT``, have an infinite TTL and carry a ``creation_date`` attribute. With
    ``fix`` enabled, TTL and ``creation_date`` problems of MT-owned resources are
    repaired instead of reported. Remaining problems fail the task and, when
    ``pull_request_id`` is positive, are posted as a comment on that pull request.
    """

    class Parameters(sdk2.Task.Parameters):
        ext_params = binary_task.binary_release_parameters(stable=True)

        description = "Check mtdata correctness in commit"
        owner = "MT"

        notifications = [
            sdk2.Notification(
                statuses=[
                    ctt.Status.FAILURE,
                    ctt.Status.EXCEPTION,
                    ctt.Status.TIMEOUT
                ],
                recipients=["alexeynoskov@yandex-team.ru"],
                transport=ctn.Transport.EMAIL
            )
        ]

        old_revision_hash = sdk2.parameters.String("Arc hash of old revision", required=False)
        new_revision_hash = sdk2.parameters.String("Arc hash of new revision", required=True)

        pull_request_id = sdk2.parameters.Integer("Id of pull request to write comment into", required=False)

        secret = sdk2.parameters.YavSecret("YAV secret identifier containing arc-token and arcanum-token (with optional version)", required=True)

        fix = sdk2.parameters.Bool("Try to fix errors", default=False)

    class Context(sdk2.Context):
        # Problems found in mtdata, as (key, location, message) triples.
        errors = []

    class Requirements(task_env.TinyRequirements):
        pass

    def on_execute(self):
        """Check every resource referenced by mtdata; fail the task if any problem remains."""
        # Start from a fresh list instead of appending to whatever Context.errors
        # already holds (the class-level default list, or entries left from an
        # earlier run if the Context survives restarts -- TODO confirm).
        self.Context.errors = []

        mtdata = self.get_mtdata_content()

        logging.info("Checking resources")
        for key, loc in mtdata.items():
            self.check_resource(key, loc)
        logging.info("Checked %d lines from mtdata", len(mtdata))

        errors = self.Context.errors
        if not errors:
            logging.info("Everything is OK")
            return

        logging.error("There are errors in mtdata:")
        # Unpack explicitly: "%" formatting only accepts a tuple, and entries may come
        # back as lists if the Context is serialized (presumably JSON -- TODO confirm).
        for key, loc, message in errors:
            logging.error("%s (%s): %s", key, loc, message)

        if self.Parameters.pull_request_id is not None and self.Parameters.pull_request_id > 0:
            logging.info("Posting results")
            self.post_comment(errors)

        error_handlers.check_failed("Mtdata contains errors in resources: " + ', '.join(e[0] for e in errors))

    def check_resource(self, key, loc):
        """Validate the Sandbox resource referenced by one mtdata entry.

        Locations not starting with ``sbr:`` are skipped. Problems are appended to
        ``self.Context.errors`` as ``(key, loc, message)``. With the ``fix``
        parameter set, a non-infinite TTL or a missing ``creation_date`` attribute
        on an MT-owned resource is repaired instead of reported.

        :param key: mtdata key (left-hand side of the ``key: location`` line).
        :param loc: location string, e.g. ``sbr:123456``.
        """
        if not loc.startswith('sbr:'):
            logging.info("Skipping %r: %r", key, loc)
            return

        # A malformed id is reported like any other mtdata error instead of
        # aborting the whole check with a ValueError.
        try:
            resource_id = int(loc[4:])
        except ValueError:
            self.Context.errors.append((key, loc, 'Invalid resource id: %r' % loc[4:]))
            return

        try:
            resource = sdk2.Resource[resource_id]
        except ResourceNotFound:
            self.Context.errors.append((key, loc, 'Resource #%s not found' % resource_id))
            return

        owned_by_mt = resource.owner == 'MT'
        # Resources owned by someone else are never modified, even with fix enabled.
        can_fix = self.Parameters.fix and owned_by_mt

        if not owned_by_mt:
            self.Context.errors.append((key, loc, 'Owner is not MT: %r' % resource.owner))

        # The TTL must be float infinity; any non-float value is reported as well.
        if not isinstance(resource.ttl, float) or not math.isinf(resource.ttl):
            if can_fix:
                logging.info("Setting infinite TTL on resource #%s (%s)", resource_id, key)
                resource.ttl = float('inf')
            else:
                self.Context.errors.append((key, loc, 'TTL is not inf: %r' % resource.ttl))

        if not hasattr(resource, 'creation_date'):
            if can_fix:
                # Back-fill from the resource's own creation timestamp as YYYY-MM-DD.
                logging.info("Setting creation_date on resource #%s (%s)", resource_id, key)
                resource.creation_date = resource.created.strftime('%Y-%m-%d')
            else:
                self.Context.errors.append((key, loc, 'Resource has no creation_date attribute'))

        # TODO: Check resource contents

    def get_mtdata_content(self):
        """Fetch the mtdata entries to check from Arc.

        :return: OrderedDict mapping mtdata key to location string, in file order.
            With ``old_revision_hash`` set, only lines added between the old and new
            revisions are included; otherwise every line of the file is.
        :raises RuntimeError: on a malformed mtdata line or an inconsistent
            ReadFile response stream.
        """
        import grpc

        arc_token = self.Parameters.secret.data()['arc-token']

        creds = grpc.composite_channel_credentials(
            grpc.ssl_channel_credentials(),
            grpc.access_token_call_credentials(arc_token),
        )
        channel = grpc.secure_channel(ARC_HOST, creds)
        try:
            if self.Parameters.old_revision_hash:
                return self._read_mtdata_diff(channel)
            return self._read_mtdata_file(channel)
        finally:
            # Both helpers fully consume their response streams, so the channel can
            # be released here.
            channel.close()

    def _read_mtdata_diff(self, channel):
        """Return the entries on lines added to mtdata between the old and new revisions."""
        from arc.api.public.repo_pb2 import DiffRequest
        from arc.api.public.repo_pb2_grpc import DiffServiceStub
        from arc.api.public.shared_pb2 import FlatPath

        logging.info('Diffing mtdata between %s and %s', self.Parameters.old_revision_hash, self.Parameters.new_revision_hash)

        stream = DiffServiceStub(channel).Diff(DiffRequest(
            FromRevision=self.Parameters.old_revision_hash,
            ToRevision=self.Parameters.new_revision_hash,
            Mode=FlatPath,
            PathFilter=['dict/mt/data.yaml'],
            ContextSize=0
        ))
        # Nothing guarantees that chunk boundaries fall on line breaks, so reassemble
        # the whole diff before splitting it into lines.
        diff_text = self._join_chunks([chunk.Data for chunk in stream])
        logging.info('Diff: %r', diff_text)

        res = OrderedDict()
        for line in diff_text.split('\n'):
            # Keep only added lines; '+++' is the new-file header, not content.
            if line.startswith('+++') or not line.startswith('+'):
                continue
            entry = self._parse_mtdata_line(line[1:])
            if entry is None:
                raise RuntimeError("Invalid mtdata diff line: %r" % line)
            res[entry[0]] = entry[1]
        return res

    def _read_mtdata_file(self, channel):
        """Return every entry of mtdata at the new revision."""
        from arc.api.public.repo_pb2 import ReadFileRequest
        from arc.api.public.repo_pb2_grpc import FileServiceStub

        logging.info('Receiving mtdata from %s', self.Parameters.new_revision_hash)

        request = ReadFileRequest(
            Revision=self.Parameters.new_revision_hash,
            Path=str('dict/mt/data.yaml'),
        )
        # This code expects exactly one Header message plus any number of Data chunks.
        chunks = []
        header = None
        for resp in FileServiceStub(channel).ReadFile(request):
            if resp.HasField('Data'):
                chunks.append(resp.Data)
            elif header is None:
                header = resp.Header
            else:
                raise RuntimeError("Two headers in response")

        if header is None:
            raise RuntimeError("No header in response")

        data = b''.join(chunks)
        if len(data) != header.FileSize:
            raise RuntimeError("Wrong data size: %r != %r" % (len(data), header.FileSize))

        res = OrderedDict()
        for line in data.decode('utf-8').rstrip('\n').split('\n'):
            entry = self._parse_mtdata_line(line)
            if entry is None:
                raise RuntimeError("Invalid mtdata line: %r" % line)
            res[entry[0]] = entry[1]
        return res

    @staticmethod
    def _join_chunks(chunks):
        """Concatenate streamed payload chunks into a single text string.

        Bytes chunks are joined first and decoded as UTF-8 afterwards, so a
        multi-byte character split across two chunks is still decoded correctly.
        """
        if chunks and isinstance(chunks[0], bytes):
            return b''.join(chunks).decode('utf-8')
        return ''.join(chunks)

    @staticmethod
    def _parse_mtdata_line(line):
        """Split a ``key: location`` line into ``(key, location)``, or return None if it does not match.

        Only the start of the line is matched, so any text after the first
        whitespace-delimited location token is ignored.
        """
        m = re.match(r'([^:\s]+):\s*([^\s]+)', line)
        return None if m is None else (m.group(1), m.group(2))

    def post_comment(self, errors):
        """Render ``comment.tmpl`` (next to this file) with ``errors`` and post it to the pull request."""
        template_path = os.path.join(os.path.dirname(__file__), 'comment.tmpl')
        template = jinja2.Template(common.fs.read_file(template_path))

        post_pull_request_comment(
            pull_request_id=self.Parameters.pull_request_id,
            comment=template.render({'errors': errors}),
            arcanum_token=self.Parameters.secret.data()['arcanum-token']
        )
