#include "server.h"

#include <balancer/server/server.h>
#include <library/cpp/coroutine/engine/helper.h>

#include <util/system/thread.h>

using namespace NBalancerServer;

namespace NYP::NServiceDiscovery::NTesting {
    // Implementation of the test service-discovery server: a TStandaloneServer driven
    // from its own TThread that answers resolve requests from an in-memory table of
    // canned responses filled via Set()/Remove().
    struct TSDServer::TImpl {
        THolder<TStandaloneServer> Server;
        // Thread whose body is Server->Run(); started by Start(), joined by Stop().
        THolder<TThread> Thread;

        // Hashes a request by the two fields that identify it in the table:
        // cluster name and endpoint set id.
        struct THash {
            size_t operator()(const NApi::TReqResolveEndpoints& req) const {
                return MultiHash(req.cluster_name(), req.endpoint_set_id());
            }
        };

        // Two requests are the same key when cluster name and endpoint set id match;
        // all other request fields are ignored.
        struct TEqual {
            bool operator()(const NApi::TReqResolveEndpoints& lhs, const NApi::TReqResolveEndpoints& rhs) const {
                return lhs.cluster_name() == rhs.cluster_name() && lhs.endpoint_set_id() == rhs.endpoint_set_id();
            }
        };

        // Canned responses keyed by request; every access is done under Lock.
        THashMap<NApi::TReqResolveEndpoints, NApi::TRspResolveEndpoints, THash, TEqual> Map;
        TMutex Lock;

        // Number of ServeRequest() invocations, including requests that fail to parse.
        TAtomic RequestCount = 0;
    public:
        // Creates the HTTP server and the thread object that will run it.
        // Nothing is started here; call Start() to begin serving.
        explicit TImpl(const TOptions& sdServerOptions) {
            NBalancerServer::TOptions bServerOptions;
            bServerOptions.Port = sdServerOptions.Port;
            bServerOptions.NetworkThreads = sdServerOptions.Threads;

            Server = MakeHolder<TStandaloneServer>(
                [this](THttpRequestEnv& request) {
                    return ServeRequest(request);
                },
                bServerOptions
            );

            Thread = MakeHolder<TThread>(
                [this]() {
                    Server->Run();
                }
            );
        }

        // Handles one HTTP request: parses the body as TReqResolveEndpoints, looks up
        // the canned response and sends it back serialized with status 200.
        // Returns an HTTP 400 error if the body cannot be parsed. A request with no
        // registered response is answered with a default-constructed response.
        NSrvKernel::TError ServeRequest(THttpRequestEnv& request) {
            AtomicIncrement(RequestCount);

            TString body = request.Body();
            NApi::TReqResolveEndpoints resolveRequest;

            if (!resolveRequest.ParseFromString(body)) {
                return Y_MAKE_ERROR(NSrvKernel::THttpError(400) << "can't parse request");
            }

            NApi::TRspResolveEndpoints resolveResponse;

            with_lock (Lock) {
                // Use find() rather than operator[] so that unknown requests do not
                // insert empty entries into the table on a read-only path.
                const auto it = Map.find(resolveRequest);
                if (it != Map.end()) {
                    resolveResponse = it->second;
                }
            }

            // Serialize outside the lock; resolveResponse is a private copy.
            TString response = resolveResponse.SerializeAsStringOrThrow();

            auto reply = request.GetReplyTransport();

            NSrvKernel::TResponse head(200, "Ok");
            head.Props().ContentLength = response.size();

            reply->SendHead(std::move(head));
            reply->SendData(response);
            reply->SendEof();

            return {};
        }

        // Registers (or overwrites) the response returned for requests equal to `req`.
        void Set(const NApi::TReqResolveEndpoints& req, const NApi::TRspResolveEndpoints& rsp) {
            with_lock (Lock) {
                Map[req] = rsp;
            }
        }

        // Drops the response registered for `req`, if any.
        void Remove(const NApi::TReqResolveEndpoints& req) {
            with_lock (Lock) {
                Map.erase(req);
            }
        }

        // Starts the thread that runs the server loop.
        void Start() {
            Thread->Start();
        }

        // Asks the server to shut down and joins the serving thread.
        void Stop() {
            Server->Shutdown();
            Thread->Join();
        }
    };

    // Builds the server implementation from `options` and keeps a copy of the options
    // (the port is needed later by Start()). Construction does not start serving.
    TSDServer::TSDServer(const TOptions& options)
        : Impl_(MakeHolder<TImpl>(options))
        , Options_(options)
    {
    }

    // Stops the implementation on destruction so the serving thread is joined
    // before the server object is destroyed.
    TSDServer::~TSDServer() {
        Impl_->Stop();
    }

    // Starts the serving thread, then waits for localhost:Options_.Port to become
    // connectable, passing `timeout` to NCoro::WaitUntilConnectable.
    // Y_ENSURE fails (throws) if WaitUntilConnectable returns false.
    void TSDServer::Start(TDuration timeout) {
        Impl_->Start();
        Y_ENSURE(NCoro::WaitUntilConnectable("localhost", Options_.Port, timeout));
    }

    // Shuts the server down and joins the serving thread (delegates to TImpl::Stop).
    // NOTE(review): the destructor calls Impl_->Stop() again after an explicit Stop() —
    // confirm that repeated shutdown/join is safe.
    void TSDServer::Stop() {
        Impl_->Stop();
    }

    // Registers `response` as the reply for requests matching `request`
    // (matching is by cluster name and endpoint set id); overwrites any previous entry.
    void TSDServer::Set(const NApi::TReqResolveEndpoints& request, const NApi::TRspResolveEndpoints& response) {
        Impl_->Set(request, response);
    }

    // Removes the reply registered for `request`, if there is one.
    void TSDServer::Remove(const NApi::TReqResolveEndpoints& request) {
        Impl_->Remove(request);
    }

    // Returns how many requests the server has handled so far; the counter is
    // incremented at the start of every request, including ones that fail to parse.
    size_t TSDServer::GetRequestCount() const {
        return AtomicGet(Impl_->RequestCount);
    }
}
