#include "pch.h"
#include "AudioDecoder.hpp"
#include "playercore/platform/windows/WindowsPlatform.hpp"
#include "debug/trace.hpp"
#include "checkhr.hpp"

using namespace twitch;
using namespace twitch::windows;

using namespace Microsoft;
using namespace Microsoft::WRL;

namespace
{
    // Scope guard around MFT_OUTPUT_DATA_BUFFER: zero-initializes the fields
    // before the buffer is handed to IMFTransform::ProcessOutput and releases
    // the COM references it still holds when it goes out of scope.
    //
    // The reference held in pSample belongs to the caller of ProcessOutput
    // whether the caller created the sample or the transform allocated it, and
    // the same applies to pEvents, so both are released unconditionally.
    // Anything that wants to keep the sample alive past this object's lifetime
    // must take its own reference.
    struct MFT_OUTPUT_DATA_BUFFER_Safe : public MFT_OUTPUT_DATA_BUFFER
    {
        // allocateForUs: true when the transform is expected to allocate the
        // output sample itself (pSample stays null until ProcessOutput runs).
        explicit MFT_OUTPUT_DATA_BUFFER_Safe(bool allocateForUs)
            : m_allocateForUs(allocateForUs)
        {
            dwStreamID = 0;
            pSample = nullptr;
            dwStatus = 0;
            pEvents = nullptr;
        }

        // Copying would duplicate the raw COM pointers and release them twice.
        MFT_OUTPUT_DATA_BUFFER_Safe(const MFT_OUTPUT_DATA_BUFFER_Safe&) = delete;
        MFT_OUTPUT_DATA_BUFFER_Safe& operator=(const MFT_OUTPUT_DATA_BUFFER_Safe&) = delete;

        ~MFT_OUTPUT_DATA_BUFFER_Safe()
        {
            if (pSample != nullptr) {
                pSample->Release();
                pSample = nullptr;
            }

            if (pEvents != nullptr) {
                pEvents->Release();
                pEvents = nullptr;
            }
        }

        // Informational only: records who was expected to allocate pSample.
        bool m_allocateForUs = false;
    };
}

// Writes into `output` the audio format that is advertised for a given
// `input` format:
//   - sample size is always 16 bits,
//   - sample rate is the input rate (48000 when the input does not carry one),
//     raised to at least 44100,
//   - channel count is the input channel count (2 when absent).
// Only the three audio keys set below are touched on `output`.
void AudioDecoder::deriveOutputFromInput(const MediaFormat& input, MediaFormat& output)
{
    // Fallbacks used when the input format does not specify a value.
    constexpr int DefaultSamplesPerSecond = 48000;
    constexpr unsigned char DefaultNumChannels = 2;

    constexpr int MinimumSamplesPerSecondOnXB1 = 44100;
    constexpr int OutputBitsPerSample = 16;

    const int samplesPerSecond = input.hasInt(MediaFormat::Audio_SampleRate) ? input.getInt(MediaFormat::Audio_SampleRate) : DefaultSamplesPerSecond;
    const unsigned char numChannels = input.hasInt(MediaFormat::Audio_ChannelCount) ? static_cast<unsigned char>(input.getInt(MediaFormat::Audio_ChannelCount)) : DefaultNumChannels;

    const int OutputSamplesPerSecond = std::max(samplesPerSecond, MinimumSamplesPerSecondOnXB1);
    const int OutputNumChannels = numChannels; // Output has as many channels as input

    output.setInt(MediaFormat::Audio_SampleSize, OutputBitsPerSample);
    output.setInt(MediaFormat::Audio_SampleRate, OutputSamplesPerSecond);
    output.setInt(MediaFormat::Audio_ChannelCount, OutputNumChannels);
}

// Instantiates the AAC decoder transform (CLSID_MSAACDecMFT) in-process and
// stores it in m_decoderTransform. A creation failure is only traced; the
// constructor itself does not report it to the caller.
AudioDecoder::AudioDecoder()
{
    const HRESULT creationResult = CoCreateInstance(
        CLSID_MSAACDecMFT,
        nullptr,
        CLSCTX_INPROC_SERVER,
        IID_IMFTransform,
        (void**)&m_decoderTransform);

    if (SUCCEEDED(creationResult)) {
        return;
    }

    TRACE_ERROR("AudioDecoder(): CoCreateInstance for AAC decoder failed.");
}

