#include "channel_listener.h"

#include <balancer/kernel/custom_io/channel_io.h>
#include <balancer/kernel/process/thread_info.h>

namespace NSrvKernel {

    // Creates one request channel per worker. Worker ids are 1-based, so slot 0
    // of WorkersChannels_ is intentionally left empty and the container holds
    // workersCount + 1 entries.
    //
    // The initializer list uses the constructor parameter rather than the
    // WorkersCount_ member. Members are initialized in declaration order, so
    // reading WorkersCount_ here would yield an uninitialized value if
    // WorkersChannels_ were ever declared before it in the header.
    TChannelListener::TChannelListener(size_t workersCount)
        : WorkersCount_(workersCount)
        , WorkersChannels_(workersCount + 1)
    {
        for (size_t i = 1u; i <= workersCount; ++i) {
            // NOTE(review): 100 is the constructor argument of TU2WChannel
            // (presumably its capacity) - confirm against TU2WChannel.
            WorkersChannels_[i] = MakeHolder<TU2WChannel<IChannelRequest*>>(100);
        }
    }

    // Returns the channel of worker `id`. Valid ids are 1 .. WorkersChannels_.size() - 1;
    // slot 0 is never populated by the constructor. Checked only by Y_ASSERT.
    TU2WChannel<IChannelRequest*>& TChannelListener::GetChannel(size_t id) {
        Y_ASSERT(id > 0 && id < WorkersChannels_.size());
        auto& slot = WorkersChannels_[id];
        return *slot;
    }

    // Hands `request` to a worker channel picked by ChooseChannel from an
    // atomically incremented request counter, i.e. worker channels are used in
    // turn. Send is given TInstant::Max() as its deadline. Returns the status
    // reported by the channel.
    EChannelStatus TChannelListener::SendRequest(IChannelRequest* request) {
        const ui64 requestId = AtomicGetAndIncrement(CurrentRequestId_);
        auto& channel = ChooseChannel(requestId);
        return channel.Send(std::move(request), TInstant::Max());
    }

    // Maps a request id onto one of the 1-based worker channels
    // (1 + reqId % WorkersCount_).
    TU2WChannel<IChannelRequest*>& TChannelListener::ChooseChannel(size_t reqId) {
        // With zero workers the modulo below would be an integer division by
        // zero (undefined behavior); fail loudly instead.
        Y_VERIFY(WorkersCount_ > 0);
        return *WorkersChannels_[1 + reqId % WorkersCount_];
    }

    // Binds this request to `worker` and creates a TCont on the worker's
    // executor that will run IChannelRequest::Run for this object. The created
    // handle is stored in Cont_.
    void IChannelRequest::RunRequest(IWorkerCtl* worker) {
        Worker_ = worker;
        auto& executor = Worker_->Executor();
        Cont_ = executor.Create<IChannelRequest, &IChannelRequest::Run>(this, "channel_request_processor");
    }

    // Entry point of the TCont created in RunRequest (the TCont* argument is
    // unused). Runs DoRun() and reports any exception escaping it through
    // SetResponseError with an empty endpoint instead of propagating it.
    void IChannelRequest::Run(TCont*) {
        try {
            DoRun();
        } catch (const yexception& exc) {
            // NOTE(review): if MakeError<yexception>(exc) copy-constructs a
            // yexception, a derived exception type thrown by DoRun() is sliced
            // down to plain yexception here - confirm against MakeError.
            SetResponseError("", MakeError<yexception>(exc));
        } catch (...) {
            // Anything else (including std::exception) is reduced to its
            // message text wrapped in a yexception.
            SetResponseError("", MakeError<yexception>(yexception{} << CurrentExceptionMessage()));
        }
    }

