#pragma once

#include <travel/hotels/lib/cpp/logging/common_logger.h>
#include <travel/hotels/lib/cpp/mon/counter.h>
#include <travel/hotels/lib/cpp/mon/counter_hypercube.h>
#include <travel/hotels/lib/cpp/mon/tools.h>
#include <travel/hotels/lib/cpp/tvm/tvm_service.h>
#include <travel/hotels/lib/cpp/util/flag.h>
#include <travel/hotels/lib/cpp/util/profiletimer.h>
#include <travel/hotels/lib/cpp/grpc/reqans_logrecords.pb.h>
#include <travel/proto/commons.pb.h>

#include <grpc/grpc.h>
#include <grpc++/completion_queue.h>
#include <grpc++/server.h>
#include <grpc++/server_builder.h>
#include <grpc++/server_context.h>
#include <grpc++/support/async_stream.h>
#include <grpc++/support/async_unary_call.h>

#include <library/cpp/logger/global/global.h>
#include <library/cpp/json/json_value.h>
#include <library/cpp/json/json_writer.h>

#include <util/thread/factory.h>
#include <util/thread/pool.h>
#include <util/datetime/base.h>
#include <library/cpp/deprecated/atomic/atomic.h>
#include <util/system/mutex.h>
#include <util/system/hostname.h>
#include <util/string/builder.h>
#include <util/generic/ptr.h>
#include <util/generic/guid.h>
#include <util/generic/yexception.h>
#include <contrib/libs/protobuf/src/google/protobuf/util/time_util.h>

#include <functional>
#include <memory>

namespace NTravel {
namespace NGrpc {

// Registers `owner->method` as the handler of the gRPC method `grpc_method_name` on `server`.
// Expands to a single `server.AddHandler(...)` call that receives:
//   - a pointer to the service's `Request<grpc_method_name>` member,
//   - the method name as a string literal,
//   - `owner` and `method` unchanged.
#define REGISTER_GRPC_HANDLER(server, grpc_method_name, owner, method)             \
    server.AddHandler(&decltype(server)::TAsyncService::Request##grpc_method_name, \
                      #grpc_method_name,                                           \
                      owner,                                                       \
                      method)

// Settings consumed by TAsyncServer.
struct TAsyncServerConfig {
    // Address passed to grpc::ServerBuilder::AddListeningPort().
    TString BindAddress;
    // Number of threads the reply thread pool is started with.
    size_t ReplyThreadCount;

    // replyThreadCount defaults to 8 when omitted.
    TAsyncServerConfig(const TString& bindAddress, size_t replyThreadCount = 8)
        : BindAddress(bindAddress), ReplyThreadCount(replyThreadCount)
    {
    }
};

// Per-request metadata extracted from the gRPC server context and passed to request handlers.
struct TServerReqMetadata {
    // Value of the "ya-grpc-call-id" client metadata (empty if absent).
    TString CallId;
    // Value of the "ya-grpc-fqdn" client metadata, or the gRPC peer string if absent.
    TString RemoteFQDN;
    // TVM service ticket from "ya-service-ticket" or "x-ya-service-ticket" (empty if absent).
    TString ServiceTicket;
    // Value of the "ya-grpc-started-at" client metadata (empty if absent).
    TString StartedAt;
    // Value of the "grpc-timeout" client metadata (empty if absent).
    TString Timeout;
    // Call deadline as milliseconds since the epoch.
    // Zero-initialized so a default-constructed struct never carries an indeterminate value.
    ui64 DeadlineTimestampMs = 0;
    // Call deadline formatted as a local-time ISO string.
    TString DeadlineLocalTime;
};

// Describes how a request finished: gRPC status, optional log prefix and trailing metadata.
// Instances can only be produced by the Build* factories below.
class TServerRespMetadata {
public:
    // Successful result (status OK) without a log prefix.
    static TServerRespMetadata BuildOk() {
        TServerRespMetadata meta;
        return meta;
    }

    // Successful result (status OK) whose log line is prefixed with logPrefix.
    static TServerRespMetadata BuildOk(const TString& logPrefix) {
        TServerRespMetadata meta;
        meta.LogPrefix = logPrefix;
        return meta;
    }

    // Failed result without a log prefix; status must not be OK.
    static TServerRespMetadata BuildFailed(const grpc::Status& status) {
        return BuildFailed(TString(), status);
    }