// Configures the AAC decoder transform for the given input format and fills
// `output` with the format derived by deriveOutputFromInput().
//
// input:  format of the compressed stream. Sample size, sample rate and
//         channel count fall back to 16 / 48000 / 2 when absent. The
//         Audio_AAC_ESDS codec data must hold at least two bytes.
// output: receives the derived sample size, sample rate and channel count.
//
// Returns MediaResult::Ok when the decoder is configured or when the requested
// settings match the ones already applied. Returns an error when the decoder
// transform was never created, when the codec data is too short, when the
// transform does not accept data after configuration, or when any Media
// Foundation call fails (reported through CHECK_HR).
//
// The cached settings (m_bitsPerSample, m_samplesPerSecond, m_numChannels and
// their m_internal* counterparts) are committed only after every step has
// succeeded, so a failed attempt is not mistaken for "already configured" by
// the next call with the same settings.
MediaResult AudioDecoder::configure(const MediaFormat& input, MediaFormat& output)
{
    const int bitsPerSample = input.hasInt(MediaFormat::Audio_SampleSize) ? input.getInt(MediaFormat::Audio_SampleSize) : 16;
    const int samplesPerSecond = input.hasInt(MediaFormat::Audio_SampleRate) ? input.getInt(MediaFormat::Audio_SampleRate) : 48000;
    const unsigned char numChannels = input.hasInt(MediaFormat::Audio_ChannelCount) ? static_cast<unsigned char>(input.getInt(MediaFormat::Audio_ChannelCount)) : 2;

    deriveOutputFromInput(input, output);

    // The constructor only traces a failed CoCreateInstance, so the transform
    // may be missing here; fail cleanly instead of dereferencing null.
    if (m_decoderTransform == nullptr) {
        TRACE_ERROR("AudioDecoder(%p)::configure - failed - AAC decoder MFT is not available.", this);
        return MediaResult::ErrorInvalidState;
    }

    // The internal values are part of the comparison because processOutput()
    // rewrites them after a stream change; in that case a reconfigure is needed.
    if (m_bitsPerSample == bitsPerSample && m_samplesPerSecond == samplesPerSecond && m_numChannels == numChannels &&
        m_internalBitsPerSample == bitsPerSample && m_internalNumChannels == numChannels) {
        TRACE_INFO("AudioDecoder(%p)::configure - ignored since settings are the same.", this);
        return MediaResult::Ok;
    }

    TRACE_INFO("AudioDecoder(%p)::configure - configuring decoder.", this);

    assert((bitsPerSample % 8) == 0);

    // ---- Input (AAC) media type ----
    ComPtr<IMFMediaType> pDecInputMediaType;

    const unsigned int payLoadType = 0;
    CHECK_HR(MFCreateMediaType(&pDecInputMediaType), "AudioDecoder - Failed to create the input MediaType");
    CHECK_HR(m_decoderTransform->GetInputAvailableType(0, 0, &pDecInputMediaType), "AudioDecoder - Can't figure out supported AAC input media type");
    CHECK_HR(pDecInputMediaType->SetGUID(MF_MT_MAJOR_TYPE, MFMediaType_Audio), "AudioDecoder - Can't set AAC Input MF_MT_MAJOR_TYPE");
    CHECK_HR(pDecInputMediaType->SetGUID(MF_MT_SUBTYPE, MFAudioFormat_AAC), "AudioDecoder - Can't set AAC Input MF_MT_SUBTYPE");
    CHECK_HR(pDecInputMediaType->SetUINT32(MF_MT_AAC_PAYLOAD_TYPE, payLoadType), "AudioDecoder - Can't set payload type (raw)");
    CHECK_HR(pDecInputMediaType->SetUINT32(MF_MT_AUDIO_BITS_PER_SAMPLE, bitsPerSample), "AudioDecoder - Can't set AAC Input MF_MT_AUDIO_BITS_PER_SAMPLE");
    CHECK_HR(pDecInputMediaType->SetUINT32(MF_MT_AUDIO_SAMPLES_PER_SECOND, samplesPerSecond), "AudioDecoder - Can't set AAC Input MF_MT_AUDIO_SAMPLES_PER_SECOND");
    CHECK_HR(pDecInputMediaType->SetUINT32(MF_MT_AUDIO_NUM_CHANNELS, numChannels), "AudioDecoder - Can't set AAC Input MF_MT_AUDIO_NUM_CHANNELS");

    // Blob stored under MF_MT_USER_DATA. Only the first 14 bytes (everything up
    // to and including lastTwoBytes) are handed to SetBlob below.
    // NOTE(review): presumably this mirrors the HEAACWAVEINFO fields that follow
    // the WAVEFORMATEX header, with lastTwoBytes carrying the first two bytes of
    // the AudioSpecificConfig - confirm against the AAC decoder documentation.
    struct WAVEUSERDATA {
        WORD wPayloadType;
        WORD wAudioProfileLevelIndication;
        WORD wStructType;
        WORD wReserved1;
        DWORD dwReserved2;
        WORD lastTwoBytes;
    };

    WAVEUSERDATA waveUserData = {};
    waveUserData.wPayloadType = 0;
    waveUserData.wAudioProfileLevelIndication = 0xFE;
    waveUserData.wStructType = 0;
    waveUserData.wReserved1 = 0;
    waveUserData.dwReserved2 = 0;

    const std::vector<uint8_t> esDs = input.hasCodecData(MediaFormat::Audio_AAC_ESDS) ? input.getCodecData(MediaFormat::Audio_AAC_ESDS) : std::vector<uint8_t>();

    if (esDs.size() < 2) {
        TRACE_ERROR("AudioDecoder(%p)::configure - failed - esDs.size() < 2", this);
        return MediaResult(MediaResult::ErrorInvalidParameter, static_cast<int>(esDs.size()));
    }

    // Pack the two bytes so that esDs[0] is stored first in memory on a
    // little-endian target.
    waveUserData.lastTwoBytes = (esDs[0]) | (esDs[1] << 8);

    TRACE_DEBUG("Creating MediaType\n"
                "MF_MT_AAC_PAYLOAD_TYPE:%d\n"
                "MF_MT_AUDIO_BITS_PER_SAMPLE:%d\n"
                "MF_MT_AUDIO_SAMPLES_PER_SECOND:%d\n"
                "MF_MT_AUDIO_NUM_CHANNELS:%d\n"
                "esDS[0]:%d\n"
                "esDS[1]:%d",
        static_cast<int>(payLoadType),
        static_cast<int>(bitsPerSample),
        static_cast<int>(samplesPerSecond),
        static_cast<int>(numChannels),
        static_cast<int>(esDs[0]),
        static_cast<int>(esDs[1]));

    CHECK_HR(pDecInputMediaType->SetBlob(MF_MT_USER_DATA, reinterpret_cast<UINT8*>(&waveUserData), 14), "Can't set AAC Input MF_MT_USERDATA");

    CHECK_HR(m_decoderTransform->SetInputType(0, pDecInputMediaType.Get(), 0), "Failed to set input media type on AAC decoder MFT.");

    // ---- Output (PCM) media type ----
    // Starts as a copy of the input type; the PCM-specific attributes are then
    // overwritten.
    ComPtr<IMFMediaType> outputMediaType;
    CHECK_HR(MFCreateMediaType(&outputMediaType), "AudioDecoder - Failed to create the output MediaType");
    CHECK_HR(pDecInputMediaType->CopyAllItems(outputMediaType.Get()), "Failed to copy the input MediaType to output MediaType");

    const int OutputBitsPerSample = bitsPerSample;
    const int OutputSamplesPerSecond = output.getInt(MediaFormat::Audio_SampleRate);
    const int OutputNumChannels = output.getInt(MediaFormat::Audio_ChannelCount);

    UINT32 blockAlign = OutputNumChannels * (OutputBitsPerSample / 8);
    UINT32 bytesPerSecond = blockAlign * OutputSamplesPerSecond;

    CHECK_HR(outputMediaType->SetGUID(MF_MT_MAJOR_TYPE, MFMediaType_Audio), "AudioDecoder - Can't set AAC Output MF_MT_MAJOR_TYPE");
    CHECK_HR(outputMediaType->SetGUID(MF_MT_SUBTYPE, MFAudioFormat_PCM), "AudioDecoder - Can't set AAC Output MF_MT_SUBTYPE");
    CHECK_HR(outputMediaType->SetUINT32(MF_MT_AUDIO_BITS_PER_SAMPLE, OutputBitsPerSample), "AudioDecoder - Can't set AAC Output MF_MT_AUDIO_BITS_PER_SAMPLE");
    CHECK_HR(outputMediaType->SetUINT32(MF_MT_AUDIO_SAMPLES_PER_SECOND, OutputSamplesPerSecond), "AudioDecoder - Can't set AAC Output MF_MT_AUDIO_SAMPLES_PER_SECOND");
    CHECK_HR(outputMediaType->SetUINT32(MF_MT_AUDIO_NUM_CHANNELS, OutputNumChannels), "AudioDecoder - Can't set AAC Output MF_MT_AUDIO_NUM_CHANNELS");

    CHECK_HR(outputMediaType->SetUINT32(MF_MT_AUDIO_BLOCK_ALIGNMENT, blockAlign), "AudioDecoder - Can't set AAC Output MF_MT_AUDIO_BLOCK_ALIGNMENT");
    CHECK_HR(outputMediaType->SetUINT32(MF_MT_AUDIO_AVG_BYTES_PER_SECOND, bytesPerSecond), "AudioDecoder - Can't set AAC Output MF_MT_AUDIO_AVG_BYTES_PER_SECOND");
    CHECK_HR(outputMediaType->SetUINT32(MF_MT_ALL_SAMPLES_INDEPENDENT, TRUE), "AudioDecoder - Can't set AAC Output MF_MT_ALL_SAMPLES_INDEPENDENT");

    CHECK_HR(m_decoderTransform->SetOutputType(0, outputMediaType.Get(), 0), "AudioDecoder - Failed to set output media type on AAC decoder MFT.");

    DWORD mftStatus = 0;
    CHECK_HR(m_decoderTransform->GetInputStatus(0, &mftStatus), "Failed to get input status from AAC decoder MFT.");

    if (MFT_INPUT_STATUS_ACCEPT_DATA != mftStatus) {
        TRACE_ERROR("AudioDecoder(%p)::configure - AAC decoder MFT is not accepting data.", this);
        return MediaResult(MediaResult::ErrorNotSupported, mftStatus);
    }

    // The message parameter is an integer (ULONG_PTR), not a pointer.
    CHECK_HR(m_decoderTransform->ProcessMessage(MFT_MESSAGE_COMMAND_FLUSH, 0), "AudioDecoder - Failed to process FLUSH command on AAC decoder MFT.");
    CHECK_HR(m_decoderTransform->ProcessMessage(MFT_MESSAGE_NOTIFY_BEGIN_STREAMING, 0), "AudioDecoder - Failed to process BEGIN_STREAMING command on AAC decoder MFT.");
    CHECK_HR(m_decoderTransform->ProcessMessage(MFT_MESSAGE_NOTIFY_START_OF_STREAM, 0), "AudioDecoder - Failed to process START_OF_STREAM command on AAC decoder MFT.");

    // Everything succeeded: remember the applied settings.
    m_internalBitsPerSample = m_bitsPerSample = bitsPerSample;
    m_samplesPerSecond = samplesPerSecond;
    m_internalNumChannels = m_numChannels = numChannels;

    TRACE_DEBUG("AudioDecoder(%p)::configure - complete!", this);

    return MediaResult::Ok;
}

