#include "events.h"
#include "requester.h"
#include "shard_actor.h"

#include <solomon/libs/cpp/actors/scheduler/scheduler.h>
#include <solomon/libs/cpp/logging/logging.h>

#include <library/cpp/actors/core/actor.h>
#include <library/cpp/actors/core/hfunc.h>
#include <library/cpp/containers/absl_flat_hash/flat_hash_map.h>

#include <unordered_map>

using namespace NSolomon::NTracing;

namespace NSolomon::NDataProxy {
namespace {

using namespace NActors;
using yandex::solomon::metabase::EMetabaseStatusCode;
// Owning pointer to an actor, referred to through the IActor base type.
using IActorPtr = std::unique_ptr<IActor>;

/**
 * Satisfied by every type on which `set_shard_id(<TShardId>)` is a valid member call.
 */
template <typename TMessage>
concept HasSetShardId = requires(TMessage message, TShardId shardId) {
    message.set_shard_id(shardId);
};

// The Scheduler identifies a scheduled event by the pair {actorId, eventId}. Every TShardsRequester
// has its own unique actorId and arms only one timeout, so a single shared eventId is sufficient.
constexpr ui64 TIMEOUT_ID = 1;

/**
 * Actor that fans a single request out to a set of shard actors and relays their replies
 * to the actor that sent the request.
 *
 * Life cycle, as implemented below:
 *   1. OnRequest remembers the sender, asks the scheduler to deliver TTimedout at the
 *      request deadline and sends the request to the shard actors.
 *   2. Every TEventResp and every non-retryable TError is forwarded to the requester as
 *      soon as it arrives. Retryable errors are only stored.
 *   3. When all shards have given a final answer, or when TTimedout arrives, TDone is
 *      sent to the requester and the actor passes away.
 *
 * @tparam TEventReq   request event type; must provide TProtoMsg, Message and Deadline
 * @tparam TEventResp  response event type handled by OnResponse
 */
template <typename TEventReq, typename TEventResp>
class TShardsRequester: public TActor<TShardsRequester<TEventReq, TEventResp>>, TPrivateEvents {
    using TBase = TActor<TShardsRequester<TEventReq, TEventResp>>;

    enum {
        Timedout = SpaceBegin,
        End
    };
    // NOTE(review): End is declared right after Timedout = SpaceBegin, so this condition holds
    // trivially. A check against the upper bound of the private event space was probably
    // intended -- confirm against TPrivateEvents.
    static_assert(SpaceBegin < End, "too many events");

    // Local event delivered by the scheduler when the request deadline is reached.
    struct TTimedout: TEventLocal<TTimedout, Timedout> {};

public:
    /**
     * @param shardActors  shards (id + actor) the request will be sent to
     * @param schedulerId  actor that accepts TSchedulerEvents::TScheduleAt / TCancel
     */
    TShardsRequester(TVector<TShardActorId> shardActors, TActorId schedulerId)
        : TBase(&TShardsRequester::StateFunc)
        , ShardActors_{std::move(shardActors)}
        , SchedulerId_{schedulerId}
        , AwaitingResponses_{static_cast<ui32>(ShardActors_.size())}
    {
        ShardsErrorResponses_.reserve(ShardActors_.size());

        // Create an empty (null) slot per shard actor. A slot that is still present later means
        // the shard has not given a final answer yet.
        for (const auto& shardActor: ShardActors_) {
            ShardsErrorResponses_[shardActor.ActorId] = {};
        }
    }

    STATEFN(StateFunc) {
        switch (ev->GetTypeRewrite()) {
            hFunc(TEventReq, OnRequest);
            hFunc(TEventResp, OnResponse);
            sFunc(TTimedout, OnTimedout);
            hFunc(TMetabaseEvents::TError, OnError);
        }
    }

