#include "snappy.h"

#include <util/generic/algorithm.h>
#include <util/generic/xrange.h>
#include <util/string/hex.h>
#include <util/system/sem.h>

namespace NHistDb {
    namespace {
        // Default codec, used whenever a stream is constructed with a null codec pointer.
        const static NBlockCodecs::ICodec* SNAPPY_CODEC = NBlockCodecs::Codec("snappy");

        // Base open mode for output files: create if missing, append, sequential, write-only.
        const static TFlags<EOpenModeFlag> OUTPUT_FLAGS = (
            EOpenModeFlag::OpenAlways | EOpenModeFlag::ForAppend | EOpenModeFlag::Seq | EOpenModeFlag::WrOnly
        );
        // Third constructor argument of TSnappyOutputStream::OutputFile_ (presumably its buffer size; 4 MiB).
        const static size_t OUTPUT_BUFFER_SIZE = 1 << 22;
        // Default number of uncompressed bytes per block, used when callers pass chunkSize == 0.
        // The parentheses matter: without them `-` binds tighter than `<<` and the expression
        // evaluates to 1 << 15. Readers take block sizes from TSnappyBlock metadata (this value is
        // only used by TSnappyInputStream for buffer reservation), so files written with either
        // value remain readable.
        const static size_t CHUNK_SIZE = (1 << 16) - 1;

        // Returns OUTPUT_FLAGS, additionally requesting EOpenModeFlag::Direct when useDirect is set.
        inline TFlags<EOpenModeFlag> GetOutputFlags(bool useDirect) {
            if (useDirect) {
                return OUTPUT_FLAGS | EOpenModeFlag::Direct;
            } else {
                return OUTPUT_FLAGS;
            }
        }

        // Sum of UncompressedSize over all blocks: the logical length of the stream they describe.
        inline size_t GetUncompressedSize(const TVector<TSnappyBlock>& blocks) noexcept {
            size_t size = 0;
            for (const auto& block : blocks) {
                size += block.UncompressedSize;
            }
            return size;
        }

        // Sum of CompressedSize over all blocks: the number of file bytes the blocks occupy,
        // since blocks are written back to back.
        inline size_t GetCompressedSize(const TVector<TSnappyBlock>& blocks) noexcept {
            size_t size = 0;
            for (const auto& block : blocks) {
                size += block.CompressedSize;
            }
            return size;
        }

        // Resizes `fileName` to exactly the compressed size of `blocks` (dropping any bytes after
        // the last recorded block) and returns the name unchanged, so the call can be used inside
        // a member-initializer list.
        inline TString TruncateFile(const TString& fileName, const TVector<TSnappyBlock>& blocks) {
            // this is a hack to use function inside direct io file constructor
            // last one doesn't know about seek
            TFile file(fileName, OUTPUT_FLAGS);
            file.Resize(GetCompressedSize(blocks));
            return fileName;
        }

        // Process-wide semaphore with two permits. TSnappyInputStream::ReadNextBlock holds one
        // while issuing Pload, so at most two block reads hit the disk at the same time.
        struct TArbiter {
            TArbiter()
                : Semaphore(2)
            {
            }

            static TFastSemaphore& Get() {
                return Singleton<TArbiter>()->Semaphore;
            }

            TFastSemaphore Semaphore;
        };
    }

    // Opens `fileName` for appending compressed blocks after the ones already described by `blocks`.
    // chunkSize == 0 selects CHUNK_SIZE; a null codec selects SNAPPY_CODEC.
    TSnappyOutputStream::TSnappyOutputStream(const TString& fileName, const TVector<TSnappyBlock>& blocks,
                                             bool useDirect, size_t chunkSize, const NBlockCodecs::ICodec* codec)
        // TruncateFile runs before OutputFile_ is opened and cuts the file back to the bytes covered by `blocks`.
        : OutputFile_(TruncateFile(fileName, blocks), GetOutputFlags(useDirect), OUTPUT_BUFFER_SIZE)
        , OutputStream_(OutputFile_)
        , ChunkSize_(chunkSize != 0 ? chunkSize : CHUNK_SIZE)
        , Codec_(codec != nullptr ? codec : SNAPPY_CODEC)
        // The logical write position continues right after the existing uncompressed data.
        , Position_(GetUncompressedSize(blocks))
        , Blocks_(blocks)
    {
        // Pre-size both staging buffers for one chunk.
        CompressedBuffer_.Reserve(ChunkSize_);
        Buffer_.Reserve(ChunkSize_);
    }

