#include "lz4.h"

#include <util/generic/buffer.h>
#include <contrib/libs/lz4/lz4.h>
#include <contrib/libs/lz4/lz4hc.h>

#include <algorithm>
#include <string.h>

//=========================================================================================


// Magic bytes written at the very start of every compressed stream.
// Only the visible characters are emitted; the literal's trailing NUL is not.
const char Sign[] = "LZ4.";
// Length of the signature, derived from the literal so the two cannot drift.
const size_t SignSize = sizeof(Sign) - 1;


//=========================================================================================


// Block compressor behind TLz4Compress.
//
// Stream layout produced (all integers in native width and byte order,
// which is what TLz4Decompress::TImpl reads back):
//   "LZ4."                          -- SignSize bytes of signature
//   size_t   BlockSize              -- max uncompressed bytes per block
//   repeated:
//     ssize_t  compressedSize
//     char[compressedSize]          -- one LZ4 (or LZ4 HC) block
class TLz4Compress::TImpl {
private:
    // Maximum number of uncompressed bytes put into a single LZ4 block.
    size_t BlockSize;

    // Downstream sinks; bound in the constructors to either an IOutputStream
    // or a std::stringstream. The raw pointer is captured, so the target must
    // outlive this object.
    std::function<void(const void *buf, size_t len)>DoWrite;
    std::function<void()>DoFlush;
    std::function<void()>DoFinish;

    // true  -> LZ4_compress_HC at LZ4HC_CLEVEL_MIN
    // false -> LZ4_compress_default
    bool DoUseHC = true;

    TBuffer InBuf;   // pending uncompressed input, never more than BlockSize bytes
    TBuffer OutBuf;  // scratch space for one compressed block

public:
    TImpl(IOutputStream *out,
          size_t blockSize)
        : BlockSize(blockSize)
    {
        DoWrite  = [out](const void *buf, size_t len){ out->Write(buf, len); };
        DoFlush  = [out](){ out->Flush(); };
        DoFinish = [out](){ out->Finish(); };
        Init();
    }

    TImpl(std::stringstream *out,
          size_t blockSize)
        : BlockSize(blockSize)
    {
        DoWrite  = [out](const void *buf, size_t len){ out->write((const char *)buf, len); };
        DoFlush  = [out](){ out->flush(); };
        DoFinish = [](){};
        Init();
    }

private:
    // Validates BlockSize, sizes the buffers and writes the stream header.
    void Init() {
        // A zero block size would make Write() spin forever, and the LZ4 API
        // takes sizes as int, so larger blocks cannot be compressed at all.
        if (BlockSize == 0 || BlockSize > static_cast<size_t>(LZ4_MAX_INPUT_SIZE)) {
            ythrow TLZ4RuntimeError() << "fail to compress stream, bad block size " << BlockSize;
        }
        InBuf.Reserve(BlockSize);
        // LZ4_compressBound is the documented worst case; a fixed 10% margin
        // is not enough for small or incompressible blocks.
        OutBuf.Reserve(static_cast<size_t>(LZ4_compressBound(static_cast<int>(BlockSize))));
        DoWrite(Sign, SignSize);
        DoWrite(&BlockSize, sizeof(BlockSize));
    }

    // Compresses whatever is buffered (if anything) and emits it as one
    // [ssize_t size][payload] frame. Does not flush the downstream sink.
    void CompressBlock() {
        if (InBuf.Size() == 0) {
            return;
        }
        const int srcSize = static_cast<int>(InBuf.Size());
        // Bound for this exact input size; never larger than OutBuf's capacity
        // because srcSize <= BlockSize.
        const int dstCapacity = LZ4_compressBound(srcSize);

        // ssize_t (not int) on purpose: its width is part of the frame format.
        ssize_t size;
        if (DoUseHC) {
            size = LZ4_compress_HC(InBuf.Data(), OutBuf.Data(), srcSize, dstCapacity, LZ4HC_CLEVEL_MIN);
        } else {
            size = LZ4_compress_default(InBuf.Data(), OutBuf.Data(), srcSize, dstCapacity);
        }
        if (size <= 0) {
            ythrow TLZ4RuntimeError() << "fail to compress stream";
        }
        InBuf.Clear();
        DoWrite(&size, sizeof(size));
        DoWrite(OutBuf.Data(), static_cast<size_t>(size));
    }

public:
    void UseHC() {
        DoUseHC = true;
    }

