#include "web_server.h"

#include <solomon/libs/cpp/threading/pool/pool.h>
#include <solomon/libs/cpp/threading/metrics/thread_pool_metrics.h>
#include <util/datetime/cputimer.h>
#include <contrib/libs/re2/re2/re2.h>

namespace NMonitoring {
    // Creates a THttpServer that hands accepted requests to `callBack`, running them on
    // `processingQueue`, and starts it immediately.
    // Throws yexception (including THttpServer::GetError()) when the server fails to start.
    static THolder<THttpServer> StartHttp(const THttpServerOptions& options,
                                          THttpServer::ICallBack* callBack,
                                          THttpServer::TMtpQueueRef processingQueue) {
        // Pool given to THttpServer as its "failed request" queue. It ignores whatever sizing
        // the server requests and always runs a single thread with no queue-size limit.
        class TFailedHttpRequestMtpQueue final: public TThreadPool {
            void Start(size_t /*threadCount*/, size_t /*queueSizeLimit*/) override {
                TThreadPool::Start(1, 0);
            }
        };

        THttpServer::TMtpQueueRef failedRequestsQueue(new TFailedHttpRequestMtpQueue());
        auto server = MakeHolder<THttpServer>(callBack, processingQueue, failedRequestsQueue, options);
        if (server->Start()) {
            return server;
        }
        ythrow yexception() << "web server fail to start: " << server->GetError();
    }

    TWebReplier::TWebReplier(TWebServer* server, TLog& logger)
        : Server(server)
        , Ready(false)
        , Logger(logger) {
    }

    // Handles one HTTP request.
    //
    // Endpoint resolution: an exact-path endpoint is looked up first; regexp endpoints are
    // consulted only when there is no exact match. If neither matches, a 404 is written and
    // true is returned.
    //
    // Otherwise the endpoint's replier is invoked. Known exception types are mapped to HTTP
    // status codes (429, 404, 400, 401, 403, 503). Any other exception becomes a 500 and is
    // logged. Completion is deferred: Finish() is subscribed to the request's future and false
    // is returned.
    //
    // Returns true immediately if Ready is already set.
    bool TWebReplier::DoReply(const TRequestReplier::TReplyParams& params) {
        if (Ready) {
            return true;
        }

        const TParsedHttpFull parsedLine(params.Input.FirstLine());

        // Exact-path endpoints take precedence over regexp endpoints.
        TPathEndpoint* handler = Server->GetReplier(parsedLine);
        TRegexpEndpoint* regexpHandler = nullptr; // always initialized; only set on an exact-path miss
        TVector<TString> pathGroups;
        TMap<int, TString> groupNames;
        if (handler == nullptr) {
            regexpHandler = Server->GetRegexpReplier(parsedLine, pathGroups, groupNames);
        }

        if (handler == nullptr && regexpHandler == nullptr) {
            params.Output << THttpResponse(HTTP_NOT_FOUND).SetContent("Service not found!");
            return true;
        }

        auto* metrics = handler ? handler->Metrics : regexpHandler->Metrics;

        char addressData[INET6_ADDRSTRLEN];
        addressData[0] = 0;
        auto gotAddress = GetRemoteAddr(Socket(), addressData, sizeof(addressData));
        TServiceRequest::TRef request(MakeAtomicShared<TServiceRequest>(params, metrics, gotAddress ? addressData : ""));

        TSimpleTimer timer;
        try {
            if (handler != nullptr) {
                handler->Replier->DoReply(request, parsedLine);
            } else {
                regexpHandler->Replier->DoReply(request, parsedLine, pathGroups, groupNames);
            }
        } catch (const TTooManyRequestsError& e) {
            request->Finish(THttpResponse(HttpCodes::HTTP_TOO_MANY_REQUESTS).SetContent(e.what()));
        } catch (const TNotFoundError& e) {
            request->Finish(THttpResponse(HttpCodes::HTTP_NOT_FOUND).SetContent(e.what()));
        } catch (const TBadRequest& e) {
            request->Finish(THttpResponse(HttpCodes::HTTP_BAD_REQUEST).SetContent(e.what()));
        } catch (const TUnauthorizedError& e) {
            request->Finish(THttpResponse(HttpCodes::HTTP_UNAUTHORIZED).SetContent(e.what()));
        } catch (const TForbiddenError& e) {
            request->Finish(THttpResponse(HttpCodes::HTTP_FORBIDDEN).SetContent(e.what()));
        } catch (const TServiceUnavailableError& e) {
            request->Finish(THttpResponse(HttpCodes::HTTP_SERVICE_UNAVAILABLE).SetContent(e.what()));
        } catch (const yexception& e) {
            request->Finish(THttpResponse(HttpCodes::HTTP_INTERNAL_SERVER_ERROR).SetContent(e.what()));
            Logger << TLOG_ERR << "HTTP call to " << parsedLine.Path << " failed: " << CurrentExceptionMessage();
        } catch (...) {
            request->Finish(THttpResponse(HttpCodes::HTTP_INTERNAL_SERVER_ERROR)
                                .SetContent("Exception during request processing"));
            Logger << TLOG_ERR << "HTTP call to " << parsedLine.Path << " failed: " << CurrentExceptionMessage();
        }

        // Measures only the synchronous part of the handler. The call is escalated to a warning
        // when a slow-log threshold is configured and exceeded.
        const auto reqTime = timer.Get();
        const auto slowLogDuration = Server->GetSlowLogDuration();
        const auto level = slowLogDuration && reqTime > slowLogDuration ? TLOG_WARNING : TLOG_DEBUG;
        Logger << level << "HTTP call to " << parsedLine.Path << " processed in " << reqTime;

        request->GetFuture().Subscribe([this](const TServiceRequest::TFuture&) {
            // Finish() releases `this`, so take a value copy of the logger before calling it.
            // A copy stays usable no matter whether Logger refers to an external log or is
            // owned by this object.
            TLog logger = Logger;
            try {
                Finish();
            } catch (...) {
                logger << TLOG_ERR << "Can't finish request: " << CurrentExceptionMessage();
            }
        });
        return false;
    }

