#include "shard.h"

#include <solomon/libs/cpp/logging/logging.h>

#include <library/cpp/actors/core/actor.h>
#include <library/cpp/actors/core/event.h>
#include <library/cpp/actors/core/hfunc.h>
#include <library/cpp/containers/absl_flat_hash/flat_hash_map.h>

namespace NSolomon::NDataProxy {
namespace {

using namespace NActors;
using namespace yandex::monitoring::memstore;

/**
 * Events that the shard actor sends to itself when an RPC call completes.
 * Event type ids start at SpaceBegin and must stay below SpaceEnd.
 */
struct TLocalEvents: private TPrivateEvents {
    enum {
        RpcResponse = SpaceBegin,
        RpcError,
        End,
    };
    static_assert(End < SpaceEnd, "too many event types");

    /**
     * Successful completion of the RPC call identified by ReqId.
     * Event owns the already-built response event for the requester.
     */
    struct TRpcResponse: TEventLocal<TRpcResponse, RpcResponse> {
        ui32 ReqId;
        std::unique_ptr<IEventBase> Event;

        TRpcResponse(ui32 id, std::unique_ptr<IEventBase> payload) noexcept
            : ReqId(id)
            , Event(std::move(payload))
        {
        }
    };

    /**
     * Failed completion of the RPC call identified by ReqId,
     * together with the gRPC status describing the failure.
     */
    struct TRpcError: TEventLocal<TRpcError, RpcError> {
        ui32 ReqId;
        NGrpc::TGrpcStatus Status;

        TRpcError(ui32 id, NGrpc::TGrpcStatus&& failure) noexcept
            : ReqId(id)
            , Status(std::move(failure))
        {
        }
    };
};

/**
 * Handles RPC response over future subscription.
 */
template <typename TRespEvent>
class TResponseHandler {
    using TResp = decltype(std::declval<TRespEvent>().Message);

public:
    TResponseHandler(TClusterId clusterId, TShardId shardId, ui32 reqId) noexcept
        : ActorSystem_{TActorContext::ActorSystem()}
        , ShardActorId_{TActorContext::AsActorContext().SelfID}
        , ClusterId_{clusterId}
        , ShardId_{shardId}
        , ReqId_{reqId}
    {
    }

    void operator()(TAsyncResponse<TResp> resp) {
        ActorSystem_->Send(ShardActorId_, MakeResultEvent(std::move(resp)));
    }

private:
    IEventBase* MakeResultEvent(TAsyncResponse<TResp> resp) noexcept {
        try {
            auto valueOrError = resp.ExtractValue();
            if (valueOrError.Success()) {
                return new TLocalEvents::TRpcResponse(
                        ReqId_,
                        std::make_unique<TRespEvent>(ShardId_, ClusterId_, valueOrError.Extract())
                );
            } else {
                return new TLocalEvents::TRpcError(ReqId_, valueOrError.ExtractError());
            }
        } catch (...) {
            auto msg = CurrentExceptionMessage();
            return new TLocalEvents::TRpcError(ReqId_, NGrpc::TGrpcStatus::Internal(msg));
        }
    }

private:
    TActorSystem* ActorSystem_;
    TActorId ShardActorId_;
    TClusterId ClusterId_;
    TShardId ShardId_;
    ui32 ReqId_;
};

/**
 * Actor that serves requests addressed to a single MemStore shard.
 *
 * Each incoming request is forwarded to the node RPC client obtained for the
 * shard's current location. The sender and trace id are remembered under a
 * locally generated request id, and once the RPC completes the reply (or a
 * TError) is sent back to that sender.
 */
class TMemStoreShard: public TActor<TMemStoreShard> {
    /** State kept for a request while its RPC call is pending. */
    struct TRequestCtx {
        TActorId ReplyTo;       // sender of the original request
        NWilson::TTraceId Span; // trace id moved out of the request event
    };

public:
    TMemStoreShard(
            TClusterId clusterId,
            TShardId shardId,
            TString shardLocation,
            std::shared_ptr<NMemStore::IMemStoreClusterRpc> rpc)
        : TActor<TMemStoreShard>{&TThis::StateFunc}
        , ClusterId_{clusterId}
        , ShardId_{shardId}
        , ShardLocation_{std::move(shardLocation)}
        , Rpc_{std::move(rpc)}
    {
    }

// Expands to a `case` that dispatches request event `Req` to OnRequest(),
// which calls IMemStoreRpc::Method and answers with event `Resp`.
// NOLINTNEXTLINE(readability-identifier-naming)
#define rpcFunc(Req, Resp, Method) \
    case TMemStoreShardEvents::Req::EventType: { \
        OnRequest< \
                TMemStoreShardEvents::Req, \
                TMemStoreShardEvents::Resp, \
                &NMemStore::IMemStoreRpc::Method>(ev); \
        break; \
    }

