#include "shard_actor.h"
#include "events.h"
#include "watcher.h"

#include <solomon/libs/cpp/actors/events/events.h>
#include <solomon/libs/cpp/logging/logging.h>
#include <solomon/libs/cpp/trace/trace.h>
#include <solomon/services/dataproxy/lib/hash/hasher.h>

#include <library/cpp/actors/core/actor.h>
#include <library/cpp/actors/core/hfunc.h>
#include <library/cpp/containers/absl_flat_hash/flat_hash_map.h>

#include <util/string/builder.h>
#include <util/string/join.h>

#include <queue>

using namespace NActors;
using namespace yandex::solomon::stockpile;

namespace NSolomon::NDataProxy {

using namespace NTracing;

namespace {

template <typename TRequest, typename TResponse>
class TRpcMethod {
public:
    using TReq = TRequest;
    using TResp = TResponse;
    using TMethod = TStockpileAsyncResponse<TResponse> (IStockpileRpc::*)(const TRequest&);

    explicit constexpr TRpcMethod(TMethod impl) noexcept
        : Impl_(impl)
    {
    }

    TStockpileAsyncResponse<TResponse> Call(IStockpileRpc* rpc, const TRequest& req) {
        return (rpc->*Impl_)(req);
    }

private:
    TMethod Impl_;
};

/**
 * Compile-time mapping from a request event type to the IStockpileRpc method
 * used to serve it. The primary template is only declared, so the mapping is
 * available solely for the request types specialized below.
 */
template <typename T>
struct TReqToRpcMethod;

/// TReadOneReq is served by IStockpileRpc::ReadCompressedOne.
template <>
struct TReqToRpcMethod<TStockpileEvents::TReadOneReq> {
    static constexpr auto RpcMethod = TRpcMethod(&IStockpileRpc::ReadCompressedOne);
};

/// TReadManyReq is served by IStockpileRpc::ReadCompressedMany.
template <>
struct TReqToRpcMethod<TStockpileEvents::TReadManyReq> {
    static constexpr auto RpcMethod = TRpcMethod(&IStockpileRpc::ReadCompressedMany);
};

class TRetryState {
public:
    void Retry() noexcept {
        Y_VERIFY_DEBUG(Countdown_ > 0);
        --Countdown_;
        NextDelay_ = Min(NextDelay_ + NextDelay_ / 2, TDuration::Seconds(10));
    }

    TDuration NextDelay() const noexcept {
        return NextDelay_;
    }

    bool CanRetry(EStockpileStatusCode code) const noexcept {
        if (Countdown_ == 0) {
            return false;
        }

        switch (code) {
            // TODO: retry SHARD_ABSENT_ON_HOST
            case EStockpileStatusCode::NODE_UNAVAILABLE:
            case EStockpileStatusCode::SHARD_NOT_READY:
            case EStockpileStatusCode::RESOURCE_EXHAUSTED:
            case EStockpileStatusCode::NOTE_ENOUGH_READY_SHARDS:
                return true;

            default:
                return false;
        }
    }

private:
    ui32 Countdown_{3};
    TDuration NextDelay_{TDuration::MilliSeconds(100)};
};

// Lifecycle state of a request context tracked by the shard actor.
enum class EStockpileReqState: ui8 {
    // initial state of a freshly created request context
    New,
    // the request was handed to the Stockpile RPC and is counted as inflight
    Inflight,
    // the request failed with a retriable error and a retry event is scheduled
    Retry,
    // the inflight limit was reached, the request waits in the postponed queue
    Postponed,
};

/**
 * Bookkeeping of the requests handled by a shard actor:
 *   - owns the request contexts (request id -> context),
 *   - keeps the queue of requests postponed because of the inflight limit,
 *   - counts the requests actually sent to Stockpile.
 *
 * The class does no synchronization of its own.
 */
class TRequestsManager {
public:
    /**
     * State of a single request. Instances are created only by
     * TRequestsManager::CreateReqCtx() and are owned by the manager.
     */
    class TReqCtx {
        friend class TRequestsManager;
    public:
        using TReqId = ui64;
    private:
        explicit TReqCtx(TReqId reqId) : ReqId(reqId)
        {
        }
    public:
        // unique id of the request within its manager
        TReqId ReqId;
        EStockpileReqState StockpileReqState{EStockpileReqState::New};
        // event handle of the incoming request
        std::unique_ptr<IEventHandle> OriginalRequest;
        // underlying value of ERequestTypes, filled in by the owner of the context
        int MessageType{0};
        TInstant ReceivedAt{};
        // request deadline; the shard actor interprets it as milliseconds (TInstant::MilliSeconds)
        ui64 Deadline{0};
        TRetryState RetryState;
    };

