#include "server_call_interceptor.h"

#include <solomon/libs/cpp/grpc/metrics/counters.h>
#include <solomon/libs/cpp/sync/rw_lock.h>

#include <library/cpp/monlib/metrics/metric_registry.h>

#include <util/string/cast.h>

#include <google/protobuf/message.h>

using namespace grpc::experimental;
using namespace NMonitoring;

namespace NSolomon {
namespace {
    // Request metadata key carrying the caller's client id; its value becomes the
    // "clientId" label of the per-client counters ("unknown" when the header is absent).
    const grpc::string CLIENT_ID = "x-solomon-clientid";
    // Metadata key carrying a send timestamp in milliseconds since the epoch. It is read from
    // incoming requests and written into outgoing initial metadata.
    const grpc::string CREATED_AT_MS = "x-solomon-created-at";

    /// Call counters for a single endpoint: one aggregate set (clientId="total") plus one set
    /// per distinct client id. Per-client sets are created on first use and never removed, so
    /// the pointers handed out stay valid for the lifetime of this object.
    struct TRpcMetrics {
    public:
        /// @param registry registry in which all counter sets are created; held by reference.
        /// @param labels   base labels attached to every counter set of this endpoint.
        TRpcMetrics(IMetricRegistry& registry, TLabels labels)
            : Registry_{registry}
            , Labels_{std::move(labels)}
            , Total_(registry, MakeLabels("total"))
        {
        }

        /// Returns a new wrapper that reports both to this client's counters and to the aggregate.
        THolder<TAggregateCallCounters> Resolve(TStringBuf clientId) {
            TServerCallCounters* perClient = ResolveByClient(clientId);
            return MakeHolder<TAggregateCallCounters>(perClient, &Total_);
        }

    private:
        using TMap = THashMap<TString, THolder<TServerCallCounters>>;

        // Declaration order matters: Total_ is initialized from Labels_ (via MakeLabels) in the
        // constructor's initializer list, so Labels_ must be declared first.
        IMetricRegistry& Registry_;
        const TLabels Labels_;
        TServerCallCounters Total_;
        NSync::TLightRwLock<TMap> Counters_;

        // Base labels extended with a "clientId" label.
        TLabels MakeLabels(TStringBuf clientId) {
            TLabels extended(Labels_);
            extended.Add("clientId", clientId);
            return extended;
        }

        // Lookup under Read() access; nullptr when this client has no counters yet.
        TServerCallCounters* FindByClient(TStringBuf clientId) {
            auto guard = Counters_.Read();
            auto* entry = guard->FindPtr(clientId);
            return entry ? entry->Get() : nullptr;
        }

        // NOTE(review): a client sending clientId "total" produces the same labels as Total_;
        // whether the registry merges or rejects such a duplicate depends on
        // TServerCallCounters / IMetricRegistry — verify.
        TServerCallCounters* ResolveByClient(TStringBuf clientId) {
            if (TServerCallCounters* existing = FindByClient(clientId)) {
                return existing;
            }

            // Re-check under Write() access: another thread may have inserted this client
            // after the read above released its guard.
            auto guard = Counters_.Write();
            THolder<TServerCallCounters>& entry = (*guard)[clientId];
            if (!entry) {
                entry = MakeHolder<TServerCallCounters>(Registry_, MakeLabels(clientId));
            }
            return entry.Get();
        }
    };

    /// Lazily created TRpcMetrics, one per endpoint (gRPC method name), labelled "endpoint".
    /// Entries are never removed, so the returned pointers stay valid for this object's lifetime.
    struct TEndpointMetrics {
    public:
        /// @param registry registry passed on to every TRpcMetrics; held by reference.
        explicit TEndpointMetrics(IMetricRegistry& registry)
            : Registry_{registry}
        {
        }

        /// Returns the metrics for `endpoint`, creating them on first use.
        TRpcMetrics* Resolve(TStringBuf endpoint) {
            if (TRpcMetrics* existing = Find(endpoint)) {
                return existing;
            }

            // Re-check under Write() access: another thread may have created the entry
            // after the lookup above released its guard.
            auto guard = Counters_.Write();
            THolder<TRpcMetrics>& entry = (*guard)[endpoint];
            if (!entry) {
                entry = MakeHolder<TRpcMetrics>(Registry_, TLabels{{"endpoint", endpoint}});
            }
            return entry.Get();
        }

    private:
        using TMap = THashMap<TString, THolder<TRpcMetrics>>;

        IMetricRegistry& Registry_;
        NSync::TLightRwLock<TMap> Counters_;

