#include "MemoryStream.hpp"
#include <algorithm>

namespace twitch {
// Creates an empty stream whose storage is split into blocks of `blockSize` bytes.
// A block size of zero is raised to one: write() moves to a new block whenever the
// current one has no room left, so a zero-capacity block would make it allocate
// new blocks forever.
MemoryStream::MemoryStream(size_t blockSize)
    : m_position(0)
    , m_index(0)
    , m_blockSize(std::max<size_t>(blockSize, 1))
{
}

// Copies up to `size` bytes starting at the current position into `buffer` and
// advances the position by the number of bytes copied.
//   buffer - destination; must have room for at least `size` bytes.
//   size   - maximum number of bytes to copy.
// Returns the number of bytes copied (0 at end of stream), or -1 if `buffer` is null.
int64_t MemoryStream::read(uint8_t* buffer, size_t size)
{
    if (!buffer) {
        return -1;
    }

    if (size == 0) {
        return 0;
    }

    size_t copied = 0;
    // Absolute stream offset of block m_index. It is computed once and advanced as the
    // cursor moves to the next block, rather than re-summing every earlier block on
    // each iteration.
    size_t blockStart = m_index < m_buffers.size() ? getBlockStartOffset(m_index) : 0;
    while (copied < size && m_index < m_buffers.size()) {
        const auto& data = m_buffers[m_index];
        const size_t offset = m_position - blockStart; // offset inside the block
        const size_t chunk = std::min<size_t>(size - copied, data.size() - offset);

        // std::copy_n comes from <algorithm>, which this file already includes.
        std::copy_n(data.data() + offset, chunk, buffer + copied);
        copied += chunk;
        m_position += chunk;

        // This block is exhausted; continue with the next one.
        if (offset + chunk >= data.size()) {
            blockStart += data.size();
            m_index++;
        }
    }

    return static_cast<int64_t>(copied);
}

// Writes `size` bytes from `buffer` at the current position and advances the position.
//
// Bytes that already exist at the position are overwritten in place. Once the end of
// the stream is reached, the remainder is appended: first into the last block up to
// m_blockSize, then into fresh blocks. Only the last block ever grows, so the start
// offset of every block stays fixed while the cursor moves. Inserting at the cursor
// instead would let a block outgrow m_blockSize and break the offset arithmetic for
// the blocks after it.
//
// Returns the number of bytes written (always `size`), or -1 if `buffer` is null.
int64_t MemoryStream::write(const uint8_t* buffer, size_t size)
{
    if (!buffer) {
        return -1;
    }

    if (size == 0) {
        return 0;
    }

    size_t written = 0;
    while (written < size) {
        // Make sure a block exists at the cursor. New blocks start empty but reserve
        // the full block size so appends into them do not reallocate.
        while (m_index >= m_buffers.size()) {
            m_buffers.emplace_back();
            m_buffers.back().reserve(m_blockSize);
        }

        auto& data = m_buffers[m_index];
        const bool isLast = (m_index + 1 == m_buffers.size());
        const size_t offset = m_position - getBlockStartOffset(m_index);
        // Earlier blocks may only be overwritten; the last block may grow to m_blockSize.
        const size_t limit = isLast ? std::max<size_t>(m_blockSize, data.size()) : data.size();

        if (offset >= limit) {
            // No room left in this block; move the cursor to the next one.
            m_index++;
            continue;
        }

        const size_t chunk = std::min<size_t>(size - written, limit - offset);
        const size_t overwrite = std::min<size_t>(chunk, data.size() - offset);
        const uint8_t* src = buffer + written;

        std::copy(src, src + overwrite, data.begin() + offset);
        data.insert(data.end(), src + overwrite, src + chunk);

        m_position += chunk;
        written += chunk;
    }

    return static_cast<int64_t>(written);
}

std::vector<uint8_t> MemoryStream::take(int64_t bytes)
{
    std::vector<uint8_t> result;
    result.reserve(static_cast<size_t>(bytes));
    int64_t remaining = bytes;
    while (remaining && !m_buffers.empty()) {
        auto& buffer = m_buffers.front();

        if (remaining < static_cast<int64_t>(buffer.size())) {
            result.insert(result.end(), buffer.begin(), buffer.begin() + remaining);
            buffer.erase(buffer.begin(), buffer.begin() + remaining);
            remaining = 0;
        } else {
            result.insert(result.end(), buffer.begin(), buffer.end());
            remaining -= buffer.size();
            erase(0);
        }
    }

    return result;
}

void MemoryStream::remove(int64_t bytes)
{
    int64_t remaining = bytes;
    while (remaining && !m_buffers.empty()) {
        auto& buffer = m_buffers.front();

        if (remaining < static_cast<int64_t>(buffer.size())) {
            buffer.erase(buffer.begin(), buffer.begin() + remaining);
            remaining = 0;
        } else {
            remaining -= buffer.size();
            erase(0);
        }
    }
}

// Moves the cursor to absolute byte offset `position` (0 ... length()).
// Returns false and leaves the cursor untouched if `position` is past the end.
bool MemoryStream::seek(size_t position)
{
    const int64_t len = length();
    // Compare in the unsigned domain. Casting a huge size_t to int64_t wraps it
    // negative, which would slip past a signed bounds check.
    if (len < 0 || position > static_cast<size_t>(len)) {
        return false;
    }

    m_position = position;
    m_index = 0;

    // Select the last block that starts strictly before `position`. A position on a
    // block boundary therefore maps to the end of the earlier block rather than to
    // the start of the next one.
    size_t blockStart = 0;
    for (size_t i = 0; i < m_buffers.size() && position > blockStart; i++) {
        blockStart += m_buffers[i].size();
        m_index = i;
    }

    return true;
}

int64_t MemoryStream::length() const
{
    size_t size = 0;
    for (const auto& buffer : m_buffers) {
        size += buffer.size();
    }
    return static_cast<int64_t>(size);
}

// Absolute stream offset of the first byte of block `index`, i.e. the combined size
// of all blocks before it. `index` must not exceed m_buffers.size().
size_t MemoryStream::getBlockStartOffset(size_t index)
{
    size_t start = 0;
    for (auto it = m_buffers.begin(); it != m_buffers.begin() + index; ++it) {
        start += it->size();
    }
    return start;
}

// Drops block `index` (no-op if out of range) and keeps m_index on the same logical
// block. Only blocks in front of the cursor shift it down by one. Erasing the
// cursor's own block, or a later one, leaves m_index on whichever block now occupies
// that slot. m_position is not adjusted here; callers that remove bytes in front of
// the cursor must move it themselves.
void MemoryStream::erase(size_t index)
{
    if (index >= m_buffers.size()) {
        return;
    }

    m_buffers.erase(m_buffers.begin() + index);
    if (index < m_index) {
        m_index--;
    }
}
}