    /** Event dispatch table; event types not listed here are ignored. */
    STATEFN(StateFunc) {
        switch (ev->GetTypeRewrite()) {
            rpcFunc(TFindReq, TFindResp, Find)
            rpcFunc(TLabelKeysReq, TLabelKeysResp, LabelKeys)
            rpcFunc(TLabelValuesReq, TLabelValuesResp, LabelValues)
            rpcFunc(TUniqueLabelsReq, TUniqueLabelsResp, UniqueLabels)
            rpcFunc(TReadOneReq, TReadOneResp, ReadOne)
            rpcFunc(TReadManyReq, TReadManyResp, ReadMany)

            hFunc(TMemStoreShardEvents::TUpdateLocation, OnUpdateLocation)
            hFunc(TLocalEvents::TRpcResponse, OnRpcResponse)
            hFunc(TLocalEvents::TRpcError, OnRpcError)
            hFunc(TEvents::TEvPoison, OnPoison)
        }
    }

    /**
     * Starts the RPC call `Method` for a request event of type TReqEvent.
     *
     * The sender and trace id are stored in InFlight_ under a fresh request
     * id; the completion is delivered back to this actor as a TLocalEvents
     * event by TResponseHandler. Aborts (Y_VERIFY) if no RPC client exists
     * for the current shard location.
     */
    template <typename TReqEvent, typename TRespEvent, auto Method>
    void OnRequest(TAutoPtr<IEventHandle>& ev) {
        // View the untyped handle as a typed one to reach the request message.
        auto* xEv = reinterpret_cast<typename TReqEvent::TPtr*>(&ev);

        NMemStore::IMemStoreRpc* nodeRpc = Rpc_->Get(ShardLocation_);
        Y_VERIFY(nodeRpc, "shard %d uses invalid location %s", ShardId_, ShardLocation_.c_str());

        ui32 reqId = ReqId_++;
        TRequestCtx& reqCtx = InFlight_[reqId];
        reqCtx.ReplyTo = ev->Sender;
        reqCtx.Span = std::move(ev->TraceId);

        (nodeRpc->*Method)((*xEv)->Get()->Message)
            .Subscribe(TResponseHandler<TRespEvent>(ClusterId_, ShardId_, reqId));
    }

    /**
     * Relays a successful RPC result to the sender of the original request,
     * passing along the stored trace id. Aborts on an unknown request id.
     */
    void OnRpcResponse(TLocalEvents::TRpcResponse::TPtr& ev) {
        ui32 reqId = ev->Get()->ReqId;
        if (auto node = InFlight_.extract(reqId)) {
            auto& reqCtx = node.mapped();
            // Ownership of the response event is handed over to Send().
            Send(reqCtx.ReplyTo, ev->Get()->Event.release(), 0, 0, std::move(reqCtx.Span));
        } else {
            Y_FAIL("invalid request id %d in shard %d", reqId, ShardId_);
        }
    }

    /**
     * Reports a failed RPC call to the sender of the original request as a
     * TMemStoreShardEvents::TError carrying the gRPC status code and message.
     * Aborts on an unknown request id.
     */
    void OnRpcError(TLocalEvents::TRpcError::TPtr& ev) {
        ui32 reqId = ev->Get()->ReqId;
        if (auto node = InFlight_.extract(reqId)) {
            auto& reqCtx = node.mapped();
            Send(reqCtx.ReplyTo, new TMemStoreShardEvents::TError{
                ShardId_,
                ClusterId_,
                static_cast<grpc::StatusCode>(ev->Get()->Status.GRpcStatusCode),
                std::move(ev->Get()->Status.Msg),
            }, 0, 0, std::move(reqCtx.Span));
        } else {
            Y_FAIL("invalid request id %d in shard %d", reqId, ShardId_);
        }
    }

    /** Switches the shard to a new location; later requests use the new address. */
    void OnUpdateLocation(TMemStoreShardEvents::TUpdateLocation::TPtr& ev) {
        MON_DEBUG(MemStoreClient,
                "shard " << ShardId_ << " changed location from "
                << ShardLocation_ << " to " << ev->Get()->Address);
        ShardLocation_ = std::move(ev->Get()->Address);
    }

    /**
     * Acknowledges the poison pill and terminates the actor.
     * Requests that are still in flight are not answered here.
     */
    void OnPoison(TEvents::TEvPoison::TPtr& ev) {
        Send(ev->Sender, new TEvents::TEvPoisonTaken);
        // Without PassAway() the actor would report itself as stopped while
        // staying registered and continuing to process events.
        PassAway();
    }

private:
    TClusterId ClusterId_;
    TShardId ShardId_;
    TString ShardLocation_;
    std::shared_ptr<NMemStore::IMemStoreClusterRpc> Rpc_;
    // Pending requests keyed by the locally generated request id.
    absl::flat_hash_map<ui32, TRequestCtx> InFlight_;
    // Next request id to hand out.
    ui32 ReqId_{0};
};

} // namespace

/**
 * Creates the actor serving the given shard of the given cluster.
 * The returned actor is not registered in any actor system.
 */
std::unique_ptr<NActors::IActor> MemStoreShard(
    TClusterId clusterId,
    TShardId shardId,
    TString shardLocation,
    std::shared_ptr<NMemStore::IMemStoreClusterRpc> rpc)
{
    std::unique_ptr<NActors::IActor> actor = std::make_unique<TMemStoreShard>(
            clusterId,
            shardId,
            std::move(shardLocation),
            std::move(rpc));
    return actor;
}

} // namespace NSolomon::NDataProxy
