#include "circuit_breaker_interceptor.h"

#include <solomon/libs/cpp/circuit_breaker/circuit_breaker.h>
#include <solomon/libs/cpp/sync/rw_lock.h>

#include <library/cpp/containers/absl_flat_hash/flat_hash_map.h>
#include <library/cpp/monlib/metrics/metric_registry.h>

using grpc::experimental::ClientInterceptorFactoryInterface;
using grpc::experimental::ClientRpcInfo;
using grpc::experimental::InterceptionHookPoints;
using grpc::experimental::Interceptor;

namespace NSolomon {
namespace {

/// Thread-safe registry of circuit breakers keyed by peer string.
///
/// Breakers are created lazily on first request and are never removed here.
/// Each breaker lives in its own heap allocation owned by a std::unique_ptr, so a
/// pointer returned by FindOrCreate() stays valid for the lifetime of this map even
/// when the underlying hash map rehashes.
class TCircuitBreakersMap {
public:
    TCircuitBreakersMap(double maxFailPercent, TDuration resetInterval) noexcept
        : MaxFailPercent_{maxFailPercent}
        , ResetInterval_{resetInterval}
    {
    }

    /// Returns the breaker associated with `peer`, creating it on first use.
    /// The returned pointer is never null.
    TAtomicCircuitBreaker* FindOrCreate(const std::string& peer) {
        // Fast path under the read lock: the breaker usually exists already.
        {
            auto map = Map_.Read();
            if (auto it = map->find(peer); it != map->end()) {
                return it->second.get();
            }
        }

        auto map = Map_.Write();
        // Another thread may have inserted the breaker between the two lock acquisitions.
        if (auto it = map->find(peer); it != map->end()) {
            return it->second.get();
        }

        // Build the breaker before touching the map: if construction throws, the map
        // stays unchanged instead of keeping a null entry that the fast path above
        // would hand out as a nullptr on every subsequent call.
        auto cb = std::make_unique<TAtomicCircuitBreaker>(MaxFailPercent_, ResetInterval_, TInstant::Now());
        auto* result = cb.get();
        map->emplace(peer, std::move(cb));
        return result;
    }

private:
    const double MaxFailPercent_;
    const TDuration ResetInterval_;
    NSync::TLightRwLock<absl::flat_hash_map<std::string, std::unique_ptr<TAtomicCircuitBreaker>>> Map_;
};

/// Lazily creates and caches, per RPC method, the rate metric
/// `grpc.client.call.status{code=OPEN_CIRCUIT_BREAKER, endpoint=<method>}`
/// that counts calls rejected by an open circuit breaker.
class TMetricsMap {
public:
    explicit TMetricsMap(NMonitoring::IMetricRegistry& registry) noexcept
        : Registry_{registry}
    {
    }

    /// Returns the rejected-calls rate for `method`, registering it on first use.
    NMonitoring::IRate* TerminatedCalls(TStringBuf method) {
        // The cache must own its keys: `method` borrows a buffer that is not owned by
        // this map (e.g. generic stubs pass a caller-owned std::string), so storing the
        // view itself could leave dangling keys behind once that buffer is freed.
        std::string key{method.data(), method.size()};

        {
            auto map = TerminatedCallsPerMethod_.Read();
            if (auto it = map->find(key); it != map->end()) {
                return it->second;
            }
        }

        auto map = TerminatedCallsPerMethod_.Write();
        // Re-check: another thread may have registered the metric after the read lock was released.
        if (auto it = map->find(key); it != map->end()) {
            return it->second;
        }

        auto labels = NMonitoring::MakeLabels({
                {"sensor", "grpc.client.call.status"},
                {"code", "OPEN_CIRCUIT_BREAKER"},
                {"endpoint", method}});
        auto* terminatedCalls = Registry_.Rate(std::move(labels));
        map->emplace(std::move(key), terminatedCalls);
        return terminatedCalls;
    }

private:
    NMonitoring::IMetricRegistry& Registry_;
    NSync::TLightRwLock<absl::flat_hash_map<std::string, NMonitoring::IRate*>> TerminatedCallsPerMethod_;
};

class TClientCircuitBreakerFactory: public ClientInterceptorFactoryInterface {
    class TInterceptor: public Interceptor {
    public:
        TInterceptor(
                grpc::ClientContext* clientCtx,
                NMonitoring::IRate* terminatedCalls,
                TCircuitBreakersMap* circuitBreakers)
            : ClientCtx_{clientCtx}
            , TerminatedCalls_{terminatedCalls}
            , CircuitBreakers_{circuitBreakers}
        {
        }

