# -*- coding: utf-8 -*-

from sandbox import sdk2
from sandbox import common
from sandbox.projects import resource_types
from sandbox.projects.judtier import shards
from sandbox.sandboxsdk import environments
import zlib

# YQL query template, filled in with str.format() by the task below; that is why
# the protobuf config JSON uses doubled braces. Placeholders:
#   {jupiter_state}       - Jupiter state name (used in the shards_prepare and export paths)
#   {result_table}        - output table; overwritten (INSERT ... WITH TRUNCATE)
#   {maxassessment_table} - table providing tab-separated OriginalUrls per Host/Path
# $tmp collects, per Host || Path, the Shard, RobotBeautyUrl and the FinalErf
# protobuf parsed as NJupiter.SDocErf2InfoProto. It reads every shards_prepare
# subtable via RANGE with a key-range suffix; presumably
# ["JudTier":"JudTies"] selects rows whose key starts with "JudTier" - TODO confirm.
# The main query joins the JudTier dups export with MaxAssessment on Host+Path and
# LEFT JOINs $tmp on MainContentUrl (so BeautyUrl/UrlHash1/UrlHash2 may be NULL).
# It then emits one row per original URL (RankingUrl), sorted by RankingUrl.
# Note: "\t" below is a Python escape, so YQL receives a literal tab character.
yql_query_template = """
USE arnold;
PRAGMA DisableSimpleColumns;
$config = AsAtom(@@{{
    "name": "NJupiter.SDocErf2InfoProto",
    "meta": "H4sIAKt7iF8CA+NK4uKsqKjQKyjKL8kX4vDzKi3ILEktUvLmEgx2yU92LUoz8sxLyw8AS0txcYQW5XgkFmcYShQpMGrwBsH5SHJGEsUockYAzNiFn2QAAAA="
}}@@);
$parse = YQL::Udf(AsAtom("Protobuf.Parse"), Void(), Void(), $config);

$tmp = SELECT Host || Path AS MainContentUrl, Shard, RobotBeautyUrl, $parse(FinalErf) AS erf
FROM RANGE(
`//home/jupiter/shards_prepare/{jupiter_state}`, `0`, `999999`, `calculated_attrs_chunked["JudTier":"JudTies"]`
);

INSERT INTO `{result_table}` WITH TRUNCATE
SELECT RankingUrl, MainContentUrl, BeautyUrl, Shard, UrlHash1, UrlHash2
FROM (
SELECT String::SplitToList(y.OriginalUrls, "\t") AS OriginalUrls,
x.MainContentUrl AS MainContentUrl, z.RobotBeautyUrl AS BeautyUrl, x.Shard AS Shard, z.erf.UrlHash1 AS UrlHash1, z.erf.UrlHash2 AS UrlHash2
FROM `//home/jupiter/export/{jupiter_state}/judtier/dups` AS x
JOIN `{maxassessment_table}` AS y
ON x.Host == y.Host and x.Path == y.Path
LEFT JOIN $tmp AS z
ON x.MainContentUrl == z.MainContentUrl
)
FLATTEN BY OriginalUrls as RankingUrl
ORDER BY RankingUrl;
"""