    // A safe way to reference ReqCtx when unsure if it still lives
    // Can be promoted to strong ReqCtx* using LookupReqCtx
    class TWeakReqCtx {
        friend class TRequestsManager;
    private:
        explicit TWeakReqCtx(TReqCtx::TReqId reqId) : ReqId_(reqId)
        {
        }
    public:
        // Intentionally implicit: callers pass TReqCtx* where a weak reference is expected.
        TWeakReqCtx(TReqCtx* reqCtx) : TWeakReqCtx{reqCtx->ReqId}
        {
        }
        /// Restores a weak reference from the value produced by ToCookie().
        static TWeakReqCtx FromCookie(ui64 cookie) {
            return TWeakReqCtx(cookie);
        }
        /// Converts the weak reference to an integer (the request id).
        ui64 ToCookie() const {
            return ReqId_;
        }
    private:
        TReqCtx::TReqId ReqId_;
    };

    /**
     * Creates a new request context with the next free request id.
     *
     * @param now  value stored as ReceivedAt of the new context
     * @return pointer to the context; it stays valid until EraseReqCtx() is called for it
     */
    TReqCtx* CreateReqCtx(TInstant now = {}) {
        // TReqCtx constructor is private, hence no std::make_unique here
        std::unique_ptr<TReqCtx> reqCtx{new TReqCtx(RequestNumber_)};
        auto* reqCtxPtr = reqCtx.get();
        ReqCtxs_[reqCtx->ReqId] = std::move(reqCtx);
        RequestNumber_++;
        reqCtxPtr->ReceivedAt = now;
        return reqCtxPtr;
    }

    /// Entry of the postponed queue: weak reference to a request plus the time it was received.
    struct TPostponedReq {
        TPostponedReq(TReqCtx* reqCtx)
            : WeakReqCtx{reqCtx}
            , ReceivedAt(reqCtx->ReceivedAt)
        {
        }

        TWeakReqCtx WeakReqCtx;
        TInstant ReceivedAt;
    };

    /**
     * Ordering of the postponed queue. With std::priority_queue this comparator puts
     * the entry with the greatest ReceivedAt (the most recently received request) on top.
     * NOTE(review): confirm that serving the newest postponed request first is intended.
     */
    struct TPostponedReqCompare {
        bool operator()(const TPostponedReq& left, const TPostponedReq& right) const {
            return left.ReceivedAt < right.ReceivedAt;
        }
    };

    /// Destroys the given context; pointers to it become invalid, weak references no longer resolve.
    void EraseReqCtx(TReqCtx* reqCtx) {
        ReqCtxs_.erase(reqCtx->ReqId);
    }

    /// Resolves a weak reference. Returns nullptr if the context was already erased.
    TReqCtx* LookupReqCtx(TWeakReqCtx weakReqCtx) {
        auto it = ReqCtxs_.find(weakReqCtx.ReqId_);
        if (it == ReqCtxs_.end()) {
            return nullptr;
        }
        return it->second.get();
    }

    /// Puts the request into the postponed queue.
    void AddPostponedRequest(TReqCtx* reqCtx) {
        PostponedReqs_.emplace(TPostponedReq{reqCtx});
    }

    /**
     * Removes entries from the postponed queue until one that refers to a still
     * living context is found.
     *
     * @return the context of that entry, or nullptr if the queue got empty
     */
    TReqCtx* ExtractPostponedRequestToProcess() {
        while (!PostponedReqs_.empty()) {
            // Take a copy: top() returns a reference into the queue storage, which
            // must not be used after pop() has reordered/removed the elements.
            const TPostponedReq postponed = PostponedReqs_.top();
            PostponedReqs_.pop();

            if (auto* reqCtx = LookupReqCtx(postponed.WeakReqCtx)) {
                return reqCtx;
            }
            // the context was erased while the request was waiting, skip the entry
        }
        return nullptr;
    }

    void IncInflight() {
        ++TotalInflight_;
    }

    void DecInflight() {
        Y_VERIFY(TotalInflight_ > 0);
        --TotalInflight_;
    }

