#include <balancer/serval/core/config.h>

#include <contrib/libs/brotli/include/brotli/encode.h>
#include <contrib/libs/zlib/zlib.h>

#include <util/string/split.h>
#include <util/string/strip.h>

#include <limits>

// Decides whether a response head permits compressing its body. The body is left untouched
// when the head carries `upgrade`, already has a `content-encoding` (never encode twice), or
// has any `transfer-encoding`. `transfer-encoding: chunked` is expected to be consumed by the
// http library before this point, so seeing the header at all is unusual.
// TODO also disable for small `content-length`
static bool EnableFor(const NSv::THead& head) {
    if (head.has("upgrade")) {
        return false;
    }
    if (head.has("content-encoding")) {
        return false;
    }
    if (head.has("transfer-encoding")) {
        return false;
    }
    return true;
}

static NSv::TAction Compress(const YAML::Node&, NSv::TAuxData&) {
    class TStreamWithGZip: public NSv::TStreamProxy {
    public:
        using NSv::TStreamProxy::TStreamProxy;

        ~TStreamWithGZip() {
            deflateEnd(&Z_);
        }

        bool WriteHead(NSv::THead& head) noexcept override {
            if (!head.IsInformational() && (Enable_ &= EnableFor(head))) {
                head.erase("content-length");
                head.emplace("content-encoding", "gzip");
            }
            Z_.next_out = (unsigned char*)B_;
            Z_.avail_out = sizeof(B_);
            return NSv::TStreamProxy::WriteHead(head);
        }

        bool Write(TStringBuf chunk) noexcept override {
            if (!Enable_) {
                return NSv::TStreamProxy::Write(chunk);
            }
            Z_.next_in = (unsigned char*)chunk.data();
            Z_.avail_in = chunk.size();
            while (Z_.avail_in) {
                int code = deflate(&Z_, Z_NO_FLUSH);
                switch (code) {
                case Z_OK:
                    continue;
                case Z_BUF_ERROR:
                    if (!Flush()) {
                        return false;
                    }
                    continue;
                default:
                    return !mun_error(EINVAL, "zlib returned %d", code);
                };
            }
            return true;
        }

        bool Close(NSv::THeaderVector& tail) noexcept override {
            if (!Enable_) {
                return NSv::TStreamProxy::Close(tail);
            }
            while (true) {
                int code = deflate(&Z_, Z_FINISH);
                switch (code) {
                case Z_OK:
                case Z_BUF_ERROR:
                    if (!Flush() MUN_RETHROW) {
                        return false;
                    }
                    continue;
                case Z_STREAM_END:
                    return Flush() && NSv::TStreamProxy::Close(tail);
                default:
                    return !mun_error(EINVAL, "zlib returned %d", code);
                }
            }
        }

    private:
        bool Flush() noexcept {
            if (!NSv::TStreamProxy::Write(TStringBuf(B_, sizeof(B_) - Z_.avail_out)) MUN_RETHROW) {
                return false;
            }
            Z_.next_out = (unsigned char*)B_;
            Z_.avail_out = sizeof(B_);
            return true;
        }

    private:
        z_stream Z_ = {};
        char B_[2048];
        bool Enable_ = !deflateInit2(&Z_, 6, Z_DEFLATED, MAX_WBITS | 16, 8, Z_DEFAULT_STRATEGY);
    };

    class TStreamWithBrotli: public NSv::TStreamProxy {
    public:
        using NSv::TStreamProxy::TStreamProxy;

        bool WriteHead(NSv::THead& head) noexcept override {
            if (!head.IsInformational() && EnableFor(head)) {
                S_.Reset(BrotliEncoderCreateInstance(nullptr, nullptr, nullptr));
                if (S_) {
                    // TODO only for text/... types?
                    BrotliEncoderSetParameter(S_.Get(), BROTLI_PARAM_MODE, BROTLI_MODE_TEXT);
                    // 4 is quite optimal according to https://quixdb.github.io/squash-benchmark/ and my experiments
                    BrotliEncoderSetParameter(S_.Get(), BROTLI_PARAM_QUALITY, 4);
                    head.erase("content-length");
                    head.emplace("content-encoding", "br");
                }
            }
            return NSv::TStreamProxy::WriteHead(head);
        }

        bool Write(TStringBuf chunk) noexcept override {
            return S_ ? WriteCompressed(chunk, BROTLI_OPERATION_PROCESS) : NSv::TStreamProxy::Write(chunk);
        }

        bool Close(NSv::THeaderVector& tail) noexcept override {
            return (!S_ || WriteCompressed({}, BROTLI_OPERATION_FINISH)) && NSv::TStreamProxy::Close(tail);
        }

    private:
        bool WriteCompressed(TStringBuf chunk, BrotliEncoderOperation op) noexcept {
            auto data = (const ui8*)chunk.data();
            auto size = chunk.size();
            do {
                size_t outSize = 0;
                ui8* outData = nullptr;
                if (BrotliEncoderCompressStream(S_.Get(), op, &size, &data, &outSize, &outData, nullptr) == BROTLI_FALSE) {
                    return !mun_error(EINVAL, "compression failure");
                }
                size_t realOutSize = 0;
                auto realOutData = BrotliEncoderTakeOutput(S_.Get(), &realOutSize);
                if (realOutSize > 0 && !NSv::TStreamProxy::Write({(const char*)realOutData, realOutSize}) MUN_RETHROW) {
                    return false;
                }
            } while (size > 0 || BrotliEncoderHasMoreOutput(S_.Get()));
            return true;
        }

    private:
        struct TDestroy {
            static void Destroy(BrotliEncoderState* s) noexcept {
                BrotliEncoderDestroyInstance(s);
            }
        };
        THolder<BrotliEncoderState, TDestroy> S_;
    };

    return [](NSv::IStreamPtr& s) {
        auto* head = s->Head();
        if (!head) {
            return false;
        }
        bool brotli = false;
        bool gzip = false;
        for (auto range = head->equal_range("accept-encoding"); range.first != range.second; range.first++) {
            for (TStringBuf tok : StringSplitter(range.first->second).Split(',')) {
                brotli |= (StripString(tok) == "br");
                gzip |= (StripString(tok) == "gzip");
            }
        }
        if (brotli) {
            s = std::make_shared<TStreamWithBrotli>(std::move(s));
        } else if (gzip) {
            s = std::make_shared<TStreamWithGZip>(std::move(s));
        }
        return true;
    };
}

// Exposes `Compress` under the action name "compress". SV_DEFINE_ACTION is defined outside this
// file (presumably in balancer/serval/core/config.h); see it for how registration happens.
SV_DEFINE_ACTION("compress", Compress);
