#include "module.h"

#include <balancer/kernel/module/module.h>
#include <library/cpp/http/io/compression.h>

using namespace NSrvKernel;

namespace NModCompressor {

// IOutputStream adapter that forwards every write to an IIoOutput as a
// freshly copied chunk.
//
// Send() failures are not thrown from DoWrite(): the first failure is stored
// and can be collected once via GetError(); after it, all further writes are
// silently dropped.
class TSendDataStream final : public IOutputStream {
public:
    // `output` is not owned; it must outlive this stream.
    explicit TSendDataStream(IIoOutput* output) : Output_(output) {}

    // Copies [data, data + size) into a new chunk and sends it using the
    // deadline most recently set with SetDeadline().
    void DoWrite(const void* data, size_t size) override {
        // Nothing to forward for an empty write. This also avoids handing a
        // list built from an empty chunk to Send(); TCompressionStream::DoSend
        // in this file treats an empty list as the end-of-stream marker.
        // NOTE(review): presumably other IIoOutput implementations interpret
        // an empty list the same way - confirm.
        if (HasError_ || size == 0) {
            return;
        }

        Error_ = Output_->Send(
            TChunkList(NewChunkForceCopy(TStringBuf(static_cast<const char*>(data), size))),
            Deadline_);
        if (Error_) {
            HasError_ = true;
        }
    }

    // Sets the deadline passed to Send() by subsequent writes.
    void SetDeadline(TInstant deadline) {
        Deadline_ = deadline;
    }

    // Returns the stored Send() error (if any) and clears it. The stream
    // nevertheless stays in the failed state and keeps dropping writes.
    TError GetError() {
        return std::exchange(Error_, TError{});
    }

private:
    IIoOutput* Output_;
    TError Error_;
    bool HasError_ = false;
    TInstant Deadline_ = TInstant::Max();
};

// IHttpOutput decorator that compresses the response body when the client's
// Accept-Encoding allows one of the configured codecs.
//
// The codec is chosen in DoSendHead(). If no codec is selected the stream is
// a transparent pass-through to the wrapped output.
class TCompressionStream final : public IHttpOutput {
public:
    // `output` is not owned. `compressionCodecs` is stored by reference, so
    // the caller must keep it alive for the lifetime of this object. The
    // Accept-Encoding values of `request` are copied.
    TCompressionStream(IHttpOutput* output, const TRequest& request, const TVector<TStringBuf>& compressionCodecs)
        : Output_(output)
        , AcceptEncodings_(request.Headers().GetValues("Accept-Encoding"))
        , CompressionCodecs_(compressionCodecs)
        , SendDataStream_(output) {}

    ~TCompressionStream() {
        // Destroying the encoder may throw; never let that escape the
        // destructor.
        try {
            Encoder_.Reset();
        } catch (...) {}
    }

    // Picks a codec (possibly rewriting the response headers) and forwards
    // the head to the wrapped output.
    TError DoSendHead(TResponse&& response, const bool forceClose, TInstant deadline) override {
        Y_PROPAGATE_ERROR(Init(response));
        return Output_->SendHead(std::move(response), forceClose, deadline);
    }

    // Sends a body part. An empty list marks the end of the body: the encoder
    // is finalized and EOF is forwarded. Without an active encoder the data is
    // passed through unchanged.
    TError DoSend(TChunkList lst, TInstant deadline) noexcept override {
        if (!Encoder_) {
            return Output_->Send(std::move(lst), deadline);
        }

        if (lst.Empty()) {
            // Propagate the caller's deadline to the final flush and to EOF.
            return Finish(deadline);
        }

        if (HasError_) {
            // A previous failure was already reported; drop further data.
            return {};
        }

        TError error = [&] () -> TError {
            try {
                SendDataStream_.SetDeadline(deadline);
                Encoder_->Write(StrInplace(lst));
            } Y_TRY_STORE(yexception);
            return {};
        }();
        if (error) {
            HasError_ = true;
            return error;
        }

        // Errors of the underlying Send() are not thrown through the encoder;
        // they are collected from the adapter stream here.
        if (auto error = SendDataStream_.GetError()) {
            HasError_ = true;
            return error;
        }

        return {};
    }