    void UseDefault() {
        DoUseHC = false;
    }

    // Buffers `len` bytes, emitting a compressed block every time exactly
    // BlockSize bytes have accumulated.
    void Write(const char *buf,
               size_t len)
    {
        while (len > 0) {
            // Bound by BlockSize rather than InBuf.Avail(): TBuffer::Reserve
            // only guarantees *at least* BlockSize bytes, and the decompressor
            // sizes its output buffer for exactly BlockSize.
            const size_t chunk = std::min(len, BlockSize - InBuf.Size());
            InBuf.Append(buf, chunk);
            buf += chunk;
            len -= chunk;
            if (InBuf.Size() >= BlockSize) {
                CompressBlock();
            }
        }
    }

    // Emits any buffered data as a (possibly short) block and flushes the
    // downstream sink.
    void Flush() {
        CompressBlock();
        DoFlush();
    }

    // Flushes remaining data, then finishes the downstream sink.
    void Finish() {
        Flush();
        DoFinish();
    }
};


//=========================================================================================
#include <util/string/printf.h>
#include <iostream>

// Block decompressor behind TLz4Decompress. Reads the layout written by
// TLz4Compress::TImpl: "LZ4." signature, a native size_t block size, then
// repeated [native ssize_t compressedSize][compressed bytes] frames.
class TLz4Decompress::TImpl {
private:
    // Window over a fixed-capacity buffer holding decoded bytes that have not
    // been handed to the caller yet: [Data(), Data() + Size()).
    class TOutBuffer {
    private:
        TBuffer Buffer;
        size_t Len_ = 0;   // usable capacity, as set by Reserve()
        size_t Head_ = 0;  // offset of the first unread byte
        size_t Size_ = 0;  // number of unread bytes

    public:
        TOutBuffer() {
        }
        explicit TOutBuffer(size_t size) {
            Reserve(size);
        }
        void Reserve(size_t size) {
            Buffer.Reserve(size);
            Len_ = size;
        }
        char *Data() {
            return Buffer.Data() + Head_;
        }
        size_t Size() const {
            return Size_;
        }
        size_t Capacity() const {
            return Len_ - Head_;
        }
        // Marks `size` freshly written bytes (at Data() + Size()) as readable.
        void Advance(size_t size) {
            Size_ += size;
        }
        void Clear() {
            Head_ = 0;
            Size_ = 0;
        }
        // Drops `size` bytes from the front; dropping everything resets the
        // window so the next block is decoded at the start of the buffer.
        void ChopHead(size_t size) {
            if (size >= Size_) {
                Clear();
            } else {
                Head_ += size;
                Size_ -= size;
            }
        }
    };

    size_t BlockSize = 0;
    // Source of compressed bytes; returns the number read, 0 at end of input.
    // It may return fewer bytes than requested, hence ReadExact().
    std::function<size_t(void *buf, size_t len)>DoRead;

    TBuffer InBuf;      // raw storage for one compressed block
    TOutBuffer OutBuf;  // decoded bytes of the current block

public:
    TImpl(IInputStream *in){
        DoRead = [in](void *buf, size_t len){ return in->Read(buf, len); };
        Init();
    }

    TImpl(std::stringstream *in) {
        DoRead = [in](void *buf, size_t len){ return in->read((char *)buf, len).gcount(); };
        Init();
    }

private:
    // Keeps calling DoRead until `len` bytes arrive or the source reports end
    // of input. Returns the number of bytes actually read.
    size_t ReadExact(void *buf, size_t len) {
        char *dst = static_cast<char *>(buf);
        size_t done = 0;
        while (done < len) {
            const size_t r = DoRead(dst + done, len - done);
            if (r == 0) {
                break;
            }
            done += r;
        }
        return done;
    }

