#include "checker_export.h"

#include <balancer/kernel/custom_io/concat.h>
#include <balancer/kernel/custom_io/limitio.h>
#include <balancer/kernel/helpers/default_instance.h>
#include <balancer/kernel/http/parser/common_headers.h>
#include <balancer/kernel/http/parser/headers.h>
#include <balancer/kernel/http/parser/httpencoder.h>
#include <balancer/kernel/memory/chunks.h>
#include <balancer/kernel/requester/requester.h>

using namespace NSrvKernel;

namespace {

// Response header of the check request: when its value matches TTrueFsm the
// request is logged as " robot" and IsRobot_ is set.
constexpr TStringBuf FORWARD_TO_USER = "X-ForwardToUser-Y";
// Response header of the check request: when its value matches TTrueFsm the
// ban flag (ShouldBan_) is raised.
constexpr TStringBuf ANTIROBOT_BAN_IP = "X-Antirobot-Ban-Source-Ip";
// Header of the original client request; a raised ban flag is kept only when
// this header's first value equals the connection's remote address string.
constexpr TStringBuf FORWARDED_FOR_USER = "X-Forwarded-For-Y";

// Input adapter that forwards reads to a wrapped input using a
// default-constructed TInstant as the deadline instead of the caller's one
// (presumably so that only already-available data is returned and the read
// never waits -- confirm against IIoInput::Recv semantics).
//
// Once the wrapped input reports a TSystemError with status ETIMEDOUT, the
// adapter latches: that call and every following call return success without
// touching the output list.
//
// The class is file-local and not meant to be derived from, hence `final`;
// the constructor is `explicit` so a raw IIoInput* is never converted to the
// adapter silently.
class TOnlyBufferedInput final : public IIoInput {
public:
    // `slave` is the wrapped input. It is stored as a raw pointer and is not
    // owned; it must outlive this adapter and is dereferenced without a null
    // check in DoRecv.
    explicit TOnlyBufferedInput(IIoInput* slave) noexcept
        : Slave_(slave)
    {}

    // Reads from the wrapped input into `lst`. The caller's deadline is
    // intentionally ignored (the parameter is unnamed).
    // Returns an empty TError on success or after a timeout was seen; any
    // other error from the wrapped input is returned unchanged.
    TError DoRecv(TChunkList& lst, TInstant) noexcept override {
        if (ShouldEof_) {
            // A timeout was seen earlier: keep reporting "nothing more".
            return {};
        }

        Y_TRY(TError, error) {
            return Slave_->Recv(lst, TInstant());
        } Y_CATCH {
            const auto* e = error.GetAs<TSystemError>();
            if (e && e->Status() == ETIMEDOUT) {
                // Timeout is not treated as a failure; remember it and fall
                // through to the successful return below.
                ShouldEof_ = true;
            } else {
                return error;
            }
        }
        return {};
    }

private:
    IIoInput* Slave_ = nullptr;   // wrapped input, not owned
    bool ShouldEof_ = false;      // set once the wrapped input timed out
};

}  // namespace