    // Trailers are forwarded untouched.
    TError DoSendTrailers(THeaders&& trailers, TInstant deadline) override {
        return Output_->SendTrailers(std::move(trailers), deadline);
    }

private:
    // Decides whether the response will be compressed and, if so, creates the
    // encoder, drops the known content length and adds Content-Encoding.
    // Responses that already carry Content-Encoding, or have status 204, 304
    // or 101, are left untouched.
    TError Init(TResponse& r) {
        if (r.Headers().FindValues("Content-Encoding")) {
            return {};
        }
        if (IsIn({HTTP_NO_CONTENT, HTTP_NOT_MODIFIED, HTTP_SWITCHING_PROTOCOLS}, r.ResponseLine().StatusCode)) {
            return {};
        }

        // Case-insensitive set of the comma-separated Accept-Encoding tokens.
        // Tokens are matched verbatim, so an entry carrying parameters (such
        // as a ";q=" suffix) does not match a bare codec name.
        THashSet<TStringBuf, TCIOps, TCIOps> accepted;
        for (const auto& value : AcceptEncodings_) {
            for (const auto& it : StringSplitter(value.AsStringBuf()).Split(',').SkipEmpty()) {
                accepted.insert(StripString(it.Token()));
            }
        }
        auto encoding = NHttp::ChooseBestCompressionScheme(
            [&accepted](TStringBuf v) {
                return accepted.contains(v);
            },
            // Fall back to the factory's codec list when none is configured.
            CompressionCodecs_ ? CompressionCodecs_ : TCompressionCodecFactory::Instance().GetBestCodecs()
        );

        if (encoding) {
            if (const auto* encoderFactory = TCompressionCodecFactory::Instance().FindEncoder(encoding)) {
                // The compressed size is not known in advance.
                r.Props().ContentLength.reset();
                r.Headers().Add("Content-Encoding", encoding);

                TError error = [&] () -> TError {
                    try {
                        Encoder_ = (*encoderFactory)(&SendDataStream_);
                    } Y_TRY_STORE(yexception);
                    return {};
                }();

                if (error) {
                    HasError_ = true;
                    return error;
                }
            }
        }
        return {};
    }

    // Finalizes the encoder (which may write remaining data through
    // SendDataStream_) and forwards EOF. `deadline` bounds both the final
    // writes and the EOF itself, instead of reusing the stale deadline of the
    // previous DoSend() call and an unbounded EOF.
    TError Finish(TInstant deadline) {
        TError error = [&] () -> TError {
            try {
                SendDataStream_.SetDeadline(deadline);
                Encoder_.Reset();
            } Y_TRY_STORE(yexception);
            return {};
        }();
        if (error) {
            HasError_ = true;
            return error;
        }

        if (auto error = SendDataStream_.GetError()) {
            HasError_ = true;
            return error;
        }

        return Output_->SendEof(deadline);
    }

    IHttpOutput* Output_;
    const TVector<TStringStorage> AcceptEncodings_;
    const TVector<TStringBuf>& CompressionCodecs_;
    TSendDataStream SendDataStream_;
    THolder<IOutputStream> Encoder_;
    bool HasError_ = false;
};

// IInputStream adapter over an IIoInput.
//
// Each read performs one Recv() on the wrapped input using the deadline
// referenced at construction time. Recv() failures are not thrown: the read
// returns 0, the error is stored for GetError(), and every later read
// returns 0 as well.
class TRecvDataStream final : public IInputStream {
public:
    // Neither `input` nor `deadline` is owned; `deadline` is read anew on
    // every DoRead() call, so the owner can update it between reads.
    TRecvDataStream(IIoInput* input, TInstant& deadline)
        : Input_(input)
        , Deadline_(deadline)
    {}