    TSnappyOutputStream::TSnappyOutputStream(const TString& fileName, bool useDirect,
                                             size_t chunkSize, const NBlockCodecs::ICodec* codec)
        : TSnappyOutputStream(fileName, TVector<TSnappyBlock>(), useDirect, chunkSize, codec)
    {
    }

    void TSnappyOutputStream::DoWrite(const void* buf, size_t len) {
        TStringBuf incoming(static_cast<const char*>(buf), len);

        size_t remaining = len;
        size_t written = 0;

        while (remaining) {
            if (Buffer_.Size() + remaining < ChunkSize_) {
                const TStringBuf toWrite(incoming.SubStr(written));
                Buffer_.Append(toWrite.data(), toWrite.size());
                written += remaining;
                remaining = 0;
            } else {
                const TStringBuf toWrite(incoming.SubStr(written).Head(ChunkSize_));
                Buffer_.Append(toWrite.data(), toWrite.size());
                written += toWrite.size();
                remaining -= toWrite.size();

                CompressBufferAndFlush();
            }
        }

        Position_ += written;
    }

    // Emits any partially filled chunk as a (shorter) block, then flushes the underlying stream.
    void TSnappyOutputStream::DoFlush() {
        const bool hasPendingData = Buffer_.Size() != 0;
        if (hasPendingData) {
            CompressBufferAndFlush();
        }
        OutputStream_.Flush();
    }

    void TSnappyOutputStream::CompressBufferAndFlush() {
        Y_ASSERT(Buffer_.Size() > 0);
        const TStringBuf incoming(Buffer_.Data(), Buffer_.Size());
        Codec_->Encode(incoming, CompressedBuffer_);
        OutputStream_.Write(CompressedBuffer_.Data(), CompressedBuffer_.Size());
        Blocks_.emplace_back(TSnappyBlock{
            .UncompressedSize=Buffer_.Size(),
            .CompressedSize=CompressedBuffer_.Size()
        });
        Buffer_.Clear();
        CompressedBuffer_.Clear();
    }

    // Opens an existing compressed file for reading. `blocks` describes its contents, in order.
    // chunkSize == 0 selects CHUNK_SIZE (used only to pre-size buffers); a null codec selects SNAPPY_CODEC.
    TSnappyInputStream::TSnappyInputStream(const TString& fileName, const TVector<TSnappyBlock>& blocks,
                                           size_t chunkSize, const NBlockCodecs::ICodec* codec)
        : InputFile_(fileName, EOpenModeFlag::RdOnly | EOpenModeFlag::OpenExisting)
        , Blocks_(blocks)
        // Computed from the parameter rather than Blocks_ so the result does not depend on the
        // relative declaration order of Blocks_ and Size_ in the class definition.
        , Size_(GetUncompressedSize(blocks))
        , ChunkSize_(chunkSize ? chunkSize : CHUNK_SIZE)
        , Codec_(codec ? codec : SNAPPY_CODEC)
    {
        BlockBuffer_.Reserve(ChunkSize_);
        ResultBuffer_.Reserve(ChunkSize_);

        // Exclusive prefix sums: entry i is where block i starts, in the compressed file and in
        // the uncompressed stream respectively. Both are empty when there are no blocks.
        CompressedPositions_.assign(Blocks_.size(), 0);
        UncompressedPositions_.assign(Blocks_.size(), 0);
        for (size_t i = 1; i < Blocks_.size(); ++i) {
            const TSnappyBlock& previous = Blocks_[i - 1];
            CompressedPositions_[i] = CompressedPositions_[i - 1] + previous.CompressedSize;
            UncompressedPositions_[i] = UncompressedPositions_[i - 1] + previous.UncompressedSize;
        }
    }

