#include "module.h"

#include <balancer/kernel/custom_io/chunkio.h>
#include <balancer/kernel/helpers/errors.h>
#include <balancer/kernel/http/parser/header_validation.h>
#include <balancer/kernel/http/parser/httpencoder.h>
#include <balancer/kernel/http/parser/http.h>
#include <balancer/kernel/memory/chunks.h>
#include <balancer/kernel/module/module.h>
#include <balancer/kernel/custom_io/stream.h>

#include <library/cpp/http/misc/httpcodes.h>
#include <library/cpp/string_utils/base64/base64.h>

#include <util/stream/file.h>

using namespace NConfig;
using namespace NSrvKernel;
using namespace NModErrorDocument;

namespace NSrvKernel {
    // Optional response body: left unset when the config names no body source.
    using TContent = TMaybe<TString>;

    // Returns the size() of the stored body, or 0 when no body is set.
    static size_t ContentLength(const TContent& c) {
        if (!c) {
            return 0;
        }
        return c->size();
    }
}

// Module that answers every request with a preconfigured response:
// a fixed status line, optional extra headers and an optional body.
MODULE(errordocument) {
public:
    // Feeds the config section through the key parser below, validates the
    // resulting settings and pre-builds the response template that DoRun copies.
    TModule(const TModuleParams& mp)
        : TModuleBase(mp)
    {
        Config->ForEach(this);

        if (!Status_) {
            ythrow TConfigParseError() << "Status is not set";
        }

        // 204 and 304 are rejected when a non-empty body was configured.
        const bool bodyless = IsIn({HTTP_NO_CONTENT, HTTP_NOT_MODIFIED}, Status_);
        Y_ENSURE_EX(!bodyless || !ContentLength(Content_),
            TConfigParseError() << "Status code " << Status_ << " must not have a body");

        // The template is parsed from a bare status line followed by an empty header section.
        TString statusLine("HTTP/1.1 ");
        statusLine.append(HttpCodeStrEx(Status_));
        statusLine.append("\r\n\r\n");
        TryRethrowError(Response_.Parse(std::move(statusLine)));
        Response_.Props().ContentLength = ContentLength(Content_);
    }

private:
    // Config key handlers. Each recognised key stores its value and returns.
    // NOTE(review): whatever happens for an unrecognised key is decided by
    // END_PARSE, whose definition is not visible here.
    START_PARSE {
        ON_KEY("status", Status_) {
            Y_ENSURE_EX(IsHttpCode(Status_),
                TConfigParseError() << "Invalid http code " << Status_);
            return;
        }

        // "file", "content" and "base64" are mutually exclusive body sources:
        // whichever comes second fails with "content redefined".
        TString path;
        ON_KEY("file", path) {
            Y_ENSURE_EX(!Content_, TConfigParseError() << "content redefined");
            TFileInput input(path);
            Content_ = input.ReadAll();
            return;
        }

        TString inlineBody;
        ON_KEY("content", inlineBody) {
            Y_ENSURE_EX(!Content_, TConfigParseError() << "content redefined");
            Content_ = inlineBody;
            return;
        }

        TString encoded;
        ON_KEY("base64", encoded) {
            Y_ENSURE_EX(!Content_, TConfigParseError() << "content redefined");
            Content_ = Base64Decode(encoded);
            return;
        }

        // Case-insensitive pattern selecting request headers to echo into the response.
        TString pattern;
        ON_KEY("remain_headers", pattern) {
            TFsm::TOptions options;
            options.SetCaseInsensitive(true);
            RemainHeaders_.Reset(new TFsm(pattern, options));
            return;
        }

        if (key == "headers") {
            ParseMap(value->AsSubConfig(), [&](auto&& name, auto&& node) {
                // These names are refused in the static header list.
                const bool reserved =
                    AsciiEqualsIgnoreCase(name, "content-length")
                    || AsciiEqualsIgnoreCase(name, "transfer-encoding")
                    || AsciiEqualsIgnoreCase(name, "connection");
                Y_ENSURE_EX(
                    name && CheckHeaderName(name) && !reserved,
                    TConfigParseError() << "invalid header name: " << name.Quote()
                );

                const TString headerValue = node->AsString();
                Y_ENSURE_EX(CheckHeaderValue(headerValue),
                    TConfigParseError() << "invalid value for header " << name << ": " << headerValue.Quote()
                );
                Headers_.emplace_back(name, headerValue);
            });
            return;
        }

        ON_KEY("force_conn_close", ForceConnClose_) {
            return;
        }
    } END_PARSE

    // Sends the configured response to the client.
    //
    // Starts from a copy of the prepared template, adds the static headers and
    // any request headers matched by RemainHeaders_, then sends head, body (if
    // non-empty) and EOF, and finally skips the rest of the request input.
    // NOTE(review): the statements after Y_CATCH are presumably reached when
    // the guarded block yields no error - confirm against the Y_TRY definition.
    TError DoRun(const TConnDescr& descr) const noexcept override {
        TResponse reply = Response_;

        for (const auto& [name, value] : Headers_) {
            reply.Headers().Add(name, value);
        }

        if (RemainHeaders_) {
            for (auto& requestHeader : descr.Request->Headers()) {
                if (!Match(*RemainHeaders_, requestHeader.first.AsStringBuf())) {
                    continue;
                }
                for (auto& requestValue : requestHeader.second) {
                    reply.Headers().Add(requestHeader.first.AsString(), requestValue.AsString());
                }
            }
        }

        Y_TRY(TError, failure) {
            // Remember the status before the response object is moved away.
            const ui16 sentStatus = reply.ResponseLine().StatusCode;
            Y_PROPAGATE_ERROR(descr.Output->SendHead(std::move(reply), ForceConnClose_, TInstant::Max()));
            if (ContentLength(Content_) > 0) {
                Y_PROPAGATE_ERROR(descr.Output->Send(TChunkList{*Content_}, TInstant::Max()));
            }
            Y_PROPAGATE_ERROR(descr.Output->SendEof(TInstant::Max()));

            descr.ExtraAccessLog << " succ " << sentStatus;

            // Drain the unread request body so that parsing of the next
            // keep-alive request on this connection does not fail.
            return SkipAll(descr.Input, TInstant::Max());
        } Y_CATCH {
            descr.Properties->ConnStats.ClientError += 1;
            descr.ExtraAccessLog << " error";
            descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "client error");
            return failure;
        }

        if (!ForceConnClose_) {
            descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "success");
            return {};
        }

        descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "success, force conn close");
        return Y_MAKE_ERROR(TForceStreamClose{});
    }

    // This module always reports that it writes to the extra access log.
    bool DoExtraAccessLog() const noexcept override {
        return true;
    }

private:
    // Response body; unset when none of "file"/"content"/"base64" was given.
    TContent Content_;
    // Static (name, value) headers from the "headers" config map.
    std::vector<std::pair<TString, TString>> Headers_;
    // Matcher for request header names to copy into the response; null if unset.
    THolder<TFsm> RemainHeaders_;
    // Configured HTTP status code; 0 means "not set" and is rejected by the constructor.
    ui32 Status_ = 0;
    // When true, DoRun passes it to SendHead and finishes with TForceStreamClose.
    bool ForceConnClose_ = false;
    // Response template built once in the constructor.
    TResponse Response_;
};

// Exposes the errordocument module's handle by forwarding to TModule::Handle().
IModuleHandle* NModErrorDocument::Handle() {
    IModuleHandle* const moduleHandle = TModule::Handle();
    return moduleHandle;
}
