#pragma once

#include "requester_ctx.h"

#include <solomon/services/dataproxy/lib/datasource/status_code.h>
#include <solomon/services/dataproxy/lib/memstore/cluster.h>
#include <solomon/services/dataproxy/lib/memstore/shard.h>
#include <solomon/libs/cpp/trace/trace.h>

#include <solomon/libs/cpp/logging/logging.h>
#include <solomon/libs/cpp/grpc/status/code.h>
#include <solomon/libs/cpp/trace/trace.h>

#include <library/cpp/actors/core/actor_bootstrapped.h>
#include <library/cpp/actors/core/hfunc.h>
#include <library/cpp/containers/absl_flat_hash/flat_hash_map.h>

namespace NSolomon::NDataProxy {

/**
 * CRTP base actor that fans a request out to memstore shards.
 *
 * On bootstrap it sends TFindShardsReq with the given shard selector to every cluster actor
 * listed in the requester context. For every shard returned by a cluster it sends a TReqEvent
 * built from the derived class' request template, and forwards each shard reply to the derived class.
 * When no cluster and no shard replies are outstanding, it reports the final status
 * via OnFinish() and passes away.
 *
 * @tparam TDerived    concrete requester; the code below calls these members on it:
 *                       - RequestTemplate() const       -- first constructor argument of TReqEvent
 *                       - OnResponse(clusterId, shardKey, message)
 *                       - OnError(clusterId, shardId, statusCode, message)
 *                       - OnFinish(status, message)
 * @tparam TReqEvent   event sent to a shard actor; constructed from {request template, deadline},
 *                     must expose a `Message` with set_num_id()
 * @tparam TRespEvent  successful shard reply; must expose ClusterId, ShardId and Message
 */
template <typename TDerived, typename TReqEvent, typename TRespEvent>
class TMemStoreRequester: public NActors::TActorBootstrapped<TDerived> {
    struct TShardState {
        // TODO: the shard key needed only for Find requester
        TShardSubKey Key;
        // TODO: remember request id and shard actor ids to cancel inflight requests
    };

public:
    /**
     * @param ctx            requester context; its Clusters are iterated in Bootstrap()
     * @param shardSelector  selector passed to every cluster in TFindShardsReq
     * @param deadline       deadline forwarded to each shard request event
     * @param span           parent tracing span; child spans are started from it
     */
    TMemStoreRequester(TRequesterContextPtr ctx, TShardSelector shardSelector, TInstant deadline, NTracing::TSpanId span)
        : Ctx_{std::move(ctx)}
        , ShardSelector_{std::move(shardSelector)}
        , Deadline_{deadline}
        , Span_(std::move(span))
    {
    }

    /**
     * Switches to the working state and asks every known cluster for the shards matching the selector.
     */
    void Bootstrap() {
        this->Become(&TMemStoreRequester::StateFunc);
        // TODO: abort request at deadline

        for (const auto& [clusterId, actorId]: Ctx_->Clusters) {
            AwaitingClusters_++;
            // TODO: add details about method (e.g. str(TDerived))
            auto span = TRACING_NEW_SPAN_START(Span_, "FindShards in " << clusterId);
            this->Send(actorId, new TMemStoreClusterEvents::TFindShardsReq{ShardSelector_}, 0, 0, std::move(span));
        }

        // with an empty cluster list nothing was sent and no reply will ever arrive,
        // so complete right away instead of waiting forever
        FinishIfDone();
    }

private:
    STATEFN(StateFunc) {
        switch (ev->GetTypeRewrite()) {
            hFunc(TMemStoreClusterEvents::TFindShardsResp, OnClusterResp)
            hFunc(TRespEvent, OnShardResponse)
            hFunc(TMemStoreShardEvents::TError, OnShardError)
            hFunc(NActors::TEvents::TEvPoison, OnPoison)
        }
    }

    /**
     * Handles a cluster reply: sends a request to every shard it returned and remembers the shard key.
     */
    void OnClusterResp(const TMemStoreClusterEvents::TFindShardsResp::TPtr& ev) {
        Y_VERIFY(--AwaitingClusters_ >= 0);
        TRACING_SPAN_END_EV(ev);

        auto clusterId = ev->Get()->ClusterId;
        auto& shards = ev->Get()->Shards;
        if (!shards.empty()) {
            MON_INFO(Sts, "found " << shards.size() << " shards in " << clusterId);

            for (auto&& shard: shards) {
                // the request must be sent before the key is moved out of the shard
                SendShardRequest(clusterId, shard);
                Shards_.emplace(shard.Id, TShardState{std::move(shard.Key)});
            }
        }

        FinishIfDone();
    }

