# -*- coding: utf-8 -*-

import os
import logging
import random
import json
import base64

from sandbox import common
from sandbox import sdk2
from sandbox.common.types.client import Tag
from sandbox.sandboxsdk.environments import PipEnvironment
import sandbox.projects.common.binary_task as binary_task


class PLAYLIST_SERVICE_RESPONSES(sdk2.Resource):
    """
        Playlist service responses.

        Sandbox resource type for stored playlist service responses.
        NOTE(review): not created by any task in this module; presumably
        produced/consumed by other playlist-service tasks — confirm.
    """


class PLAYLIST_SERVICE_REQUEST_PLAN(sdk2.Resource):
    """
        Plan file for servant client.

        Text file with one JSON-encoded request context per line, as written by
        VhGeneratePlaylistRequestsFromYt.create_requests_file.
    """


class VhGeneratePlaylistRequestsFromYt(binary_task.LastBinaryTaskRelease, sdk2.Task):
    """
        Select max_request_number requests to playlist service from YT logs.
        Generate plan for servant client and initial requests.

        The newest table under request_log_path is read (at most yt_logs_limit
        rows). Rows with a non-empty VH response and rows with a non-empty UGC
        response are reservoir-sampled independently, and each sample is turned
        into a PLAYLIST_SERVICE_REQUEST_PLAN output resource with one JSON
        request context per line.
    """

    class Requirements(sdk2.Requirements):
        privileged = True
        client_tags = Tag.INTEL_E5_2650 & Tag.LXC & Tag.GENERIC
        execution_space = 10 * 1024
        required_ram = 16 * 1024
        environments = [
            PipEnvironment("yandex-yt", "0.9.26"),
        ]

    class Parameters(sdk2.Task.Parameters):
        # Name of the Sandbox vault entry holding the YT token.
        yt_token_vault = sdk2.parameters.String(
            "YT_TOKEN vault name",
            name="yt_token_vault",
            default="yt_token_for_testenv",
            required=True,
        )
        yt_cluster = sdk2.parameters.String(
            "YT cluster (i.e. hahn)",
            name="yt_cluster",
            default="hahn",
            required=True,
        )
        # Directory whose lexicographically greatest child table is read.
        request_log_path = sdk2.parameters.String(
            "Request logs path in YT",
            name="request_log_path",
            default="//logs/vh-playlist-service-logs/1d",
            required=True,
        )
        # Reservoir size, applied separately to the VH and UGC samples; parsed with int().
        max_request_number = sdk2.parameters.String(
            "maximum requests number",
            name="max_request_number",
            default="10000"
        )
        # Empty value means "read the whole table"; parsed with int().
        yt_logs_limit = sdk2.parameters.String(
            "number of rows in YT table with logs to read",
            name="yt_logs_limit",
            default="1000000"
        )
        ext_params = binary_task.LastBinaryReleaseParameters()

        with sdk2.parameters.Output():
            plan_file_ugc = sdk2.parameters.Resource(
                "requests to PLAYLIST_SERVICE_UGC",
                name="raw_requests_file_ugc",
                resource_type=PLAYLIST_SERVICE_REQUEST_PLAN
            )
            plan_file_vh = sdk2.parameters.Resource(
                "requests to PLAYLIST_SERVICE_VH",
                name="raw_requests_file_vh",
                resource_type=PLAYLIST_SERVICE_REQUEST_PLAN
            )

    def on_execute(self):
        """Sample request logs from YT and write the VH and UGC plan resources."""
        log_sample_vh, log_sample_ugc = self.get_log_samples_from_yt()
        logging.info(
            "Successfully read logs from yt: %d VH rows, %d UGC rows sampled",
            len(log_sample_vh), len(log_sample_ugc),
        )

        self.Parameters.plan_file_vh = PLAYLIST_SERVICE_REQUEST_PLAN(
            self, "plan file for servant client PLAYLIST_SERVICE_VH", "plan_file_vh.txt"
        )
        plan_file_vh = sdk2.ResourceData(self.Parameters.plan_file_vh).path

        self.Parameters.plan_file_ugc = PLAYLIST_SERVICE_REQUEST_PLAN(
            self, "plan file for servant client PLAYLIST_SERVICE_UGC", "plan_file_ugc.txt"
        )
        plan_file_ugc = sdk2.ResourceData(self.Parameters.plan_file_ugc).path

        # VH requests are taken from http_request_vh because they are most often
        # made by kinopoisk via http.
        self.create_requests_file(log_sample_vh, plan_file_vh, "http_request_vh", "http")
        self.create_requests_file(log_sample_ugc, plan_file_ugc, "ugc_request", "proto")
        logging.info("Successfully created requests file")

    def get_default_ctx(self):
        """Return the raw JSON template (embedded resource) for proto request contexts."""
        from library.python import resource
        return resource.find("sandbox/projects/vh/frontend/generate_playlist_service_requests/default_ctx.json")

    def get_default_ctx_http(self):
        """Return the raw JSON template (embedded resource) for http request contexts."""
        from library.python import resource
        return resource.find("sandbox/projects/vh/frontend/generate_playlist_service_requests/default_ctx_http.json")

    def create_ctx(self, json_request_proto):
        """
            Build a request context carrying a serialized TStreamsByUuidRequest.

            :param json_request_proto: JSON text of the request message
            :returns: the default_ctx.json template with answers[1]["binary"] set
                to the base64 of the serialized proto
            :raises ValueError: if json_request_proto is not valid JSON
            :raises google.protobuf.json_format.ParseError: if the JSON does not
                match the TStreamsByUuidRequest schema
        """
        from extsearch.video.vh.playlist_service.library.data_structures.protos.handle_by_uuid_structs_pb2 import TStreamsByUuidRequest
        from google.protobuf.json_format import Parse

        js_request = json.loads(json_request_proto)

        proto_res = Parse(json.dumps(js_request), TStreamsByUuidRequest())

        ctx = json.loads(self.get_default_ctx())
        # b64encode returns bytes on Python 3, which json.dumps cannot serialize;
        # decoding to text leaves the serialized JSON unchanged on Python 2.
        ctx["answers"][1]["binary"] = base64.b64encode(proto_res.SerializeToString()).decode("ascii")
        return ctx

    def create_ctx_http_request(self, json_request):
        """
            Build a request context embedding the JSON request as-is.

            :param json_request: JSON text of the http request
            :returns: the default_ctx_http.json template with answers[1]["binary"]
                set to the decoded JSON object
            :raises ValueError: if json_request is not valid JSON
        """
        ctx = json.loads(self.get_default_ctx_http())
        ctx["answers"][1]["binary"] = json.loads(json_request)
        return ctx

    def create_requests_file(self, log_sample, plan_file, column_name, scheme):
        """
            Write one JSON request context per sampled log row to plan_file.

            Rows whose request column is missing/empty, or whose request cannot be
            parsed, are skipped (and counted) instead of failing the task.

            :param log_sample: iterable of log rows (dicts)
            :param plan_file: destination path; str() is applied before opening
            :param column_name: row key holding the raw JSON request
            :param scheme: "proto" (see create_ctx) or "http" (see create_ctx_http_request)
            :raises ValueError: if scheme is not one of the above
        """
        logging.info("Start create requests file")
        logging.info("Scheme: {}".format(scheme))

        if scheme == "proto":
            from google.protobuf.json_format import ParseError
            make_request = self.create_ctx
            # Parse() reports schema mismatches with ParseError, which is not a ValueError.
            bad_request_errors = (ValueError, ParseError)
        elif scheme == "http":
            make_request = self.create_ctx_http_request
            bad_request_errors = (ValueError,)
        else:
            raise ValueError("Unknown scheme: {}".format(scheme))

        written = 0
        skipped = 0
        with open(str(plan_file), 'wb') as plan_out:
            for log_row in log_sample:
                raw_request = log_row.get(column_name)
                if not raw_request:
                    # Rows are sampled by their *response* column, so the request
                    # column itself may still be null or empty.
                    skipped += 1
                    continue

                try:
                    request = make_request(raw_request)
                except bad_request_errors:
                    logging.info("bad request:%s", raw_request)
                    skipped += 1
                    continue

                # The file is binary; encode explicitly (json.dumps output is ASCII).
                plan_out.write((json.dumps(request) + '\n').encode("utf-8"))
                written += 1

        logging.info("Wrote %d requests, skipped %d rows", written, skipped)

    def get_log_samples_from_yt(self):
        """
            Reservoir-sample VH and UGC rows from the newest request-log table.

            :returns: (vh_sample, ugc_sample) ReservoirSampling instances; a row
                lands in a sample only if the corresponding response column
                (http_response_vh / ugc_node_response) is not null, "" or "{}"
            :raises common.errors.TaskFailure: if no log table is found
        """
        from yt.wrapper import JsonFormat, YtClient

        yt_token = sdk2.Vault.data(self.Parameters.yt_token_vault)
        yt_cluster = self.Parameters.yt_cluster
        client = YtClient(yt_cluster, yt_token)

        log_path_suffix = client.list(self.Parameters.request_log_path)
        logging.info("Found logs: %s", ", ".join(log_path_suffix))
        if not log_path_suffix:
            raise common.errors.TaskFailure(
                "No log tables found in {}".format(self.Parameters.request_log_path)
            )
        # Presumably tables are date-named, so the greatest name is the newest — verify.
        current_log_path_suffix = max(log_path_suffix)

        table_name = os.path.join(self.Parameters.request_log_path, current_log_path_suffix)
        logging.info("Used logs table: " + table_name)

        if not client.exists(table_name):
            raise common.errors.TaskFailure("Table with logs not found")

        sample_size = int(self.Parameters.max_request_number)
        # Parsed once, not per row; an empty value disables the limit.
        rows_limit = int(self.Parameters.yt_logs_limit) if self.Parameters.yt_logs_limit else None
        empty_values = (None, "", "{}")

        logging.info("Start ReservoirSampling")
        sample_rows_ugc = self.ReservoirSampling(sample_size)
        sample_rows_vh = self.ReservoirSampling(sample_size)
        rows = client.read_table(table_name, format=JsonFormat())

        for rows_counter_in_yt, row in enumerate(rows):
            # Check before processing so exactly rows_limit rows are consumed.
            if rows_limit is not None and rows_counter_in_yt >= rows_limit:
                break

            if row.get("http_response_vh") not in empty_values:
                sample_rows_vh.sample(row)

            if row.get("ugc_node_response") not in empty_values:
                sample_rows_ugc.sample(row)

            if rows_counter_in_yt % 1000 == 0:
                logging.info("Rows counter in yt: " + str(rows_counter_in_yt))

        return sample_rows_vh, sample_rows_ugc

    class ReservoirSampling(object):
        """
            Uniform fixed-size random sample over a stream (Algorithm R).

            The first sample_num items fill the reservoir; afterwards the i-th
            offered item (1-based) replaces a uniformly chosen slot with
            probability sample_num / i. Iterating yields the sampled items.
        """

        def __init__(self, sample_num):
            sample_num = int(sample_num)
            if sample_num <= 0:
                raise ValueError("sample_num must be positive, got {}".format(sample_num))
            self._sample_num = sample_num
            self._reservoir = []
            self._pos = 0  # number of items offered so far
            # Dispatch via a swappable bound method so neither phase branches per item.
            self._sample = self._no_sample

        def _no_sample(self, item):
            # Fill phase: keep every item until the reservoir is full.
            self._reservoir.append(item)
            if len(self._reservoir) >= self._sample_num:
                self._sample = self._naive_sample

        def _naive_sample(self, item):
            # Replace phase: k is uniform over [0, self._pos).
            k = random.randrange(self._pos)
            if k < self._sample_num:
                self._reservoir[k] = item

        def sample(self, item):
            """Offer one item to the reservoir."""
            self._pos += 1
            self._sample(item)

        def __len__(self):
            return len(self._reservoir)

        def __iter__(self):
            return iter(self._reservoir)