    size_t GetTotalInflightCount() const {
        return TotalInflight_;
    }

private:
    // Unique number associated with each reqCtx. ReqCtx pointers can be easily reused by allocator
    // and are bad choices to track whether request is already served or is in flight.
    TReqCtx::TReqId RequestNumber_{1};
    absl::flat_hash_map<TReqCtx::TReqId, std::unique_ptr<TReqCtx>> ReqCtxs_;
    std::priority_queue<TPostponedReq, std::vector<TPostponedReq>, TPostponedReqCompare> PostponedReqs_;
    /**
     * Number of requests actually sent to Stockpile (rather than scheduled)
     */
    ui64 TotalInflight_{0};
};

// Tag of the request kind. MessageTypeFor() maps a request event type to a tag,
// WithRespType() maps a tag back to the corresponding response event type.
enum class ERequestTypes: ui8 {
    // TStockpileEvents::TReadOneReq / TReadOneResp
    ReadOne = 0,
    // TStockpileEvents::TReadManyReq / TReadManyResp
    ReadMany,
};

/**
 * Runtime -> compile-time dispatch: calls the templated call operator of `fn`
 * instantiated with the response event type matching the given request tag.
 *
 * @param eventType  underlying value of ERequestTypes
 * @param fn         callable providing `template <typename TResp> operator()()`
 * @return whatever `fn` returns; no call is made for a value outside of ERequestTypes
 */
template <typename TFunc>
auto WithRespType(ui32 eventType, TFunc fn) {
    const auto requestType = static_cast<ERequestTypes>(eventType);

    // a switch (rather than an if-chain) is kept so that the compiler can check enum coverage
    switch (requestType) {
        case ERequestTypes::ReadMany:
            return fn.template operator()<TStockpileEvents::TReadManyResp>();
        case ERequestTypes::ReadOne:
            return fn.template operator()<TStockpileEvents::TReadOneResp>();
    }
}

/**
 * Compile-time mapping from a request event type to its ERequestTypes tag.
 * Instantiation with any other type fails to compile.
 */
template <typename TReq>
constexpr ERequestTypes MessageTypeFor() {
    if constexpr (std::is_same_v<TStockpileEvents::TReadOneReq, TReq>) {
        return ERequestTypes::ReadOne;
    } else if constexpr (std::is_same_v<TStockpileEvents::TReadManyReq, TReq>) {
        return ERequestTypes::ReadMany;
    } else {
        static_assert(TDependentFalse<TReq>, "unsupported req type");
    }
}

/**
 * Actor which forwards ReadOne/ReadMany requests to the Stockpile RPC resolved
 * for the location of one shard.
 *
 * Besides plain forwarding the actor
 *   - sends at most MaxInflight_ requests at a time and postpones the rest,
 *   - retries requests failed with a retriable status (see TRetryState),
 *   - replies with DEADLINE_EXCEEDED when the request deadline is reached.
 */
class TStockpileShardActor: public TActor<TStockpileShardActor>, private TPrivateEvents {
    enum {
        Timeout = SpaceBegin,
        Retry,
        End,
    };
    static_assert(End < SpaceEnd, "too many event types");

    // Local event: the deadline of the referenced request is reached.
    struct TTimeoutEvent: public TEventLocal<TTimeoutEvent, Timeout> {
        TTimeoutEvent(TRequestsManager::TWeakReqCtx weakReqCtx)
            : WeakReqCtx(std::move(weakReqCtx))
        {
        }
        TRequestsManager::TWeakReqCtx WeakReqCtx;
    };

    // Local event: it is time to make another attempt for the referenced request.
    struct TRetryEvent: public TEventLocal<TRetryEvent, Retry> {
        TRetryEvent(TRequestsManager::TWeakReqCtx weakReqCtx)
            : WeakReqCtx(std::move(weakReqCtx))
        {
        }
        TRequestsManager::TWeakReqCtx WeakReqCtx;
    };

public:
    /**
     * @param rpc          cluster RPC used to resolve the per-location RPC object
     * @param shard        shard info; its Location selects the RPC object and is used as host name in error messages
     * @param maxInflight  maximum number of requests sent to Stockpile at the same time
     */
    TStockpileShardActor(IStockpileClusterRpcPtr rpc, const TStockpileShardInfo& shard, size_t maxInflight)
        : TActor<TStockpileShardActor>(&TStockpileShardActor::Requesting)
        , Rpc_{std::move(rpc)}
        , NodeRpc_{Rpc_->Get(shard.Location)}
        , HostName_{shard.Location}
        , MaxInflight_(maxInflight)
    {
    }