    void IChannelRequest::DoRun() {
        // input
        bool useStreaming = InputChannel_.Get() != nullptr;

        std::variant<std::monostate, TChunksInput, TChannelInput> inputHolder;
        IIoInput* inputStream = nullptr;

        if (useStreaming) {
            inputStream = &inputHolder.emplace<TChannelInput>(std::move(InputChannel_), &Worker_->Executor());
        } else {
            inputStream = &inputHolder.emplace<TChunksInput>(TChunkList(std::move(Data)));
        }

        // output
        THolder<IHttpOutput> outputHolder;

        std::variant<std::monostate, TChunksOutput, TChannelOutput> dataOutputHolder;
        IIoOutput* dataOutputStream = nullptr;
        if (OutputChannel_) {
            dataOutputStream = &dataOutputHolder.emplace<TChannelOutput>(OutputChannel_, &Worker_->Executor());
        } else {
            dataOutputStream = &dataOutputHolder.emplace<TChunksOutput>();
        }

        TResponse response;
        THeaders trailers;
        if (AttemptsHolder_->ShouldValidateBody()) {
            auto headOp = [](TResponse&&, const bool, TInstant) {
                return TError{};
            };
            auto bodyOp = [](TChunkList, TInstant) {
                return TError{};
            };
            auto trailersOp = [](THeaders&&, TInstant) {
                return TError{};
            };
            outputHolder = MakeHolder<NSrvKernel::TAnyHttpOutput<decltype(headOp), decltype(bodyOp), decltype(trailersOp)>>(std::move(headOp), std::move(bodyOp), std::move(trailersOp));
        } else {
            auto headOp = [&response](TResponse&& result, const bool, TInstant) {
                response = std::move(result);
                return TError{};
            };
            auto bodyOp = [&](TChunkList lst, TInstant deadline) {
                return dataOutputStream->Send(std::move(lst), deadline);
            };
            auto trailersOp = [&trailers](THeaders&& headers, TInstant) {
                trailers = std::move(headers);
                return TError{};
            };
            outputHolder = MakeHolder<NSrvKernel::TAnyHttpOutput<decltype(headOp), decltype(bodyOp), decltype(trailersOp)>>(std::move(headOp), std::move(bodyOp), std::move(trailersOp));
        }
        IHttpOutput* outputStream = outputHolder.Get();

        Y_VERIFY(outputStream);

        // basic things for conndescr
        TAddrHolder addrHolder(&TDummyAddr::Instance());
        TTcpConnProps tcpConnProps(*Worker_, addrHolder, addrHolder, nullptr);
        TConnProps connProps(tcpConnProps, TInstant::Now(), 0ull);

        TConnDescr descr(*inputStream, *outputStream, connProps);

        descr.HaveFullBody = !useStreaming;
        descr.AttemptsHolder = AttemptsHolder_.Get();
        descr.ClientRequest = ClientRequest_.Get();
        Y_ASSERT(descr.AttemptsHolder);
        descr.Request = &Request;
        if (descr.AttemptsHolder->SrcrwrAddrs() && !descr.AttemptsHolder->SrcrwrAddrs()->Empty()) {
            descr.Properties->SrcrwrAddrs = descr.AttemptsHolder->SrcrwrAddrs();
        }
        descr.ErrorLog = ErrorLog_;

        auto overridenHash = descr.AttemptsHolder->Hash();
        if (!overridenHash.Empty()) {
            descr.Hash = *overridenHash;
        }

        if (auto err = Entry.Run(descr)) {
            SetResponseError(descr.AttemptsHolder->GetEndpoint(), std::move(err));
        } else {
            TString data;
            if (auto* chunksOutput = std::get_if<TChunksOutput>(&dataOutputHolder)) {
                data = ToString(chunksOutput->Chunks());
            }

            SetResponseValue(descr.AttemptsHolder->GetEndpoint(), std::move(response), std::move(data), std::move(trailers));
        }
    }

    // Returns the attempts holder referenced by AttemptsHolder_. The pointer is
    // dereferenced unconditionally, so it must be set.
    IAttemptsHolder& IChannelRequest::GetAttemptsHolder() const {
        IAttemptsHolder& holder = *AttemptsHolder_;
        return holder;
    }
}