    // Moves the uncompressed read position. `offset` is interpreted relative to `origin`
    // (sSet: start, sCur: current position, sEnd: end of stream). Only the target block index
    // and in-block offset are recorded here; the block itself is loaded lazily by ReadNextBlock.
    // Throws yexception if the resulting position lies outside [0, Size_].
    void TSnappyInputStream::Seek(i64 offset, SeekDir origin) {
        // Recalculate offset as an absolute file position.
        switch (origin) {
            case sSet: {
                break;
            }
            case sCur: {
                offset = Position_ + offset;
                break;
            }
            case sEnd: {
                offset = Size_ + offset;
                break;
            }
        }

        // A position outside the stream would leave BlockOffset_ beyond the end of its block
        // (negative values also wrap when compared with the unsigned positions below), and the
        // next read would then underflow and read out of bounds.
        if (offset < 0 || static_cast<ui64>(offset) > static_cast<ui64>(Size_)) {
            ythrow yexception() << "seek position " << offset << " is out of range [0, " << Size_ << "]";
        }

        // Last block whose start is <= offset.
        const auto it(UpperBound(UncompressedPositions_.cbegin(), UncompressedPositions_.cend(), offset));
        NextBlockIndex_ = std::distance(UncompressedPositions_.cbegin(), it) - 1;
        BlockOffset_ = offset - UncompressedPositions_.at(NextBlockIndex_);
        Position_ = offset;
    }

    // Appends data from block NextBlockIndex_ to ResultBuffer_, loading and decoding the block
    // first unless it is already cached in BlockBuffer_ (tracked by BufferedBlockIndex_).
    // len != 0: append at most `len` bytes starting at BlockOffset_ and stay on this block.
    // len == 0: append the rest of the block and advance NextBlockIndex_ to the next one.
    // Throws yexception if there are no blocks or BlockOffset_ lies past the decoded block.
    void TSnappyInputStream::ReadNextBlock(size_t len) {
        if (NextBlockIndex_ != BufferedBlockIndex_) {
            if (Blocks_.empty()) {
                ythrow yexception() << "file is empty";
            }

            const size_t offset(CompressedPositions_.at(NextBlockIndex_));
            const size_t size(Blocks_.at(NextBlockIndex_).CompressedSize);

            RawBuffer_.Resize(size);
            {
                // don't read in parallel from disk
                TGuard<TFastSemaphore> guard(TArbiter::Get());
                InputFile_.Pload(RawBuffer_.Data(), size, offset);
            }

            Codec_->Decode(RawBuffer_, BlockBuffer_);
            BufferedBlockIndex_ = NextBlockIndex_;
        }

        // Guard the unsigned subtraction below: an offset past the decoded data (e.g. block
        // metadata disagreeing with the file contents) would otherwise wrap around and make
        // the Append calls read out of bounds.
        const size_t blockOffset = BlockOffset_;
        if (blockOffset > BlockBuffer_.Size()) {
            ythrow yexception() << "offset " << blockOffset << " is past the end of block "
                                << NextBlockIndex_ << " (" << BlockBuffer_.Size() << " bytes)";
        }
        const size_t available = BlockBuffer_.Size() - blockOffset;

        if (len) {
            len = Min(len, available);
            ResultBuffer_.Append(BlockBuffer_.Data() + blockOffset, len);
            Position_ += len;
            BlockOffset_ += len;
        } else {
            ResultBuffer_.Append(BlockBuffer_.Data() + blockOffset, available);
            Position_ += available;
            BlockOffset_ = 0;
            NextBlockIndex_ += 1;
        }
    }