    STRICT_STFUNC(Requesting,
            hFunc(TStockpileWatcherEvents::TUpdate, OnStateChange)
            hFunc(TStockpileEvents::TReadManyReq, OnRequest<TStockpileEvents::TReadManyReq>)
            hFunc(TStockpileEvents::TReadManyResp, OnResponse<TStockpileEvents::TReadManyResp>)
            hFunc(TStockpileEvents::TReadOneReq, OnRequest<TStockpileEvents::TReadOneReq>)
            hFunc(TStockpileEvents::TReadOneResp, OnResponse<TStockpileEvents::TReadOneResp>)
            hFunc(TStockpileEvents::TError, OnError)
            hFunc(TTimeoutEvent, OnTimeout)
            hFunc(TRetryEvent, OnRetry)
            hFunc(TEvents::TEvPoison, OnPoison)
    )

private:
    /**
     * Switches the actor to the location reported by the first shard of the update.
     * NOTE(review): assumes the update always carries at least one shard - confirm against the sender.
     */
    void OnStateChange(TStockpileWatcherEvents::TUpdate::TPtr& ev) {
        auto& shard = ev->Get()->Shards.at(0);
        NodeRpc_ = Rpc_->Get(shard.Location);
        HostName_ = std::move(shard.Location);
    }

    /**
     * Handles an incoming request: creates a context for it, takes ownership
     * of the event handle and tries to send the request right away.
     */
    template <typename TReqEvent>
    void OnRequest(typename TReqEvent::TPtr& ev) {
        auto reqCtx = RequestsManager_.CreateReqCtx(TActivationContext::Now());

        reqCtx->Deadline = ev->Get()->Message.deadline();
        reqCtx->OriginalRequest = std::unique_ptr<IEventHandle>(ev.Release());
        reqCtx->MessageType = ToUnderlying(MessageTypeFor<TReqEvent>());

        this->TrySendRequest<TReqEvent>(reqCtx);
    }

    /**
     * Sends the request unless
     *   - less than 100ms is left until its deadline: a timeout event is sent to self instead;
     *   - the inflight limit is reached: the request is put into the postponed queue.
     * For a sent request a timeout event is scheduled at the deadline.
     */
    template <typename TReqEvent>
    void TrySendRequest(TRequestsManager::TReqCtx* reqCtx) {
        auto deadline = TInstant::MilliSeconds(reqCtx->Deadline);

        Y_VERIFY_DEBUG(deadline != TInstant::Zero());

        if ((TActivationContext::Now() + TDuration::MilliSeconds(100)) >= deadline) {
            Send(SelfId(), new TTimeoutEvent(reqCtx));
            return;
        }

        if (RequestsManager_.GetTotalInflightCount() >= MaxInflight_) { // max inflight is reached
            MON_TRACE(StockpileClient, SelfId() << " Postponed");
            reqCtx->StockpileReqState = EStockpileReqState::Postponed;

            // no timeout is scheduled here: the deadline of a postponed request is
            // checked again when it is taken out of the queue
            RequestsManager_.AddPostponedRequest(reqCtx);
            return;
        }

        this->SendRequest<TReqEvent>(reqCtx);
        // TODO: use Timer and cancel when reply received
        this->Schedule(deadline, new TTimeoutEvent(reqCtx));
    }