    // Failed result with a log prefix; throws (Y_ENSURE) if status is OK.
    static TServerRespMetadata BuildFailed(const TString& logPrefix, const grpc::Status& status) {
        Y_ENSURE(!status.ok(), "Expected not ok status");
        TServerRespMetadata meta;
        meta.LogPrefix = logPrefix;
        meta.Status = status;
        return meta;
    }

    // Failed result that additionally carries the serialized error in trailing metadata
    // under METADATA_ERROR_KEY; throws (Y_ENSURE) if status is OK.
    static TServerRespMetadata BuildFailed(const grpc::Status& status, const NTravelProto::TError& error) {
        Y_ENSURE(!status.ok(), "Expected not ok status");
        TServerRespMetadata meta;
        meta.Status = status;
        meta.TrailingMeta.emplace(METADATA_ERROR_KEY, error.SerializeAsString());
        return meta;
    }

    const TString& GetLogPrefix() const {
        return LogPrefix;
    }

    const grpc::Status& GetStatus() const {
        return Status;
    }

    const THashMap<TString, TString>& GetTrailingMeta() const {
        return TrailingMeta;
    }

private:
    // Trailing-metadata key under which the serialized TError is stored (defined out of line).
    const static TString METADATA_ERROR_KEY;

    // Construction is restricted to the factories above.
    TServerRespMetadata() = default;

    TString LogPrefix;

    THashMap<TString, TString> TrailingMeta;
    grpc::Status Status = grpc::Status::OK;
};

// Monitoring counters kept separately for every gRPC method served by TAsyncServer.
// The constructor and QueryCounters() are defined outside this header.
struct TAsyncServerCountersPerMethod: public NMonitor::TCounterSource {
    TAsyncServerCountersPerMethod();

    // Number of reply-pool tasks currently processing a request of this method.
    NMonitor::TCounter NActiveProcessThreads;
    // Set to the configured reply thread count whenever a job completion is handled.
    NMonitor::TCounter NLimitProcessThreads;
    // Number of job objects of this method that are currently alive.
    NMonitor::TCounter NActiveJobs;

    // Incremented for every received request.
    NMonitor::TDerivCounter NRequests;
    // Replies sent with an OK status.
    NMonitor::TDerivCounter NRespOK;
    // Replies sent with grpc::INVALID_ARGUMENT.
    NMonitor::TDerivCounter NRespInvalidArgument;
    // Replies sent with any other non-OK status.
    NMonitor::TDerivCounter NRespError;

    // Time a job spent waiting for an incoming request.
    NMonitor::TDerivHistogramCounter TimeWaitingForRequestMs;
    // Time between receiving a request and the start of its processing in the reply pool.
    NMonitor::TDerivHistogramCounter TimeWaitingForProcessingMs;
    // Time between the start of processing and sending the reply.
    NMonitor::TDerivHistogramCounter TimeProcessingRequestMs;
    // Time until the completion queue reported the reply as finished.
    NMonitor::TDerivHistogramCounter TimeSendingResponseMs;

    // Requests rejected because no TVM service ticket was supplied.
    NMonitor::TDerivCounter NTvmNoTicket;
    // Requests rejected because the TVM service ticket was not allowed.
    NMonitor::TDerivCounter NTvmTicketRefused;

    void QueryCounters(NMonitor::TCounterTable* ct) const override;
};

// Server-wide (not per-method) monitoring counters of TAsyncServer.
// QueryCounters() is defined outside this header.
struct TAsyncServerCounters: public NMonitor::TCounterSource {
    // Incremented when the completion queue yields no event or an event with ok == false.
    NMonitor::TDerivCounter NCompletionQueueMisses;

    void QueryCounters(NMonitor::TCounterTable* ct) const override;
};

// Scope guard for a gauge counter: increments it on construction and decrements it on destruction.
// The counter must be non-null and must outlive the holder.
struct TCountHolder {
    explicit TCountHolder(NMonitor::TCounter* counter)
        : Counter(counter)
    {
        Counter->Inc();
    }

    ~TCountHolder() {
        Counter->Dec();
    }

    // A copy would decrement the counter a second time for a single increment,
    // so the guard is intentionally non-copyable.
    TCountHolder(const TCountHolder&) = delete;
    TCountHolder& operator=(const TCountHolder&) = delete;

