import logging
import os
import string
import time
from collections import defaultdict
from sandbox import sdk2
from sandbox.projects.common import file_utils as fu
from sandbox.projects.suggest.dicts import SuggestDictTask
from sandbox.sandboxsdk.environments import PipEnvironment


# ASCII punctuation in the two forms needed by bytes.translate / text translate.
_PUNCTUATION_BYTES = string.punctuation.encode('ascii')
_PUNCTUATION_DELETE_TABLE = dict.fromkeys(map(ord, string.punctuation))


def normalize_query(title):
    """Strip ASCII punctuation and surrounding whitespace, then lowercase.

    Accepts either a UTF-8 encoded byte string or a text (unicode) string and
    returns a value of the same kind:

    * bytes in  -> UTF-8 encoded bytes out (the historical behaviour);
    * text in   -> text out.

    Lowercasing is always performed on the decoded text so that non-ASCII
    letters are handled as well.

    Raises:
        UnicodeDecodeError: if a byte string is not valid UTF-8.
    """
    if isinstance(title, bytes):
        # Deleting ASCII bytes is safe on UTF-8 data: bytes of multi-byte
        # sequences never fall into the ASCII range.
        cleaned = title.translate(None, _PUNCTUATION_BYTES).strip()
        return cleaned.decode('utf-8').lower().encode('utf-8')
    # Text input: the two-argument translate() form does not exist for
    # unicode strings, so use a deletion table keyed by code point.
    return title.translate(_PUNCTUATION_DELETE_TABLE).strip().lower()


def current_timestamp():
    """Return the current Unix time truncated to whole seconds."""
    now = time.time()
    return int(now)


class BuildEzSuggestDict(sdk2.Task, SuggestDictTask):
    """ Build easy suggest dictionary """

    # The task reads rows from a YT table, writes the builder's input files
    # into the working directory, runs the data builder and publishes the
    # resulting dictionary directory.

    class Parameters(sdk2.Task.Parameters):
        # --- YT access ---
        yt_token_vault_name = sdk2.parameters.String(
            'YT token vault record name', name='vault_token_name', default='yt_token')
        yt_cluster = sdk2.parameters.String('YT cluster', default='hahn')

        # --- Output dictionary and input table ---
        dictionary_name = sdk2.parameters.String("Dictionary name", default="dict")
        source_table_path = sdk2.parameters.String("Source table location", default="")

        # --- Query handling ---
        query_column_name = sdk2.parameters.String("Query column name", default='query')
        normalize_queries = sdk2.parameters.Bool('Normalize queries', default=True)
        use_word_index = sdk2.parameters.Bool('Use word index', default=True)

        # --- Source columns (an empty name disables the optional ones) ---
        frequency_column_name = sdk2.parameters.String(
            "Frequency column name ", default='frequency')
        regional_frequencies_column_name = sdk2.parameters.String(
            "Regional frequency column name ", default='regional_frequencies')
        data_column_name = sdk2.parameters.String("Data column name ", default='')

        # --- Post-build options ---
        create_dict_info = sdk2.parameters.Bool("Create dict info file", default=False)
        start_ts = sdk2.parameters.Integer(
            "Start timestamp for whole build dict process", default=0)
        autodeploy = sdk2.parameters.Bool("Autodeploy dict", default=False)

    class Requirements(sdk2.Requirements):
        environments = [
            PipEnvironment("yandex-yt-yson-bindings-skynet"),
            PipEnvironment("yandex-yt"),
        ]

    def on_execute(self):
        """Read the source table, build the dictionary and publish it.

        Files written to the current working directory:

        * ``queries`` - sorted lines of the form
          ``<query>\\t\\t<frequency>\\t<region>:<freq>,...\\n``;
        * ``groups``  - a single ``'\\n'`` entry;
        * ``streams`` - ``<region>\\t<summed frequency>\\n`` per region plus a
          final ``ALL\\t<total frequency>\\n`` line;
        * ``data``    - ``<query>\\t<data>\\n`` lines, only when a data column
          is configured and at least one row was read.

        Raises:
            Exception: if the same (optionally normalized) query occurs twice.
        """
        logging.debug('Reading source table...')

        # Imported here rather than at module level; the package is listed
        # in Requirements.environments.
        from yt.wrapper import YtClient

        params = self.Parameters
        yt_client = YtClient(
            params.yt_cluster + '.yt.yandex.net',
            sdk2.Vault.data(params.yt_token_vault_name)
        )

        query_col = params.query_column_name
        freq_col = params.frequency_column_name
        regional_col = params.regional_frequencies_column_name
        data_col = params.data_column_name

        seen_queries = set()
        query_lines = []
        data_lines = []
        total_frequency = 0
        region_totals = defaultdict(int)

        for row in yt_client.read_table(params.source_table_path):
            query = row[query_col]
            if params.normalize_queries:
                query = normalize_query(query)
            # Duplicates are treated as a fatal input error.
            if query in seen_queries:
                raise Exception('Duplicate query: ' + query)
            seen_queries.add(query)

            frequency = int(row[freq_col])
            total_frequency += frequency

            # Per-row "region:freq" parts; totals are accumulated per region
            # across all rows for the streams file.
            regional_parts = []
            if regional_col:
                for region, region_freq in row[regional_col].items():
                    region_totals[region] += int(region_freq)
                    regional_parts.append(region + ':' + str(region_freq))

            query_lines.append(
                query + '\t\t' + str(frequency) + '\t' + ','.join(regional_parts) + '\n')

            if data_col:
                data_lines.append(query + '\t' + row[data_col] + '\n')

        logging.debug('Writing source files...')

        work_dir = os.getcwd()

        queries_path = os.path.join(work_dir, 'queries')
        query_lines.sort()
        fu.write_lines(queries_path, query_lines)

        groups_path = os.path.join(work_dir, 'groups')
        fu.write_lines(groups_path, ['\n'])

        streams_path = os.path.join(work_dir, 'streams')
        stream_lines = [
            region + '\t' + str(total) + '\n'
            for region, total in region_totals.items()
        ]
        stream_lines.append('ALL\t' + str(total_frequency) + '\n')
        fu.write_lines(streams_path, stream_lines)

        # An empty data path is passed to the builder when there is no data.
        if data_lines:
            data_path = os.path.join(work_dir, 'data')
            fu.write_lines(data_path, data_lines)
        else:
            data_path = ''

        logging.debug('Building dictionary...')

        dict_name = params.dictionary_name
        dict_dir = os.path.join(os.getcwd(), dict_name)
        os.makedirs(dict_dir)

        # Builder output prefix is <cwd>/<dict_name>/<dict_name>.
        self.run_data_builder(
            os.path.join(dict_dir, dict_name),
            queries_path,
            groups_path,
            streams_path,
            data_path,
            word_index=params.use_word_index
        )

        # Fall back to "now" when no explicit start timestamp was supplied.
        start_ts = params.start_ts or current_timestamp()
        if params.create_dict_info:
            self.create_dict_info(dict_name, dict_dir, start_ts)

        logging.debug('Publishing dictionary...')

        self.publish_dict(
            dict_name,
            "Easy suggest dictionary",
            dict_dir,
            autodeploy=params.autodeploy
        )