    // Copies at most `len` received bytes into `buf`; whatever did not fit is
    // handed back to the input with UnRecv(). Returns 0 on error or when
    // Recv() produced an empty list.
    size_t DoRead(void* buf, size_t len) override {
        if (HasError_) {
            return 0;
        }

        TChunkList received;
        Error_ = Input_->Recv(received, Deadline_);
        if (Error_) {
            HasError_ = true;
            return 0;
        }
        if (received.Empty()) {
            return 0;
        }

        const size_t copied = received.CopyDataTo(buf, len);
        received.Skip(copied);
        if (!received.Empty()) {
            // Keep the tail for the next read.
            Input_->UnRecv(std::move(received));
        }
        return copied;
    }

    // Returns the stored Recv() error (if any) and clears it; the stream
    // itself stays in the failed state.
    TError GetError() {
        return std::exchange(Error_, TError{});
    }

private:
    IIoInput* Input_;
    TInstant& Deadline_;
    TError Error_;
    bool HasError_ = false;
};

// IIoInput decorator that decodes a request body according to its
// Content-Encoding header.
//
// Until Init() installs a decoder, DoRecv() is a transparent pass-through to
// the wrapped input.
class TDecompressionStream final : public IIoInput {
public:
    // `input` is not owned and must outlive this object.
    explicit TDecompressionStream(IIoInput* input) : Input_(input), RecvDataStream_(Input_, Deadline_) {}

    ~TDecompressionStream() {
        // Destroying the decoder may throw; never let that escape the
        // destructor.
        try {
            Decoder_.Reset();
        } catch (...) {}
    }


    // Installs a decoder for the first non-identity Content-Encoding token for
    // which the codec factory has a decoder. In that case the request is
    // switched to chunked transfer with no content length, since the decoded
    // size is not known. Requests whose Props().Version is not 1 are left
    // untouched.
    // NOTE(review): the Content-Encoding header itself is left in the request
    // - confirm that downstream consumers expect that.
    TError Init(TRequest& r) {
        if (r.Props().Version != 1) {
            return {};
        }
        if (TStringBuf encoding = GetEncoding(r)) {
            if (const auto* decoderFactory = TCompressionCodecFactory::Instance().FindDecoder(encoding)) {
                r.Props().ContentLength.reset();
                r.Props().ChunkedTransfer = true;

                TError error = [&] () -> TError {
                    try {
                        // A decoder may already read from RecvDataStream_
                        // while being constructed; Deadline_ defaults to
                        // TInstant::Max() so such a read is not made with an
                        // expired deadline.
                        Decoder_ = (*decoderFactory)(&RecvDataStream_);
                    } Y_TRY_STORE(yexception);
                    return {};
                }();
                return error;
            }
        }
        return {};
    }

    // Reads one chunk of decoded data into `lst`. Leaving `lst` untouched and
    // returning no error means the decoder produced no more data.
    TError DoRecv(TChunkList& lst, TInstant deadline) noexcept override {
        if (!Decoder_) {
            return Input_->Recv(lst, deadline);
        }

        // RecvDataStream_ holds a reference to Deadline_ and picks this up.
        Deadline_ = deadline;

        TChunkPtr chunk = NewChunkReserve();
        size_t size = 0;
        TError error = [&] () -> TError {
            try {
                size = Decoder_->Read(chunk->Data(), chunk->Length());
            } Y_TRY_STORE(yexception);
            return {};
        }();

        // A transport error takes precedence over whatever the decoder threw
        // as a consequence of it.
        if (auto error = RecvDataStream_.GetError()) {
            HasError_ = true;
            return error;
        }

        if (error) {
            HasError_ = true;
            return error;
        }

        if (!size) {
            return {};
        }

        chunk->Shrink(size);
        lst.Push(std::move(chunk));
        return {};
    }

private:
    // Returns the first comma-separated Content-Encoding token that is not
    // "identity", with surrounding whitespace removed (so that
    // "identity, gzip" yields "gzip" rather than " gzip"), or an empty buffer
    // if there is none. The result points into the request's header storage.
    TStringBuf GetEncoding(const TRequest& r) {
        for (const auto& value : r.Headers().GetValuesRef("Content-Encoding")) {
            for (const auto& it : StringSplitter(value.AsStringBuf()).Split(',').SkipEmpty()) {
                const TStringBuf token = StripString(it.Token());
                if (!token.empty() && token != "identity") {
                    return token;
                }
            }
        }
        return {};
    }