    /**
     * Remembers who to reply to, schedules the deadline timeout and sends the request
     * to the shard actors.
     */
    void OnRequest(typename TEventReq::TPtr& ev) {
        ReplyTo_ = ev->Sender;
        ReplyCookie_ = ev->Cookie;
        TraceCtx_ = std::move(ev->TraceId);

        MON_TRACE(MetabaseClient, "Requester" << this->SelfId() << ": sending a req to "
                << ShardActors_.size() << " shards" << ": " << ev->Get()->Message->ShortDebugString());

        this->Send(SchedulerId_, new TSchedulerEvents::TScheduleAt{
                TIMEOUT_ID,
                ev->Get()->Deadline,
                std::make_unique<TTimedout>()});

        if constexpr (HasSetShardId<typename TEventReq::TProtoMsg>) {
            for (size_t i = 0; i != ShardActors_.size(); ++i) {
                const auto& [shardId, actor] = ShardActors_[i];
                // TODO: drop shard selector from selectors
                // Each shard gets its own copy of the event with its own copy of the message,
                // because the shard id written below differs per shard.
                auto evCopy = std::make_unique<TEventReq>(*ev->Get());
                auto msg = std::make_shared<typename TEventReq::TProtoMsg>(*evCopy->Message);
                msg->set_shard_id(shardId);
                evCopy->Message = std::move(msg);
                auto span = TRACING_NEW_SPAN_START(TraceCtx_, ev->Get()->Message->GetTypeName() << " in " << shardId);
                this->Send(actor, evCopy.release(), 0, 0, std::move(span));
            }
        } else {
            // The message cannot carry a shard id, so it can only be addressed to a single shard;
            // the original event is forwarded as is.
            Y_VERIFY(ShardActors_.size() == 1, "cross-shard request without shard_id field");
            auto span = TRACING_NEW_SPAN_START(TraceCtx_, ev->Get()->Message->GetTypeName() << " in " << ShardActors_[0].ShardId);
            this->Send(ShardActors_[0].ActorId, ev->Release().Release(), 0, 0, std::move(span));
        }

        Y_VERIFY(AwaitingResponses_ > 0);
    }

    /**
     * Sends TCancelRequest to every shard actor that still has a slot in
     * ShardsErrorResponses_, i.e. has not given a final answer.
     */
    void CancelTimedoutRequests() {
        // The map is keyed by the shard actor's TActorId; the stored error is irrelevant here.
        for (const auto& [actorId, _]: ShardsErrorResponses_) {
            auto cancelEv = std::make_unique<TMetabaseShardActorEvents::TCancelRequest>();
            this->Send(actorId, cancelEv.release());
        }
    }

    /**
     * Puts a DEADLINE_EXCEEDED error into every slot that does not hold an error yet.
     */
    void FillEmptyResponsesWithTimeoutErr() {
        TStringBuilder errMsg;
        errMsg << "request timed out at " << TActivationContext::Now();

        for (auto& [_, response]: ShardsErrorResponses_) {
            if (response) { // shard actor has already replied with some error
                continue;
            }

            auto errResponse = std::make_unique<TMetabaseEvents::TError>();
            errResponse->RpcCode = grpc::StatusCode::DEADLINE_EXCEEDED;
            errResponse->MetabaseCode = EMetabaseStatusCode::DEADLINE_EXCEEDED;
            errResponse->Message = errMsg;

            response = std::move(errResponse);
        }
    }

    /**
     * Sends every stored error to the requester. Ownership of each event is handed over
     * to Send(), so the slots hold null pointers afterwards.
     */
    void ReplyWithErrors() {
        for (auto& [_, errResponse]: ShardsErrorResponses_) {
            this->Send(ReplyTo_, errResponse.release(), 0, ReplyCookie_);
        }
    }

    /**
     * Deadline handler: cancels outstanding shard requests, reports an error for each of
     * them, sends TDone to the requester and terminates the actor.
     */
    void OnTimedout() {
        CancelTimedoutRequests();
        FillEmptyResponsesWithTimeoutErr();
        ReplyWithErrors();

        this->Send(ReplyTo_, new TMetabaseEvents::TDone, 0, ReplyCookie_, std::move(TraceCtx_));
        this->PassAway();
    }

    /**
     * When no final answers are awaited any more: cancels the scheduled timeout, sends
     * TDone to the requester and terminates the actor. Otherwise does nothing.
     */
    void PassAwayIfDone() {
        if (AwaitingResponses_ == 0) {
            this->Send(SchedulerId_, new TSchedulerEvents::TCancel{TIMEOUT_ID});
            this->Send(ReplyTo_, new TMetabaseEvents::TDone, 0, ReplyCookie_, std::move(TraceCtx_));
            this->PassAway();
        }
    }