    TWebServer::TWebServer(const THttpServerOptions& options, THolder<IThreadPool> pool, TLog& logger)
        : Logger(logger)
        , HttpOptions(options)
        , ProcessingQueue(std::move(pool))
        , ServerMetrics(*NMonitoring::TMetricRegistry::Instance()) {
    }

    // Builds a server with its own TThreadPool. The pool is wrapped by NSolomon's thread-pool
    // proxy together with TThreadPoolMetrics named "WebServer-<port>" in the global
    // TMetricRegistry.
    TWebServer::TWebServer(const THttpServerOptions& options, TLog& logger)
        : Logger(logger)
        , HttpOptions(options)
        , ServerMetrics(*NMonitoring::TMetricRegistry::Instance())
    {
        auto poolName = TStringBuilder() << "WebServer-" << options.Port;
        auto& registry = *NMonitoring::TMetricRegistry::Instance();
        auto poolMetrics = std::make_shared<NSolomon::TThreadPoolMetrics>(poolName, registry);
        auto instrumentedPool = NSolomon::CreateThreadPoolProxyUnique(std::make_unique<TThreadPool>(), std::move(poolMetrics));
        ProcessingQueue = instrumentedPool.release();
    }

    // Tears the server down in dependency order:
    //   1. stop serving and join the serving thread;
    //   2. release the HTTP server;
    //   3. release the worker pool.
    TWebServer::~TWebServer() {
        try {
            // ThreadProc() blocks in HttpServer->Wait(). Resetting HttpServer while that thread
            // is still running would destroy the server underneath it, so shut down and join
            // first. Stop() is a no-op when the server was never started or is already stopped.
            Stop();
        } catch (...) {
            // Destructors must not throw; report and continue tearing down.
            Logger << TLOG_ERR << "Failed to stop web server: " << CurrentExceptionMessage();
        }
        HttpServer.Reset();
        ProcessingQueue.Reset();
    }

    // Factory invoked for each incoming connection. Returns a freshly allocated TWebReplier
    // bound to this server and its logger. Its lifetime is presumably managed by THttpServer
    // and TWebReplier::Finish() — TODO confirm.
    TClientRequest* TWebServer::CreateClient() {
        auto* replier = new TWebReplier(this, Logger);
        return replier;
    }