    NMonitor::TCounter* Counter;
};

class TAutoReqAnsLogWriter {
 public:
    using TReqAnsLogger = TCommonLogger<TString>;

    TAutoReqAnsLogWriter(TReqAnsLogger* logger,
                         TInstant startedAt,
                         const TString& methodName,
                         const ::google::protobuf::Message& req,
                         const THashSet<TString>& reqAnsLoggerBannedMethods)
        : Logger(logger)
        , StartedAt(startedAt)
    {
        if (!Logger) {
            return;
        }
        if (reqAnsLoggerBannedMethods.contains(methodName)) {
            Logger = nullptr;
            return;
        }

        TInstant now = Now();

        LogRecord.set_unixtime(now.Seconds());
        LogRecord.set_local_time(now.ToIsoStringLocal());

        LogRecord.SetMethodName(methodName);

        Req = THolder(static_cast<google::protobuf::Message*>(req.New()));
        Req->CopyFrom(req);

        if (auto mutableInfo = LogRecord.MutableInfo()) {
            mutableInfo->SetLocalHostName(FQDNHostName());
            mutableInfo->SetLocalStartedAt(StartedAt.ToIsoStringLocal());
        }
    }

    ~TAutoReqAnsLogWriter() {
        WriteReqAnsLog();
    }

    void SetReqMeta(const NTravel::NGrpc::TServerReqMetadata& srvReqMeta) {
        if (!Logger) {
            return;
        }

        if (auto mutableInfo = LogRecord.MutableInfo()) {
            mutableInfo->SetCallId(srvReqMeta.CallId);
            mutableInfo->SetRemoteFQDN(srvReqMeta.RemoteFQDN);
            mutableInfo->SetRemoteStartedAt(srvReqMeta.StartedAt);
            mutableInfo->SetTimeout(srvReqMeta.Timeout);
            mutableInfo->SetDeadlineTimestampMs(srvReqMeta.DeadlineTimestampMs);
            mutableInfo->SetDeadlineLocalTime(srvReqMeta.DeadlineLocalTime);
        }
    }

    void SetResp(const ::google::protobuf::Message& resp, const NTravel::NGrpc::TServerRespMetadata& srvRespMeta) {
        if (!Logger) {
            return;
        }
        TInstant now = Now();

        Resp = THolder(static_cast<google::protobuf::Message*>(resp.New()));
        Resp->CopyFrom(resp);

        if (auto mutableInfo = LogRecord.MutableInfo()) {
            mutableInfo->SetDurationMicros(now.MicroSeconds() - StartedAt.MicroSeconds());
        }

        LogRecord.SetStatusErrorCode(srvRespMeta.GetStatus().error_code());
        LogRecord.SetStatusErrorMessage(TString(srvRespMeta.GetStatus().error_message()));
    }

    void WriteReqAnsLog() const {
        if (!Logger) {
            return;
        }

        NProtobufJson::TProto2JsonConfig cfg;
        cfg.SetFormatOutput(false);
        cfg.SetMapAsObject(true);
        cfg.SetEnumMode(NProtobufJson::TProto2JsonConfig::EnumName);
        auto jsonRec = NJson::TJsonValue();
        Proto2Json(LogRecord, jsonRec, cfg);

        if (Req) {
            auto jsonReq = NJson::TJsonValue();
            Proto2Json(*Req, jsonReq, cfg);
            jsonRec.InsertValue("Req", jsonReq);
        } else {
            jsonRec.EraseValue("Req");
        }

        if (Resp) {
            auto jsonResp = NJson::TJsonValue();
            Proto2Json(*Resp, jsonResp, cfg);
            jsonRec.InsertValue("Resp", jsonResp);
        } else {
            jsonRec.EraseValue("Resp");
        }

        auto strRec = NJson::WriteJson(jsonRec, false);
        Logger->AddRecord(strRec);
    }

