#include "call_interceptor.h"
#include "headers.h"

#include <solomon/libs/cpp/grpc/metrics/counters.h>
#include <solomon/libs/cpp/sync/rw_lock.h>

#include <library/cpp/monlib/metrics/metric_registry.h>

#include <util/string/cast.h>

#include <google/protobuf/message.h>

using namespace grpc::experimental;
using namespace NMonitoring;

namespace NSolomon {
namespace {
    // Outgoing metadata key carrying the client identifier attached to every call.
    const grpc::string CLIENT_ID = "x-solomon-clientid";
    // Metadata key carrying a millisecond timestamp, encoded as decimal text.
    // It is sent with outgoing calls and read back from the response's initial metadata.
    const grpc::string CREATED_AT_MS = "x-solomon-created-at";

    class TCounterInterceptor: public Interceptor, TNonCopyable {
    public:
        TCounterInterceptor(TClientCallCounters& counters, const TString& clientId)
            : Counters_{counters}
            , CreatedAt_{TInstant::Now()}
            , ClientId_{clientId}
        {
            Counters_.ReportCallStart();
        }

        void Intercept(InterceptorBatchMethods* methods) override {
            if (methods->QueryInterceptionHookPoint(InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
                auto meta = methods->GetSendInitialMetadata();
                meta->emplace(CLIENT_ID, ClientId_);
                meta->emplace(CREATED_AT_MS, ::ToString(TInstant::Now().MilliSeconds()));
            }

            if (methods->QueryInterceptionHookPoint(InterceptionHookPoints::PRE_SEND_MESSAGE)) {
                if (auto* msg = static_cast<const ::google::protobuf::Message*>(methods->GetSendMessage())) {
                    Counters_.ReportOutboundBytes(msg->ByteSizeLong());
                } else if (auto* buf = methods->GetSerializedSendMessage()) {
                    Counters_.ReportOutboundBytes(buf->Length());
                }
            }

            if (methods->QueryInterceptionHookPoint(InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
                const TInstant now{TInstant::Now()};
                auto* meta = methods->GetRecvInitialMetadata();
                auto it = meta->find(CREATED_AT_MS);

                if (it != meta->end()) {
                    ui64 createdAt;
                    TStringBuf createdAtStr = TStringBuf{it->second.data(), it->second.size()};
                    if (TryFromString<ui64>(createdAtStr, createdAt)) {
                        Counters_.ReportDelivery(now.MilliSeconds() - createdAt);
                    }
                }
            }

            if (methods->QueryInterceptionHookPoint(InterceptionHookPoints::POST_RECV_MESSAGE)) {
                if (auto* resp = static_cast<google::protobuf::Message*>(methods->GetRecvMessage())) {
                    Counters_.ReportInboundBytes(resp->ByteSizeLong());
                }
            }

            if (methods->QueryInterceptionHookPoint(InterceptionHookPoints::POST_RECV_STATUS)) {
                auto status = *methods->GetRecvStatus();
                Counters_.ReportCallStats(status.error_code(), TInstant::Now() - CreatedAt_);
            }

            methods->Proceed();
        }

    private:
        TClientCallCounters& Counters_;
        TInstant CreatedAt_;
        const TString& ClientId_;
    };

    /// Creates one TCounterInterceptor per client call.
    ///
    /// Counters are shared per called method: each distinct method name gets one
    /// TClientCallCounters, registered with the "endpoint" label. The map is protected by a
    /// reader/writer lock.
    struct TCounterInterceptorFactory: public ClientInterceptorFactoryInterface {
        TCounterInterceptorFactory(IMetricRegistry& registry, TString clientId)
            : Registry_{registry}
            , ClientId_{std::move(clientId)}
        {
        }

        // gRPC takes ownership of the returned raw pointer.
        Interceptor* CreateClientInterceptor(ClientRpcInfo* info) override {
            return new TCounterInterceptor{*GetOrCreateCounters(info->method()), ClientId_};
        }

    private:
        using TMap = THashMap<TString, THolder<TClientCallCounters>>;

        IMetricRegistry& Registry_;
        const TString ClientId_;
        NSync::TLightRwLock<TMap> Counters_;

        /// Returns the counters for @p endpoint, creating them on first use. Never returns nullptr.
        TClientCallCounters* GetOrCreateCounters(TStringBuf endpoint) {
            // Fast path: the counters already exist; only a shared lock is needed.
            {
                auto counters = Counters_.Read();
                if (auto* callCounters = counters->FindPtr(endpoint)) {
                    return callCounters->Get();
                }
            }

            auto counters = Counters_.Write();
            // Another thread may have inserted the entry between releasing the read lock and
            // acquiring the write lock.
            if (auto* callCounters = counters->FindPtr(endpoint)) {
                return callCounters->Get();
            }

            // Build the counters before touching the map. If construction throws, no empty holder
            // is left behind for the fast path to return as nullptr.
            auto created = MakeHolder<TClientCallCounters>(Registry_, TLabels{{"endpoint", endpoint}});
            TClientCallCounters* result = created.Get();
            (*counters)[endpoint] = std::move(created);
            return result;
        }
    };
} // namespace

    /// Builds a client interceptor factory that records per-method call metrics into @p registry.
    ///
    /// @param registry  metric registry where the per-endpoint counters are created.
    /// @param clientId  identifier sent with every call. When empty, HOSTNAME_HEADER_VALUE is used.
    TInterceptorFactoryPtr CreateCounterInterceptorFactory(IMetricRegistry& registry, TString clientId) {
        const bool noExplicitId = clientId.empty();
        if (noExplicitId) {
            clientId = HOSTNAME_HEADER_VALUE;
        }

        return std::make_unique<TCounterInterceptorFactory>(
            registry,
            std::move(clientId));
    }
} // namespace NSolomon