class JudTierGetDupsTable(sdk2.Task):
    """
        Downloads the table with canonicalization results for assessed URLs,
        computed while preparing JudTier, as a JUDTIER_DUPS_TABLE resource.

        With split_shards enabled, an additional JUDTIER_DUPS_TABLE resource is
        produced for every basesearch shard, holding only that shard's lines.
    """

    class Parameters(sdk2.Task.Parameters):
        # Jupiter state, e.g. 20161231-123456.
        jupiter_state = sdk2.parameters.String('Jupiter state to work with', required=True, description="Юпитерный стейт, e.g. 20161231-123456")
        # If enabled, one resource is created per shard.
        split_shards = sdk2.parameters.Bool('Split for shards', description="Если включено, создавать по одному ресурсу на каждый шард")
        # If set, per-shard resources are created already when the task is enqueued.
        precreate_count = sdk2.parameters.Integer('Count of per-shard files, 0 = auto-detect', description="Если указано, то ресурсы создаются сразу при создании таска")

    class Requirements(sdk2.Task.Requirements):
        environments = (
            environments.PipEnvironment('yandex-yt', use_wheel=True),
            environments.PipEnvironment('yql', use_wheel=True),
        )
        cores = 1

        class Caches(sdk2.Requirements.Caches):
            pass

    # Column layout of every produced TSV file; stored as the resources' "format" attribute.
    format_attribute = "RankingUrl;MainContentUrl;BeautyUrl;Shard;UrlHash"

    def _basesearch_shard_name(self, yt_shard):
        """
        Map a YT shard name to the basesearch shard name.

        Shard in YT table is "JudTier/x-yy"; basesearch shard is
        "primus-JudTier-x-yy-<timestamp>".
        """
        return "primus-" + yt_shard.replace('/', '-') + "-" + self.Context.timestamp

    def _create_shard_resource(self, shard):
        """Create the per-shard JUDTIER_DUPS_TABLE resource and return its id."""
        return resource_types.JUDTIER_DUPS_TABLE(
            self,
            "JudTier dups table for " + shard,
            shard + "-urls.tsv.gz",
            shard_name=shard,
            format=self.format_attribute).id

    def on_enqueue(self):
        """
        Create the output resources once per task (memoized stage).

        Always creates the main resource. With split_shards, it also precreates
        precreate_count per-shard resources named after shards "JudTier/0-<i>".
        """
        with self.memoize_stage.create_output_resource:
            self.Context.timestamp = str(shards.state_to_timestamp(self.Parameters.jupiter_state))
            res = resource_types.JUDTIER_DUPS_TABLE(
                self,
                "JudTier dups table for " + self.Parameters.jupiter_state,
                "urls.tsv.gz",
                jupiter_state=self.Parameters.jupiter_state,
                format=self.format_attribute)
            self.Context.output_resource_id = res.id
            if self.Parameters.split_shards:
                # Build the mapping locally and store it once, so persistence does not
                # rely on in-place mutation of a Context value.
                shard_resource_ids = {}
                # precreate_count is declared without a default: treat unset like 0 (auto-detect).
                for i in xrange(self.Parameters.precreate_count or 0):
                    shard = self._basesearch_shard_name("JudTier/0-{}".format(i))
                    shard_resource_ids[shard] = self._create_shard_resource(shard)
                self.Context.shard_resource_ids = shard_resource_ids

    class SplittingWriter(object):
        """
        Write sorted, deduplicated TSV lines as gzip into the main resource and,
        when split_shards is on, also into one resource per shard.

        Shard resources that were not precreated are created on first use; their
        names are reported by get_created_shards(). Used as a context manager, the
        writer finalizes the gzip streams on success and only closes the files on error.
        """

        def __init__(self, parent):
            self.parent = parent
            self.split_shards = parent.Parameters.split_shards
            self.main_output_file = sdk2.ResourceData(sdk2.Resource[parent.Context.output_resource_id]).path.open('wb')
            self.main_zlib_object = self._new_gzip_compressor()
            self.created_shards = []
            # Always present (empty when not splitting) so close/abort can iterate it uniformly.
            self.shard_output_files = {}
            if self.split_shards:
                for shard, resource_id in parent.Context.shard_resource_ids.iteritems():
                    self._add_shard_output(shard, resource_id)
            self.prev_line = ''

        @staticmethod
        def _new_gzip_compressor():
            # wbits = 16 + MAX_WBITS makes zlib emit a gzip header/trailer, i.e. a valid .gz stream.
            return zlib.compressobj(6, zlib.DEFLATED, 16 + zlib.MAX_WBITS)

        def _add_shard_output(self, shard, resource_id):
            """Open the resource file for shard and register it with a fresh compressor."""
            output_file = sdk2.ResourceData(sdk2.Resource[resource_id]).path.open('wb')
            self.shard_output_files[shard] = (output_file, self._new_gzip_compressor())

        def write(self, line, shard):
            """
            Append line to the main output and, when splitting, to shard's output.

            Lines must arrive in non-decreasing order; an exact repeat of the previous
            line is skipped. Raises TaskError if the order is broken.
            """
            if line < self.prev_line:
                raise common.errors.TaskError("Sort order is broken: {} < {}".format(line, self.prev_line))
            if line == self.prev_line:
                return
            self.prev_line = line
            self.main_output_file.write(self.main_zlib_object.compress(line))
            if self.split_shards:
                if shard not in self.shard_output_files:
                    # Shard was not precreated in on_enqueue: create its resource now.
                    resource_id = self.parent._create_shard_resource(shard)
                    self._add_shard_output(shard, resource_id)
                    self.created_shards.append(shard)
                output_file, zlib_object = self.shard_output_files[shard]
                output_file.write(zlib_object.compress(line))

        def close(self):
            """Flush every compressor into its file and close the file, finalizing the gzip streams."""
            self.main_output_file.write(self.main_zlib_object.flush())
            self.main_output_file.close()
            for output_file, zlib_object in self.shard_output_files.itervalues():
                output_file.write(zlib_object.flush())
                output_file.close()

        def abort(self):
            """Close all files without finalizing; used when writing failed midway."""
            files = [self.main_output_file] + [f for f, _ in self.shard_output_files.itervalues()]
            for output_file in files:
                try:
                    output_file.close()
                except (IOError, OSError):
                    # Best effort: the error that caused the abort is the one to report.
                    pass

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            if exc_type is None:
                self.close()
            else:
                self.abort()
            return False  # never suppress the exception

        def get_created_shards(self):
            """Return, sorted, the shards whose resources were created during writing."""
            return sorted(self.created_shards)

    def on_execute(self):
        """
        Build the dups table with YQL and stream it into the output resources.

        Raises TaskFailure if the YQL query fails. When split_shards and
        precreate_count are set, it also raises TaskFailure if the data contains
        shards that were not precreated.
        """
        import yt.wrapper as yt
        yt_token = sdk2.Vault.data('RATED-POOL-MAKERS', 'RATED_POOLS_YT_TOKEN')
        yt.config['token'] = yt_token
        yt.config['proxy']['url'] = 'arnold.yt.yandex.net'
        maxassessment_table = self._resolve_maxassessment_table(yt_token)
        with yt.TempTable(prefix="judtier-get-dups-table-") as result_table:
            # The YT token is also used as the YQL token.
            self._run_yql_query(yt_token, result_table, maxassessment_table)
            with JudTierGetDupsTable.SplittingWriter(self) as writer:
                self._dump_result_table(result_table, writer)
        if self.Parameters.split_shards and self.Parameters.precreate_count:
            non_precreated_shards = writer.get_created_shards()
            if non_precreated_shards:
                raise common.errors.TaskFailure("Mismatch between Jupiter data and precreate_count. New shards: {}".format(non_precreated_shards))

    def _resolve_maxassessment_table(self, yt_token):
        """
        Return the path of the MaxAssessment table matching the Jupiter state.

        The hahn state-mapping history is consulted first (the last matching row
        wins). If it has no entry, the delivery_state attribute of the Jupiter
        delivery snapshot is read through the globally configured yt client.
        """
        import yt.wrapper as yt
        jupiter_state = self.Parameters.jupiter_state
        hahn_client = yt.YtClient(proxy='hahn', config={'token': yt_token})
        history_path = yt.TablePath("//home/rated-pools/jupiter_MaxAssessment/state-mapping-history", exact_key=jupiter_state)
        maxassessment_state = None
        for row in hahn_client.read_table(history_path, format="dsv"):
            maxassessment_state = row["delivery_state"]
        if maxassessment_state is None:
            maxassessment_state = yt.get("//home/jupiter/delivery_snapshot/" + jupiter_state + "/MaxAssessment/@delivery_state")
        return "//home/jupiter/delivery/MaxAssessment/" + maxassessment_state + "/MaxAssessment.raw"

    def _run_yql_query(self, yql_token, result_table, maxassessment_table):
        """Fill result_table by running yql_query_template; raise TaskFailure if the query fails."""
        import yql.api.v1.client
        query_text = yql_query_template.format(
            result_table=result_table,
            jupiter_state=self.Parameters.jupiter_state,
            maxassessment_table=maxassessment_table)
        with yql.api.v1.client.YqlClient(token=yql_token) as yql_client:
            query = yql_client.query(query_text, syntax_version=1)
            query.run()
            results = query.get_results()
            if not results.is_success:
                raise common.errors.TaskFailure("YQL request has failed: " + '; '.join(str(error) for error in results.errors))

    def _dump_result_table(self, result_table, writer):
        """
        Stream result_table into writer.

        The query orders rows by RankingUrl. The output lines of one RankingUrl are
        sorted before writing, so the writer sees a sorted stream and can drop
        duplicate lines.
        """
        import itertools
        import yt.wrapper as yt
        yt_format = yt.SchemafulDsvFormat(["RankingUrl", "MainContentUrl", "BeautyUrl", "Shard", "UrlHash1", "UrlHash2"], enable_escaping=False, raw=True)
        rows = (self._format_row(raw_line) for raw_line in yt.read_table(result_table, format=yt_format))
        for _, group in itertools.groupby(rows, key=lambda row: row[0]):
            for out_line, shard in sorted((line, shard_name) for _, line, shard_name in group):
                writer.write(out_line, shard)

    def _format_row(self, raw_line):
        """Convert one raw schemaful-dsv row into (ranking_url, output TSV line, basesearch shard)."""
        ranking_url, main_content_url, beauty_url, yt_shard, urlhash1, urlhash2 = raw_line.rstrip('\n').split('\t')
        shard = self._basesearch_shard_name(yt_shard)
        # UrlHash1 supplies the bits above position 32, UrlHash2 the low part.
        urlhash = (int(urlhash1) << 32) + int(urlhash2)
        out_line = "{}\t{}\t{}\t{}\t{}\n".format(ranking_url, main_content_url, beauty_url, shard, urlhash)
        return ranking_url, out_line, shard