 private:
    NTravelProto::NGrpcReqAnsLog::TReqAnsLogRecord LogRecord;
    THolder<google::protobuf::Message> Req;
    THolder<google::protobuf::Message> Resp;
    TReqAnsLogger* Logger;
    TInstant StartedAt;
};

// Reads the metadata entry `name` from `context`.
// NOTE(review): the definition lives outside this header; presumably it returns the client
// metadata value for `name` and falls back to `defValue` when the entry is absent — confirm
// against the implementation.
TString GetMetadata(const grpc::ServerContext& context, const grpc::string_ref& name, const TString& defValue);

// Asynchronous gRPC server for unary methods of TGrpcAsyncService.
//
// Threads:
//   - one thread (DoExecute) pulls events from the completion queue and dispatches them to jobs;
//   - a reply thread pool (Config_.ReplyThreadCount threads) runs request processing.
//
// Every registered handler is represented by a job that waits for one incoming call. When a
// call arrives, a fresh job is registered for the next call of the same method and the current
// job is processed in the reply pool. After the reply is sent the job is registered again under
// a new id and released once the completion queue reports the reply event.
template <class TGrpcAsyncService>
class TAsyncServer : public IThreadFactory::IThreadAble {
public:
    using TAsyncService = TGrpcAsyncService;

    // Callback a handler invokes to send the response.
    template <class TGrpcResponse>
    using TResponseCb = std::function<void (const TGrpcResponse& resp, const TServerRespMetadata& meta)>;

    // Handler of one request.
    template <class TGrpcRequest, class TGrpcResponse>
    using TOnRequest = std::function<void (const TGrpcRequest& req, const TServerReqMetadata& metadata, TResponseCb<TGrpcResponse> respCb)>;

    // Wrapper around the generated Request<Method> member of the async service.
    template <class TGrpcRequest, class TGrpcResponse>
    using TReqFunction = std::function<void (TGrpcAsyncService& service, ::grpc::ServerContext* context, TGrpcRequest* request,  ::grpc::ServerAsyncResponseWriter<TGrpcResponse>* response, ::grpc::CompletionQueue* cq1, ::grpc::ServerCompletionQueue* cq2, void *tag)>;

    using TReqAnsLogger = TAutoReqAnsLogWriter::TReqAnsLogger;

    TAsyncServer(const TAsyncServerConfig& config)
        : Config_(config)
        , CountersPerMethod_({"request_type"})
        , TvmService_(nullptr)
        , ReqAnsLogger_(nullptr)
        , LastJobId_(0)
    {
    }

    ~TAsyncServer() {
        Stop();
    }

    // Enables TVM ticket checks for incoming requests (only while tvmService->IsEnabled()).
    // The service is not owned.
    void SetTvm(const NTravel::NTvm::TTvmService* tvmService) {
        TvmService_ = tvmService;
    }

    // Enables request/response logging for all methods except the banned ones.
    // The logger is not owned.
    void SetReqAnsLogger(TReqAnsLogger* logger, std::initializer_list<TString> bannedMethods) {
        ReqAnsLogger_ = logger;
        ReqAnsLoggerBannedMethods_ = THashSet<TString>(bannedMethods);
    }

    void SetReqAnsLogger(TReqAnsLogger* logger) {
        SetReqAnsLogger(logger, {});
    }

    // Registers `owner->membFunc` as the handler of the method whose generated request function
    // is `rpcCallFunc`. `methodName` is used for logs and per-method counters.
    // `owner` is captured by pointer and is not owned.
    template<class TLocalGrpcAsyncService, class TGrpcRequest, class TGrpcResponse, class TOwner>
    void AddHandler(void (TLocalGrpcAsyncService::*rpcCallFunc)(::grpc::ServerContext*,
                                                                TGrpcRequest*,
                                                                ::grpc::ServerAsyncResponseWriter<TGrpcResponse>*,
                                                                ::grpc::CompletionQueue*,
                                                                ::grpc::ServerCompletionQueue*,
                                                                void*),
                    const TString& methodName,
                    TOwner* owner,
                    void (TOwner::*membFunc)(const TGrpcRequest& request, const TServerReqMetadata& meta, const TResponseCb<TGrpcResponse>& responseCb)) {
        TJobBaseRef job = new TJob<TGrpcRequest, TGrpcResponse>(
            *this,
            [rpcCallFunc](TGrpcAsyncService& service,
                          ::grpc::ServerContext* context,
                          TGrpcRequest* request,
                          ::grpc::ServerAsyncResponseWriter<TGrpcResponse>* response,
                          ::grpc::CompletionQueue* cq1,
                          ::grpc::ServerCompletionQueue* cq2,
                          void* tag) {
                (service.*rpcCallFunc)(context, request, response, cq1, cq2, tag);
            },
            [owner, membFunc](const TGrpcRequest& request, const TServerReqMetadata& meta, const TResponseCb<TGrpcResponse>& responseCb) {
                (owner->*membFunc)(request, meta, responseCb);
            },
            methodName);
        RegisterJob(job);
    }

