#include "testenv.hpp"
#include "file/FileStream.hpp"
#include "playercore/platform/NativePlatform.hpp"
#include "media/MemoryStream.hpp"
#include <algorithm>
#include <cassert>
#include <gtest/gtest.h>
#include <memory>

using namespace twitch;

// Factory methods for creating Stream objects
// Declared only; no generic definition is provided here, so every type
// registered with the typed test suite needs an explicit specialization.
template <class T>
std::unique_ptr<Stream> CreateStream();

// MemoryStream requires no arguments, so it is simply value-constructed and
// handed back through the common Stream interface.
template <>
std::unique_ptr<Stream> CreateStream<MemoryStream>()
{
    std::unique_ptr<Stream> stream(new MemoryStream());
    return stream;
}

// Tests for Stream implementations
// Typed test fixture: each test body runs once per registered Stream
// implementation, against a fresh stream obtained from CreateStream<T>().
template <class T>
class StreamTest : public ::testing::Test {
protected:
    StreamTest()
        : m_stream(CreateStream<T>())
    {
    }

    void SetUp() override
    {
        // Fail fast if the factory did not produce a stream.
        ASSERT_TRUE(m_stream != nullptr) << "CreateStream<T>() returned a null stream";
    }

protected:
    // Writes `size` bytes at the stream's current position. Each byte holds its
    // own stream offset truncated to 8 bits, so readers can verify content
    // purely from the offset they read at.
    //
    // Uses ASSERT_* internally: a fatal failure only returns from this helper,
    // not from the calling test, so callers should wrap the call in
    // ASSERT_NO_FATAL_FAILURE.
    void writeTestData(size_t size)
    {
        // The starting offset does not change while the buffer is being filled,
        // so query it once instead of on every iteration.
        const auto startPosition = this->m_stream->position();

        std::vector<uint8_t> buffer(size);
        for (size_t i = 0; i < size; i++) {
            // Wrap-around past 255 is intentional; make the narrowing explicit.
            buffer[i] = static_cast<uint8_t>(startPosition + i);
        }

        ASSERT_EQ(static_cast<int64_t>(size), this->m_stream->write(buffer.data(), size));
    }

    std::unique_ptr<Stream> m_stream;
};

// Register Stream types
using testing::Types;
// Every Stream type listed here runs the full StreamTest suite; each one also
// needs a matching CreateStream<T>() specialization.
using Implementations = Types<MemoryStream>;
TYPED_TEST_CASE(StreamTest, Implementations);

// Stream type tests

TYPED_TEST(StreamTest, EmptyStream)
{
    // A freshly created stream is expected to hold no data and sit at offset 0.
    auto& stream = *this->m_stream;
    EXPECT_EQ(0, stream.length());
    EXPECT_EQ(0, stream.position());
}

TYPED_TEST(StreamTest, InvalidWriteInput)
{
    auto& stream = *this->m_stream;

    // A null source buffer is expected to be rejected with -1.
    EXPECT_EQ(-1, stream.write(nullptr, 10));

    // A zero-length write from a valid buffer is expected to report 0 bytes.
    uint8_t singleByte = 1;
    EXPECT_EQ(0, stream.write(&singleByte, 0));
}

TYPED_TEST(StreamTest, InvalidReadInput)
{
    auto& stream = *this->m_stream;

    // Put some data in the stream first so the invalid-read checks below are
    // not merely exercising an empty stream.
    const int64_t payloadSize = 10;
    std::vector<uint8_t> payload(payloadSize, 1);
    ASSERT_EQ(payloadSize, stream.write(payload.data(), payloadSize));

    // Reading into a null destination is expected to fail with -1.
    EXPECT_EQ(-1, stream.read(nullptr, 10));

    // A zero-length read into a valid buffer is expected to report 0 bytes.
    uint8_t sink = 0;
    EXPECT_EQ(0, stream.read(&sink, 0));
}

TYPED_TEST(StreamTest, Write)
{
    auto& stream = *this->m_stream;
    int64_t expectedLength = 0;

    // Append chunks of 1..9 bytes (each filled with its own size) and check
    // that both length and position advance by exactly the chunk size.
    for (int64_t chunkSize = 1; chunkSize != 10; ++chunkSize) {
        std::vector<uint8_t> chunk(chunkSize, chunkSize);
        EXPECT_EQ(chunkSize, stream.write(chunk.data(), chunkSize));

        expectedLength += chunkSize;
        EXPECT_EQ(expectedLength, stream.length()) << "Stream length did not increment by " << chunkSize;
        EXPECT_EQ(expectedLength, stream.position()) << "Stream position did not increment by " << chunkSize;
    }
}

TYPED_TEST(StreamTest, Read)
{
    const int64_t writeBufferSize = 100;
    // writeTestData() reports failures with ASSERT_*, which only returns from
    // the helper; propagate the fatal failure so we don't read an unpopulated
    // stream and pile up misleading follow-on failures.
    ASSERT_NO_FATAL_FAILURE(this->writeTestData(writeBufferSize));

    this->m_stream->seek(0);

    int64_t expectedPosition = 0;
    ASSERT_EQ(expectedPosition, this->m_stream->position());
    size_t readIndex = 0;

    // Read consecutive chunks of 1..9 bytes. writeTestData() stored each byte's
    // offset as its value, so the bytes must come back as 0, 1, 2, ...
    for (int64_t bufferSize = 1; bufferSize < 10; bufferSize++) {
        // Pre-fill with a sentinel so a short read is visible in the content check.
        std::vector<uint8_t> buffer(bufferSize, bufferSize);
        EXPECT_EQ(bufferSize, this->m_stream->read(buffer.data(), bufferSize));

        for (const auto& byte : buffer) {
            EXPECT_EQ(readIndex, byte);
            readIndex++;
        }

        expectedPosition += bufferSize;
        EXPECT_EQ(writeBufferSize, this->m_stream->length()) << "Stream length changed during read";
        EXPECT_EQ(expectedPosition, this->m_stream->position()) << "Stream position did not increment by " << bufferSize;
    }
}