        // Lookup under Read() access; nullptr when the endpoint has not been seen yet.
        TRpcMetrics* Find(TStringBuf endpoint) {
            auto guard = Counters_.Read();
            auto* entry = guard->FindPtr(endpoint);
            return entry ? entry->Get() : nullptr;
        }
    };

    class TServerCounterInterceptor: public Interceptor, TNonCopyable {
        using TMetaPtr = std::multimap<grpc::string_ref, grpc::string_ref>*;

    public:
        explicit TServerCounterInterceptor(TRpcMetrics& metrics)
            : Metrics_(metrics)
            , StartedAt_{TInstant::Now()}
        {
        }

        void Intercept(InterceptorBatchMethods* methods) override {
            if (methods->QueryInterceptionHookPoint(InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
                const TMetaPtr meta = methods->GetRecvInitialMetadata();
                ByClientMetrics_ = Metrics_.Resolve(ResolveClientId(meta));
                ByClientMetrics_->ReportCallStart();
                TInstant createdAt = ResolveCreatedAt(meta);
                if (createdAt != TInstant::Zero()) {
                    ByClientMetrics_->ReportDelivery(TInstant::Now() - createdAt);
                }
            }

            if (methods->QueryInterceptionHookPoint(InterceptionHookPoints::POST_RECV_MESSAGE)) {
                if (auto* resp = static_cast<google::protobuf::Message*>(methods->GetRecvMessage())) {
                    ByClientMetrics_->ReportInboundBytes(resp->ByteSizeLong());
                }
            }

            if (methods->QueryInterceptionHookPoint(InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
                auto meta = methods->GetSendInitialMetadata();
                meta->emplace(CREATED_AT_MS, ::ToString(TInstant::Now().MilliSeconds()));
            }

            if (methods->QueryInterceptionHookPoint(InterceptionHookPoints::PRE_SEND_MESSAGE)) {
                if (auto* msg = static_cast<const ::google::protobuf::Message*>(methods->GetSendMessage())) {
                    ByClientMetrics_->ReportOutboundBytes(msg->ByteSizeLong());
                } else if (auto* buf = methods->GetSerializedSendMessage()) {
                    ByClientMetrics_->ReportOutboundBytes(buf->Length());
                }
            }

            if (methods->QueryInterceptionHookPoint(InterceptionHookPoints::PRE_SEND_STATUS)) {
                auto status = methods->GetSendStatus();
                ByClientMetrics_->ReportCallStats(status.error_code(), TInstant::Now() - StartedAt_);
            }

            methods->Proceed();
        }

    private:
        TRpcMetrics& Metrics_;
        TInstant StartedAt_;
        THolder<TAggregateCallCounters> ByClientMetrics_;

        static TStringBuf ResolveClientId(const TMetaPtr meta) {
            auto it = meta->find(CLIENT_ID);

            if (it != meta->end()) {
                return TStringBuf{it->second.data(), it->second.size()};
            } else {
                return "unknown";
            }
        }

        static TInstant ResolveCreatedAt(const TMetaPtr meta) {
            ui64 createdAt = 0;
            auto it = meta->find(CREATED_AT_MS);
            if (it != meta->end()) {
                TStringBuf createdAtStr = TStringBuf{it->second.data(), it->second.size()};
                TryFromString<ui64>(createdAtStr, createdAt);
            }
            return TInstant::MilliSeconds(createdAt);
        }
    };

    /// gRPC factory that creates one TServerCounterInterceptor per incoming RPC.
    /// Metrics are grouped by the RPC's method name and shared by all calls to that method.
    struct TServerCounterInterceptorFactory: public ServerInterceptorFactoryInterface {
    public:
        /// @param registry registry for all created metrics; held by reference.
        explicit TServerCounterInterceptorFactory(IMetricRegistry& registry)
            : MetricsByEndpoint_{registry}
        {
        }

        Interceptor* CreateServerInterceptor(ServerRpcInfo* info) override {
            TRpcMetrics& endpointMetrics = *MetricsByEndpoint_.Resolve(info->method());
            // The interface returns a raw pointer; the gRPC runtime owns and deletes it.
            return new TServerCounterInterceptor{endpointMetrics};
        }

    private:
        TEndpointMetrics MetricsByEndpoint_;
    };
} // namespace

    /// Creates a server interceptor factory that records per-endpoint, per-client call metrics
    /// in `registry`. The registry is held by reference and must outlive the returned factory.
    TServerInterceptorFactoryPtr CreateServerCounterInterceptorFactory(IMetricRegistry& registry) {
        auto factory = std::make_unique<TServerCounterInterceptorFactory>(registry);
        return TServerInterceptorFactoryPtr{std::move(factory)};
    }
} // namespace NSolomon