    // Body of the serving thread: blocks until the HTTP server finishes. Any exception is
    // logged rather than allowed to escape the thread.
    void* TWebServer::ThreadProc() noexcept {
        Y_ASSERT(HttpServer);
        auto* const server = HttpServer.Get();
        try {
            server->Wait();
        } catch (...) {
            Logger << TLOG_ERR << "Unhandled exception: " << CurrentExceptionMessage();
        }
        return this;
    }

    void TWebServer::Stop() {
        auto guard = Guard(ThreadMutex);
        if (HttpServer) {
            HttpServer->Shutdown();
            Join();
            HttpServer.Reset();
        }
    }

    // Registers `service` for requests whose path equals `path` exactly; metrics are keyed by
    // the same path. Re-registering a path replaces the previous endpoint. `service` is held by
    // pointer and must outlive the server.
    void TWebServer::Add(const TString& path, IServiceReplier& service) {
        auto endpoint = MakeHolder<TPathEndpoint>(ServerMetrics.Metrics(path), &service);
        Services[path] = std::move(endpoint);
    }

    void TWebServer::AddRegexp(const TString& reString, const TString& metricPath, IServiceRegexpReplier& service) {
        THolder<RE2> re = MakeHolder<RE2>(reString, RE2::Quiet);
        if (!re->ok()) {
            Logger << TLOG_WARNING << "Bad regexp \"" << reString << "\": " << re->error();
            return;
        }
        MaxNumberOfCapturingGroups = Max(MaxNumberOfCapturingGroups, re->NumberOfCapturingGroups());
        auto* metrics = ServerMetrics.Metrics(metricPath);
        ServicesRegexp.push_back(MakeHolder<TRegexpEndpoint>(metrics, std::move(re), &service));
    }

    // Returns the endpoint registered for the request's exact path, or nullptr if none.
    TPathEndpoint* TWebServer::GetReplier(const TParsedHttpFull& parsedHttp) {
        const auto found = Services.find(parsedHttp.Path);
        if (found == Services.end()) {
            return nullptr;
        }
        return found->second.Get();
    }

    // Finds the first regexp endpoint (in registration order) whose pattern fully matches the
    // request path.
    //
    // On a match:
    //   - `pathGroups` receives the captured groups, in group order;
    //   - `groupNames` receives that regexp's named-group index -> name mapping;
    //   - the endpoint is returned.
    // Returns nullptr when nothing matches; the out-parameters are then left untouched.
    TRegexpEndpoint* TWebServer::GetRegexpReplier(const TParsedHttpFull& parsedHttp, TVector<TString>& pathGroups, TMap<int, TString>& groupNames) {
        // Heap buffers sized for the regexp with the most groups. Variable-length arrays are not
        // standard C++, and a zero-length one (no registered regexp has groups) is undefined.
        TVector<TString> matches(MaxNumberOfCapturingGroups);
        TVector<RE2::Arg> argv(MaxNumberOfCapturingGroups);
        TVector<RE2::Arg*> args(MaxNumberOfCapturingGroups);
        for (int i = 0; i < MaxNumberOfCapturingGroups; ++i) {
            argv[i] = &matches[i];
            args[i] = &argv[i];
        }

        // Convert the path once instead of once per endpoint.
        const TString path(parsedHttp.Path);
        for (auto& endpoint : ServicesRegexp) {
            const auto& re = *endpoint->Regexp;
            const int numberOfCapturingGroups = re.NumberOfCapturingGroups();
            if (RE2::FullMatchN(path, re, args.data(), numberOfCapturingGroups)) {
                // Only the first numberOfCapturingGroups slots belong to this regexp.
                matches.resize(numberOfCapturingGroups);
                pathGroups = std::move(matches);
                for (const auto& [index, name] : re.CapturingGroupNames()) {
                    groupNames[index] = name;
                }
                return endpoint.Get();
            }
        }
        return nullptr;
    }

    void TWebServer::StartServing() {
        auto guard = Guard(ThreadMutex);
        ProcessingQueue->Start(HttpOptions.nThreads, HttpOptions.MaxQueueSize);
        HttpServer.Reset(StartHttp(HttpOptions, this, ProcessingQueue));
        Start();
    }
} // namespace NMonitoring
