#include "module.h"

#include <balancer/kernel/custom_io/rewind.h>
#include <balancer/kernel/http/parser/header_validation.h>
#include <balancer/kernel/http/parser/common_headers.h>
#include <balancer/kernel/http/parser/httpdecoder.h>
#include <balancer/kernel/http/parser/httpencoder.h>
#include <balancer/kernel/http/parser/http.h>
#include <balancer/kernel/log/errorlog.h>
#include <balancer/kernel/module/module.h>
#include <balancer/kernel/requester/requester.h>

using namespace NConfig;
using namespace NSrvKernel;

Y_TLS(rpcrewrite) {
    // Per-worker checker for the optional "file_switch" file. It is only bound
    // to a file in DoInitTls when file_switch is configured.
    TSharedFileExistsChecker KillSwitchChecker;

    // Reports whether the kill switch file is currently present; DoRun bypasses
    // rpc rewriting entirely while this returns true.
    bool KillSwitchFileExists() const noexcept {
        const bool switchIsOn = KillSwitchChecker.Exists();
        return switchIsOn;
    }
};

// rpcrewrite module.
//
// For an eligible request the whole client request (head and body) is
// serialized and sent as the body of a "POST <url>" request to the "rpc"
// module. Depending on the outcome:
//  - if the rpc response carries X-ForwardToUser-Y with a value matching
//    TTrueFsm, that rpc response is sent to the client and no submodule runs;
//  - if the rpc response body decodes as an HTTP request, that request replaces
//    the client one for the main submodule (in dry_run mode the original
//    request is kept and only X-Metabalancer-Y is copied into it);
//  - if the rpc call or the decoding fails, "on_rpc_error" (when configured) or
//    the main submodule is run with the original, rewound request.
// Requests without descr.Request, upgrade requests, and every request while the
// "file_switch" file exists go straight to the main submodule.
MODULE_WITH_TLS(rpcrewrite) {
private:
    // Header taken from the rpc-generated request and copied into the original
    // request in dry_run mode.
    static constexpr TStringBuf CHECK_METABALANCER = "X-Metabalancer-Y";

public:
    // Parses the module config and prebuilds the rpc request template.
    // Throws TConfigParseError when url, host, rpc or the main submodule is missing.
    TModule(const TModuleParams& mp)
        : TModuleBase(mp)
    {
        Config->ForEach(this);

        if (!Url_) {
            ythrow TConfigParseError() << "no url configured for rpcrewrite";
        }

        if (!Host_) {
            ythrow TConfigParseError() << "no host configured for rpcrewrite";
        }

        if (!Rpcmodule_) {
            ythrow TConfigParseError() << "rpc module is not configured";
        }

        if (!Submodule_) {
            ythrow TConfigParseError() << "submodule is not configured";
        }

        InitRpcRequest();
    }

private:
    // Config keys:
    //   url, host          - request line target and Host header of the rpc POST (required);
    //   rpc_success_header - optional header set to "1"/"0" on the request passed downstream;
    //   dry_run            - call rpc but keep serving the original request;
    //   file_switch        - kill switch file; while it exists rewriting is skipped;
    //   rpc                - module the rpc request is sent to (required);
    //   on_rpc_error       - optional module run instead of the main one when rpc fails;
    //   any other key      - the main submodule (required).
    START_PARSE {
        ON_KEY("url", Url_) {
            return;
        }

        ON_KEY("host", Host_) {
            return;
        }

        ON_KEY("rpc_success_header", RpcSuccessHeader_) {
            if (!CheckHeaderName(RpcSuccessHeader_)) {
                ythrow TConfigParseError() << "\"rpc_success_header\" is not a valid http header";
            } else if (!CheckRestrictedHeaderName(RpcSuccessHeader_)) {
                ythrow TConfigParseError{} << "\"rpc_success_header\" value " <<  RpcSuccessHeader_.Quote()
                    << " contains one of the restricted headers " << RestrictedHeadersListString() << "\n";
            }
            // Case-insensitive matcher used to drop any client-supplied copy of the header.
            RpcSuccessFsm_ = MakeHolder<TFsm>(RpcSuccessHeader_, TFsm::TOptions().SetCaseInsensitive(true));
            return;
        }

        ON_KEY("dry_run", DryRun_) {
            return;
        }

        ON_KEY("file_switch", KillSwitchFile_) {
            return;
        }

        if (key == "rpc") {
            TSubLoader(Copy(value->AsSubConfig())).Swap(Rpcmodule_);
            return;
        }

        if (key == "on_rpc_error") {
            TSubLoader(Copy(value->AsSubConfig())).Swap(OnErrorSubmodule_);
            return;
        }

        {
            // Any unrecognized key is treated as the main submodule.
            Submodule_.Reset(Loader->MustLoad(key, Copy(value->AsSubConfig())).Release());
            return;
        }
    } END_PARSE

private:
    // Outcome of the rpc round-trip.
    enum class ERpcResult {
        SUCCESS,            // rpc answered with a request; newRequest/newRequestBody are filled
        FAIL,               // rpc call or decoding failed; serve the original request
        FORWARDED_TO_USER   // rpc response was already sent to the client
    };

    // Case-insensitive matcher for the X-ForwardToUser-Y rpc response header.
    class TForwardToUserFsm : public NRegExp::TFsm, public TWithDefaultInstance<TForwardToUserFsm> {
    public:
        TForwardToUserFsm()
            : TFsm("X-ForwardToUser-Y", TFsm::TOptions().SetCaseInsensitive(true))
        {}
    };

    // Serializes the client request head (TToBackendEncoder) followed by the body
    // read from descr.Input, and sends the result via Rpcmodule_ as the body of a
    // copy of RpcRequest_. `response` is reset before the request is issued;
    // on success `response`/`responseBody` hold the rpc answer.
    TError DoRunRpc(const TConnDescr& descr, TResponse& response, TChunkList& responseBody) const noexcept {
        TChunkList rpcRequestBody;
        {
            // Creating new request body.
            TChunksOutput requestOut;
            TToBackendEncoder encoder{&requestOut};
            Y_PROPAGATE_ERROR(encoder.WriteRequest(*descr.Request, false, Nothing(), TInstant::Max()));
            Y_PROPAGATE_ERROR(Transfer(descr.Input, &encoder, TInstant::Max()));
            rpcRequestBody = std::move(requestOut.Chunks());
        }

        TRequester requester{*Rpcmodule_, descr};
        response = {};
        return requester.Request(TRequest(RpcRequest_), std::move(rpcRequestBody),
                                 false, response, responseBody);
    }

    // Performs the rpc round-trip for connDescr.
    // Returns:
    //   SUCCESS           - newRequest/newRequestBody hold the request decoded from the rpc body;
    //   FORWARDED_TO_USER - the rpc response was sent to the client as-is;
    //   FAIL              - rpc failed or its body did not decode; the success header is set to "0";
    //   a TError          - the client connection failed (read while draining, or write).
    TErrorOr<ERpcResult> RunRpc(const TConnDescr& connDescr, TRequest& newRequest, TChunkList& newRequestBody) const {
        TExtraAccessLogEntry rpc{ connDescr, "rpc" };

        TResponse rpcResponse;
        TChunkList rpcResponseBody;

        Y_TRY(TError, error) {
            // TODO: do not ignore rpc response
            return DoRunRpc(connDescr, rpcResponse, rpcResponseBody);
        } Y_CATCH {
            LOG_ERROR(TLOG_ERR, connDescr, "rpc request failed: " << GetErrorMessage(error));
            // Drain what is left of the client body so the rewindable input holds
            // the complete request for the fallback submodule.
            if (TError skipError = SkipAll(connDescr.Input, TInstant::Max())) {
                connDescr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "client read error");
                return skipError;
            }
            FillRpcSuccessHeader(connDescr, false);
            return ERpcResult::FAIL;
        }

        TStringBuf forwardToUserHeaderValue = rpcResponse.Headers().GetFirstValue(TForwardToUserFsm::Instance());

        if (forwardToUserHeaderValue && Match(TTrueFsm::Instance(), forwardToUserHeaderValue)) {
            // The rpc asked to answer the client directly; strip the control header first.
            rpcResponse.Headers().Delete(TForwardToUserFsm::Instance());
            Y_TRY(TError, error) {
                Y_PROPAGATE_ERROR(connDescr.Output->SendHead(std::move(rpcResponse), false, TInstant::Max()));
                if (!rpcResponseBody.Empty()) {
                    Y_PROPAGATE_ERROR(connDescr.Output->Send(std::move(rpcResponseBody), TInstant::Max()));
                }
                return connDescr.Output->SendEof(TInstant::Max());
            } Y_CATCH {
                connDescr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "client write error");
                return error;
            }
            connDescr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "forwarded to user");
            return ERpcResult::FORWARDED_TO_USER;
        }

        Y_TRY(TError, error) {
            // Retrieving new request from rpc response body.
            TChunksInput newRequestInput{std::move(rpcResponseBody)};
            TFromClientDecoder reqDecoder{&newRequestInput};
            Y_PROPAGATE_ERROR(reqDecoder.ReadRequest(newRequest, TInstant::Max()));
            return RecvAll(&reqDecoder, newRequestBody, TInstant::Max());
        } Y_CATCH {
            // TODO: get rid of these useless logs
            if (const auto* httpError = error.GetAs<THttpError>()) {
                connDescr.ExtraAccessLog << " http_parse_error " << httpError->Code();
                LOG_ERROR(TLOG_ERR, connDescr, "rpc_http_parse_error " << httpError->Code());
            } else if (const auto* exc = error.GetAs<yexception>()) {
                connDescr.ExtraAccessLog << " rpc_response_decode_error";
                LOG_ERROR(TLOG_ERR, connDescr, "rpc_response_decode_error " << exc->what());
            } else {
                connDescr.ExtraAccessLog << " rpc_response_decode_unknown_error";
                LOG_ERROR(TLOG_ERR, connDescr, "rpc_response_decode_unknown_error");
            }

            FillRpcSuccessHeader(connDescr, false);
            return ERpcResult::FAIL;
        }

        return ERpcResult::SUCCESS;
    }

    // Entry point: runs the rpc round-trip unless the kill switch is on, there is
    // no request, or a protocol upgrade was requested, then dispatches to the
    // appropriate submodule (see the class comment).
    TError DoRun(const TConnDescr& descr, TTls& tls) const noexcept override {
        if (!tls.KillSwitchFileExists() && descr.Request && !descr.Request->Props().UpgradeRequested) {
            // The client body is consumed by the rpc call; buffer it so it can be
            // replayed to a submodule afterwards.
            // TODO(velavokr): BALANCER-2183 is unlimited buffer really ok here?
            TLimitedRewindableInput rewindIn{ *descr.Input, Max<size_t>() };
            const TConnDescr connDescr = descr.CopyIn(rewindIn);

            TRequest newRequest;
            TChunkList newRequestBody;

            // TODO(tender-bum): replace with Y_TRY/Y_CATCH
            ERpcResult rpcResult;
            // We are creating new request which we got from Rpcmodule_ response.
            Y_PROPAGATE_ERROR(RunRpc(connDescr, newRequest, newRequestBody).AssignTo(rpcResult));

            if (rpcResult == ERpcResult::FORWARDED_TO_USER) {
                return {};
            }

            rewindIn.Rewind();
            if (rpcResult == ERpcResult::FAIL) {
                if (OnErrorSubmodule_) {
                    TExtraAccessLogEntry onRpcErrorLog{ connDescr, "on_rpc_error" };
                    return OnErrorSubmodule_->Run(connDescr);
                } else {
                    return Submodule_->Run(connDescr);
                }
            }

            if (!DryRun_) {
                // Keep connection-level properties of the original client request.
                newRequest.Props().KeepAlive = descr.Request->Props().KeepAlive;
                newRequest.Props().ExplicitKeepAliveHeader = descr.Request->Props().ExplicitKeepAliveHeader;
                newRequest.Props().HTTP2 = descr.Request->Props().HTTP2;

                // Filling conn descr with new request.
                TChunksInput newRequestIn{std::move(newRequestBody)};
                TConnDescr newDescr = connDescr.Copy(&newRequest);
                newDescr.Input = &newRequestIn;
                FillRpcSuccessHeader(newDescr, true);

                return Submodule_->Run(newDescr);
            } else {
                // Dry run: serve the original request, only carrying over the marker header.
                connDescr.Request->Headers().Replace(CHECK_METABALANCER,
                    newRequest.Headers().GetValues(CHECK_METABALANCER)
                );

                FillRpcSuccessHeader(connDescr, true);

                return Submodule_->Run(connDescr);
            }
        } else {
            return Submodule_->Run(descr);
        }
    }

    // Creates per-worker state; binds the kill switch checker only when
    // file_switch is configured (checker interval is 1 second — presumably the
    // file-existence refresh period, TODO confirm).
    THolder<TTls> DoInitTls(IWorkerCtl* process) override {
        auto tls = MakeHolder<TTls>();
        if (KillSwitchFile_) {
            tls->KillSwitchChecker = process->SharedFiles()->FileChecker(KillSwitchFile_, TDuration::Seconds(1));
        }
        return tls;
    }

    bool DoExtraAccessLog() const noexcept override {
        return true;
    }

    // When rpc_success_header is configured, replaces any existing copies of that
    // header (matched case-insensitively) on descr.Request with "1" or "0".
    void FillRpcSuccessHeader(const TConnDescr& descr, bool succeed) const noexcept {
        if (!RpcSuccessHeader_.empty()) {
            descr.Request->Headers().Delete(*RpcSuccessFsm_);
            descr.Request->Headers().Add(RpcSuccessHeader_, (succeed ? "1" : "0"));
        }
    }

    // Builds the "POST <url> HTTP/1.1" rpc request template once at config time
    // with Content-Type, Host and keep-alive disabled. Throws if it does not parse.
    void InitRpcRequest() {
        // preparing request so we won't have to parse it each request
        TString requestString;

        requestString += "POST ";
        requestString += Url_;
        requestString += " HTTP/1.1";
        requestString += CRLF;
        requestString += CRLF;

        TryRethrowError(RpcRequest_.Parse(std::move(requestString)));
        RpcRequest_.Headers().Add("Content-Type", "application/octet-stream");
        RpcRequest_.Headers().Add("Host", Host_);
        RpcRequest_.Props().KeepAlive = false;
    }

private:
    TString KillSwitchFile_;
    TString Url_;
    TString Host_;
    TString RpcSuccessHeader_;
    THolder<TFsm> RpcSuccessFsm_;        // set only together with RpcSuccessHeader_
    THolder<IModule> Rpcmodule_;         // "rpc"
    THolder<IModule> Submodule_;         // main submodule
    THolder<IModule> OnErrorSubmodule_;  // "on_rpc_error", optional

    bool DryRun_ = false;

    TRequest RpcRequest_;                // prebuilt by InitRpcRequest, copied per request
};

IModuleHandle* NModRpcRewrite::Handle() {
    // Public accessor for the handle of the rpcrewrite TModule defined in this file.
    IModuleHandle* const rpcRewriteHandle = TModule::Handle();
    return rpcRewriteHandle;
}
