import json
import itertools
import logging
import os
import re

from sandbox import sdk2
from sandbox.projects.common import file_utils
from sandbox.projects.suggest.dicts import SuggestDictTask
from sandbox.sandboxsdk.environments import PipEnvironment


def normalize(query):
    """Return *query* as a lower-cased, single-line UTF-8 byte string.

    Tabs and newlines are replaced with spaces before lower-casing; the
    output files written by this module use those characters as field and
    line separators.
    """
    single_line = re.sub('\t|\n', ' ', query)
    lowered = single_line.lower()
    return lowered.encode('utf-8')


def format_float(f):
    """Format *f* with at most 8 decimal places and no trailing padding.

    Trailing zeros are removed, and the decimal point itself is dropped
    when no fractional digits remain (e.g. ``2.0`` -> ``'2'``).
    """
    fixed_point = format(f, '.8f')
    without_zeros = fixed_point.rstrip('0')
    return without_zeros.rstrip('.')


def format_regional_frequencies(freqs):
    """Serialize region/frequency pairs as ``region:freq,region:freq,...``.

    Each item of *freqs* is indexed as ``item[0]`` (region, converted with
    ``str``) and ``item[1]`` (frequency, rendered with up to 8 decimal
    places and trailing zeros / dot stripped). An empty *freqs* yields an
    empty string.
    """
    return ','.join(
        str(entry[0]) + ':' + format(entry[1], '.8f').rstrip('0').rstrip('.')
        for entry in freqs
    )


class BuildEdadealTestDicts(sdk2.Task, SuggestDictTask):
    """ Build edadeal suggests dictionary """

    # Task flow (see on_execute): read the queries table from YT, write the
    # 'queries', 'data', 'streams' and 'groups' input files into the current
    # working directory, run the data builder and publish the result.

    class Parameters(sdk2.Task.Parameters):
        # Vault record whose data is used as the YT token.
        yt_token_vault_name = sdk2.parameters.String(
            "Name of the vault record with YT token",
            name="vault_token_name",
            default="yt_token"
        )

        # YT table that on_execute reads rows from.
        queries_table_path = sdk2.parameters.String(
            "Queries table path",
            default="//home/edadeal/cooked/butara/suggest/2018-11-27"
        )

        # Forwarded as the `autodeploy` argument of publish_dict.
        autodeploy = sdk2.parameters.Bool("Autodeploy", default=True)

    class Requirements(sdk2.Requirements):
        environments = [
            PipEnvironment("ujson==2.0.3"),
            PipEnvironment("yandex-yt-yson-bindings-skynet"),
            PipEnvironment("yandex-yt")
        ]

    def get_yt_client_for_reader(self, proxy, token):
        """Create a YtClient for *proxy* with parallel reading switched on."""
        from yt.wrapper import YtClient

        reader_client = YtClient(proxy, token=token)
        reader_client.config['read_parallel']['enable'] = True
        return reader_client

    def on_execute(self):
        """Build the dictionary input files, run the builder and publish.

        Each table row is read as raw JSON and contributes:

        * one line ``<query>\\t\\t<global freq>\\t<region:freq,...>`` to the
          queries file, where the global frequency is the sum of the row's
          weights;
        * one line ``<query>\\t<json of row['query']>`` to the data file;
        * one line ``<alias>\\t<query>\\t<global freq>`` to the queries file
          per keyword that has not been seen yet.

        Rows and keywords whose normalized text was already emitted are
        skipped.
        """
        # Imported inside the method; both packages are listed in
        # Requirements.environments.
        import ujson
        import yt.wrapper as yt

        proxy = "hahn.yt.yandex.net"
        token = sdk2.Vault.data(self.Parameters.yt_token_vault_name)

        self.input_table = self.Parameters.queries_table_path

        # All intermediate artifacts live in the current working directory.
        work_dir = os.getcwd()
        data_path = os.path.join(work_dir, 'data')
        queries_path = os.path.join(work_dir, 'queries')
        streams_path = os.path.join(work_dir, 'streams')
        groups_path = os.path.join(work_dir, 'groups')
        dict_path = os.path.join(work_dir, "dict")

        query_lines = []
        seen_regions = set()
        seen_queries = set()

        reader = self.get_yt_client_for_reader(proxy, token)
        raw_rows = reader.read_table(
            self.input_table,
            format=yt.JsonFormat(attributes={'encode_utf8': False}),
            raw=True
        )

        for row_index, raw_row in enumerate(raw_rows):
            if row_index % 10000 == 0:
                logging.debug('%d rows processed', row_index)

            row = ujson.loads(raw_row)

            # Regions are recorded before the duplicate check below, so even
            # rows that end up skipped contribute to the streams file.
            total_weight = 0.0
            per_region = []
            for weight in row['weights']:
                region_id, region_weight = weight[0], weight[1]
                seen_regions.add(region_id)
                total_weight += region_weight
                per_region.append((region_id, region_weight))

            total_weight_str = format_float(total_weight)

            query = normalize(row['text'])
            if query in seen_queries:
                continue
            seen_queries.add(query)

            # Main query line: the second column is left empty.
            regional_str = format_regional_frequencies(per_region)
            main_line = (
                query + '\t\t' + total_weight_str + '\t' + regional_str + '\n'
            )
            query_lines.append(main_line)

            # Data line: the normalized query followed by the JSON dump of
            # the row's 'query' field.
            payload = json.dumps(row['query'], ensure_ascii=False)
            data_line = query + '\t' + payload.encode('utf-8') + '\n'
            file_utils.append_lines(data_path, (data_line,))

            # Alias lines reference the main query and reuse its global
            # frequency; they carry no regional breakdown.
            for alias in row['keywords']:
                alias_key = normalize(alias)
                if alias_key in seen_queries:
                    continue
                seen_queries.add(alias_key)

                alias_line = (
                    alias_key + '\t' + query + '\t' + total_weight_str + '\n'
                )
                query_lines.append(alias_line)

        query_lines.sort()
        file_utils.write_lines(queries_path, query_lines)
        del query_lines

        # One line per region seen, followed by an 'ALL' line carrying the
        # number of distinct regions.
        ordered_regions = sorted(seen_regions)
        del seen_regions
        stream_lines = [
            '{}\t1\n'.format(region) for region in ordered_regions
        ]
        stream_lines.append('ALL\t{}\n'.format(len(ordered_regions)))
        file_utils.write_lines(streams_path, stream_lines)
        del stream_lines
        del ordered_regions

        # The groups file consists of a single empty line.
        file_utils.write_lines(groups_path, ['\n'])

        os.makedirs(dict_path)
        dict_prefix = os.path.join(dict_path, "edadeal_test")

        self.setup_yt_client(proxy, token)
        self.run_data_builder(
            dict_prefix,
            queries_path,
            groups_path,
            streams_path,
            data_path,
            word_index=True,
            top_size=100
        )

        self.publish_dict(
            "edadeal_test",
            "Dictionaries for edadeal",
            dict_path,
            autodeploy=self.Parameters.autodeploy
        )
