#include "pch.h"
#include "AudioSample.hpp"
#include "debug/trace.hpp"
#include "playercore/platform/windows/WindowsPlatform.hpp"
#include "MediaSample.hpp"

using namespace twitch;
using namespace windows;
using namespace Microsoft;
using namespace Microsoft::WRL;

// Compile-time layout guard: per the message below, objects are cast between AudioSample and
// twitch::MediaSampleBuffer elsewhere, so AudioSample must stay exactly the same size.
// NOTE(review): the message says "MediaSample" while the check compares against
// twitch::MediaSampleBuffer — confirm which type the cast actually targets.
static_assert(sizeof(AudioSample) == sizeof(twitch::MediaSampleBuffer), "AudioSample must be same size as MediaSample since the structure will be casted");

namespace
{
    // Width in bytes of one SSE register; both converters consume the source one lane at a time.
    constexpr int LANE_SIZE = 16;

    // Expands one lane of 16-bit mono audio (8 samples, 16 bytes) into 8 interleaved stereo
    // frames (32 bytes) by writing every sample to both channels. No alignment is required.
    void DuplicateMono16BitsLane(const unsigned char * source, unsigned char * dest)
    {
        const __m128i sample = _mm_loadu_si128(reinterpret_cast<const __m128i *>(source));

        // unpack(x, x) interleaves each 16-bit sample with itself: a a b b c c ...
        const __m128i low = _mm_unpacklo_epi16(sample, sample);
        const __m128i high = _mm_unpackhi_epi16(sample, sample);

        // Unaligned stores: the destination buffer has no 16-byte alignment guarantee.
        _mm_storeu_si128(reinterpret_cast<__m128i *>(dest), low);
        _mm_storeu_si128(reinterpret_cast<__m128i *>(dest + LANE_SIZE), high);
    }

    // Converts one lane of 32-bit float mono audio (4 samples, 16 bytes) into 4 interleaved
    // 16-bit stereo frames (16 bytes). Samples are scaled by 32767 and clamped to the int16
    // range. No alignment is required.
    void TransformFloatMonoIntoStereo16BitsLane(const unsigned char * source, unsigned char * dest)
    {
        const __m128 MODULATE = _mm_set_ps1(32767.0f);

        __m128 scaled = _mm_mul_ps(_mm_loadu_ps(reinterpret_cast<const float *>(source)), MODULATE);

        // Clamp before converting: _mm_cvtps_epi32 maps values outside the int32 range to
        // INT_MIN, which would turn a very large positive sample into full-scale negative.
        scaled = _mm_min_ps(scaled, _mm_set_ps1(32767.0f));
        scaled = _mm_max_ps(scaled, _mm_set_ps1(-32768.0f));

        __m128i pcm = _mm_cvtps_epi32(scaled);
        // Saturating narrow to 16 bits; the 4 results land in the low 64 bits.
        pcm = _mm_packs_epi32(pcm, pcm);

        // Duplicate each 16-bit sample into the left and right channels.
        _mm_storeu_si128(reinterpret_cast<__m128i *>(dest), _mm_unpacklo_epi16(pcm, pcm));
    }

    // Converts 16-bit mono PCM into 16-bit interleaved stereo by duplicating every sample.
    //   sourceBuffer        mono samples; no alignment requirement.
    //   destBuffer          receives 2 * sourceBufferLength bytes (rounded down to whole
    //                       samples); no alignment requirement.
    //   sourceBufferLength  source size in bytes; a trailing odd byte is ignored.
    void DuplicateMono16Bits(const unsigned char * sourceBuffer, unsigned char * destBuffer, int sourceBufferLength)
    {
        constexpr int SAMPLE_SIZE = 2;

        while (sourceBufferLength >= LANE_SIZE)
        {
            DuplicateMono16BitsLane(sourceBuffer, destBuffer);

            sourceBufferLength -= LANE_SIZE;
            sourceBuffer += LANE_SIZE;
            destBuffer += 2 * LANE_SIZE;
        }

        // Remaining samples that do not fill a lane go through a zero-padded scratch lane, so
        // the SIMD code never reads or writes outside the caller's buffers.
        const int tailBytes = sourceBufferLength - sourceBufferLength % SAMPLE_SIZE;
        if (tailBytes > 0)
        {
            unsigned char paddedSource[LANE_SIZE] = {};
            unsigned char paddedDest[2 * LANE_SIZE];
            memcpy(paddedSource, sourceBuffer, tailBytes);
            DuplicateMono16BitsLane(paddedSource, paddedDest);
            memcpy(destBuffer, paddedDest, 2 * tailBytes);
        }
    }