// Wraps the compressed payload of `input` in a Media Foundation sample and
// submits it to the decoder transform.
//
// input: buffer to decode; its decodeTime and duration become the sample time
//        and duration (converted to 100-nanosecond units).
//
// Returns MediaResult::Ok when the transform accepted the sample, and also
// when it answered MF_E_NOTACCEPTING (the sample is dropped in that case, see
// the TODO below). Returns an error when any Media Foundation call fails.
MediaResult AudioDecoder::processInput(const twitch::MediaSampleBuffer& input)
{
    // The status values themselves are not used; the queries only serve to
    // bail out early when the transform reports an error.
    DWORD inputStatus = 0;
    CHECK_HR(m_decoderTransform->GetInputStatus(0, &inputStatus), "AudioDecoder::processInput - GetInputStatus failed.");

    DWORD outputStatus = 0;
    CHECK_HR(m_decoderTransform->GetOutputStatus(&outputStatus), "AudioDecoder::processInput - GetOutputStatus failed.");

    ComPtr<IMFSample> inputSample;
    CHECK_HR(MFCreateSample(&inputSample), "AudioDecoder::processInput - Cannot create input sample for processing");

    // Sample time is in 100-nanoseconds unit. So 0.1 microsecond, or 0.0001 ms
    const LONGLONG sampleTimeInHundredsNano = input.decodeTime.nanoseconds().count() / 100;
    const LONGLONG sampleDurationInHundredsNano = input.duration.nanoseconds().count() / 100;

    CHECK_HR(inputSample->SetSampleTime(sampleTimeInHundredsNano), "AudioDecoder::processInput - Error setting the input audio sample time.");
    CHECK_HR(inputSample->SetSampleDuration(sampleDurationInHundredsNano), "AudioDecoder::processInput - Error setting input audio sample duration.");

    const DWORD payloadSize = static_cast<DWORD>(input.buffer.size());

    ComPtr<IMFMediaBuffer> inputBuffer;
    CHECK_HR(MFCreateMemoryBuffer(payloadSize, &inputBuffer), "AudioDecoder::processInput - Failed to create memory buffer.");
    CHECK_HR(inputSample->AddBuffer(inputBuffer.Get()), "AudioDecoder::processInput - Failed to add buffer to re-constructed sample.");

    // Copy the payload into the media buffer. Nothing between Lock and Unlock
    // can return early, so the buffer is never left locked.
    BYTE* inputByteBuffer = nullptr;
    DWORD inputBufferLength = 0;
    DWORD inputBufferMaxLength = 0;
    CHECK_HR(inputBuffer->Lock(&inputByteBuffer, &inputBufferMaxLength, &inputBufferLength), "AudioDecoder::processInput - Error locking input buffer.");
    memcpy(inputByteBuffer, input.buffer.data(), payloadSize);
    CHECK_HR(inputBuffer->Unlock(), "AudioDecoder::processInput - Error unlocking input buffer.");
    CHECK_HR(inputBuffer->SetCurrentLength(payloadSize), "AudioDecoder::processInput - Error setting the input buffer length.");

    const HRESULT hrInput = m_decoderTransform->ProcessInput(0, inputSample.Get(), 0);

    if (hrInput == MF_E_NOTACCEPTING) {
        // The sample is not queued anywhere, so it is lost.
        TRACE_WARN_ONCE("AudioDecoder::processInput - decoder is not accepting input, the sample was dropped");
        return MediaResult::Ok; // TODO: add to an internal buffer to reprocess the input
    } else if (FAILED(hrInput)) {
        WindowsPlatform::hError("AudioDecoder::processInput - IMFTransform::ProcessInput failed", hrInput);
        return MediaResult(MediaResult::ErrorInvalidData, hrInput);
    }

    return MediaResult::Ok;
}