    // Builds and starts the gRPC server, arms all jobs registered so far, then starts the
    // reply pool and the completion-queue thread.
    void Start() {
        grpc::ServerBuilder builder;
        builder.AddListeningPort(Config_.BindAddress, grpc::InsecureServerCredentials());
        builder.RegisterService(&Service_);
        CQ_ = builder.AddCompletionQueue();
        Server_ = builder.BuildAndStart();
        for (auto& [jobId, job] : Jobs_) {
            job->OnRegistered(jobId);
        }
        StartedFlag_.Set();
        INFO_LOG << "gRPC server started at " << Config_.BindAddress << Endl;
        ReplyThreads_.Start(Config_.ReplyThreadCount);
        Thread_ = SystemThreadFactory()->Run(this);
    }

    // Requests shutdown of the server and of the completion queue without waiting for threads.
    // Safe to call when Start() was never called.
    void Shutdown() {
        StopFlag_.Set();
        if (Server_) {
            Server_->Shutdown();
        }
        if (CQ_) {
            CQ_->Shutdown();
        }
    }

    // Shuts the server down and waits for the completion-queue thread and the reply pool.
    // Does nothing when the server is not running.
    void Stop() {
        if (!Thread_) {
            return;
        }
        Shutdown();
        Thread_->Join();
        Thread_.Reset();
        ReplyThreads_.Stop();
    }

    // Registers the server-wide counters as `name` and the per-method ones as `name + "PerMethod"`.
    void RegisterCounters(NMonitor::TCounterSource& source, const TString& name) {
        source.RegisterSource(&Counters_, name);
        source.RegisterSource(&CountersPerMethod_, name + "PerMethod");
    }

private:
    // Type-erased job stored in Jobs_ and addressed by the completion-queue tag.
    struct TJobBase: public TThrRefBase {
        // Called once the job has received its id (the completion-queue tag).
        virtual void OnRegistered(size_t jobId) = 0;
        // Called on the completion-queue thread when a successful event for the job arrives.
        virtual void OnComplete() = 0;
    };
    using TJobBaseRef = TIntrusivePtr<TJobBase>;

    // State of a single call of one method: first waits for a request, then sends the reply.
    template <class TGrpcRequest, class TGrpcResponse>
    struct TJob : public TJobBase  {
        TAsyncServer& Server;
        TReqFunction<TGrpcRequest, TGrpcResponse> ReqFunc;
        TOnRequest<TGrpcRequest, TGrpcResponse> OnRequest;
        bool WaitingRequest;  // true until the reply is being sent
        size_t JobId;
        grpc::ServerContext Context;
        TGrpcRequest Request;
        grpc::ServerAsyncResponseWriter<TGrpcResponse> Responder;
        TProfileTimer Timer;
        TAtomicSharedPtr<TAsyncServerCountersPerMethod> Counters_;
        TCountHolder JobCountHolder;
        TString MethodName;

        TJob(TAsyncServer& server, TReqFunction<TGrpcRequest, TGrpcResponse> asyncReqFunc, TOnRequest<TGrpcRequest, TGrpcResponse> onRequest, const TString& methodName)
            : Server(server)
            , ReqFunc(asyncReqFunc)
            , OnRequest(onRequest)
            , WaitingRequest(true)
            , JobId(0)
            , Responder(&Context)
            , Counters_(Server.CountersPerMethod_.GetOrCreate({methodName}))
            , JobCountHolder(&Counters_->NActiveJobs)
            , MethodName(methodName)
        {
        }

        void OnRegistered(size_t jobId) override {
            JobId = jobId;
            if (WaitingRequest) {
                // Ask gRPC to deliver the next call of this method into this job.
                ReqFunc(Server.Service_, &Context, &Request, &Responder, Server.CQ_.get(), Server.CQ_.get(), reinterpret_cast<void*>(JobId));
            }
        }