    // Reads and validates the stream header, then sizes the buffers.
    void Init() {
        char sig[SignSize];
        // A short read must not fall through to memcmp on uninitialized bytes.
        if (ReadExact(sig, SignSize) != SignSize || memcmp(sig, Sign, SignSize) != 0) {
            ythrow TLZ4TypeError() << "fail to decompress stream, bad signature";
        }
        // Reject a truncated or absurd header instead of reserving garbage sizes.
        if (ReadExact(&BlockSize, sizeof(BlockSize)) != sizeof(BlockSize)
            || BlockSize > static_cast<size_t>(LZ4_MAX_INPUT_SIZE)) {
            ythrow TLZ4RuntimeError() << "fail to decompress stream, bad block size";
        }
        // Any valid block fits in LZ4_compressBound(BlockSize); the 10% margin
        // is kept so chunk sizes accepted before remain accepted.
        const size_t bound = static_cast<size_t>(LZ4_compressBound(static_cast<int>(BlockSize)));
        InBuf.Reserve(std::max(bound, static_cast<size_t>(BlockSize * 1.1)));
        OutBuf.Reserve(BlockSize);
    }

    // Reads one frame and decodes it into OutBuf, which must be empty.
    // Returns false on a clean end of stream (no bytes of a new frame).
    bool ReadChunk() {
        ssize_t size = 0;
        const size_t got = ReadExact(&size, sizeof(size));
        if (got == 0) {
            return false;
        }
        if (got != sizeof(size)) {
            ythrow TLZ4RuntimeError() << "fail to decompress stream, corrupted chunk";
        }
        if (size <= 0 || InBuf.Capacity() < static_cast<size_t>(size)) {
            ythrow TLZ4RuntimeError() << "fail to decompress stream, bad chunk size";
        }
        const size_t packed = static_cast<size_t>(size);
        if (ReadExact(InBuf.Data(), packed) != packed) {
            ythrow TLZ4RuntimeError() << "fail to decompress stream, corrupted chunk";
        }
        const int unpacked = LZ4_decompress_safe(InBuf.Data(), OutBuf.Data(),
                                                 static_cast<int>(packed),
                                                 static_cast<int>(OutBuf.Capacity()));
        if (unpacked < 0) {
            ythrow TLZ4RuntimeError() << "fail to decompress stream";
        }
        OutBuf.Advance(static_cast<size_t>(unpacked));
        return true;
    }

public:
    // Ensures OutBuf holds unread data, decoding further frames as needed.
    // Returns the number of unread bytes, 0 only at end of stream.
    size_t ReadToBuffer() {
        // Loop so a (theoretical) empty block is not mistaken for end of stream.
        while (OutBuf.Size() == 0) {
            if (!ReadChunk()) {
                return 0;
            }
        }
        return OutBuf.Size();
    }

    // Copies up to `len` decoded bytes into `buf`; returns the count (0 = EOF).
    size_t Read(char *buf,
                size_t len)
    {
        const size_t n = std::min(ReadToBuffer(), len);
        if (n != 0) {
            memcpy(buf, OutBuf.Data(), n);
        }
        OutBuf.ChopHead(n);
        return n;
    }

    // Reads up to and including the delimiter `ch`. `st` receives the bytes
    // before the delimiter; the return value counts the delimiter too.
    size_t ReadTo(TString &st,
                  char ch)
    {
        size_t result = 0;
        st.clear();

        while (const size_t size = ReadToBuffer()) {
            const char *begin = OutBuf.Data();
            if (const char *pos = (const char *)memchr(begin, ch, size)) {
                const size_t prefix = static_cast<size_t>(pos - begin);
                st.append(begin, prefix);
                OutBuf.ChopHead(prefix + 1);
                return result + prefix + 1;
            }
            st.append(begin, size);
            OutBuf.Clear();
            result += size;
        }
        return result;
    }

    // Streams the rest of the decoded data into `out`; returns bytes written.
    ui64 DoReadAll(std::ostream& out) {
        TTempBuf buffer;
        char *ptr = buffer.Data();
        size_t size = buffer.Size();

        ui64 result = 0;
        while (size_t read = Read(ptr, size)) {
            out.write(ptr, read);
            result += read;
        }
        return result;
    }