// Asks the decoder transform for one output sample and, unless `pt` requests
// discarding, queues it through handleOutputDataBuffer().
//
// hr: receives the HRESULT of IMFTransform::ProcessOutput. It is not written
//     when the function fails before reaching that call.
// pt: forwarded to handleOutputDataBuffer().
//
// Returns MediaResult::Ok when a sample was produced and also when the
// transform needs more input (hr == MF_E_TRANSFORM_NEED_MORE_INPUT, nothing is
// queued). On MF_E_TRANSFORM_STREAM_CHANGE the output type is renegotiated,
// m_internalBitsPerSample / m_internalNumChannels are updated from the new
// type, and the call is retried. Any other failure yields an error result.
MediaResult AudioDecoder::processOutput(HRESULT& hr, ProcessType pt)
{
    MFT_OUTPUT_STREAM_INFO outputStreamInfo;
    CHECK_HR(m_decoderTransform->GetOutputStreamInfo(0, &outputStreamInfo), "Can't figure out the output stream info type");

    // Will the Transform allocate the output sample for us ?
    const bool allocateForUs = (outputStreamInfo.dwFlags & (MFT_OUTPUT_STREAM_PROVIDES_SAMPLES | MFT_OUTPUT_STREAM_CAN_PROVIDE_SAMPLES)) != 0;

    // Owns the sample/events for the duration of this call.
    MFT_OUTPUT_DATA_BUFFER_Safe outputDataBuffer(allocateForUs);

    if (!allocateForUs) {
        // The transform expects a caller-provided sample backed by a buffer of
        // the size it advertised.
        ComPtr<IMFMediaBuffer> outputBuffer;

        CHECK_HR(MFCreateSample(&outputDataBuffer.pSample), "Failed to create MF sample.");
        CHECK_HR(MFCreateMemoryBuffer(outputStreamInfo.cbSize, &outputBuffer), "Failed to create output memory buffer.");
        CHECK_HR(outputDataBuffer.pSample->AddBuffer(outputBuffer.Get()), "Failed to add sample to buffer.");
    }

    DWORD processOutputStatus = 0;
    hr = m_decoderTransform->ProcessOutput(0, 1, &outputDataBuffer, &processOutputStatus);

    if (hr == MF_E_TRANSFORM_NEED_MORE_INPUT) {
        return MediaResult::Ok;
    }
    else if (hr == MF_E_TRANSFORM_STREAM_CHANGE) {
        // The message parameter is an integer (ULONG_PTR), not a pointer.
        CHECK_HR(m_decoderTransform->ProcessMessage(MFT_MESSAGE_COMMAND_FLUSH, 0), "AudioDecoder - Failed to process FLUSH command on AAC decoder MFT.");
        CHECK_HR(m_decoderTransform->ProcessMessage(MFT_MESSAGE_COMMAND_DRAIN, 0), "Failed to process DRAIN command on AAC decoder MFT.");

        // Adopt the first output type the transform now offers.
        ComPtr<IMFMediaType> outputMediaType;
        CHECK_HR(m_decoderTransform->GetOutputAvailableType(0, 0, &outputMediaType), "AudioDecoder - Failed to get available output types");
        CHECK_HR(m_decoderTransform->SetOutputType(0, outputMediaType.Get(), 0), "AudioDecoder - Failed to set output media type");

        UINT32 OutputBitsPerSample = 0;
        UINT32 OutputNumChannels = 0;
        CHECK_HR(outputMediaType->GetUINT32(MF_MT_AUDIO_BITS_PER_SAMPLE, &OutputBitsPerSample), "AudioDecoder - Can't get AAC Output MF_MT_AUDIO_BITS_PER_SAMPLE");
        CHECK_HR(outputMediaType->GetUINT32(MF_MT_AUDIO_NUM_CHANNELS, &OutputNumChannels), "AudioDecoder - Can't get AAC Output MF_MT_AUDIO_NUM_CHANNELS");
        m_internalBitsPerSample = static_cast<int>(OutputBitsPerSample);
        m_internalNumChannels = static_cast<unsigned char>(OutputNumChannels);

        // Retry to process output with new output MediaType
        return processOutput(hr, pt);
    }
    else if (FAILED(hr)) {
        WindowsPlatform::hError("ProcessOutput failed", hr);
        return MediaResult(MediaResult::ErrorInvalidData, hr);
    }

    // MFT_PROCESS_OUTPUT_STATUS_NEW_STREAMS is reported through the status
    // out-parameter of ProcessOutput, not through the data buffer's dwStatus.
    if (processOutputStatus & MFT_PROCESS_OUTPUT_STATUS_NEW_STREAMS) {
        TRACE_DEBUG("AudioDecoder - Signaled MFT_PROCESS_OUTPUT_STATUS_NEW_STREAMS");
    }

    handleOutputDataBuffer(outputDataBuffer, pt);

    return MediaResult::Ok;
}

