import json
import logging
import codecs
import os
import urllib2

from collections import defaultdict
from datetime import datetime
from sandbox import sdk2
from sandbox.common.types.misc import NotExists
from sandbox.projects import resource_types
from sandbox.sandboxsdk import environments
from random import shuffle


class YaneUploadMarkupToToloka(sdk2.Task):
    """Upload grouped markupdaemon hypotheses to YT for processing in Toloka.

    The task reads texts from the input resource (one text per line), sends
    every text to the markupdaemon, drops unwanted hypotheses, groups the
    remaining ones by their position in the text and writes one row per
    group to a YT table.
    """

    # Score at which a hypothesis is rejected unless its source is "good".
    _REJECTED_SCORE = -200
    # Number of leading characters of a hypothesis id that encode its source.
    _SOURCE_CODE_LENGTH = 3
    # Maximum number of (best-scored) hypotheses kept for one text position.
    _MAX_HYPOS_PER_TASK = 10
    # How often (in texts) the progress is reported to the log.
    _PROGRESS_LOG_STEP = 100
    # The only hypothesis fields that are uploaded to YT.
    _UPLOADED_HYPO_FIELDS = ("title", "url", "id")

    class Requirements(sdk2.Task.Requirements):
        # Packages that provide the `yt.wrapper` module imported in on_execute.
        environments = [
            environments.PipEnvironment('yandex-yt'),
            environments.PipEnvironment('yandex-yt-yson-bindings-skynet')
        ]

    class Parameters(sdk2.Task.Parameters):
        # common parameters
        kill_timeout = 3600

        # custom parameters
        texts = sdk2.parameters.Resource("Texts for processing (each text on new line, \\n replaced with \\t)",
                                         resource_type=resource_types.OTHER_RESOURCE, required=False)
        yt_proxy = sdk2.parameters.String("YT proxy", default="hahn", required=True)
        highlighter_url = sdk2.parameters.String("Markupdaemon URL", required=True)
        yt_output_folder = sdk2.parameters.String("YT output folder", required=True)
        yt_token_vault_name = sdk2.parameters.Vault("YT token vault name", required=True)
        good_sources = sdk2.parameters.List("Good key sources codes (ontoids from these sources will not be rejected on -200 score)", required=False)
        bad_sources = sdk2.parameters.List("Sources to exclude at all when generating Toloka tasks", required=False)

    def fetch_markup_from_highlighter(self, text):
        """Request markup for ``text`` from the markupdaemon.

        :param text: unicode text; it is sent UTF-8 encoded as the POST body.
        :return: the decoded JSON response, or None when the response body is
            empty (``''``, ``null``, ``[]`` or ``{}``, ignoring surrounding
            whitespace).
        :raises Exception: when the daemon answers with a non-200 status.
        """
        request = urllib2.Request(self.Parameters.highlighter_url, data=text.encode('utf-8'))
        response = urllib2.urlopen(request)
        try:
            if response.code != 200:
                raise Exception("Could'nt highlight text: %s" % response.code)
            raw_markup = response.read()
        finally:
            # One request is made per text, so leaked connections add up.
            response.close()

        # Whitespace is stripped so that a trailing newline after an empty
        # payload does not turn it into a "non-empty" (or unparsable) one.
        if raw_markup is None or raw_markup.strip() in ('', 'null', '[]', '{}'):
            return None

        return json.loads(raw_markup)

    def read_texts_from_file(self, filestream):
        """Return all lines of ``filestream`` as a list (line endings are kept)."""
        return list(filestream)

    def group_hypos(self, predicted_markup):
        """Group hypos by position in text.

        :param predicted_markup: iterable of hypo dicts with ``begin``,
            ``length`` and ``score`` keys.
        :return: defaultdict mapping ``(begin, length)`` to the list of hypos
            at that position, sorted by ``score`` in descending order.
        """
        grouped_markup = defaultdict(list)
        for hypo in predicted_markup:
            grouped_markup[(hypo['begin'], hypo['length'])].append(hypo)
        for hypos in grouped_markup.values():
            hypos.sort(key=lambda elem: elem['score'], reverse=True)
        return grouped_markup

    def filter_hypo_fields(self, hypo):
        """Return a copy of ``hypo`` reduced to the fields uploaded to YT.

        :raises KeyError: when ``hypo`` lacks one of the uploaded fields.
        """
        return {field: hypo[field] for field in self._UPLOADED_HYPO_FIELDS}

    def filter_bad_hypos(self, markup):
        """Drop hypotheses according to the good/bad sources parameters.

        The source of a hypothesis is the prefix of its ``id``. A hypothesis
        is dropped when its source is listed in ``bad_sources``, or when its
        score is the rejected one and its source is not in ``good_sources``.

        :param markup: list of hypo dicts with ``id`` and ``score`` keys.
        :return: ``markup`` itself when both source lists are empty or unset,
            otherwise a new list with the remaining hypotheses.
        """
        # Optional list parameters may be unset, so treat None as empty.
        good_sources = set(self.Parameters.good_sources or ())
        bad_sources = set(self.Parameters.bad_sources or ())
        if not good_sources and not bad_sources:
            return markup

        good_markup = []
        for hypo in markup:
            source = hypo['id'][:self._SOURCE_CODE_LENGTH]
            if source in bad_sources:
                continue
            if hypo['score'] == self._REJECTED_SCORE and source not in good_sources:
                continue
            good_markup.append(hypo)
        return good_markup

    def _build_yt_rows(self, text, grouped_hypos, text_id, first_task_id):
        """Build one YT row per hypothesis group of a single text.

        :param text: the original text line; trailing whitespace is stripped.
        :param grouped_hypos: mapping ``(begin, length)`` -> sorted hypos.
        :param text_id: value stored (as a string) in ``yaneTextId``.
        :param first_task_id: ``yaneTaskId`` of the first row; the following
            rows get consecutive ids.
        :return: list of row dicts matching the output table schema.
        """
        stripped_text = text.rstrip()
        rows = []
        for offset, ((begin, length), hypos) in enumerate(grouped_hypos.items()):
            # Hypos are sorted by score, so the slice keeps the best ones;
            # they are then shuffled so their order does not reveal the score.
            cleaned_hypos = [self.filter_hypo_fields(hypo) for hypo in hypos[:self._MAX_HYPOS_PER_TASK]]
            shuffle(cleaned_hypos)
            rows.append({
                "begin": begin,
                "length": length,
                "text": stripped_text,
                "hypos": cleaned_hypos,
                "yaneTaskId": str(first_task_id + offset),
                "yaneTextId": str(text_id),
            })
        return rows

    def on_execute(self):
        """Fetch markup for every input text and upload the grouped rows to YT.

        The output table is created once under ``yt_output_folder`` (named by
        the current timestamp) and its path is stored in the task context, so
        a re-executed task writes to the same table.

        :raises ValueError: when the input resource with texts is not set.
        """
        from yt.wrapper import YtClient, JsonFormat

        # The parameter is optional in the form, but the task cannot do
        # anything without it, so fail early with a clear message.
        if self.Parameters.texts is None:
            raise ValueError("Input resource with texts is not specified")

        texts_file_path = sdk2.ResourceData(self.Parameters.texts).path
        yt_client = YtClient(proxy=self.Parameters.yt_proxy, token=self.Parameters.yt_token_vault_name.data())

        with codecs.open(str(texts_file_path), encoding='utf-8') as texts_file:
            texts = self.read_texts_from_file(texts_file)

        yt_lines = []
        total_texts = len(texts)

        # Ids are 1-based; text ids are assigned only to texts with markup.
        count = 1
        text_count = 1

        for processed_count, text in enumerate(texts, 1):
            markup = self.fetch_markup_from_highlighter(text)
            if markup is not None:
                grouped_hypos = self.group_hypos(self.filter_bad_hypos(markup))
                rows = self._build_yt_rows(text, grouped_hypos, text_count, count)
                yt_lines.extend(rows)
                count += len(rows)
                text_count += 1
            # Progress counts every handled text, including those without markup.
            if processed_count % self._PROGRESS_LOG_STEP == 0:
                logging.info("Processed %d texts out of %d", processed_count, total_texts)

        if self.Context.yt_table == NotExists:
            yt_table = os.path.join(self.Parameters.yt_output_folder, str(datetime.now()))
            yt_client.create("table",
                             yt_table,
                             attributes={
                                 'schema': [
                                     {"name": "begin", "type": "uint64"},
                                     {"name": "length", "type": "uint64"},
                                     {"name": "text", "type": "string"},
                                     {"name": "hypos", "type": "any"},
                                     {"name": "yaneTaskId", "type": "string"},
                                     {"name": "yaneTextId", "type": "string"}
                                 ]})
            self.Context.yt_table = yt_table

        yt_client.write_table(
            self.Context.yt_table,
            yt_lines,
            format=JsonFormat(attributes={
                'encode_utf8': False,
                'plain': True,
            })
        )
