#pragma once

#include <library/cpp/logger/log.h>
#include <util/generic/bitmap.h>
#include <util/random/shuffle.h>

#include <infra/monitoring/common/logging.h>

#include <infra/yasm/stockpile_client/common/base_types.h>
#include <infra/yasm/stockpile_client/rpc.h>

namespace NHistDb::NStockpile::NDataProxyClient {
    // Async-call aliases for the DataProxy label RPCs: each *CallMethod is the type of the
    // corresponding PrepareAsync* stub member, and each *CallState is the TGrpcAsyncCallState
    // instantiation built on it (its TRequestType is the message filled by the functions below).
    using TLabelKeysCallMethod = decltype(&TDataProxyService::Stub::PrepareAsyncLabelKeys);
    using TLabelKeysCallState = TGrpcAsyncCallState<TLabelKeysCallMethod>;

    using TLabelValuesCallMethod = decltype(&TDataProxyService::Stub::PrepareAsyncLabelValues);
    using TLabelValuesCallState = TGrpcAsyncCallState<TLabelValuesCallMethod>;

    using TUniqueLabelsCallMethod = decltype(&TDataProxyService::Stub::PrepareAsyncUniqueLabels);
    using TUniqueLabelsCallState = TGrpcAsyncCallState<TUniqueLabelsCallMethod>;

    // Request builders for the RPCs above; each writes into requestToFill and is defined
    // out of line. NOTE(review): parameter semantics are inferred from names only (project id,
    // selector expression, label keys, [from, to] time range, result limit) - whether the range
    // is inclusive or what limit == 0 means must be confirmed in the definitions.
    void FillLabelKeysRequest(const TString& projectId, const TString& selectors, TInstant from, TInstant to,
                              TLabelKeysCallState::TRequestType& requestToFill);

    void FillLabelValuesRequest(const TString& projectId, const TString& selectors, const TSet<TString>& keys,
                                TInstant from, TInstant to, size_t limit,
                                TLabelValuesCallState::TRequestType& requestToFill);

    void FillUniqueLabelsRequest(const TString& projectId, const TString& selectors, const TSet<TString>& keys,
                                 TUniqueLabelsCallState::TRequestType& requestToFill);

    /// TGrpcState that owns exactly one async call (TCall) and reports it under a fixed name.
    ///
    /// Handle() does no response processing itself; it only runs TCall::Check() so that a bad
    /// status reaches the handler as an exception. Results are expected to be consumed by a
    /// callback attached to the state (TDataProxyRequester attaches one via SetCallback(),
    /// presumably inherited from TGrpcState).
    template <class TCall>
    class TCallbackOnlyState: public TGrpcState {
    public:
        /// @param call      The call to own; moved in.
        /// @param callName  Value returned by GetRequestName(); moved in (taken by value as a sink).
        TCallbackOnlyState(TCall call, TString callName)
            : Call(std::move(call))
            , CallName(std::move(callName)) {
        }

        TStringBuf GetRequestName() const override {
            return CallName;
        }

        void Handle() override {
            Call.Check(); // throws exceptions to propagate the response code to handler
        }

        /// Mutable access to the owned call (its request, response, remote host, deadline).
        TCall& GetCall() {
            return Call;
        }

    private:
        TCall Call;
        TString CallName;
    };

    class TDataProxyHostSelector {
    public:
        using TClusterAndHostIndices = std::pair<size_t, size_t>; // first - cluster index, second - host index
        using TAttemptTargetSequence = TVector<TClusterAndHostIndices>;
        using TClusterHosts = TVector<TVector<TAtomicSharedPtr<NHistDb::NStockpile::TGrpcRemoteHost>>>;

        TDataProxyHostSelector(TClusterHosts hosts, size_t clustersToRequest, size_t attemptsPerCluster);

        TAtomicSharedPtr<NHistDb::NStockpile::TGrpcRemoteHost> GetHost(const TClusterAndHostIndices& hostIndex) const;
        const TAttemptTargetSequence& GetDefaultTargetSequence() const;

    private:
        void ReshuffleClusters(size_t clustersToKeep);
        void InitDefaultTargetSequence();

        TVector<TVector<TAtomicSharedPtr<NHistDb::NStockpile::TGrpcRemoteHost>>> Hosts;
        const size_t AttemptsPerCluster;
        TAttemptTargetSequence DefaultTargetSequence;
    };

    /// Default handler factory for TDataProxyRequester.
    ///
    /// Produces a TGrpcStateLoggingHandler bound to the supplied request log. The handler is
    /// returned as a prvalue, so the handler type needs to be neither copyable nor movable.
    struct TDefaultGrpcHandlerFactory {
        using THandlerType = NHistDb::NStockpile::TGrpcStateLoggingHandler<NMonitoring::TRequestLog>;

        THandlerType operator()(NMonitoring::TRequestLog& requestLog) const {
            return THandlerType(requestLog);
        }
    };