// Traces the status flags carried by an output data buffer and, unless `pt`
// contains ProcessType::Discard, queues its sample via associateOutput().
//
// outputDataBuffer: buffer filled by IMFTransform::ProcessOutput.
// pt:               processing mode; Discard suppresses queuing.
//
// Only MFT_OUTPUT_DATA_BUFFER_* flags are tested here because those are the
// flags defined for dwStatus. MFT_PROCESS_OUTPUT_STATUS_NEW_STREAMS belongs to
// the status out-parameter of ProcessOutput and has the same numeric value as
// MFT_OUTPUT_DATA_BUFFER_FORMAT_CHANGE, so testing it against dwStatus would
// only mirror the format-change trace.
void AudioDecoder::handleOutputDataBuffer(MFT_OUTPUT_DATA_BUFFER& outputDataBuffer, ProcessType pt)
{
    if (outputDataBuffer.dwStatus & MFT_OUTPUT_DATA_BUFFER_INCOMPLETE) {
        TRACE_WARN_ONCE("ProcessOutput set status to MFT_OUTPUT_DATA_BUFFER_INCOMPLETE - we ignore subsequent samples. Your output may be missing frames");
    }

    if ((outputDataBuffer.dwStatus & MFT_OUTPUT_DATA_BUFFER_STREAM_END)) {
        TRACE_DEBUG("AudioDecoder - Signaled MFT_OUTPUT_DATA_BUFFER_STREAM_END");
    }

    if ((outputDataBuffer.dwStatus & MFT_OUTPUT_DATA_BUFFER_FORMAT_CHANGE)) {
        TRACE_DEBUG("AudioDecoder - Signaled MFT_OUTPUT_DATA_BUFFER_FORMAT_CHANGE");
    }

    if (!(pt & ProcessType::Discard)) {
        // Never queue an output entry that has no sample behind it.
        if (outputDataBuffer.pSample != nullptr) {
            associateOutput(outputDataBuffer.pSample);
        } else {
            TRACE_DEBUG("AudioDecoder - ProcessOutput returned no sample, nothing queued");
        }
    }
}