TYPED_TEST(StreamTest, ReadCappedByLength)
{
    const int64_t writeBufferSize = 10;
    // Propagate a fatal failure from the helper (ASSERT_* there only returns
    // from writeTestData itself).
    ASSERT_NO_FATAL_FAILURE(this->writeTestData(writeBufferSize));

    this->m_stream->seek(0);

    // Request twice the available data; the read must stop at the end of the
    // stream and report only the bytes actually present.
    const int64_t readBufferSize = writeBufferSize * 2;
    std::vector<uint8_t> readBuffer(readBufferSize);
    EXPECT_EQ(writeBufferSize, this->m_stream->read(readBuffer.data(), readBufferSize));
    EXPECT_EQ(writeBufferSize, this->m_stream->position()) << "Stream position did not stop at the end of the data";

    // Use a signed index so it compares cleanly with the int64_t size.
    for (int64_t i = 0; i < writeBufferSize; i++) {
        EXPECT_EQ(i, readBuffer[i]);
    }
}

TYPED_TEST(StreamTest, InvalidSeekInput)
{
    auto& stream = *this->m_stream;

    // On an empty stream, seeking is expected to leave the position at 0,
    // even when the target lies beyond the end.
    EXPECT_EQ(0, stream.position());
    stream.seek(0);
    EXPECT_EQ(0, stream.position());
    stream.seek(1);
    EXPECT_EQ(0, stream.position());

    const int64_t dataSize = 10;
    std::vector<uint8_t> data(dataSize, 1);
    ASSERT_EQ(dataSize, stream.write(data.data(), dataSize));

    // In-range targets (including exactly the end) are honoured as given.
    stream.seek(1);
    EXPECT_EQ(1, stream.position());
    stream.seek(10);
    EXPECT_EQ(10, stream.position());

    // Targets past the end are expected to clamp to the stream length.
    stream.seek(11);
    EXPECT_EQ(10, stream.position());
    stream.seek(1000);
    EXPECT_EQ(10, stream.position());
}

TYPED_TEST(StreamTest, SeekThroughPositions)
{
    const int64_t writeBufferSize = 100;
    // Propagate a fatal failure from the helper (ASSERT_* there only returns
    // from writeTestData itself).
    ASSERT_NO_FATAL_FAILURE(this->writeTestData(writeBufferSize));

    uint8_t readBuffer = 0;

    // Forward pass: seek to every offset and read the single byte there, which
    // writeTestData() set to the offset value itself.
    for (int64_t position = 0; position < writeBufferSize; position++) {
        this->m_stream->seek(position);
        EXPECT_EQ(position, this->m_stream->position());

        ASSERT_EQ(1, this->m_stream->read(&readBuffer, 1));
        EXPECT_EQ(position, readBuffer);
    }

    // Backward pass: same checks, visiting offsets from the last to the first.
    for (int64_t position = writeBufferSize - 1; position >= 0; position--) {
        this->m_stream->seek(position);
        EXPECT_EQ(position, this->m_stream->position());

        ASSERT_EQ(1, this->m_stream->read(&readBuffer, 1));
        EXPECT_EQ(position, readBuffer);
    }
}

TYPED_TEST(StreamTest, WriteReadInterleaved)
{
    // Chunks are filled with their own size, and with twice their size in the
    // second pass; both values must fit in a uint8_t or the content checks
    // would silently compare against wrapped values.
    constexpr int64_t iterations = 100;
    static_assert(2 * (iterations - 1) <= 0xFF, "fill values must fit in a uint8_t");

    // Pass 1: append a chunk, seek back to where it started and read it back.
    for (int64_t writeSize = 1; writeSize < iterations; writeSize++) {
        const int64_t writePosition = this->m_stream->position();
        std::vector<uint8_t> writeBuffer(writeSize, static_cast<uint8_t>(writeSize));
        ASSERT_EQ(writeSize, this->m_stream->write(writeBuffer.data(), writeSize));

        this->m_stream->seek(writePosition);
        std::vector<uint8_t> readBuffer(writeSize);
        EXPECT_EQ(writeSize, this->m_stream->read(readBuffer.data(), writeSize));
        for (const auto& byte : readBuffer) {
            ASSERT_EQ(writeSize, byte);
        }
    }

    // Pass 2: append two back-to-back chunks with different fill values, seek
    // back and verify both come back in order.
    for (int64_t writeSize = 1; writeSize < iterations; writeSize++) {
        const int64_t writePosition = this->m_stream->position();
        std::vector<uint8_t> writeBuffer(writeSize, static_cast<uint8_t>(writeSize));
        ASSERT_EQ(writeSize, this->m_stream->write(writeBuffer.data(), writeSize));

        std::fill(writeBuffer.begin(), writeBuffer.end(), static_cast<uint8_t>(writeSize * 2));
        ASSERT_EQ(writeSize, this->m_stream->write(writeBuffer.data(), writeSize));

        this->m_stream->seek(writePosition);
        std::vector<uint8_t> readBuffer(writeSize);
        EXPECT_EQ(writeSize, this->m_stream->read(readBuffer.data(), writeSize));
        for (const auto& byte : readBuffer) {
            ASSERT_EQ(writeSize, byte);
        }

        EXPECT_EQ(writeSize, this->m_stream->read(readBuffer.data(), writeSize));
        for (const auto& byte : readBuffer) {
            ASSERT_EQ(writeSize * 2, byte);
        }
    }
}