    /// Sends a batch of calls of one gRPC method to DataProxy hosts. Retriable failures are
    /// retried on further targets from TDataProxyHostSelector's default attempt sequence.
    ///
    /// Usage: optionally Reserve(), then PrepareCall() once per request (filling the returned
    /// request message), then Request() to send everything and wait for completion.
    ///
    /// NOTE(review): CompletedCallIds is not synchronized. This assumes the handler invokes state
    /// callbacks on the thread that calls WaitAsync()/Wait() inside Request() - confirm.
    template <class TMethod, class TGrpcHandlerFactory = TDefaultGrpcHandlerFactory>
    class TDataProxyRequester {
    public:
        using TCall = NHistDb::NStockpile::TGrpcAsyncCallState<TMethod>;
        using TRequest = typename TCall::TRequestType;
        using TResponse = typename TCall::TResponseType;

        // Upper bound for a single WaitAsync(); the batch deadline is re-checked after each one.
        static constexpr TDuration POLL_INTERVAL = TDuration::MilliSeconds(100);

        /// @param method              Stub method used to create every call.
        /// @param hosts               Hosts grouped by cluster, forwarded to TDataProxyHostSelector.
        /// @param callName            Name reported by every call state's GetRequestName().
        /// @param clustersToRequest   Forwarded to TDataProxyHostSelector.
        /// @param attemptsPerCluster  Forwarded to TDataProxyHostSelector.
        /// @param logger              Request log; stored by reference, so it must outlive this object.
        /// @param handlerFactory      Builds the gRPC handler from the logger.
        TDataProxyRequester(TMethod method,
                            TVector<TVector<TAtomicSharedPtr<NHistDb::NStockpile::TGrpcRemoteHost>>> hosts,
                            TString callName,
                            size_t clustersToRequest,
                            size_t attemptsPerCluster,
                            NMonitoring::TRequestLog& logger,
                            TGrpcHandlerFactory handlerFactory = TGrpcHandlerFactory())
            : Method(method)
            , CallName(std::move(callName))
            , HostSelector(std::move(hosts), clustersToRequest, attemptsPerCluster)
            , Logger(logger)
            , GrpcHandler(handlerFactory(Logger))
            , CallbacksFailed(0) {
        }

        /// Pre-allocates storage for callsCount calls. PrepareCall() returns a reference into this
        /// storage, and a later PrepareCall() that grows it can invalidate earlier references.
        void Reserve(size_t callsCount) {
            Calls.reserve(callsCount);
        }

        /// Registers one more call and returns its request message for the caller to fill in.
        ///
        /// The first target host is chosen here; the call is only sent by Request().
        /// @param callback  Stored and invoked as callback(callState, response) when the call is not
        ///                  going to be retried; response is nullptr unless the call succeeded.
        ///                  An lvalue callback is copied, an rvalue one is moved.
        /// @throws yexception if no host is available; nothing is registered in that case.
        template <class TCallback>
        TRequest& PrepareCall(TCallback&& callback) {
            Calls.emplace_back();
            const size_t callId = Calls.size() - 1;
            try {
                TCallInfo& callInfo = Calls[callId];
                callInfo.RemainingAttempts = HostSelector.GetDefaultTargetSequence();

                auto host = SelectNextHost(callId);
                if (!host) {
                    throw yexception() << "No hosts to request";
                }

                // forward, not move: an lvalue callback reused for several calls must not be stolen.
                callInfo.UserCallback = std::forward<TCallback>(callback);
                InitializeCallState(
                    callInfo,
                    GrpcAsyncCallState(Method, host, EGrpcRetryModeFlags::ON_DEADLINE | EGrpcRetryModeFlags::ON_INTERNAL_ERROR, false),
                    callId);
                return callInfo.CallState->GetCall().GetRequest();
            } catch (...) {
                // Drop the half-initialized entry so Request() never meets an empty CallState.
                Calls.pop_back();
                throw;
            }
        }

        /// Sends all prepared calls and drives them, including retries, to completion.
        ///
        /// @param deadline  Absolute deadline for the whole batch; TInstant::Zero() means none.
        ///                  Each (re)sent call gets the time remaining until it as its timeout.
        ///                  Calls still unsuccessful when the loop ends are logged and cancelled.
        ///                  Once the handler drains, their completions are processed with retries
        ///                  disabled.
        /// @return Number of calls that did not succeed plus the number of user callbacks that threw
        ///         (CallbacksFailed accumulates over the object's lifetime).
        size_t Request(TInstant deadline = TInstant::Zero()) {
            Deadline = deadline;
            for (auto& callInfo: Calls) {
                ExecuteCallState(callInfo.CallState.GetRef());
            }

            bool needNextWait = true;
            bool addedNewCalls = true;
            while (needNextWait || addedNewCalls) {
                needNextWait = GrpcHandler.WaitAsync(POLL_INTERVAL);
                addedNewCalls = ProcessCompletedCalls(false);
                if (deadline && TInstant::Now() >= deadline) {
                    break;
                }
            }

            size_t callsFailed = 0;
            for (auto& callInfo: Calls) {
                if (!callInfo.CallState->IsSuccess()) {
                    const TGrpcRemoteHost& curHost = callInfo.CallState->GetCall().GetRemoteHost();
                    if (!callInfo.CallState->IsFinished()) {
                        Logger << ELogPriority::TLOG_WARNING << "Call to " << curHost << " timed out";
                    } else {
                        Logger << ELogPriority::TLOG_WARNING << "Call to " << curHost << " failed";
                    }
                    callsFailed += 1;
                    callInfo.CallState->Cancel();
                }
            }

            GrpcHandler.Wait(); // wait for cancelled calls to be cancelled
            ProcessCompletedCalls(true);
            return callsFailed + CallbacksFailed;
        }

