# -*- coding: utf-8 -*-

import collections
import jinja2
import json
import logging
import os
import re
from urllib import urlencode
from urlparse import urlparse, parse_qsl, urlunparse

from sandbox import sdk2
from sandbox import common
from sandbox.projects import resource_types
from sandbox.sandboxsdk.svn import Arcadia, Svn
from sandbox.sdk2.helpers import subprocess

from sandbox.projects.common.arcadia import sdk as arcadiasdk
from sandbox.projects.common.wizard import utils as wizard_utils
from sandbox.projects.websearch.begemot import parameters as bp
from sandbox.projects.websearch.begemot import resources as br


class CheckBegemotResponses(sdk2.Task):
    '''
    Process begemot logs and summarize info about exceptions.

    Besides scanning the eventlog for request/rule errors, the task can
    compare relevs of repeated requests and run shard-specific checks of the
    responses (see on_execute).
    '''

    # Regexes of exception texts that are still reported in the results but do
    # not mark the run as failed. They are applied with re.match, i.e. each
    # pattern has to match at the very beginning of the exception text.
    FILTERED_EXCEPTIONS = [
        r'Search request is empty',
        r'Search request consists of non-letter characters, which are stripped entirely by normalization',
        r'search/begemot/rules/spellchecker/rule.cpp:\d+: request is too long',
        r'text too long \(over \d+B\)',
        # when synonyms are disabled metadata doesn't exist
        r'search/begemot/core/resource.h:\d+: Required resource Thesaurus large rule metadata not found',
        # TODO: BEGEMOT-2255
        r'search/begemot/rules/qtree/morphology/query/query.cpp:\d+: Created tree is corrupted: kernel/qtree/richrequest/verifytree.cpp:\d+: oOr node should have at least two children, but there is 1',
        r'UserId is invalid',  # invalid input for profile_log
    ]

    # Relev names that _diff_relevs skips when counting mismatches.
    FILTERED_RELEVS = [
        'bgrtfactors',
        'dopp_url_counters_for_base',
        'fresh_detector_predict',
        'fresh_news_detector_predict',
    ]

    class Parameters(sdk2.Parameters):
        # Binary that on_execute runs to decode the eventlog.
        evlogdump = bp.BegemotEvlogdumpBinaryResource(required=True)
        # Eventlog that is piped into evlogdump's stdin.
        begemot_eventlog = sdk2.parameters.Resource(
            'Logs from GetBegemotResponses',
            resource_type=br.BEGEMOT_EVENTLOG,
            required=True,
        )
        # Queries used to map request ids to URIs and, together with
        # begemot_responses, to validate relevs.
        queries = sdk2.parameters.Resource(
            'Begemot queries',
            required=False
        )
        # Path of an extra queries file; it is joined with the arcadia_url
        # checkout/mount root in _parse_additional_queries.
        additional_queries = sdk2.parameters.String(
            'Additional cgi queries',
            required=False
        )
        # Responses inspected by the relev validation and the shard checks.
        begemot_responses = sdk2.parameters.Resource(
            'Responses from GetBegemotResponses',
            resource_type=br.BEGEMOT_RESPONSES_RESULT,
            required=False,
        )
        # When set, _validate_relevs reports on this single relev only.
        relev_to_debug = sdk2.parameters.String(
            'Relev to debug',
            required=False,
            description='e.g. AverageDistanceNormed',
        )
        arcadia_url = sdk2.parameters.String(
            'Arcadia URL',
            required=False,
            description='Required if additional queries set'
        )
        # 'Bert' and 'SrcSetup + Merger' enable extra checks in on_execute.
        shard = sdk2.parameters.String(
            'Shard',
            required=False
        )

    class Requirements(sdk2.Requirements):
        # Host tags come from the constant shared through wizard utils.
        client_tags = wizard_utils.ALL_SANDBOX_HOSTS_TAGS

    def is_filtered_exception(self, text):
        for ex in self.FILTERED_EXCEPTIONS:
            if re.match(ex, text):
                return True
        return False

    def _parse_additional_queries(self, queries=None):
        '''
        Append the queries from the additional queries file to `queries`.

        Parameters.additional_queries is resolved against
        Parameters.arcadia_url: arc URLs are mounted, anything else is checked
        out with svn. Lines starting with "@" and blank lines are skipped;
        every other line is stored under the next sequential request id (as a
        string), numbering continues after len(queries).

        :param queries: dict to extend in place; a fresh dict is used when omitted
        :return: the extended dict
        '''
        # A mutable default argument would be shared between calls and keep
        # growing, so the dict is created per call instead.
        if queries is None:
            queries = {}

        relative_path = self.Parameters.additional_queries
        if self.Parameters.arcadia_url.startswith(Arcadia.ARCADIA_ARC_SCHEME):
            # NOTE(review): the mount is presumably released when the `with`
            # block exits, so the file is read while it is still available.
            with arcadiasdk.mount_arc_path(self.Parameters.arcadia_url) as arc_arcadia:
                with open(os.path.join(arc_arcadia, relative_path), 'r') as f:
                    lines = f.readlines()
        else:
            arcadia_path = Arcadia.checkout(self.Parameters.arcadia_url, 'arcadia', depth=Svn.Depth.IMMEDIATES)
            additional_requests_path = os.path.join(arcadia_path, relative_path)
            Arcadia.update(additional_requests_path, depth=Svn.Depth.IMMEDIATES, parents=True)
            with open(additional_requests_path, 'r') as f:
                lines = f.readlines()

        last_reqid = len(queries)
        for line in lines:
            if line[0] != "@" and line != "\n":
                last_reqid += 1
                queries[str(last_reqid)] = line
        return queries

    def _parse_wizard_queries(self, path):
        query_map = {}
        with path.open('r') as queries:
            i = 0
            for line in queries:
                if line[0] != "@" and line != "\n":
                    i += 1
                    query_map[str(i)] = line
        if self.Parameters.additional_queries:
            query_map = self._parse_additional_queries(query_map)
        return query_map

    def _parse_queries(self):
        if self.Parameters.queries is not None:
            queries_path = sdk2.ResourceData(self.Parameters.queries).path
            if self.Parameters.queries.type.name == "BEGEMOT_CGI_QUERIES":
                return self._parse_wizard_queries(queries_path)
            uris = {}
            with queries_path.open('r') as queries:
                for line in queries:
                    try:
                        query = json.loads(line)
                    except ValueError:
                        logging.debug('Failed to parse line {line}'.format(line=line))
                        return uris
                    try:
                        reqid = None
                        uri = None
                        for q in query:
                            if (uri is None or reqid is None) and 'results' in q:
                                results = q['results']
                                if isinstance(results, dict):
                                    results = [results]
                                for res in results:
                                    if reqid is None:
                                        if 'reqid' in res:
                                            reqid = res['reqid']
                                        elif 'binary' in res and 'reqid' in res['binary']:
                                            reqid = res['binary']['reqid']
                                    if uri is None:
                                        if 'uri' in res:
                                            uri = res['uri']
                                        elif 'binary' in res and 'uri' in res['binary']:
                                            uri = res['binary']['uri']
                        if reqid is not None and uri is not None:
                            uris[reqid] = uri
                        else:
                            logging.debug('Failed to find uri or reqid, here is the complete query: {query}'.format(query=query))
                            return uris
                    except (KeyError, IndexError, TypeError):
                        logging.debug('Unexpected query JSON format, here is the complete query: {query}'.format(query=query))
                        return uris
            return uris
        elif self.Parameters.additional_queries:
            return self._parse_additional_queries()
        return {}

    def _shorten_uri(self, uri):
        if uri is None:
            return "Not found URI or query"
        parsed = urlparse(uri)
        parsed_query_list = parse_qsl(parsed.query)
        new_query_list = [x for x in parsed_query_list if 'srcrwr' not in x and 'metahost2' not in x and 'json_dump' not in x]
        new_query = urlencode(new_query_list)
        uri = urlunparse((parsed.scheme, parsed.netloc, parsed.path, parsed.params, new_query, parsed.fragment))
        return uri

    def _parse_relev(self, relev):
        result = {}
        for item in relev.strip().split(";"):
            kv = item.split('=', 1)
            try:
                result[kv[0]] = kv[1]
            except:
                pass
        return result

    def _diff_relevs(self, pairs_map):
        missmatch_cnts = collections.Counter()
        debug_key = None

        for key in pairs_map:
            p = pairs_map[key]
            first = self._parse_relev(p[0])
            second = self._parse_relev(p[1])
            for item in first:
                if item in self.FILTERED_RELEVS:
                    continue
                if item not in missmatch_cnts:
                    missmatch_cnts[item] = 0
                if item not in second or first[item] != second[item]:
                    if item == self.Parameters.relev_to_debug:
                        debug_key = key
                    missmatch_cnts[item] += 1
            for item in second:
                if item in self.FILTERED_RELEVS:
                    continue
                if item not in first:
                    if item == self.Parameters.relev_to_debug:
                        debug_key = key
                    missmatch_cnts[item] += 1
        return missmatch_cnts, debug_key

    def _validate_relevs(self):
        '''
        Check that requests with the same (text, lr) pair get the same relev.

        The queries resource is read first to map every reqid to its
        (text, lr) pair. Then the responses resource is walked and the relev of
        each response is compared with the first relev seen for the same pair.
        A summary is published with set_info. When Parameters.relev_to_debug
        differs inside some pair, the two queries of that pair are saved as the
        resources query_1.txt and query_2.txt.

        :return: False when such a pair of queries was saved, True otherwise
        '''
        answers_map = {}  # (text, lrs) -> first relev seen for it
        reqids_map = {}  # reqid -> (text, lrs)
        queries_map = {}  # reqid -> raw query line
        with open(str(sdk2.ResourceData(self.Parameters.queries).path), 'r') as f:
            for q in f.readlines():
                lr = set()
                text = None
                reqid = None
                for item in json.loads(q, encoding='utf-8'):
                    if item.get('name') == 'BEGEMOT_WORKERS_MISSPELL_OUT':
                        # Resetting reqid makes the whole query be skipped below.
                        reqid = None
                        break
                    for i in item.get('results'):
                        if reqid is None and 'reqid' in i:
                            reqid = i['reqid']
                        res = i.get('binary')
                        if not isinstance(res, dict):
                            continue
                        if 'text' in res:
                            text = res['text']
                        # user_request is only a fallback for a missing text
                        if text is None and 'user_request' in res:
                            text = res['user_request']
                        if 'lr' in res:
                            if isinstance(res['lr'], dict):
                                lr.add(res['lr']['id'])
                            else:
                                lr.add(res['lr'])
                        if reqid is None and 'reqid' in res:
                            reqid = res['reqid']
                if isinstance(text, dict):
                    text = text['main']
                # Queries without text, region or reqid cannot be matched to responses.
                if text is None or not len(lr) or reqid is None:
                    continue
                lrs = " ".join(sorted(list(lr)))
                reqids_map[reqid] = (text, lrs)
                queries_map[reqid] = q

        responses, matches, mismatches, not_found = 0, 0, 0, 0
        mismatches_map = {}  # (text, lrs) -> (first relev, differing relev)
        reqid_by_request = {}  # (text, lrs) -> reqid of the first response
        second_reqid_by_request = {}  # (text, lrs) -> reqid of the differing response

        with open(str(sdk2.ResourceData(self.Parameters.begemot_responses).path), 'r') as f:
            for ans in f.readlines():
                responses += 1
                reqid = None
                relev = None
                for items in json.loads(ans, encoding='utf-8'):
                    if isinstance(items, dict):
                        items = [items]
                    for item in items:
                        if isinstance(item, dict) and item.get('type') == 'wizard':
                            relev = item.get('relev')
                        if 'reqid' in item:
                            reqid = item['reqid']

                if isinstance(reqid, list):
                    reqid = reqid[0]
                if reqid not in reqids_map or relev is None:
                    not_found += 1
                    continue

                text_lr = reqids_map[reqid]
                if text_lr in answers_map:
                    if reqid == reqid_by_request[text_lr]:
                        # same reqid again (presumably a hedged duplicate) - skip it
                        continue
                    if answers_map[text_lr] != relev:
                        mismatches += 1
                        mismatches_map[text_lr] = (answers_map[text_lr], relev)
                        second_reqid_by_request[text_lr] = reqid
                    else:
                        matches += 1
                else:
                    answers_map[text_lr] = relev
                    reqid_by_request[text_lr] = reqid

        # text_lr_pair is set only when relev_to_debug differs for some pair
        diff_stats, text_lr_pair = self._diff_relevs(mismatches_map)
        queries_pair = (
            queries_map[reqid_by_request[text_lr_pair]],
            queries_map[second_reqid_by_request[text_lr_pair]]
        ) if text_lr_pair is not None else None

        if self.Parameters.relev_to_debug:
            if queries_pair is None:
                info = "Relev {} is a function of user_request, lr".format(self.Parameters.relev_to_debug)
            else:
                info = "Relev {} is NOT a function of user_request, lr. See examples in resources query_1.txt, query_2.txt, produced by this task".format(self.Parameters.relev_to_debug)
        else:
            info = "\n".join([
                "Responses analyzed: {}".format(responses - not_found),
                "Unique requests: {}".format(len(answers_map)),
                "Same requests, same responses: {}".format(matches),
                "Same requests, different responses (BAD): {}".format(mismatches),
                "",
                "Relevs with diff:",
            ])
            # diff_stats is non-empty only when mismatches > 0, so the division is safe
            for k, v in diff_stats.most_common(len(diff_stats)):
                info = "\n".join([info, "{}: {:.2f}%".format(k, 100 * (float(v) / mismatches))])

        self.set_info(info)

        if queries_pair is not None:
            # Save both queries of the differing pair as separate resources.
            for num in [1, 2]:
                output_file = 'query_{}.txt'.format(num)
                res_type = resource_types.OTHER_RESOURCE
                output = res_type(
                    self,
                    "query_{}.txt (output of CHECK_BEGEMOT_RESPONSES)".format(num),
                    output_file
                )
                with open(output_file, 'w') as out:
                    out.write("{}\n".format(queries_pair[num - 1]))
                sdk2.ResourceData(output).ready()

            return False

        return True  # in future return False for failed validation

    def _check_bert_responses(self):
        '''
        Check that at least half of the responses carry a Bert embedding.

        For every response line the first item of type `wizard` is inspected:
        the response counts when its `rules.BertInference` output has a
        non-null embedding.

        :return: (True, None) on success, (False, message) otherwise
        '''
        cnt, has_embedding = 0, 0
        with open(str(sdk2.ResourceData(self.Parameters.begemot_responses).path), 'r') as f:
            for ans in f:
                cnt += 1
                for item in json.loads(ans):
                    if item.get('type') != 'wizard':
                        continue
                    try:
                        bert_output = item['rules']['BertInference']
                        # NOTE(review): the key used to be looked up with a
                        # trailing space only ('FullSplitEmbedding '), which
                        # looks like a typo that made every lookup fail
                        # silently. Both spellings are accepted until the real
                        # key is confirmed.
                        embedding = bert_output.get('FullSplitEmbedding ')
                        if embedding is None:
                            embedding = bert_output.get('FullSplitEmbedding')
                    except (KeyError, IndexError, TypeError, AttributeError):
                        # The response has no usable BertInference output.
                        embedding = None
                    if embedding is not None:
                        has_embedding += 1
                    # Only the first wizard item of a response is considered.
                    break

        if has_embedding * 2 < cnt:
            return False, 'Only {} of {} responses have SerializedEmbedding in BertInference output'.format(has_embedding, cnt)
        return True, None

    def _check_webfresh_post_setup(self):
        '''
        Check that WebFreshPostSetup is present in more than 90% of answers.

        An answer counts when any dict nested in its top-level items has a
        non-empty `rules.WebFreshPostSetup`. The percentage is published with
        set_info.

        :return: (True, None) on success, (False, message) otherwise
        '''
        cnt, with_webfresh_post_setup = 0, 0
        with open(str(sdk2.ResourceData(self.Parameters.begemot_responses).path), 'r') as f:
            for ans in f:
                cnt += 1
                found = False
                for item in json.loads(ans):
                    if found:
                        break
                    for i in item:
                        if isinstance(i, dict):
                            if i.get('rules', {}).get('WebFreshPostSetup', {}):
                                found = True

                if found:
                    with_webfresh_post_setup += 1

        # An empty responses file used to end in ZeroDivisionError; report it
        # as a regular check failure instead.
        if cnt == 0:
            return False, 'No answers found, cannot compute percent of answers with WebFreshPostSetup'

        percent = 100 * float(with_webfresh_post_setup) / cnt
        self.set_info('WebFreshPostSetup found in {}% answers'.format(percent))
        if percent > 90:
            return True, None
        return False, 'Percent of answers with WebFreshPostSetup is too low. At least 90% Required'

    def on_save(self):
        '''Delegate host setup for this task to wizard_utils.setup_hosts.'''
        wizard_utils.setup_hosts(self)

    def on_execute(self):
        '''
        Run all checks and collect the errors found in the eventlog.

        Steps:
          * report responses that are not valid UTF-8 (info message only);
          * validate relevs when both responses and queries are given;
          * run the shard-specific check for the 'Bert' and
            'SrcSetup + Merger' shards;
          * decode the eventlog with evlogdump and group request/rule errors
            by (rule name, event type) and exception text;
          * store a summary in Context.results and the full data in an output
            resource (one JSON file per rule/event pair).

        :raises Exception: evlogdump exited with a non-zero code
        :raises common.errors.TaskFailure: an exception that is not in
            FILTERED_EXCEPTIONS was found, the shard check failed or the relev
            validation failed
        '''
        if self.Parameters.begemot_responses and not wizard_utils.validate_utf8(str(sdk2.ResourceData(self.Parameters.begemot_responses).path)):
            self.set_info(wizard_utils.INVALID_UNICODE_FAILURE_MESSAGE)

        relevs_ok, shard_special_check = True, True
        if self.Parameters.begemot_responses and self.Parameters.queries:
            relevs_ok = self._validate_relevs()

        shard_issue = None
        if self.Parameters.shard == 'Bert':
            shard_special_check, shard_issue = self._check_bert_responses()
        if self.Parameters.shard == 'SrcSetup + Merger':
            shard_special_check, shard_issue = self._check_webfresh_post_setup()

        begemot_eventlog = sdk2.ResourceData(self.Parameters.begemot_eventlog).path
        evlogdump = sdk2.ResourceData(self.Parameters.evlogdump).path
        evlogdump_err = self.log_path() / 'evlogdump.err.txt'

        reqids = {}  # frame id -> request id announced by THttpRequestId
        query_uri = {}
        # (rule name, event type) -> [{exception text: {reqid: uri}}, event count]
        results = collections.defaultdict(lambda: [collections.defaultdict(dict), 0])
        rules_counter = collections.defaultdict(int)
        has_error = False
        parsed_queries = False

        with begemot_eventlog.open('r') as input_file, evlogdump_err.open('w') as err_file:
            p = subprocess.Popen([evlogdump.as_posix()], stdin=input_file, stdout=subprocess.PIPE, stderr=err_file)
            for line in p.stdout:
                fields = line.strip('\n').split('\t', 5)
                frame_id = int(fields[1])
                event_type = fields[2]
                if event_type == 'TRequestInput':
                    if fields[3] == 'InternalContext':
                        rules_counter['.request_error'] += 1
                elif event_type == 'TRuleStart':
                    rule_name = fields[3]
                    rules_counter[rule_name] += 1
                elif event_type == 'THttpRequestId':
                    reqids[frame_id] = fields[3]
                elif event_type in ('TRequestError', 'TRuleError', 'TLingBoostRuleError'):
                    if event_type == 'TRequestError':
                        rule_name = '.request_error'
                        event_type = fields[3]
                    elif event_type == 'TLingBoostRuleError':
                        rule_name = 'LingBoost'
                    else:
                        rule_name = fields[3]
                    exception_text = fields[-1].strip()
                    if not self.is_filtered_exception(exception_text):
                        has_error = True

                    # Queries are parsed lazily: only when the first error shows up.
                    if not parsed_queries:
                        query_uri = self._parse_queries()
                        parsed_queries = True
                    results[(rule_name, event_type)][1] += 1
                    if frame_id in reqids:
                        results[(rule_name, event_type)][0][exception_text][reqids[frame_id]] = self._shorten_uri(query_uri.get(reqids[frame_id]))
                    else:
                        results[(rule_name, event_type)][0][exception_text][frame_id] = self._shorten_uri(query_uri.get(frame_id))

            p.wait()
            if p.returncode != 0:
                # Popen has no `evlogdump` attribute; the exit code is `returncode`.
                raise Exception('evlogdump has died with code {}'.format(p.returncode))

        # rule name -> [{event type: [count, sample text, sample reqid, sample uri]}, rule starts]
        ctx_results = {}

        for (rule_name, event_type), (events, event_count) in results.iteritems():
            if rule_name not in ctx_results:
                ctx_results[rule_name] = [{}, rules_counter[rule_name]]
            event = ctx_results[rule_name][0][event_type] = [event_count, '', '', '']
            for exception_text, queries in events.iteritems():
                if not event[1]:
                    event[1] = exception_text
                for reqid in queries:
                    if not event[2] or event[3] is None:
                        event[2] = reqid
                        event[3] = queries[reqid]

        self.Context.results = ctx_results

        if len(results):
            output = resource_types.OTHER_RESOURCE(
                self,
                'CheckBegemotResponses output',
                'output'
            )
            self.Context.output_url = output.http_proxy
            output_dir = output.path
            output_dir.mkdir(exist_ok=True)
            for (rule_name, event_type), (events, event_count) in results.iteritems():
                with (output_dir / ('{}_{}.json'.format(rule_name, event_type))).open('wb') as out_file:
                    json.dump(events, out_file)
            sdk2.ResourceData(output).ready()

            if has_error:
                raise common.errors.TaskFailure('There are errors')
        if not shard_special_check:
            raise common.errors.TaskFailure('Shard check failure: {}'.format(shard_issue))
        if not relevs_ok:
            raise common.errors.TaskFailure('Relevs validation failure')

    @sdk2.header()
    def header(self):
        '''
        Render the task header from header.html located next to this file.

        :return: rendered template, or None when Context.results is empty
        '''
        results = self.Context.results
        if not results:
            return None

        templates_dir = os.path.dirname(os.path.abspath(__file__))
        loader = jinja2.FileSystemLoader(templates_dir)
        env = jinja2.Environment(loader=loader, extensions=['jinja2.ext.do'])
        template = env.get_template('header.html')
        return template.render({'results': results, 'output_url': self.Context.output_url})