        void Intercept(grpc::experimental::InterceptorBatchMethods* methods) override {
            bool needHijack = false;
            if (methods->QueryInterceptionHookPoint(InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
                std::string peer = ClientCtx_->peer();
                Y_VERIFY(!peer.empty(), "grpc: peer is empty");

                CircuitBreaker_ = CircuitBreakers_->FindOrCreate(peer);
                if (!CircuitBreaker_->TryAcquirePermission(TInstant::Now())) {
                    TerminatedCalls_->Inc();
                    Terminated_ = true;
                    Peer_ = std::move(peer);
                    needHijack = true;
                }
            }

            if (Terminated_) {
                if (methods->QueryInterceptionHookPoint(InterceptionHookPoints::PRE_RECV_STATUS)) {
                    auto* status = methods->GetRecvStatus();
                    *status = grpc::Status{grpc::StatusCode::UNAVAILABLE, "CircuitBreaker::Open " + Peer_};
                    ClientCtx_->TryCancel();
                }
            } else {
                if (methods->QueryInterceptionHookPoint(InterceptionHookPoints::POST_RECV_STATUS)) {
                    auto* status = methods->GetRecvStatus();
                    ReportStatusToCircuitBreaker(status->error_code());
                }
            }

            if (needHijack) {
                // Hijack is called only once when PRE_SEND_INITIAL_METADATA is present in the hook points
                methods->Hijack();
            } else {
                methods->Proceed();
            }
        }

        void ReportStatusToCircuitBreaker(grpc::StatusCode statusCode) {
            switch (statusCode) {
                case grpc::StatusCode::UNKNOWN:
                case grpc::StatusCode::DEADLINE_EXCEEDED:
                case grpc::StatusCode::RESOURCE_EXHAUSTED:
                case grpc::StatusCode::INTERNAL:
                case grpc::StatusCode::UNAVAILABLE:
                case grpc::StatusCode::UNAUTHENTICATED:
                    CircuitBreaker_->MarkFailure(TInstant::Now());
                    break;
                default:
                    CircuitBreaker_->MarkSuccess(TInstant::Now());
                    break;
            }
        }

    private:
        grpc::ClientContext* ClientCtx_;
        NMonitoring::IRate* TerminatedCalls_;
        TCircuitBreakersMap* CircuitBreakers_;
        TAtomicCircuitBreaker* CircuitBreaker_{nullptr};
        bool Terminated_{false};
        std::string Peer_;
    };

public:
    TClientCircuitBreakerFactory(
            NMonitoring::IMetricRegistry& registry,
            double maxFailPercent,
            TDuration resetInterval)
        : Metrics_{registry}
        , CircuitBreakers_{maxFailPercent, resetInterval}
    {
    }

    Interceptor* CreateClientInterceptor(ClientRpcInfo* info) override {
        NMonitoring::IRate* terminatedCalls = Metrics_.TerminatedCalls(info->method());
        // peer is not yet initialized in client context on this stage
        return new TInterceptor{info->client_context(), terminatedCalls, &CircuitBreakers_};
    }

private:
    TMetricsMap Metrics_;
    TCircuitBreakersMap CircuitBreakers_;
};

} // namespace

/// Builds a gRPC client interceptor factory that protects outgoing calls with
/// per-peer circuit breakers configured by `maxFailPercent` and `resetInterval`,
/// registering its rejected-call metrics in `registry`.
std::unique_ptr<ClientInterceptorFactoryInterface> ClientCircuitBreakerInterceptorFactory(
        NMonitoring::IMetricRegistry& registry,
        double maxFailPercent,
        TDuration resetInterval)
{
    std::unique_ptr<ClientInterceptorFactoryInterface> factory =
            std::make_unique<TClientCircuitBreakerFactory>(registry, maxFailPercent, resetInterval);
    return factory;
}

} // namespace NSolomon