    IIoInput* Input_;
    // Declared before RecvDataStream_, which keeps a reference to it.
    TInstant Deadline_ = TInstant::Max();
    TRecvDataStream RecvDataStream_;
    THolder<IInputStream> Decoder_;
    bool HasError_ = false;
};

} // namespace NModCompressor

// The "compressor" module: wraps a submodule so that the response body can be
// compressed and/or the request body decompressed.
//
// Config keys handled below:
//   enable_compression   - bool, default true
//   enable_decompression - bool, default false
//   compression_codecs   - comma-separated codec names
//   any other key        - loaded as the submodule
MODULE_BASE(compressor, TModuleWithSubModule) {
public:
    TModule(const TModuleParams& mp)
        : TModuleBase(mp)
    {
        Config->ForEach(this);

        // A submodule is mandatory; fail config parsing without one.
        Y_ENSURE_EX(Submodule_,
            TConfigParseError() << "no submodule configured");
    }

private:
    START_PARSE {
        ON_KEY("enable_compression", EnableCompression_) {
            return;
        }

        ON_KEY("enable_decompression", EnableDecompression_) {
            return;
        }

        if (key == "compression_codecs") {
            // CompressionCodecs_ holds views into CompressionCodecsValue_, so
            // the string is kept as a member alongside the vector.
            // NOTE(review): tokens are not stripped here, so "gzip, br" yields
            // " br" - confirm whether whitespace in the list is expected.
            CompressionCodecsValue_ = value->AsString();
            CompressionCodecs_ = StringSplitter(CompressionCodecsValue_).Split(',');
            return;
        }

        // Every key not matched above is loaded as the submodule.
        {
            Submodule_.Reset(Loader->MustLoad(key, Copy(value->AsSubConfig())).Release());
            return;
        }

    } END_PARSE

    // Runs the submodule on a copy of `descr` whose Input/Output are replaced
    // by the (de)compression wrappers, depending on the configured flags.
    TError DoRun(const TConnDescr& descr) const noexcept override {
        // Both wrappers are constructed unconditionally; they are only put
        // into the new descriptor when the corresponding feature is enabled.
        NModCompressor::TDecompressionStream in(descr.Input);
        NModCompressor::TCompressionStream out{descr.Output, *descr.Request, CompressionCodecs_};

        TConnDescr newDescr = descr.Copy();
        if (EnableDecompression_) {
            // Decoder creation failure aborts the request and is recorded in
            // the access log summary.
            if (TError error = in.Init(*descr.Request)) {
                descr.ExtraAccessLog.SetSummary(GetHandle()->Name(), "decompression error");
                return error;
            }
            newDescr.Input = &in;
        }
        if (EnableCompression_) {
            newDescr.Output = &out;
        }
        return Submodule_->Run(newDescr);
    }

    // NOTE(review): presumably opts this module into extra access logging
    // (DoRun writes a summary via ExtraAccessLog) - confirm against the base
    // class.
    bool DoExtraAccessLog() const noexcept override {
        return true;
    }

private:
    bool EnableCompression_ = true;
    bool EnableDecompression_ = false;
    // Raw "compression_codecs" value; backing storage for CompressionCodecs_.
    TString CompressionCodecsValue_;
    // Codec names as views into CompressionCodecsValue_; empty if not configured.
    TVector<TStringBuf> CompressionCodecs_;
};

// Returns the handle of the compressor module defined in this file; simply
// delegates to TModule::Handle().
NSrvKernel::IModuleHandle* NModCompressor::Handle() {
    NSrvKernel::IModuleHandle* const handle = TModule::Handle();
    return handle;
}