        void OnComplete() override {
            Counters_->NLimitProcessThreads = Server.Config_.ReplyThreadCount;
            if (WaitingRequest) {
                Counters_->TimeWaitingForRequestMs.Update(Timer.Step().MilliSeconds());
                Counters_->NRequests.Inc();
                // A new job replaces the current one and waits for the next call of this method.
                Server.RegisterJob(new TJob(Server, ReqFunc, OnRequest, MethodName));
                // The pool task keeps the job alive while the request is being processed.
                TIntrusivePtr<TJob> self(this);
                Server.ReplyThreads_.SafeAddFunc([self]() {
                    TCountHolder holder(&self->Counters_->NActiveProcessThreads);
                    self->ProcessRequest();
                });
            } else {
                Counters_->TimeSendingResponseMs.Update(Timer.Step().MilliSeconds());
            }
        }

        // Sends the reply, logs the outcome and updates the response counters.
        void SendReply(const TGrpcResponse& resp, const TServerReqMetadata& reqMeta, const TServerRespMetadata& respMeta, TAutoReqAnsLogWriter& logWriter) {
            WaitingRequest = false;
            Server.RegisterJob(this);
            // Sample the timer BEFORE Finish(): as soon as Finish() is issued the completion-queue
            // thread may run OnComplete(), which steps the same Timer concurrently.
            const auto duration = Timer.Step();
            Responder.Finish(resp, respMeta.GetStatus(), reinterpret_cast<void*>(JobId));
            logWriter.SetResp(resp, respMeta);
            if (respMeta.GetStatus().ok()) {
                INFO_LOG << respMeta.GetLogPrefix() << "Successfully processed '" << MethodName << "' request, callId '" << reqMeta.CallId << "' from '" << reqMeta.RemoteFQDN
                         << "', in " << duration << Endl;
                Counters_->NRespOK.Inc();
            } else if (respMeta.GetStatus().error_code() == grpc::INVALID_ARGUMENT) {
                ERROR_LOG << respMeta.GetLogPrefix() << "Invalid argument during processing '" << MethodName << "' request, callId '" << reqMeta.CallId << "' from '" << reqMeta.RemoteFQDN
                          << "', in " << duration << ", message '" << respMeta.GetStatus().error_message() << "'" << Endl;
                Counters_->NRespInvalidArgument.Inc();
            } else {
                ERROR_LOG << respMeta.GetLogPrefix() << "Error during processing '" << MethodName << "' request, callId '" << reqMeta.CallId << "' from '" << reqMeta.RemoteFQDN
                          << "', in " << duration << ", message '" << respMeta.GetStatus().error_message() << "'" << Endl;
                Counters_->NRespError.Inc();
            }
            Counters_->TimeProcessingRequestMs.Update(duration.MilliSeconds());
        }

        // Runs in the reply pool: collects request metadata, checks the TVM ticket (if enabled)
        // and passes the request to the user handler.
        void ProcessRequest() {
            Counters_->TimeWaitingForProcessingMs.Update(Timer.Step().MilliSeconds());

            auto startedAt = Now();
            auto autoReqAnsLogWriter = MakeAtomicShared<TAutoReqAnsLogWriter>(Server.ReqAnsLogger_, startedAt, MethodName, Request, Server.ReqAnsLoggerBannedMethods_);

            TServerReqMetadata reqMeta{};
            reqMeta.CallId = GetMetadata(Context, "ya-grpc-call-id", "");
            reqMeta.RemoteFQDN = GetMetadata(Context, "ya-grpc-fqdn", TString(Context.peer()));
            reqMeta.ServiceTicket = GetMetadata(Context, "ya-service-ticket", GetMetadata(Context, "x-ya-service-ticket", ""));
            reqMeta.StartedAt = GetMetadata(Context, "ya-grpc-started-at", "");
            reqMeta.Timeout = GetMetadata(Context, "grpc-timeout", "");
            auto deadlineMs = std::chrono::duration_cast<std::chrono::milliseconds>(Context.deadline().time_since_epoch()).count();
            reqMeta.DeadlineTimestampMs = deadlineMs;
            reqMeta.DeadlineLocalTime = TInstant::MilliSeconds(deadlineMs).ToIsoStringLocal();

            autoReqAnsLogWriter->SetReqMeta(reqMeta);

            Context.AddInitialMetadata("ya-grpc-fqdn", FQDNHostName());
            // Keeps the job alive until the handler has invoked the response callback.
            TIntrusivePtr<TJob> self(this);

            if (Server.TvmService_ && Server.TvmService_->IsEnabled()) {
                bool ok = false;
                TString error = "";
                const auto& ticket = reqMeta.ServiceTicket;
                if (!ticket) {
                    error = "Request without tvm service ticket (ya-service-ticket or x-ya-service-ticket)";
                    Counters_->NTvmNoTicket.Inc();
                } else if (!Server.TvmService_->IsAllowedServiceTicket(ticket)) {
                    error = "Tvm ticket refused for request";
                    Counters_->NTvmTicketRefused.Inc();
                } else {
                    ok = true;
                }
                if (!ok) {
                    WARNING_LOG << "GRPC: " << error << " (method: " << MethodName
                                << ", callId: " << reqMeta.CallId << ", remoteFQDN: "
                                << reqMeta.RemoteFQDN << ")" << Endl;

                    // Reject with an empty response; the user handler is not invoked.
                    auto respMeta = TServerRespMetadata::BuildFailed(grpc::Status(grpc::StatusCode::UNAUTHENTICATED, error));
                    SendReply(TGrpcResponse(), reqMeta, respMeta, *autoReqAnsLogWriter);
                    return;
                }
            }

            OnRequest(Request, reqMeta, [self, reqMeta, autoReqAnsLogWriter](const TGrpcResponse& resp, const TServerRespMetadata& respMeta) {
                for (const auto& [key, value]: respMeta.GetTrailingMeta()) {
                    self->Context.AddTrailingMetadata(key, value);
                }
                self->SendReply(resp, reqMeta, respMeta, *autoReqAnsLogWriter);
            });
        }
    };

