#pragma once

#include "config.h"

#include <passport/infra/libs/cpp/request/http/request.h>
#include <passport/infra/libs/cpp/utils/atomic.h>
#include <passport/infra/libs/cpp/utils/log/global.h>

#include <library/cpp/containers/stack_vector/stack_vec.h>
#include <library/cpp/http/server/response.h>

#include <util/generic/string.h>
#include <util/generic/yexception.h>
#include <util/stream/format.h>
#include <util/system/thread.h>

namespace NPassport::NDaemon {
    template <class T>
    class TReplier: public TRequestReplier {
    public:
        TReplier(T& s,
                 NUtils::TAtomicNum<>& busy,
                 TDuration requestTimeout,
                 const TConfig::TResponseTimeHeaders& headerNames)
            : S_(s)
            , Busy_(busy)
            , RequestStart_(TInstant::Now())
            , RequestTimeout_(requestTimeout)
            , HeaderNames_(headerNames)
        {
        }

        struct TBusyHolder {
            NUtils::TAtomicNum<>& Busy;
            TBusyHolder(NUtils::TAtomicNum<>& busy)
                : Busy(busy)
            {
                ++Busy;
            }
            ~TBusyHolder() {
                --Busy;
            }
        };

        bool DoReply(const TReplyParams& params) override {
            TBusyHolder h(Busy_);

            const TDuration inQueueTime = TInstant::Now() - RequestStart_;
            if (inQueueTime > RequestTimeout_) {
                TLog::Warning("DoReply(): Waiting in queue was too long");
                GetErrorResp("Waiting in queue was too long", inQueueTime, {}).OutTo(params.Output);
                params.Output.Finish();
                return true;
            }

            const TInstant inHandlerStart = TInstant::Now();
            try {
                NPassport::NHttp::TRequest req(params, *this);
                S_.HandleRequest(req);
                req.Flush(GetResponseTimes(inQueueTime, inHandlerStart));
            } catch (const NPassport::NHttp::TRequest::TMethodException& e) {
                TLog::Warning("Unsupported HTTP method: %s", e.what());

                THttpResponse resp(HTTP_BAD_REQUEST);
                resp.SetContent("Unsupported HTTP method");
                AddResponseTimes(resp, inQueueTime, inHandlerStart);
                resp.OutTo(params.Output);
            } catch (const std::exception& e) {
                TLog::Warning("DoReply() exception: %s", e.what());
                GetErrorResp(e.what(), inQueueTime, inHandlerStart).OutTo(params.Output);
            } catch (...) {
                TLog::Error("DoReply() exception: Unexpected exception");
                GetErrorResp("Unexpected exception", inQueueTime, inHandlerStart).OutTo(params.Output);
            }

            params.Output.Finish();
            return true;
        }

    private:
        THttpResponse GetErrorResp(const char* msg, TDuration inQueueTime, TInstant inHandlerStart) const {
            THttpResponse resp(HTTP_INTERNAL_SERVER_ERROR);
            resp.SetContent(TString("Unexpected error: ") + msg);
            AddResponseTimes(resp, inQueueTime, inHandlerStart);
            return resp;
        }

        void AddResponseTimes(THttpResponse& resp, TDuration inQueueTime, TInstant inHandlerStart) const {
            for (const THttpInputHeader& h : GetResponseTimes(inQueueTime, inHandlerStart)) {
                resp.AddHeader(h);
            }
        }

        TStackVec<THttpInputHeader, 2> GetResponseTimes(TDuration inQueueTime, TInstant inHandlerStart) const {
            TStackVec<THttpInputHeader, 2> res;

            if (HeaderNames_.InQueue) {
                res.push_back(THttpInputHeader(
                    HeaderNames_.InQueue,
                    TimeToString(inQueueTime)));
            }

            if (HeaderNames_.InHandler) {
                res.push_back(THttpInputHeader(
                    HeaderNames_.InHandler,
                    inHandlerStart == TInstant()
                        ? EMPTY_HEADER
                        : TimeToString(TInstant::Now() - inHandlerStart)));
            }

            return res;
        }

        static TString TimeToString(TDuration time) {
            if (time == TDuration()) {
                return EMPTY_HEADER;
            }

            return TStringBuilder() << Prec(time.MicroSeconds() / 1000000., PREC_POINT_DIGITS, 4);
        }

    private:
        T& S_;
        NUtils::TAtomicNum<>& Busy_;
        const TInstant RequestStart_;
        const TDuration RequestTimeout_;
        const TConfig::TResponseTimeHeaders& HeaderNames_;

        static const inline TString EMPTY_HEADER = "-";
    };

    class TCallBackBase: public THttpServer::ICallBack {
    public:
        struct TCounters {
            NUtils::TAtomicNum<> All;
            NUtils::TAtomicNum<> Busy;
            NUtils::TAtomicNum<> FailRequest;
            NUtils::TAtomicNum<> Exception;
            NUtils::TAtomicNum<> MaxCon;
        };

        const TCounters& GetStats() const {
            return Counters_;
        }

    protected:
        TCounters Counters_;
    };

    template <class T>
    class TCallBack: public TCallBackBase {
    public:
        TCallBack(T& s,
                  const TString& name,
                  TDuration requestTimeout,
                  const TConfig::TResponseTimeHeaders& headerNames)
            : S_(s)
            , ThreadName_(name)
            , RequestTimeout_(requestTimeout)
            , HeaderNames_(headerNames)
        {
        }

        TClientRequest* CreateClient() override {
            ++Counters_.All;
            return new TReplier<T>(S_, Counters_.Busy, RequestTimeout_, HeaderNames_);
        }

        void OnFailRequest(int failstate) override {
            ++Counters_.FailRequest;
            TLog::Warning("OnFailRequest(): %s: %d", ThreadName_.c_str(), failstate);
        }

        void OnException() override {
            ++Counters_.Exception;
            TLog::Warning("OnException(): %s: %s", ThreadName_.c_str(), CurrentExceptionMessage().c_str());
        }

        void OnMaxConn() override {
            ++Counters_.MaxCon;
            TLog::Warning("OnMaxConn(): %s: limit of connections reached", ThreadName_.c_str());
        }

    private:
        T& S_;
        const TString ThreadName_;
        const TDuration RequestTimeout_;
        const TConfig::TResponseTimeHeaders HeaderNames_;
    };

}