    /**
     * Calls the RPC method matching TReq and subscribes to its result.
     * The result is delivered back to this actor as TResp (status OK) or as
     * TStockpileEvents::TError, with the request id passed in the event cookie.
     */
    template <typename TReq, typename TResp = typename TStockpileEvents::TRequestToResponse<TReq>::TResponse>
    void SendRequest(TRequestsManager::TReqCtx* reqCtx) {
        auto* origEv = static_cast<typename TReq::THandle*>(reqCtx->OriginalRequest.get());
        auto* actorSystem = TActorContext::ActorSystem();
        auto selfId = this->SelfId();

        reqCtx->StockpileReqState = EStockpileReqState::Inflight;
        RequestsManager_.IncInflight();

        // the callback captures the request id only, so it never touches the context itself
        auto reqCtxAsACookie = TRequestsManager::TWeakReqCtx(reqCtx).ToCookie();

        auto rpcMethod = TReqToRpcMethod<TReq>::RpcMethod;
        auto future = rpcMethod.Call(NodeRpc_, origEv->Get()->Message);
        future.Subscribe([actorSystem, selfId, reqCtxAsACookie, fqdnStr = " [" + HostName_ + "]"](auto response) {
            try {
                auto respOrError = response.ExtractValue();
                if (respOrError.Success()) {
                    if (respOrError.Value().status() == EStockpileStatusCode::OK) {
                        auto event = std::make_unique<TResp>();
                        event->Message = respOrError.Extract();
                        actorSystem->Send(new IEventHandle(selfId, selfId, event.release(), 0, reqCtxAsACookie));
                    } else {
                        // application level error
                        auto event = std::make_unique<TStockpileEvents::TError>();
                        event->RpcCode = grpc::StatusCode::OK;
                        event->StockpileCode = respOrError.Value().status();
                        event->Message = respOrError.Value().statusmessage();

                        // fall back to the status name when the response has no message
                        if (!event->Message) {
                            event->Message = EStockpileStatusCode_Name(event->StockpileCode);
                        }

                        event->Message += fqdnStr;

                        actorSystem->Send(new IEventHandle(selfId, selfId, event.release(), 0, reqCtxAsACookie));
                    }
                } else {
                    // transport level error
                    auto event = std::make_unique<TStockpileEvents::TError>();
                    event->RpcCode = respOrError.Error().RpcCode;
                    event->StockpileCode = respOrError.Error().StockpileCode;
                    event->Message = respOrError.Error().Message;
                    event->Message += fqdnStr;

                    actorSystem->Send(new IEventHandle(selfId, selfId, event.release(), 0, reqCtxAsACookie));
                }
            } catch (...) {
                // unexpected error
                auto event = std::make_unique<TStockpileEvents::TError>();
                event->RpcCode = grpc::StatusCode::UNKNOWN;
                event->StockpileCode = EStockpileStatusCode::UNKNOWN;
                event->Message = CurrentExceptionMessage();
                event->Message += fqdnStr;
                actorSystem->Send(new IEventHandle(selfId, selfId, event.release(), 0, reqCtxAsACookie));
            }
        });
    }

    /**
     * Takes one request out of the postponed queue (if any) and tries to send it.
     * Must be called only when there is a free inflight slot.
     */
    void ProcessPostponedRequests() {
        Y_VERIFY(RequestsManager_.GetTotalInflightCount() < MaxInflight_);

        auto* reqCtx = RequestsManager_.ExtractPostponedRequestToProcess();
        if (!reqCtx) { // no reqs to process
            return;
        }

        MON_TRACE(StockpileClient, SelfId() << " Processing postponed request");
        WithRespType(reqCtx->MessageType, [&]<typename TResp>() {
            using TReq = typename TStockpileEvents::TResponseToRequest<TResp>::TRequest;
            this->TrySendRequest<TReq>(reqCtx);
        });
    }

    /**
     * Handles a successful response: forwards it to the sender of the original
     * request, releases the inflight slot and the context, then gives the slot
     * to a postponed request.
     */
    template <typename TRespEvent, typename TReq = typename TStockpileEvents::TResponseToRequest<TRespEvent>::TRequest>
    void OnResponse(typename TRespEvent::TPtr& ev) {
        auto* reqCtx = RequestsManager_.LookupReqCtx(TRequestsManager::TWeakReqCtx::FromCookie(ev->Cookie));
        if (!reqCtx) {
            MON_WARN(StockpileClient, "Response to already erased request context");
            return;
        }

        auto origReq = std::move(reqCtx->OriginalRequest);
        auto handle = static_cast<typename TReq::THandle*>(origReq.get());
        MON_TRACE(StockpileClient, "response: { " << ev->Get()->Message.ShortDebugString() << "} for request {" << handle->Get()->Message.ShortDebugString() << " }");

        this->Send(origReq->Sender, ev->Release().Release(), 0, origReq->Cookie, std::move(origReq->TraceId));

        RequestsManager_.DecInflight();
        RequestsManager_.EraseReqCtx(reqCtx);

        ProcessPostponedRequests();
    }