    // Converts 32-bit float mono samples into 16-bit interleaved stereo (each sample is scaled
    // by 32767, clamped to the int16 range and written to both channels).
    //   sourceBuffer        float samples; no alignment requirement.
    //   destBuffer          receives sourceBufferLength bytes (rounded down to whole samples);
    //                       one 4-byte float becomes one 4-byte stereo frame. No alignment requirement.
    //   sourceBufferLength  source size in bytes; trailing bytes short of a whole float are ignored.
    void TransformFloatMonoIntoStereo16Bits(const unsigned char * sourceBuffer, unsigned char * destBuffer, int sourceBufferLength)
    {
        constexpr int SAMPLE_SIZE = 4;

        while (sourceBufferLength >= LANE_SIZE)
        {
            TransformFloatMonoIntoStereo16BitsLane(sourceBuffer, destBuffer);

            sourceBufferLength -= LANE_SIZE;
            sourceBuffer += LANE_SIZE;
            destBuffer += LANE_SIZE;
        }

        // Same zero-padded scratch-lane approach as DuplicateMono16Bits for the partial lane.
        const int tailBytes = sourceBufferLength - sourceBufferLength % SAMPLE_SIZE;
        if (tailBytes > 0)
        {
            unsigned char paddedSource[LANE_SIZE] = {};
            unsigned char paddedDest[LANE_SIZE];
            memcpy(paddedSource, sourceBuffer, tailBytes);
            TransformFloatMonoIntoStereo16BitsLane(paddedSource, paddedDest);
            memcpy(destBuffer, paddedDest, tailBytes);
        }
    }
}

AudioSample::AudioSample(IMFSample* sample,
                         int bitsPerSampleSource, unsigned char numChannelsSource,
                         int bitsPerSampleDest, unsigned char numChannelsDest)
    : MediaSample(sample)
{
    // Copy the sample's audio payload into this object's buffer, converting the sample format
    // when source and destination differ and a conversion is available.
    // NOTE(review): the success flag is intentionally discarded here; the helper reports its
    // failures through WindowsPlatform::hError, but the caller of this constructor gets no signal.
    const bool extracted = extractBufferAndConvertIfNeeded(sample,
                                                           bitsPerSampleSource, numChannelsSource,
                                                           bitsPerSampleDest, numChannelsDest);
    (void)extracted;
}

// Copies the audio payload of `sample` into `buffer`, converting it when the source format
// (bitsPerSampleSource / numChannelsSource) differs from the destination format.
// Supported conversions (destination must be 16-bit stereo, source must be mono):
//   - 32-bit float mono -> 16-bit stereo (output is the same byte size as the input)
//   - 16-bit mono       -> 16-bit stereo (output is twice the input size)
// Any other mismatch logs a one-time warning and copies the bytes unchanged.
// Returns false (after reporting via WindowsPlatform::hError) if the media buffer cannot be
// obtained or locked; true otherwise.
bool AudioSample::extractBufferAndConvertIfNeeded(IMFSample* sample,
                                                  int bitsPerSampleSource, unsigned char numChannelsSource,
                                                  int bitsPerSampleDest, unsigned char numChannelsDest)
{
    ComPtr<IMFMediaBuffer> mediaBuffer;
    HRESULT hr = sample->ConvertToContiguousBuffer(&mediaBuffer);

    if (FAILED(hr)) {
        WindowsPlatform::hError("ConvertToContiguousBuffer failed", hr);
        return false;
    }

    BYTE* pMediaBuffer = nullptr;
    DWORD bufferMaxLength = 0;
    DWORD bufferLength = 0;
    hr = mediaBuffer->Lock(&pMediaBuffer, &bufferMaxLength, &bufferLength);

    if (FAILED(hr)) {
        WindowsPlatform::hError("Can't lock buffer", hr);
        return false;
    }

    // Unlock on every exit path, including an exception thrown by buffer.resize(). Declared
    // after mediaBuffer so it is destroyed (and unlocks) before the ComPtr releases the buffer.
    struct UnlockOnExit {
        IMFMediaBuffer* lockedBuffer;
        ~UnlockOnExit() { lockedBuffer->Unlock(); }
    } unlockOnExit{ mediaBuffer.Get() };

    const int sourceLength = static_cast<int>(bufferLength);
    bool converted = false;

    if ((bitsPerSampleSource != bitsPerSampleDest) || (numChannelsSource != numChannelsDest)) {

        // Both converters duplicate a single channel, so they only apply to mono sources.
        if ((bitsPerSampleDest == 16) && (numChannelsDest == 2) && (numChannelsSource == 1)) {
            if (bitsPerSampleSource == 32) {
                // One 4-byte float becomes one 4-byte stereo frame: same total size.
                buffer.resize(bufferLength);
                TransformFloatMonoIntoStereo16Bits(pMediaBuffer, buffer.data(), sourceLength);
                converted = true;
            } else if (bitsPerSampleSource == 16) {
                // Same sample width, twice the channels: twice the size.
                buffer.resize(2 * bufferLength);
                DuplicateMono16Bits(pMediaBuffer, buffer.data(), sourceLength);
                converted = true;
            }
        }

        if (!converted) {
            TRACE_WARN_ONCE("AudioSample needs conversion: %d/%d to %d/%d", bitsPerSampleSource, numChannelsSource, bitsPerSampleDest, numChannelsDest);
        }
    }

    if (!converted) {
        buffer.resize(bufferLength);
        // Skip the copy for an empty payload: data() of an empty container may be null.
        if (bufferLength != 0) {
            memcpy(buffer.data(), pMediaBuffer, bufferLength);
        }
    }

    return true;
}
