from collections import defaultdict
import json
import logging
import os

from sandbox import sdk2

from sandbox.projects.common import file_utils as fu
from sandbox.projects.suggest.dicts import SuggestDictTask
from sandbox.sandboxsdk.environments import PipEnvironment
from sandbox.sandboxsdk.errors import SandboxTaskFailureError


class BuildVideoHostingDicts(sdk2.Task, SuggestDictTask):
    """ Build video hosting dictionaries from yt data """

    @staticmethod
    def normalize_title(string):
        chars = [c if c.isalnum() else ' ' for c in string.lower()]
        words = [w for w in (''.join(chars)).split(' ') if w]
        return ' '.join(words)

    class Parameters(sdk2.Task.Parameters):
        # Input parameters of the task; read through ``self.Parameters`` in the methods below.

        # Name of the Vault record whose data is used as the YT token.
        # NOTE(review): ``name`` differs from the attribute name; presumably it is the
        # externally visible parameter name -- confirm against the sdk2 parameters API.
        yt_token_vault_name = sdk2.parameters.String(
            'Name of the vault record with YT token',
            name='vault_token_name',
            default='yt_token'
        )

        # Cluster short name; '.yt.yandex.net' is appended to build the proxy address.
        yt_cluster = sdk2.parameters.String(
            'YT Cluster',
            default='arnold'
        )

        # Table with query rows; the task fails if YT reports it as not sorted.
        queries_table_path = sdk2.parameters.String(
            'Queries table path',
            default='//home/video-hosting/suggest/prevdata/queries'
        )

        # Table with object rows (uuid, type, obj, region and catchup time columns).
        data_table_path = sdk2.parameters.String(
            'Data table path',
            default='//home/video-hosting/suggest/prevdata/data'
        )

        # Name passed to publish_dict() for the resulting dictionaries.
        dict_name = sdk2.parameters.String(
            "Dict name",
            name="dict_name",
            default="video_hosting",
            required=True,
        )

        # When set, every item type is built with the stream weights summed over
        # all item types instead of its own per-type stream weights.
        general_normalization = sdk2.parameters.Bool(
            "General normalization",
            default=False
        )

    class Requirements(sdk2.Requirements):
        # Execution requirements of the task.

        # Pip packages installed for the task; read_data_from_yt() imports
        # ``yt.wrapper`` lazily, so it is only needed at execution time.
        environments = [
            PipEnvironment('yandex-yt-yson-bindings-skynet'),
            PipEnvironment('yandex-yt')
        ]
        cores = 1
        # NOTE(review): ram and disk_space are presumably expressed in MiB
        # (1 GiB and 10 GiB respectively) -- confirm against the sdk2 documentation.
        ram = 1024
        disk_space = 10 * 1024

        class Caches(sdk2.Requirements.Caches):
            pass  # an empty Caches class means that the task does not use any shared caches

    def on_execute(self):
        queries, data, regions, uuid2key = self.read_data_from_yt()
        self.build_dicts(queries, data, regions, uuid2key)

    def read_data_from_yt(self):
        """Read the queries and data tables from YT into in-memory structures.

        Returns a tuple ``(queries, data, regions, uuid2key)``:

        * ``queries``  -- ``{type: {uuid: [(query, weight), ...]}}``
        * ``data``     -- ``{type: [tab-separated data line, ...]}``
        * ``regions``  -- ``{uuid: set of region strings}``
        * ``uuid2key`` -- ``{uuid: '@<uuid> <normalized title>'}``

        Raises SandboxTaskFailureError if the queries table is not sorted.
        """
        from yt.wrapper import YtClient

        params = self.Parameters
        yt_proxy = params.yt_cluster + '.yt.yandex.net'
        yt_token = sdk2.Vault.data(params.yt_token_vault_name)
        yt_client = YtClient(yt_proxy, yt_token)

        queries_table = params.queries_table_path
        if not yt_client.is_sorted(queries_table):
            message = "Table {} isn't sorted".format(queries_table)
            logging.error(message)
            raise SandboxTaskFailureError(message)

        queries = defaultdict(lambda: defaultdict(list))
        last_query = defaultdict(str)
        repeats = defaultdict(int)

        # Consecutive equal queries of one type get an '@<n>' suffix (n = 1, 2, ...)
        # so that every emitted query string is distinct.
        # NOTE(review): presumably relies on the table sort order to make equal
        # queries adjacent; an empty first query also matches the '' default and
        # gets a suffix -- confirm both are intended.
        for row in yt_client.read_table(queries_table):
            kind = row['type']
            text = row['query']
            if last_query[kind] != text:
                repeats[kind] = 0
                last_query[kind] = text
            else:
                repeats[kind] += 1
                text += '@' + str(repeats[kind])

            uuid = row['uuid']
            weight = row['query_weight'] or 0.0  # falsy weight (e.g. None) counts as zero
            queries[kind][uuid].append((text, weight))

        data = defaultdict(list)
        regions = defaultdict(set)
        uuid2key = defaultdict(str)

        for row in yt_client.read_table(params.data_table_path):
            kind = row['type']
            uuid = row['uuid']
            obj_fields = row['obj']

            # The third element of 'obj' is used as the title.
            title = self.normalize_title(obj_fields[2].decode('utf-8')).encode('utf-8')
            key = '@' + uuid + ' ' + title
            uuid2key[uuid] = key

            decoded_fields = [
                field.decode('utf-8') if isinstance(field, basestring) else field
                for field in obj_fields
            ]
            serialized = json.dumps(decoded_fields, ensure_ascii=False).encode('utf-8')

            columns = [key, serialized]
            if kind == 'catchup':
                # catchup rows carry three extra time columns
                columns.append(str(row['start_time']))
                columns.append(str(row['finish_time']))
                columns.append(str(row['expiration_date']))
            data[kind].append('\t'.join(columns))

            regions[uuid] = set(part.strip() for part in row['region'].split(','))

        return queries, data, regions, uuid2key

    def build_dicts(self, queries, data, regions, uuid2key):
        """Write builder input files per item type, run the data builder and publish.

        For every item type present in ``queries`` a folder ``data/<type>`` is
        created in the current working directory with the files ``queries``,
        ``groups``, ``data`` and ``streams``; the dictionary for the type is then
        built into ``dicts/<type>/<type>*`` and the whole ``dicts`` folder is
        published under ``self.Parameters.dict_name``.

        :param queries: ``{type: {uuid: [(query, weight), ...]}}``
        :param data: ``{type: [data line without trailing newline, ...]}``
        :param regions: ``{uuid: set of region names}``
        :param uuid2key: ``{uuid: canonical key}``; uuids with an empty key
            produce no query lines
        """
        work_dir = os.getcwd()
        input_data_path = os.path.join(work_dir, 'data')
        os.makedirs(input_data_path)

        dicts_path = os.path.join(work_dir, 'dicts')
        os.makedirs(dicts_path)

        # One record per item type instead of several parallel dicts.
        inputs4type = {}
        for item_type, grouped_queries in queries.iteritems():
            folder = os.path.join(input_data_path, item_type)
            os.makedirs(folder)

            query_lines, group_lines, streams = self._collect_type_lines(grouped_queries, regions, uuid2key)

            queries_path = os.path.join(folder, 'queries')
            fu.write_lines(queries_path, sorted(query_lines))

            groups_path = os.path.join(folder, 'groups')
            fu.write_lines(groups_path, group_lines)

            data_path = os.path.join(folder, 'data')
            fu.write_lines(data_path, [line + '\n' for line in data[item_type]])

            inputs4type[item_type] = {
                'folder': folder,
                'queries_path': queries_path,
                'groups_path': groups_path,
                'data_path': data_path,
                'streams': streams,
            }

        general_normalization = self.Parameters.general_normalization
        # Always defined; only filled (and used) when general normalization is on.
        general_streams = defaultdict(int)
        if general_normalization:
            for type_inputs in inputs4type.itervalues():
                for region, weight in type_inputs['streams'].iteritems():
                    general_streams[region] += weight

        for item_type, type_inputs in inputs4type.iteritems():
            streams = general_streams if general_normalization else type_inputs['streams']

            stream_lines = ['{}\t{}\n'.format(region, weight) for region, weight in streams.iteritems()]
            # The total is the last line and is written without a trailing newline.
            stream_lines.append('ALL\t{}'.format(sum(streams.values())))
            streams_path = os.path.join(type_inputs['folder'], 'streams')
            fu.write_lines(streams_path, stream_lines)

            dict_folder = os.path.join(dicts_path, item_type)
            os.makedirs(dict_folder)
            dict_prefix = os.path.join(dict_folder, item_type)

            self.run_data_builder(
                dict_prefix,
                type_inputs['queries_path'],
                type_inputs['groups_path'],
                streams_path,
                type_inputs['data_path'],
                word_index=True
            )

        self.publish_dict(self.Parameters.dict_name, 'Dictionaries for video hosting', dicts_path, autodeploy=True)

    @staticmethod
    def _collect_type_lines(grouped_queries, regions, uuid2key):
        """Build the query lines, group lines and per-region weights of one item type.

        :param grouped_queries: ``{uuid: [(query, weight), ...]}``
        :param regions: ``{uuid: set of region names}``
        :param uuid2key: ``{uuid: canonical key}``
        :return: ``(query_lines, group_lines, streams)`` where ``streams`` maps a
            region to the summed weight of all groups available in it
        """
        query_lines = []
        group_lines = []
        streams = defaultdict(int)

        for uuid, uuid_queries in grouped_queries.iteritems():
            # NOTE(review): a group line is emitted even when the uuid has no
            # canonical key and is skipped below -- kept as is, confirm intended.
            if len(uuid_queries) > 1:
                group_lines.append('\t'.join([query for query, _ in uuid_queries]) + '\n')

            canonical_query = uuid2key[uuid]
            if not canonical_query:
                continue

            group_weight = sum([weight for _, weight in uuid_queries])
            uuid_regions = regions[uuid]

            for region in uuid_regions:
                streams[region] += group_weight

            regions_str = ','.join([region + ':' + str(group_weight) for region in uuid_regions])

            # Canonical line first (empty second column), then one line per query
            # pointing at the canonical key.
            query_lines.append('{}\t\t{}\t{}\n'.format(canonical_query, group_weight, regions_str))
            for query, weight in uuid_queries:
                query_lines.append('{}\t{}\t{}\n'.format(query, canonical_query, str(weight)))

        return query_lines, group_lines, streams