    /**
     * Builds a request from the derived class' template, addresses it to the given shard and sends it.
     */
    void SendShardRequest(TClusterId clusterId, const TMemStoreShard& shard) {
        auto* event = new TReqEvent{SelfConst()->RequestTemplate(), Deadline_};
        event->Message.set_num_id(shard.Id);

        MON_DEBUG(Sts, "sent request to " << clusterId << '/' << shard.Id << ": " << event->Message.GetTypeName()
                << "{\n" << event->Message.DebugString() << '}');
        auto span = TRACING_NEW_SPAN_START(Span_, event->Message.GetTypeName() << " in " << clusterId << '/' << shard.Id);
        this->Send(shard.ActorId, event, 0, 0, std::move(span));
        AwaitingShards_++;
    }

    /**
     * Handles a successful shard reply and forwards it to the derived class.
     * A yexception thrown while processing the reply is reported through OnError().
     */
    void OnShardResponse(const typename TRespEvent::TPtr& ev) {
        Y_VERIFY(--AwaitingShards_ >= 0);
        TRACING_SPAN_END_EV(ev);
        auto clusterId = ev->Get()->ClusterId;
        auto shardId = ev->Get()->ShardId;
        try {
            MON_DEBUG(Sts,
                      "received OK from " << clusterId << '/' << shardId << ": " << ev->Get()->Message.GetTypeName()
                                          << "{\n" << ev->Get()->Message.DebugString() << '}');

            if (auto it = Shards_.find(shardId); Y_LIKELY(it != Shards_.end())) {
                Self()->OnResponse(clusterId, it->second.Key, ev->Get()->Message);
            } else {
                MON_ERROR(Sts, "got response from unknown shard: " << shardId);
            }
        } catch (const yexception& e) {
            MON_ERROR(Sts, "error while processing sts response: " << e.AsStrBuf());
            Self()->OnError(clusterId, shardId, grpc::StatusCode::UNKNOWN, e.what());
        }

        // checked outside of the try block, so a failure on the last outstanding
        // response still completes the request instead of leaving the actor alive forever
        FinishIfDone();
    }

    /**
     * Handles an error reply from a shard and forwards it to the derived class.
     */
    void OnShardError(const TMemStoreShardEvents::TError::TPtr& ev) {
        Y_VERIFY(--AwaitingShards_ >= 0);
        TRACING_SPAN_END_EV(ev);

        auto clusterId = ev->Get()->ClusterId;
        auto shardId = ev->Get()->ShardId;
        auto code = ev->Get()->StatusCode;
        auto& message = ev->Get()->Message;

        MON_DEBUG(Sts, "received ERROR from " << clusterId << '/' << shardId << ": "
                << ::NGrpc::StatusCodeToString(code) << ' ' << message);

        Self()->OnError(clusterId, shardId, code, std::move(message));

        FinishIfDone();
    }

    /**
     * Interrupts the request: reports UNKNOWN status, acknowledges the poison to its sender and passes away.
     */
    void OnPoison(const NActors::TEvents::TEvPoison::TPtr& ev) {
        TRACING_SPAN_END(Span_);
        Self()->OnFinish(EDataSourceStatus::UNKNOWN, "request was interrupted"); // TODO: use better status
        this->Send(ev->Sender, new NActors::TEvents::TEvPoisonTaken);
        this->PassAway();
    }

    /**
     * Completes the request once neither cluster nor shard replies are outstanding.
     *
     * Shard replies may arrive while other clusters have not answered yet, so both counters
     * must be zero; otherwise shards of the slower clusters would never be queried.
     * Reports NOT_FOUND if no cluster returned any shard, OK otherwise.
     */
    void FinishIfDone() {
        if (AwaitingClusters_ != 0 || AwaitingShards_ != 0) {
            return;
        }

        TRACING_SPAN_END(Span_);
        if (Shards_.empty()) {
            Self()->OnFinish(
                    EDataSourceStatus::NOT_FOUND,
                    TStringBuilder{} << "shard(s) not found by selector " << ShardSelector_);
        } else {
            // TODO: properly convert errors to result
            Self()->OnFinish(EDataSourceStatus::OK, TString{});
        }
        this->PassAway();
    }

    TDerived* Self() {
        return static_cast<TDerived*>(this);
    }

    const TDerived* SelfConst() const {
        return static_cast<const TDerived*>(this);
    }

private:
    TRequesterContextPtr Ctx_;
    const TShardSelector ShardSelector_;
    TInstant Deadline_;
    // every shard returned by the clusters, keyed by shard id
    absl::flat_hash_map<TShardId, TShardState> Shards_;
    // number of clusters that have not answered TFindShardsReq yet
    i32 AwaitingClusters_{0};
    // number of shard requests without a reply (response or error) yet
    i32 AwaitingShards_{0};
    NTracing::TSpanId Span_;
};

} // namespace NSolomon::NDataProxy
