#include "zstd_compressor.h"

#include <contrib/libs/zstd/include/zdict.h>
#include <contrib/libs/zstd/include/zstd.h>
#include <contrib/libs/zstd/include/zstd_errors.h>

#include <library/cpp/packers/packers.h>

#include <util/ysaveload.h>
#include <util/generic/buffer.h>
#include <util/generic/ptr.h>
#include <util/generic/refcount.h>
#include <util/system/src_location.h>

namespace {
    // Placeholder deleter for dictionary handles that do not own anything yet.
    // It mirrors the `size_t (*)(T*)` shape of the ZSTD_free* functions, so it can
    // sit in the same holder type, but it never touches the pointer.
    template <typename TObject>
    size_t NoOpDeleter(TObject* /*unused*/) {
        const size_t nothingFreed = 0;
        return nothingFreed;
    }
}

namespace NTravel {
    TZStdCompressor::TZStdCompressor(int compressionLevel)
        : CompressionLevel(compressionLevel)
        , CompressionDict(nullptr, NoOpDeleter<ZSTD_CDict>)
        , DecompressionDict(nullptr, NoOpDeleter<ZSTD_DDict>)
    {
    }

    void TZStdCompressor::Compress(TStringBuf in, TBuffer& outbuf) const {
        outbuf.Clear();

        const size_t maxDatSz = ZSTD_compressBound(in.size());

        outbuf.Resize(maxDatSz);

        auto context = CompressionContext.Get();
        if (*context == nullptr) {
            *context = std::make_unique<TCCtx>(ZSTD_createCCtx(), ZSTD_freeCCtx);
            ZSTD_CCtx_refCDict(context->get()->get(), CompressionDict.get());
        }

        const size_t resSz = ZSTD_compress_usingCDict(context->get()->get(), outbuf.data(), maxDatSz, in.data(), in.size(), CompressionDict.get());
        if (ZSTD_isError(resSz)) {
            ythrow yexception() << "Failed to compress string: " << ZSTD_getErrorName(resSz) << " (code " << (int)ZSTD_getErrorCode(resSz) << ")";
        }

        outbuf.Resize(resSz);
    }

    void TZStdCompressor::Decompress(TStringBuf in, TBuffer& outbuf) const {
        outbuf.Clear();

        auto uncompressedSize = ZSTD_getFrameContentSize(in.data(), in.size());
        outbuf.Resize(uncompressedSize);

        auto context = DecompressionContext.Get();
        if (*context == nullptr) {
            *context = std::make_unique<TDCtx>(ZSTD_createDCtx(), ZSTD_freeDCtx);
            ZSTD_DCtx_refDDict(context->get()->get(), DecompressionDict.get());
        }
        auto resSz = ZSTD_decompress_usingDDict(context->get()->get(), outbuf.data(), outbuf.size(), in.data(), in.size(), DecompressionDict.get());
        if (ZSTD_isError(resSz)) {
            ythrow yexception() << "Failed to decompress string: " << ZSTD_getErrorName(resSz) << " (code " << (int)ZSTD_getErrorCode(resSz) << ")";
        }

        Y_ENSURE(outbuf.size() == resSz);
    }

    void TZStdCompressor::Train(const TVector<TString>& strings) {
        Y_ENSURE(CompressionDict == nullptr, "Train called multiple times");

        TBuffer data;
        TVector<size_t> lens;
        lens.reserve(strings.size());
        data.Reserve(strings.size() * strings[0].size() * 2); // approximate size

        for (const auto& s : strings) {
            data.Append(s.data(), s.length());
            lens.push_back(s.length());
        }

        DictBuffer.Resize(Max<size_t>(1u << 20u, data.Size() * 4));

        ZDICT_fastCover_params_t params;
        memset(&params, 0, sizeof(params));
        params.d = 8;
        params.steps = 4;
        params.zParams.compressionLevel = CompressionLevel;
        auto trainResult = ZDICT_optimizeTrainFromBuffer_fastCover(DictBuffer.data(), DictBuffer.size(), data.Data(), lens.data(), lens.size(), &params);

        if (ZSTD_isError(trainResult)) {
            DictBuffer.Clear();
            ythrow yexception() << "Failed to train compressor: " << ZSTD_getErrorName(trainResult) << " (code " << (int)ZSTD_getErrorCode(trainResult) << ")";
        }
        DictBuffer.Resize(trainResult);
        DictBuffer.ShrinkToFit();

        CompressionDict = TCDict(ZSTD_createCDict(DictBuffer.data(), DictBuffer.size(), CompressionLevel), ZSTD_freeCDict);
        DecompressionDict = TDDict(ZSTD_createDDict(DictBuffer.data(), DictBuffer.size()), ZSTD_freeDDict);
    }

    size_t TZStdCompressor::GetDictSize() const {
        return DictBuffer.Capacity();
    }
}