    private:
        struct TCallInfo: public TMoveOnly {
            TMaybe<TCallbackOnlyState<TCall>> CallState;
            // Attempt targets not yet used for this call; consumed by SelectNextHost().
            TDataProxyHostSelector::TAttemptTargetSequence RemainingAttempts;
            std::function<void(TCallbackOnlyState<TCall>&, const TResponse*)> UserCallback;
        };

        /// Drains CompletedCallIds. Returns true if any call was re-sent to another host.
        bool ProcessCompletedCalls(bool stopping) {
            bool madeNewCalls = false;
            while (!CompletedCallIds.empty()) {
                auto callId = CompletedCallIds.front();
                madeNewCalls |= ProcessCompletedCall(callId, stopping);
                CompletedCallIds.pop_front();
            }
            return madeNewCalls;
        }

        /// Either re-sends a retriable call to the next target (returns true) or hands the final
        /// outcome to the user callback (returns false). No retries happen while stopping.
        bool ProcessCompletedCall(size_t callId, bool stopping) {
            TCallInfo& callInfo = Calls[callId];
            if (!stopping && callInfo.CallState->IsRetriable()) {
                // Only consume an attempt (and probe channel states) when a retry is actually possible.
                auto nextHost = SelectNextHost(callId);
                if (nextHost) {
                    const TGrpcRemoteHost& prevHost = callInfo.CallState->GetCall().GetRemoteHost();
                    Logger << ELogPriority::TLOG_INFO << "Call to " << prevHost << " failed. Retrying on " << *nextHost;
                    InitializeCallState(callInfo, callInfo.CallState->GetCall().RecreateWithNewHost(nextHost), callId);
                    ExecuteCallState(callInfo.CallState.GetRef());
                    return true;
                }
            }

            const TResponse* userResponse = callInfo.CallState->IsSuccess()
                ? &callInfo.CallState->GetCall().GetResponse()
                : nullptr;
            try {
                callInfo.UserCallback(callInfo.CallState.GetRef(), userResponse);
            } catch (...) {
                Logger << ELogPriority::TLOG_ERR << "Unexpected user callback error: " << CurrentExceptionMessage();
                ++CallbacksFailed;
            }
            return false;
        }

        /// Pops the next target for the call. Among the remaining attempts it prefers the entry
        /// nearest the back whose channel is GRPC_CHANNEL_READY, and otherwise falls back to the
        /// last entry. Returns a null pointer when no attempts remain. GetState(true) may make gRPC
        /// start connecting idle channels.
        TAtomicSharedPtr<NHistDb::NStockpile::TGrpcRemoteHost> SelectNextHost(size_t callId) {
            TAtomicSharedPtr<NHistDb::NStockpile::TGrpcRemoteHost> result;
            auto& remainingAttempts = Calls[callId].RemainingAttempts;
            if (!remainingAttempts.empty()) {
                auto bestHostIt = remainingAttempts.rbegin();
                for (auto it = bestHostIt; it != remainingAttempts.rend(); ++it) {
                    auto host = HostSelector.GetHost(*it);
                    if (host->GetChannel()->GetState(true) == GRPC_CHANNEL_READY) {
                        bestHostIt = it;
                        break;
                    }
                }
                std::swap(*bestHostIt, remainingAttempts.back());
                result = HostSelector.GetHost(remainingAttempts.back());
                remainingAttempts.pop_back();
            }
            return result;
        }

        /// (Re)constructs the call state in place. Its completion callback enqueues callId for
        /// ProcessCompletedCalls(); it captures `this`, so the requester must outlive in-flight calls.
        void InitializeCallState(TCallInfo& callInfo, TCall call, size_t callId) {
            callInfo.CallState.ConstructInPlace(std::move(call), CallName);
            callInfo.CallState->SetCallback([this, callId](auto&) {
                CompletedCallIds.push_back(callId);
            });
        }

        void ExecuteCallState(TCallbackOnlyState<TCall>& callState) {
            if (Deadline) {
                // SetDeadline actually accepts timeout. Calculate it.
                auto now = TInstant::Now();
                auto timeout = (now < Deadline) ? Deadline - now : TDuration::Zero();
                callState.GetCall().SetDeadline(timeout);
            }
            GrpcHandler.Execute(callState.GetCall(), callState);
        }

        const TMethod Method;
        const TString CallName;
        TInstant Deadline;
        const TDataProxyHostSelector HostSelector;
        NMonitoring::TRequestLog& Logger;
        typename TGrpcHandlerFactory::THandlerType GrpcHandler;
        TList<size_t> CompletedCallIds;
        size_t CallbacksFailed;

        TVector<TCallInfo> Calls;
    };
}