// Wraps `sample` in an AudioSample, together with the decoder-internal and the
// configured bits-per-sample / channel-count values, and appends it to the
// queue of pending outputs.
void AudioDecoder::associateOutput(IMFSample* sample)
{
    auto decoded = std::make_shared<AudioSample>(
        sample,
        m_internalBitsPerSample,
        m_internalNumChannels,
        m_bitsPerSample,
        m_numChannels);

    m_outputSamples.push_back(std::move(decoded));
}

// Submits one compressed buffer to the decoder. Output is not collected here.
// Returns MediaResult::Ok on success, otherwise the result of processInput().
MediaResult AudioDecoder::decode(const twitch::MediaSampleBuffer& input)
{
    MediaResult submitted = processInput(input);

    if (submitted != MediaResult::Ok) {
        // Hand the failure back untouched.
        return submitted;
    }

    return MediaResult::Ok;
}

// Pulls at most one sample out of the decoder, then reports through
// `hasOutput` whether the queue of pending outputs is non-empty.
// Returns the result of processOutput().
MediaResult AudioDecoder::hasOutput(bool& hasOutput)
{
    // The raw HRESULT is not needed here; only the MediaResult is reported.
    HRESULT ignoredHr = S_OK;
    MediaResult pulled = processOutput(ignoredHr);

    const bool queueIsEmpty = m_outputSamples.empty();
    hasOutput = !queueIsEmpty;

    return pulled;
}