    // Returns the rest of the decoded data, appending block by block directly
    // into the result (no intermediate stringstream copy).
    TString ReadAll() {
        TString result;
        while (const size_t size = ReadToBuffer()) {
            result.append(OutBuf.Data(), size);
            OutBuf.Clear();
        }
        return result;
    }
};


//=========================================================================================


// Creates a compressing stream on top of `out`. Constructing TImpl writes the
// stream header (signature and block size) to `out` right away.
// NOTE(review): `out` is captured by raw pointer inside TImpl and used on every
// later write/flush/finish, so it must outlive this object.
TLz4Compress::TLz4Compress(IOutputStream *out,
                           size_t blockSize)
    : Impl_(new TImpl(out, blockSize))
{
}

// Creates a compressing stream that appends to a std::stringstream. The
// stream header is written immediately; finishing is a no-op for this sink.
// NOTE(review): `out` is captured by raw pointer and must outlive this object.
TLz4Compress::TLz4Compress(std::stringstream *out,
                           size_t blockSize)
    : Impl_(new TImpl(out, blockSize))
{
}

// Finalizes the stream if the caller has not done so. Finish() presumably
// dispatches to DoFinish() through the output-stream base class; after an
// explicit finish Impl_ is already released, so this is a no-op.
// Errors are swallowed on purpose, because a destructor must not throw. Callers
// that need to observe compression or write failures must call Finish() themselves.
TLz4Compress::~TLz4Compress() {
    try {
        Finish();
    } catch (...) {
    }
}

void TLz4Compress::UseHC() {
    Impl_->UseHC();
}

void TLz4Compress::UseDefault() {
    Impl_->UseDefault();
}

// Hands `len` bytes to the block compressor. Throws once the stream has been
// finished (DoFinish releases Impl_).
void TLz4Compress::DoWrite(const void* buf, size_t len) {
    if (Impl_) {
        Impl_->Write(static_cast<const char*>(buf), len);
        return;
    }
    ythrow yexception() << "can not write to finalized stream";
}

// Forces buffered input out as a compressed block. Throws once the stream has
// been finished (DoFinish releases Impl_).
void TLz4Compress::DoFlush() {
    if (Impl_) {
        Impl_->Flush();
        return;
    }
    ythrow yexception() << "can not flush finalized stream";
}

void TLz4Compress::DoFinish() {
    THolder<TImpl> impl(Impl_.Release());

    if (impl) {
        impl->Finish();
    }
}

// Creates a decompressing stream over `in`. Constructing TImpl reads and
// validates the stream header immediately, so a stream that is not LZ4 fails here.
// NOTE(review): `in` is captured by raw pointer and must outlive this object.
TLz4Decompress::TLz4Decompress(IInputStream *in)
    : Impl_(new TImpl(in))
{
}

// Creates a decompressing stream that reads from a std::stringstream. The
// stream header is read and validated immediately.
// NOTE(review): `in` is captured by raw pointer and must outlive this object.
TLz4Decompress::TLz4Decompress(std::stringstream *in)
    : Impl_(new TImpl(in))
{
}

// Defined out of line, where TImpl is a complete type, so the owning holder
// can destroy it.
TLz4Decompress::~TLz4Decompress() = default;

// Decodes everything left in the stream and returns it as one string.
TString TLz4Decompress::ReadAll() {
    TString decoded = Impl_->ReadAll();
    return decoded;
}

size_t TLz4Decompress::ReadAll(std::ostream &out) {
    return Impl_->DoReadAll(out);
}

// Copies up to `len` decoded bytes into `buf`; returns 0 at end of stream.
size_t TLz4Decompress::DoRead(void *buf,
                              size_t len)
{
    char *const dst = static_cast<char *>(buf);
    return Impl_->Read(dst, len);
}

// Reads up to the delimiter `ch`. `st` receives the text before it; the
// returned count includes the consumed delimiter.
size_t TLz4Decompress::DoReadTo(TString &st,
                                char ch)
{
    TImpl &impl = *Impl_;
    return impl.ReadTo(st, ch);
}

