from concurrent import futures
import signal

import grpc
from grpc_reflection.v1alpha import reflection
from library.python.protobuf.json import proto2json
import yt.yson

from crypta.lab.rule_estimator.proto import rule_estimate_stats_pb2
from crypta.lab.rule_estimator.services.api.proto import (
    api_pb2,
    api_pb2_grpc,
)
from crypta.lib.python import time_utils
from crypta.lib.python.grpc import keepalive
from crypta.lib.python.lb_pusher import logbroker
from crypta.lib.python.yt.dyntables import kv_client


class RuleEstimator(api_pb2_grpc.RuleEstimatorServicer):
    """gRPC servicer that enqueues rule estimations and serves stored stats.

    Stats are kept as serialized ``RuleEstimateStats`` protos in a key-value
    table. Each rule condition has a record keyed by ``str(condition_id)``.
    Each rule (a set of condition ids) has a record keyed by
    ``make_rule_key(condition_ids)``.
    """

    def __init__(self, pq_writer, logger, kv_client):
        """
        :param pq_writer: writer that receives JSON-encoded ``Update`` requests.
        :param logger: logger used for request logging.
        :param kv_client: key-value client used to read/write serialized stats.
        """
        self.pq_writer = pq_writer
        self.logger = logger
        self.kv_client = kv_client

    def Update(self, request, context):
        """Publish an estimation request and mark its stats as not ready.

        The request is written to ``pq_writer`` as JSON. A placeholder
        ``RuleEstimateStats`` is then stored for every id in
        ``request.RuleConditionIds`` and for the combined rule key. The
        placeholder has ``IsReady = False`` and the current timestamp.
        """
        payload = proto2json.proto2json(request)
        self.pq_writer.write(payload)
        self.logger.info("Update: %s", payload)

        proto = rule_estimate_stats_pb2.RuleEstimateStats()
        proto.IsReady = False
        proto.Timestamp = time_utils.get_current_time()
        value = proto.SerializeToString()

        # The same placeholder goes under every condition key and the rule key,
        # all in a single write_many call.
        records = {
            str(rule_condition_id): value
            for rule_condition_id in request.RuleConditionIds
        }
        records[make_rule_key(request.RuleConditionIds)] = value
        self.kv_client.write_many(records)

        return api_pb2.UpdateResponse(Message="Ok")

    def Ping(self, request, context):
        """Health check: log and answer ``Ok``."""
        self.logger.info("Ping")
        return api_pb2.PingResponse(Message="Ok")

    def GetRuleConditionStats(self, request, context):
        """Return stored stats for a single rule condition."""
        return self._get_stats(str(request.RuleConditionId), context)

    def _get_stats(self, key, context):
        """Look up and deserialize the ``RuleEstimateStats`` stored under ``key``.

        If no record exists, the RPC status is set to NOT_FOUND with the key in
        the status details, and an empty ``RuleEstimateStats`` is returned.
        Callers should rely on the status code, not on the empty message.
        """
        result = self.kv_client.lookup(key)
        proto = rule_estimate_stats_pb2.RuleEstimateStats()

        if result is None:
            context.set_code(grpc.StatusCode.NOT_FOUND)
            # Without details the client sees a bare NOT_FOUND and cannot
            # tell which key was missing.
            context.set_details("Stats not found for key {}".format(key))
        else:
            proto.ParseFromString(yt.yson.get_bytes(result))

        return proto

    def SetRuleConditionStats(self, request, context):
        """Store stats for a single rule condition."""
        return self._set_stats(str(request.RuleConditionId), request.Stats, context)

    def _set_stats(self, key, stats, context):
        """Serialize ``stats`` and write it under ``key``; answer ``Ok``."""
        self.kv_client.write(
            key,
            stats.SerializeToString(),
        )
        # NOTE(review): SetRuleStats also returns this response type; confirm it
        # matches the RPC's declared response message in api.proto.
        return api_pb2.SetRuleConditionStatsResponse(Message="Ok")

    def GetRuleStats(self, request, context):
        """Return stored stats for the rule formed by ``request.RuleConditionIds``."""
        return self._get_stats(make_rule_key(request.RuleConditionIds), context)

    def SetRuleStats(self, request, context):
        """Store stats for the rule formed by ``request.RuleConditionIds``."""
        return self._set_stats(make_rule_key(request.RuleConditionIds), request.Stats, context)


def serve(config, logger):
    """Run the RuleEstimator gRPC server and block until it terminates.

    Wires a Logbroker writer for ``config.Topic`` and a key-value client for
    ``config.RuleConditionStatsPath`` into a ``RuleEstimator`` servicer. It then
    enables server reflection and listens on ``config.Port`` without TLS.
    SIGTERM stops the server with a 10 second grace period.
    """
    logbroker_config = config.Logbroker
    tvm_config = config.Tvm
    pq_client = logbroker.PQClient(
        logbroker_config.Url,
        logbroker_config.Port,
        tvm_id=tvm_config.SourceTvmId,
        tvm_secret=tvm_config.Secret,
    )
    pq_writer = pq_client.get_writer(config.Topic)
    stats_client = kv_client.make_kv_client(
        config.Yt.Proxy,
        config.RuleConditionStatsPath,
        config.Yt.Token,
    )

    executor = futures.ThreadPoolExecutor(max_workers=config.Workers)
    server = grpc.server(executor, options=keepalive.get_keepalive_options())

    servicer = RuleEstimator(pq_writer, logger, stats_client)
    api_pb2_grpc.add_RuleEstimatorServicer_to_server(servicer, server)

    reflected_services = (
        api_pb2.DESCRIPTOR.services_by_name["RuleEstimator"].full_name,
        reflection.SERVICE_NAME,
    )
    reflection.enable_server_reflection(reflected_services, server)
    server.add_insecure_port("[::]:{}".format(config.Port))

    def on_sigterm(*_):
        # Begin a graceful stop (10s grace) and wait up to 10s for it to finish.
        return server.stop(10).wait(10)

    signal.signal(signal.SIGTERM, on_sigterm)

    # The Logbroker client and writer stay open for the whole server lifetime.
    with pq_client, pq_writer:
        server.start()
        server.wait_for_termination()


def make_rule_key(rule_revision_ids):
    """Build the key-value key for a rule from its condition ids.

    Ids are de-duplicated and sorted before joining, so the key does not
    depend on input order or repetition. For example, ``[3, 1, 3]`` gives
    ``"rule-1-3"``. An empty input gives ``"rule-"``.
    """
    unique_ids = sorted(set(rule_revision_ids))
    return "rule-{}".format("-".join(map(str, unique_ids)))