    /**
     * Handles a response from a shard actor: counts it as a final answer, drops the
     * shard's slot and forwards the response event to the requester.
     */
    void OnResponse(typename TEventResp::TPtr& ev) {
        TRACING_SPAN_END_EV(ev);
        --AwaitingResponses_;
        ShardsErrorResponses_.erase(ev->Sender);

        size_t reqNum = ShardActors_.size() - AwaitingResponses_;
        MON_TRACE(MetabaseClient, "Requester" << this->SelfId() << ": got a response "
                << reqNum << "/" << ShardActors_.size());

        this->Send(ReplyTo_, ev->Release().Release(), 0, ReplyCookie_);
        PassAwayIfDone();
    }

    /**
     * Handles both an error and a timeout received from ShardActor
     *
     * A non-retryable error is a final answer: it is counted and forwarded to the requester
     * right away. A retryable error is only stored in the sender's slot and is not counted;
     * it is sent to the requester from OnTimedout if the slot is still there at that point.
     * NOTE(review): presumably the shard actor retries on its own and replies again after a
     * retryable error -- confirm against the shard actor implementation.
     */
    void OnError(TMetabaseEvents::TError::TPtr& ev) {
        if (!IsRetryableError(ev->Get()->MetabaseCode)) {
            TRACING_SPAN_END_EV(ev);
            --AwaitingResponses_;
            ShardsErrorResponses_.erase(ev->Sender);

            this->Send(ReplyTo_, ev->Release().Release(), 0, ReplyCookie_);
        } else {
            ShardsErrorResponses_[ev->Sender] = std::unique_ptr<IEventBase>(ev->Release().Release());
        }

        size_t reqNum = ShardActors_.size() - AwaitingResponses_;
        MON_TRACE(MetabaseClient, "Requester" << this->SelfId() << ": got an err "
                << reqNum << "/" << ShardActors_.size());

        PassAwayIfDone();
    }

private:
    // Sender of the original request; every reply and the final TDone go there.
    TActorId ReplyTo_;
    // Cookie of the original request, echoed in every reply. Overwritten in OnRequest;
    // value-initialized so it is never read indeterminate.
    ui64 ReplyCookie_{0};
    TVector<TShardActorId> ShardActors_;
    TActorId SchedulerId_;
    // Number of shards that have not given a final answer (response or non-retryable error).
    ui32 AwaitingResponses_;
    // One slot per shard actor without a final answer. A slot is null until the shard reports
    // a retryable error, which is then kept here.
    absl::flat_hash_map<TActorId, std::unique_ptr<IEventBase>, THash<TActorId>> ShardsErrorResponses_;
    TSpanId TraceCtx_;
};

/**
 * Creates a TShardsRequester for the request event type TReq. The response event type is
 * looked up through TMetabaseEvents::TRequestToResponse<TReq>.
 *
 * @tparam TReq  request event type
 * @param args   constructor arguments of TShardsRequester, perfectly forwarded
 * @return the newly created actor as an owning IActorPtr
 */
template <typename TReq, typename... TArgs>
IActorPtr MakeShardsRequester(TArgs&&... args) {
    using TResp = typename TMetabaseEvents::TRequestToResponse<TReq>::TResponse;

    // TArgs&& are forwarding references, so std::forward preserves the value category of
    // every argument instead of unconditionally moving a by-value copy.
    return std::make_unique<TShardsRequester<TReq, TResp>>(std::forward<TArgs>(args)...);
}

} // namespace

// Defines the explicit specialization of ShardsRequester<> for the request event type
// TMetabaseEvents::<ReqName>. The specialization delegates to MakeShardsRequester.
#define MAKE_FUNCTIONS(ReqName) \
    template <> \
    IActorPtr ShardsRequester<TMetabaseEvents::ReqName>(TVector<TShardActorId> shardActors, TActorId schedulerId) { \
        return MakeShardsRequester<TMetabaseEvents::ReqName>(std::move(shardActors), schedulerId); \
    }

// Request event types for which a ShardsRequester<> specialization is defined.
MAKE_FUNCTIONS(TFindReq)
MAKE_FUNCTIONS(TResolveOneReq)
MAKE_FUNCTIONS(TResolveManyReq)
MAKE_FUNCTIONS(TMetricNamesReq)
MAKE_FUNCTIONS(TLabelNamesReq)
MAKE_FUNCTIONS(TLabelValuesReq)
MAKE_FUNCTIONS(TUniqueLabelsReq)

} // namespace NSolomon::NDataProxy