    /**
     * Handles a failed attempt: either schedules a retry (retriable status and
     * attempts left) or forwards the error to the sender of the original request.
     * In both cases the inflight slot is released and given to a postponed request.
     */
    void OnError(TStockpileEvents::TError::TPtr& ev) {
        auto* reqCtx = RequestsManager_.LookupReqCtx(TRequestsManager::TWeakReqCtx::FromCookie(ev->Cookie));
        if (!reqCtx) {
            MON_WARN(StockpileClient, "Error response to already erased request context");
            return;
        }

        if (reqCtx->StockpileReqState == EStockpileReqState::Inflight) {
            RequestsManager_.DecInflight();
        }

        EStockpileStatusCode statusCode = ev->Get()->StockpileCode;
        if (reqCtx->RetryState.CanRetry(statusCode)) {
            MON_WARN(StockpileClient,
                    "got " << EStockpileStatusCode_Name(statusCode)
                    << ": " << ev->Get()->Message << ", will retry after "
                    << reqCtx->RetryState.NextDelay());

            reqCtx->StockpileReqState = EStockpileReqState::Retry;
            this->Schedule(reqCtx->RetryState.NextDelay(), new TRetryEvent(reqCtx));
        } else {
            MON_WARN(StockpileClient, "got " << EStockpileStatusCode_Name(statusCode) << ": " << ev->Get()->Message);

            auto origReq = std::move(reqCtx->OriginalRequest);
            this->Send(origReq->Sender, ev->Release().Release(), 0, origReq->Cookie, std::move(origReq->TraceId));
            RequestsManager_.EraseReqCtx(reqCtx);
        }

        ProcessPostponedRequests();
    }

    /**
     * Handles a deadline of a request which is still alive: replies to the sender
     * of the original request with DEADLINE_EXCEEDED and erases the context.
     * Timeout events of already completed requests are ignored.
     */
    void OnTimeout(typename TTimeoutEvent::TPtr& ev) {
        auto* reqCtx = RequestsManager_.LookupReqCtx(ev->Get()->WeakReqCtx);
        if (!reqCtx) {
            return;
        }

        if (reqCtx->StockpileReqState == EStockpileReqState::Inflight) {
            RequestsManager_.DecInflight();
        }

        auto event = std::make_unique<TStockpileEvents::TError>();
        event->RpcCode = grpc::StatusCode::UNKNOWN;
        event->StockpileCode = EStockpileStatusCode::DEADLINE_EXCEEDED;
        event->Message = TStringBuilder{} << TStringBuf("client-side deadline exceeded: ")
                << TInstant::MilliSeconds(reqCtx->Deadline);
        auto& origReq = reqCtx->OriginalRequest;
        this->Send(origReq->Sender, event.release(), 0, origReq->Cookie, std::move(origReq->TraceId));
        RequestsManager_.EraseReqCtx(reqCtx);

        // Postponed requests have no timer of their own, they only move when somebody
        // drains the queue. A timeout either frees an inflight slot or reports an expired
        // request which has just been taken from the queue without using its slot, so
        // the queue has to be drained here as well, otherwise the remaining postponed
        // requests may stay unanswered. The check keeps the invariant verified by
        // ProcessPostponedRequests() (the timed out request may not have held a slot).
        if (RequestsManager_.GetTotalInflightCount() < MaxInflight_) {
            ProcessPostponedRequests();
        }
    }

    /**
     * Handles a retry event of a request which is still alive: registers the
     * attempt in the retry state and tries to send the request again.
     */
    void OnRetry(typename TRetryEvent::TPtr& ev) {
        auto* reqCtx = RequestsManager_.LookupReqCtx(ev->Get()->WeakReqCtx);
        if (!reqCtx) {
            return;
        }
        WithRespType(reqCtx->MessageType, [&]<typename TResp>() {
            using TReq = typename TStockpileEvents::TResponseToRequest<TResp>::TRequest;

            reqCtx->RetryState.Retry();
            this->TrySendRequest<TReq>(reqCtx);
        });
    }

    /// Confirms the poison to its sender and terminates the actor.
    void OnPoison(TEvents::TEvPoison::TPtr& ev) {
        // TODO: kill or wait all requests in flight
        Send(ev->Sender, new TEvents::TEvPoisonTaken);
        PassAway();
    }

private:
    IStockpileClusterRpcPtr Rpc_;
    // RPC object resolved from Rpc_ for the current shard location
    IStockpileRpc* NodeRpc_;
    // current shard location, appended to error messages
    TString HostName_;

    TRequestsManager RequestsManager_;
    // maximum number of requests sent to Stockpile at the same time
    size_t MaxInflight_;
};

} // namespace

/**
 * Creates an actor serving requests to the given Stockpile shard.
 *
 * @param rpc          cluster RPC used to reach the shard location
 * @param shard        shard the actor is bound to
 * @param maxInflight  maximum number of requests the actor sends to Stockpile at the same time
 */
std::unique_ptr<IActor> StockpileShardActor(const IStockpileClusterRpcPtr& rpc, const TStockpileShardInfo& shard, size_t maxInflight) {
    std::unique_ptr<IActor> actor = std::make_unique<TStockpileShardActor>(rpc, shard, maxInflight);
    return actor;
}

} // namespace NSolomon::NDataProxy