// Pops the oldest pending output into `output`.
// Returns MediaResult::ErrorInvalidState (leaving `output` untouched) when no
// output is queued, MediaResult::Ok otherwise.
MediaResult AudioDecoder::getOutput(std::shared_ptr<twitch::MediaSample>& output)
{
    if (!m_outputSamples.empty()) {
        output = m_outputSamples.front();
        m_outputSamples.pop_front();
        return MediaResult::Ok;
    }

    return MediaResult::ErrorInvalidState;
}

// Sends a DRAIN command to the decoder and keeps pulling output until the
// transform reports that it needs more input or an error occurs.
// Returns MediaResult::Ok once drained, otherwise the first failing result.
MediaResult AudioDecoder::flush()
{
    CHECK_HR(m_decoderTransform->ProcessMessage(MFT_MESSAGE_COMMAND_DRAIN, 0), "Failed to process DRAIN command on AAC decoder MFT.");

    HRESULT lastHr = S_OK;

    for (;;) {
        MediaResult drained = processOutput(lastHr);

        // lastHr is only inspected when the pull itself succeeded.
        const bool pullSucceeded = (drained == MediaResult::Ok);
        if (!pullSucceeded || lastHr == MF_E_TRANSFORM_NEED_MORE_INPUT) {
            return drained;
        }
    }
}

// Sends a FLUSH command to the decoder transform. The queue of already
// decoded outputs is left as it is.
// Returns MediaResult::Ok unless the command fails.
MediaResult AudioDecoder::reset()
{
    CHECK_HR(
        m_decoderTransform->ProcessMessage(MFT_MESSAGE_COMMAND_FLUSH, 0),
        "Failed to process FLUSH command on AAC decoder MFT.");

    return MediaResult::Ok;
}