    size_t TSnappyInputStream::DoNext(const void** ptr, size_t len) {
        const size_t expectedPosition(Position_ + len);
        const auto it(UpperBound(UncompressedPositions_.cbegin(), UncompressedPositions_.cend(), expectedPosition));
        const ssize_t lastBlockIndex(std::distance(UncompressedPositions_.cbegin(), it) - 1);

        ResultBuffer_.Clear();

        for (const auto index : xrange(NextBlockIndex_, lastBlockIndex)) {
            ReadNextBlock();
            Y_UNUSED(index);
        }

        Y_ASSERT(len >= ResultBuffer_.Size());
        size_t remain(len - ResultBuffer_.Size());
        if (remain) {
            ReadNextBlock(remain);
        }

        *ptr = ResultBuffer_.Data();

        return ResultBuffer_.Size();
    }

    // Opens `fileName` in exactly one direction: ESnappyMode::APPEND constructs the output
    // stream (useDirect applies only there); any other mode constructs the input stream.
    TSnappyFile::TSnappyFile(const TString& fileName, ESnappyMode mode, const TVector<TSnappyBlock>& blocks, bool useDirect)
    {
        const bool appending = (mode == ESnappyMode::APPEND);
        if (!appending) {
            Input_.ConstructInPlace(fileName, blocks);
            return;
        }
        Output_.ConstructInPlace(fileName, blocks, useDirect);
    }

    TSnappyFile::TSnappyFile(const TString& fileName, bool useDirect)
        : TSnappyFile(fileName, ESnappyMode::APPEND, TVector<TSnappyBlock>(), useDirect)
    {
    }

    // Direct access to the reading side; valid only for files opened in a non-APPEND mode.
    TSnappyInputStream& TSnappyFile::GetInputStream() {
        TSnappyInputStream& stream = Input_.GetRef();
        return stream;
    }

    // Direct access to the writing side; valid only for files opened in APPEND mode.
    TSnappyOutputStream& TSnappyFile::GetOutputStream() {
        TSnappyOutputStream& stream = Output_.GetRef();
        return stream;
    }

    // Current uncompressed position of whichever stream this file was opened with.
    size_t TSnappyFile::Position() {
        if (!Input_.Defined()) {
            return Output_->Position();
        }
        return Input_->Position();
    }

    // Block metadata accumulated by the output stream (existing blocks plus newly written ones).
    const TVector<TSnappyBlock>& TSnappyFile::Blocks() {
        const TVector<TSnappyBlock>& blocks = Output_->Blocks();
        return blocks;
    }

    // Repositions the input stream; see TSnappyInputStream::Seek.
    void TSnappyFile::Seek(i64 offset, SeekDir origin) {
        TSnappyInputStream* input = Input_.operator->();
        input->Seek(offset, origin);
    }

    // Appends `buf` to the output stream and reports the full length as written.
    size_t TSnappyFile::Write(TStringBuf buf) {
        const size_t byteCount = buf.size();
        Output_->Write(buf);
        return byteCount;
    }

    // Reads up to `len` bytes into `buf` (resized to `len`) and returns a view of the bytes
    // actually read; the view points into `buf`.
    TStringBuf TSnappyFile::Read(size_t len, TBuffer& buf) {
        buf.Resize(len);
        char* const destination = buf.Data();
        const size_t got = Input_->Read(destination, buf.Size());
        return TStringBuf(destination, got);
    }

    // Same as above, using the file's internal Buffer_; the returned view is invalidated by
    // the next call that reuses that buffer.
    TStringBuf TSnappyFile::Read(size_t len) {
        TBuffer& scratch = Buffer_;
        return Read(len, scratch);
    }

    void TSnappyFile::Flush() {
        Output_->Flush();
    }

    void TSnappyFile::Finish() {
        Output_->Finish();
    }
}