namespace NModAntiRobot {

// Records the collaborators and settings used by Run(); does no work itself.
//   submodule       - module the check request is sent through (see RunForRequest).
//   descr           - descriptor of the client connection being checked.
//   request         - request passed to the submodule by RunCutRequest /
//                     RunFullRequest.
//   cutRequestBytes - byte limit applied to the client body in RunCutRequest.
//   wrapper         - when true, errors are propagated to the caller and the
//                     submodule's response is always sent to the client.
//   cutRequest      - selects RunCutRequest (true) or RunFullRequest (false).
TChecker::TChecker(
    IModule* submodule,
    const TConnDescr& descr,
    const NSrvKernel::TRequest& request,
    size_t cutRequestBytes,
    bool wrapper,
    bool cutRequest
) noexcept
    : Submodule_(submodule)
    , Descr_(descr)
    , Request_(request)
    , CutRequestBytes_(cutRequestBytes)
    , Wrapper_(wrapper)
    , CutRequest_(cutRequest)
{}

// Sends `request` with `body` through the submodule and interprets the reply.
//
// Request failure: returned to the caller in wrapper mode; otherwise the
// robot and ban flags are cleared and success is returned.
//
// Reply handling:
//   * FORWARD_TO_USER matching TTrueFsm  -> IsRobot_ = true, " robot <status>"
//     is logged; otherwise " not_robot" is logged.
//   * ANTIROBOT_BAN_IP matching TTrueFsm -> ShouldBan_ = true, but the flag is
//     dropped again (" skip_ban_ip") unless the client's FORWARDED_FOR_USER
//     header equals the connection's remote address (" ban_ip").
//   * Outside wrapper mode both control headers are removed from the reply.
//
// In wrapper mode, or when the request was classified as a robot, the reply
// (head, body if non-empty, EOF) is sent to the client. A send failure is
// logged as " client error" and returned only in wrapper mode; otherwise
// SentRedirect_ is set and success is returned.
TError TChecker::RunForRequest(TRequest request, IIoInput& body) noexcept {
    TRequester requester{*Submodule_, Descr_};
    TChunkList replyBody;

    Response_ = {};
    if (auto requestError = requester.Request(std::move(request), body, Response_, replyBody)) {
        if (Wrapper_) {
            return requestError;
        }
        // Non-wrapper mode: a failed check is treated as "not a robot, no ban".
        IsRobot_ = false;
        ShouldBan_ = false;
        return {};
    }

    // Robot verdict.
    const auto forwardToUser = Response_.Headers().GetFirstValue(FORWARD_TO_USER);
    if (forwardToUser && Match(TTrueFsm::Instance(), forwardToUser)) {
        Descr_.ExtraAccessLog << " robot";
        Descr_.ExtraAccessLog << " " << Response_.ResponseLine().StatusCode;
        IsRobot_ = true;
    } else {
        Descr_.ExtraAccessLog << " not_robot";
    }
    if (!Wrapper_) {
        Response_.Headers().Delete(FORWARD_TO_USER);
    }

    // Ban verdict.
    const auto banSourceIp = Response_.Headers().GetFirstValue(ANTIROBOT_BAN_IP);
    if (banSourceIp && Match(TTrueFsm::Instance(), banSourceIp)) {
        ShouldBan_ = true;
    }
    if (!Wrapper_) {
        Response_.Headers().Delete(ANTIROBOT_BAN_IP);
    }

    if (ShouldBan_) {
        // The ban is honoured only when the forwarded-for header names the
        // address this connection actually comes from.
        const bool addrMismatch =
            Descr_.Request->Headers().GetFirstValue(FORWARDED_FOR_USER) != Descr_.RemoteAddrStr();
        if (addrMismatch) {
            ShouldBan_ = false;
            Descr_.ExtraAccessLog << " skip_ban_ip";
        } else {
            Descr_.ExtraAccessLog << " ban_ip";
        }
    }

    // Nothing is sent to the client for a non-robot outside wrapper mode.
    if (!Wrapper_ && !IsRobot_) {
        return {};
    }

    // TODO: Refactor BALANCER-1716
    Y_TRY(TError, error) {
        Y_PROPAGATE_ERROR(Descr_.Output->SendHead(std::move(Response_), false, TInstant::Max()));
        if (!replyBody.Empty()) {
            Y_PROPAGATE_ERROR(Descr_.Output->Send(std::move(replyBody), TInstant::Max()));
        }
        return Descr_.Output->SendEof(TInstant::Max());
    } Y_CATCH {
        Descr_.ExtraAccessLog << " client error";
        if (Wrapper_) {
            return error;
        }
        // Outside wrapper mode the send failure is only logged.
    }
    SentRedirect_ = true;
    return {};
}

// Runs the check with the serialized request head (`requestChunks`) followed
// by a limited part of the client body.
//
// The client input is wrapped so that reads go through TOnlyBufferedInput and
// are capped at CutRequestBytes_ by TLengthLimitInput.
//   * CutRequestBytes_ == Max<size_t>(): the wrapped input is streamed to the
//     submodule right after the head.
//   * otherwise: the wrapped input is drained first, appended to the head, and
//     the combined chunks are sent as one in-memory input.
TError TChecker::RunCutRequest(TChunkList requestChunks) noexcept {
    Y_ASSERT(Descr_.Input != nullptr);
    TOnlyBufferedInput bufferedOnly{Descr_.Input};
    TLengthLimitInput cappedBody{&bufferedOnly, CutRequestBytes_};

    if (CutRequestBytes_ == Max<size_t>()) {
        TChunksInput headInput{std::move(requestChunks)};
        TConcatInput headAndBody{&headInput, &cappedBody};
        return RunForRequest(Request_, headAndBody);
    }

    TChunkList bodyPart;
    Y_PROPAGATE_ERROR(RecvAll(&cappedBody, bodyPart, TInstant::Max()));
    requestChunks.Append(std::move(bodyPart));

    TChunksInput headAndBody{std::move(requestChunks)};
    return RunForRequest(Request_, headAndBody);
}

// Runs the check with the serialized request head (`requestChunks`) followed
// by the client input as is (no buffering or size limit applied here).
// The check is sent using a copy of Request_ with a "Host: localhost" header
// added to it.
TError TChecker::RunFullRequest(TChunkList requestChunks) noexcept {
    TRequest checkRequest = Request_;
    checkRequest.Headers().Add("Host", "localhost");

    TChunksInput headInput{std::move(requestChunks)};
    TConcatInput headAndBody{&headInput, Descr_.Input};

    return RunForRequest(std::move(checkRequest), headAndBody);
}

// Entry point of the check: serializes the client request, terminates its
// header section and hands the result to RunCutRequest or RunFullRequest
// depending on CutRequest_. Returns whatever error the chosen path reports.
//
// NOTE(review): only IsRobot_ is reset here; the other result flags written by
// RunForRequest are left untouched -- confirm a checker is never run twice.
TError TChecker::Run() noexcept {
    IsRobot_ = false;

    const TExtraAccessLogEntry logEntry(Descr_, "sub_antirobot");

    // NOTE(review): 4196 may be a typo for 4096 -- confirm before changing.
    TChunksOutputStream serialized(4196);
    Descr_.Request->BuildTo(serialized);

    // Close the header section, adding a Connection header first when the
    // client request carried an explicit one.
    const TBaseProperties& props = Descr_.Request->Props();
    if (!props.ExplicitConnectionHeader) {
        serialized << CRLF;
    } else if (props.KeepAlive) {
        serialized << "Connection: Keep-Alive\r\n\r\n";
    } else {
        serialized << "Connection: Close\r\n\r\n";
    }

    if (CutRequest_) {
        return RunCutRequest(std::move(serialized.Chunks()));
    }
    return RunFullRequest(std::move(serialized.Chunks()));
}

}  // namespace NModAntiRobot
