#include "logging_interceptor.h"

#include <crypta/lib/native/proto_serializer/proto_serializer.h>

#include <grpc++/server_context.h>
#include <contrib/libs/grpc/src/core/lib/channel/status_util.h>

#include <util/generic/strbuf.h>
#include <util/generic/yexception.h>
#include <util/string/split.h>

using namespace NCrypta::NGrpc;

// Top-level keys of the per-RPC JSON log record.
constexpr TStringBuf SERVICE{"service"};
constexpr TStringBuf HANDLE{"handle"};
constexpr TStringBuf PEER{"peer"};
constexpr TStringBuf REQUEST_METADATA{"request_metadata"};
constexpr TStringBuf RESPONSE_METADATA{"response_metadata"};
constexpr TStringBuf REQUEST{"request"};
constexpr TStringBuf RESPONSE{"response"};
constexpr TStringBuf STATUS{"status"};
constexpr TStringBuf SEND_STATUS{"send_status"};

// Keys inside the nested "status" object.
constexpr TStringBuf CODE{"code"};
constexpr TStringBuf MESSAGE{"message"};

// Delimiter used to split the gRPC method path into the service and handle names.
constexpr char METHOD_DELIMITER{'/'};

namespace {
    template<typename TMap>
    void WriteMetadata(const TMap* metadata, const TStringBuf& fieldName, NJson::TJsonWriter& writer) {
        writer.WriteKey(fieldName);

        if (!metadata) {
            writer.WriteNull();
            return;
        }

        writer.OpenArray();

        for (const auto& [key, value] : *metadata) {
            writer.OpenArray();
            writer.Write(TStringBuf(key.data(), key.length()));
            writer.Write(TStringBuf(value.data(), value.length()));
            writer.CloseArray();
        }

        writer.CloseArray();
    }

    void WriteMessage(const void* message, const TStringBuf& fieldName, NJson::TJsonWriter& writer) {
        if (message) {
            writer.UnsafeWrite(fieldName, NCrypta::NProtoSerializer::ToJson(*static_cast<const google::protobuf::Message*>(message)));
        }
    }
}

// Starts the JSON record for one RPC: opens the top-level map and records the service
// and handle parsed from the method name. Calls to the
// grpc.reflection.v1alpha.ServerReflection service are flagged as Disabled, so
// Intercept() does not record anything for them.
//
// `log` is a by-value sink parameter and is moved into the member rather than copied a
// second time (if TLogPtr is a shared pointer, this saves an atomic refcount round-trip).
TLoggingInterceptor::TLoggingInterceptor(grpc::experimental::ServerRpcInfo* info, TLogHandler& logHandler, NLog::TLogPtr log)
    : LogHandler(logHandler)
    , Log(std::move(log))
    , Writer(&Stream, true)
    , Info(info)
{
    // The top-level map opened here is closed by WriteLog().
    Writer.OpenMap();
    SaveMethod(info->method());
    Disabled = Service == "grpc.reflection.v1alpha.ServerReflection";
}

// Records the parts of the current batch that are relevant for the log record, then
// always forwards the batch with Proceed(). Errors while recording are logged and never
// propagate into the RPC. The record is emitted once POST_RECV_CLOSE is observed.
void TLoggingInterceptor::Intercept(grpc::experimental::InterceptorBatchMethods* methods) {
    if (!Disabled) {
        using EHook = grpc::experimental::InterceptionHookPoints;
        const auto reached = [methods](EHook hook) {
            return methods->QueryInterceptionHookPoint(hook);
        };

        try {
            // Incoming side of the call.
            if (reached(EHook::POST_RECV_INITIAL_METADATA)) {
                SaveRequestMetadata(methods->GetRecvInitialMetadata());
            }
            if (reached(EHook::POST_RECV_MESSAGE)) {
                SaveRequestMessage(methods->GetRecvMessage());
            }

            // Outgoing side of the call.
            if (reached(EHook::PRE_SEND_INITIAL_METADATA)) {
                SaveResponseMetadata(methods->GetSendInitialMetadata());
            }
            if (reached(EHook::PRE_SEND_MESSAGE)) {
                SaveResponseMessage(methods->GetSendMessage());
            }
            if (reached(EHook::PRE_SEND_STATUS)) {
                SaveStatus(methods->GetSendStatus());
            }
            if (reached(EHook::POST_SEND_MESSAGE)) {
                SaveSendMessageStatus(methods->GetSendMessageStatus());
            }

            // The call is closed: finalize and emit the record.
            if (reached(EHook::POST_RECV_CLOSE)) {
                WriteLog();
            }
        } catch (...) {
            // Logging is best-effort: report the failure and still let the RPC proceed.
            Log->error("Exception: {}", CurrentExceptionMessage());
            Log->error("Backtrace: {}", TBackTrace::FromCurrentException().PrintToString());
        }
    }

    methods->Proceed();
}

// Records the caller's peer string and the incoming initial metadata (null if absent).
void TLoggingInterceptor::SaveRequestMetadata(std::multimap<grpc::string_ref, grpc::string_ref>* metadata) {
    auto* const context = Info->server_context();
    Writer.Write(PEER, context->peer());

    WriteMetadata(metadata, REQUEST_METADATA, Writer);
}

// Records the received request message as JSON; nothing is written when it is nullptr.
void TLoggingInterceptor::SaveRequestMessage(void* message) {
    const void* const request = message;
    WriteMessage(request, REQUEST, Writer);
}

// Records the outgoing initial metadata (JSON null when the pointer is nullptr).
void TLoggingInterceptor::SaveResponseMetadata(std::multimap<TString, TString>* metadata) {
    const auto* const headers = metadata;
    WriteMetadata(headers, RESPONSE_METADATA, Writer);
}

// Records the outgoing response message as JSON; nothing is written when it is nullptr.
void TLoggingInterceptor::SaveResponseMessage(const void* message) {
    const void* const response = message;
    WriteMessage(response, RESPONSE, Writer);
}

// Records the final RPC status as a nested object: {"code": <symbolic name>, "message": ...}.
void TLoggingInterceptor::SaveStatus(grpc::Status status) {
    const auto code = static_cast<grpc_status_code>(status.error_code());

    Writer.OpenMap(STATUS);
    Writer.Write(CODE, grpc_status_code_to_string(code));
    Writer.Write(MESSAGE, status.error_message());
    Writer.CloseMap();
}

// Records the success flag reported for the outgoing message at POST_SEND_MESSAGE.
void TLoggingInterceptor::SaveSendMessageStatus(bool status) {
    const bool sent = status;
    Writer.Write(SEND_STATUS, sent);
}

void TLoggingInterceptor::WriteLog() {
    Writer.CloseMap();
    Writer.Flush();
    LogHandler(std::move(Stream.Str()));
}

// Splits the gRPC method path on METHOD_DELIMITER (skipping empty pieces, at most two
// parts) into Service and Handle. Both are written to the record only if the split
// succeeds; otherwise neither key is written.
void TLoggingInterceptor::SaveMethod(const TString& method) {
    const bool parsed = StringSplitter(method)
        .Split(METHOD_DELIMITER)
        .SkipEmpty()
        .Limit(2)
        .TryCollectInto(&Service, &Handle);

    if (!parsed) {
        return;
    }

    Writer.Write(SERVICE, Service);
    Writer.Write(HANDLE, Handle);
}