    const TAsyncServerConfig Config_;

    TGrpcAsyncService Service_;
    std::unique_ptr<grpc::ServerCompletionQueue> CQ_;
    std::unique_ptr<grpc::Server> Server_;
    NMonitor::TCounterHypercube<TAsyncServerCountersPerMethod> CountersPerMethod_;
    TAsyncServerCounters Counters_;
    const NTravel::NTvm::TTvmService* TvmService_;
    TReqAnsLogger* ReqAnsLogger_;
    THashSet<TString> ReqAnsLoggerBannedMethods_;

    TThreadPool ReplyThreads_;
    TAutoPtr<IThreadFactory::IThread> Thread_;
    TAtomicFlag StopFlag_;
    TAtomicFlag StartedFlag_;

    // Guarded by Lock_
    TMutex Lock_;
    size_t LastJobId_;
    THashMap<size_t, TJobBaseRef> Jobs_;
private:
    // Completion-queue loop: maps every event tag back to its job and dispatches it.
    void DoExecute() override {
        void* tag = nullptr;  // uniquely identifies a request.
        bool ok = false;
        while (!StopFlag_) {
            if (!CQ_->Next(&tag, &ok)) {
                // The queue is shut down and fully drained: no further event will ever arrive,
                // so leave the loop instead of spinning on Next().
                Counters_.NCompletionQueueMisses.Inc();
                break;
            }
            const size_t jobId = reinterpret_cast<size_t>(tag);
            TJobBaseRef job;
            with_lock (Lock_) {
                auto it = Jobs_.find(jobId);
                if (it != Jobs_.end()) {
                    job = it->second;
                    Jobs_.erase(it);
                }
            }
            if (!ok) {
                // The operation did not succeed (e.g. the call is already dead or the server is
                // shutting down). No further event will come for this tag, so the job has been
                // removed from the map above and is simply released here.
                Counters_.NCompletionQueueMisses.Inc();
                continue;
            }
            if (!job) {
                ERROR_LOG << "gRPC server job not found by id " << jobId << Endl;
                continue;
            }
            job->OnComplete();
        }
        // Drain the queue
        while (CQ_->Next(&tag, &ok)) {
        }
    }

    // Stores the job under a fresh id; once the server is started the job is armed immediately.
    void RegisterJob(TJobBaseRef job) {
        size_t jobId;
        with_lock (Lock_) {
            LastJobId_ += 1;
            jobId = LastJobId_;
            Jobs_[jobId] = job;
        }
        if (StartedFlag_) {
            job->OnRegistered(jobId);
        }
    }

    // Returns the client metadata value for `name`, or `defValue` when it is absent.
    TString GetClientMetadata(const grpc::ServerContext& context, const grpc::string_ref& name, const TString& defValue) {
        auto it = context.client_metadata().find(name);
        if (it == context.client_metadata().end()) {
            return defValue;
        }
        return TString(it->second.data(), it->second.size());
    }

};

}// Namespace NGrpc
}// Namespace NTravel
