#include "AudioDecoder.hpp"

#include "AudioRenderer.hpp"
#include "AudioResampler.hpp"
#include "debug/trace.hpp"
#include "playercore/platform/ps4/PS4Platform.hpp"
#include "Stream.hpp"

#include <libsysmodule.h>

#include <cassert>
#include <cmath>
#include <cstdio>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>

namespace {
// Tolerance used by this file when comparing durations (in seconds) and duration ratios.
constexpr float kEpsilon = 0.01f;

// Return true if two floats are equal within an epsilon tolerance (the bound is inclusive).
// Uses std::fabs so the float overload is selected instead of relying on the C `fabs`
// being visible through a transitive include.
bool equalWithEpsilon(float first, float second, float epsilon)
{
    return std::fabs(first - second) <= epsilon;
}
} // namespace

using namespace twitch;
using namespace twitch::ps4;

// Creates a decoder object with no platform decoder attached. The handle starts at -1, which
// deleteDecoder(), decode() and shouldReconfigure() all treat as "no decoder", and the cached
// stream parameters start at zero. The platform decoder itself is created by configure().
AudioDecoder::AudioDecoder()
    : m_audioHandle(-1)
    , m_sampleRate(0)
    , m_numChannels(0)
    , m_inputChannels(0)
{
}

// Releases the platform decoder instance, if one was created, via deleteDecoder().
AudioDecoder::~AudioDecoder()
{
    deleteDecoder();
}

// Destroys the platform decoder instance, if any, and marks the handle as invalid (-1).
// Calling this with no decoder present is a no-op, so it is safe to call repeatedly.
// The result of sceAudiodecDeleteDecoder is not inspected.
void AudioDecoder::deleteDecoder()
{
    if (m_audioHandle < 0) {
        return;
    }

    sceAudiodecDeleteDecoder(m_audioHandle);
    m_audioHandle = -1;
}

// Returns true when 'freq' (in Hz) is a sample rate frequencyToIndex() knows how to map.
bool AudioDecoder::isValidFrequency(int freq)
{
    const int samplingFreqIndex = frequencyToIndex(freq);
    return samplingFreqIndex != -1;
}

// Maps a sample rate in Hz to the matching SCE_AUDIODEC_M4AAC_SAMPLING_FREQ_* value.
// Returns -1 when 'freq' is not one of the rates listed in the table below.
int AudioDecoder::frequencyToIndex(int freq)
{
    struct FrequencyEntry {
        int hz;
        int index;
    };

    // A single table keeps each rate next to its decoder constant, so the two can no longer
    // drift apart, and static storage avoids rebuilding the table on every call.
    static const FrequencyEntry kSupportedFrequencies[] = {
        { 8000, SCE_AUDIODEC_M4AAC_SAMPLING_FREQ_08000 },
        { 11025, SCE_AUDIODEC_M4AAC_SAMPLING_FREQ_11025 },
        { 12000, SCE_AUDIODEC_M4AAC_SAMPLING_FREQ_12000 },
        { 16000, SCE_AUDIODEC_M4AAC_SAMPLING_FREQ_16000 },
        { 22050, SCE_AUDIODEC_M4AAC_SAMPLING_FREQ_22050 },
        { 24000, SCE_AUDIODEC_M4AAC_SAMPLING_FREQ_24000 },
        { 32000, SCE_AUDIODEC_M4AAC_SAMPLING_FREQ_32000 },
        { 44100, SCE_AUDIODEC_M4AAC_SAMPLING_FREQ_44100 },
        { 48000, SCE_AUDIODEC_M4AAC_SAMPLING_FREQ_48000 },
        { 64000, SCE_AUDIODEC_M4AAC_SAMPLING_FREQ_64000 },
        { 88200, SCE_AUDIODEC_M4AAC_SAMPLING_FREQ_88200 },
        { 96000, SCE_AUDIODEC_M4AAC_SAMPLING_FREQ_96000 },
    };

    for (const FrequencyEntry& entry : kSupportedFrequencies) {
        if (entry.hz == freq) {
            return entry.index;
        }
    }

    return -1;
}

/*
 * Validate the input audio format, describe the PCM format this decoder produces in 'output',
 * and (re)create the platform AAC decoder when the sample rate or channel layout changed.
 *
 * input  - must carry Audio_ChannelCount, Audio_SampleRate and Audio_SampleSize.
 * output - receives the output channel count (1, 2, 6 or 8), the renderer's supported sample
 *          rate, the input sample size and the Audio_PCM type.
 *
 * Returns ErrorInvalidParameter when an attribute is missing or the platform decoder cannot be
 * created (the platform error code is attached), ErrorNotSupported for an unsupported channel
 * count or sample rate, or when a required resampler cannot be created, and Ok otherwise.
 */
MediaResult AudioDecoder::configure(const MediaFormat& input, MediaFormat& output)
{
    if (!input.hasInt(MediaFormat::Audio_ChannelCount)) {
        TRACE_ERROR("AudioDecoder - Missing Audio_ChannelCount attribute");
        return MediaResult::ErrorInvalidParameter;
    }

    if (!input.hasInt(MediaFormat::Audio_SampleRate)) {
        TRACE_ERROR("AudioDecoder - Missing Audio_SampleRate attribute");
        return MediaResult::ErrorInvalidParameter;
    }

    if (!input.hasInt(MediaFormat::Audio_SampleSize)) {
        TRACE_ERROR("AudioDecoder - Missing Audio_SampleSize attribute");
        return MediaResult::ErrorInvalidParameter;
    }

    uint32_t numChannels = input.getInt(MediaFormat::Audio_ChannelCount);
    uint32_t sampleRate = input.getInt(MediaFormat::Audio_SampleRate);
    uint32_t sampleSize = input.getInt(MediaFormat::Audio_SampleSize);

    // The values are unsigned, so they are printed with %u.
    TRACE_DEBUG("AudioDecoder wants to configure itself with %u channels, %u rate and %u bytes sample size", numChannels,
        sampleRate, sampleSize);

    if (!numChannels || numChannels > SCE_AUDIODEC_M4AAC_MAX_CHANNELS_FOR_8CH) {
        return MediaResult(MediaResult::ErrorNotSupported, numChannels);
    }

    // Remember what the stream declared, then round the layout up to one the decoder is
    // configured with below: mono/stereo stay as they are, 3-6 become 6 and 7-8 become 8.
    m_inputChannels = numChannels;
    switch (numChannels) {
    case 1:
    case 2:
        break;
    case 3:
    case 4:
    case 5:
    case 6:
        numChannels = 6;
        break;
    case 7:
    case 8:
        numChannels = 8;
        break;
    default:
        // Unreachable: the range check above only lets 1..8 through.
        assert(0);
        break;
    }

    if (!isValidFrequency(sampleRate)) {
        return MediaResult(MediaResult::ErrorNotSupported, sampleRate);
    }

    output.setInt(MediaFormat::Audio_ChannelCount, numChannels);
    output.setInt(MediaFormat::Audio_SampleRate, AudioRenderer::SupportedOutputSampleRate);
    output.setInt(MediaFormat::Audio_SampleSize, sampleSize);
    output.setType(MediaType::Audio_PCM);

    // Same parameters as the live decoder? Then there is no need to rebuild it at all.
    if (!shouldReconfigure(sampleRate, numChannels)) {
        return MediaResult::Ok;
    }

    m_sampleRate = sampleRate;
    m_numChannels = numChannels;

    m_param.uiSize = sizeof(m_param);
    m_param.iBwPcm = SCE_AUDIODEC_WORD_SZ_FLOAT;
    m_param.uiConfigNumber = SCE_AUDIODEC_M4AAC_CONFIG_NUMBER_RAW;
    m_param.uiSamplingFreqIndex = frequencyToIndex(sampleRate);
    m_param.uiMaxChannels = numChannels;
    m_param.uiEnableHeaac = SCE_AUDIODEC_M4AAC_HEAAC_DISABLE;
    m_param.uiEnableNondelayOutput = SCE_AUDIODEC_M4AAC_NONDELAY_ENABLE;

    m_info.uiSize = sizeof(m_info);
    m_au.uiSize = sizeof(m_au);
    m_pcm.uiSize = sizeof(m_pcm);

    // Wire the control block to the parameter, bitstream-info, access-unit and PCM structures.
    m_audiodecCtrl.pParam = reinterpret_cast<void*>(&m_param);
    m_audiodecCtrl.pBsiInfo = reinterpret_cast<void*>(&m_info);
    m_audiodecCtrl.pAuInfo = &m_au;
    m_audiodecCtrl.pPcmItem = &m_pcm;

    // delete previous decoder if it exists
    deleteDecoder();
    int ret = sceAudiodecCreateDecoder(&m_audiodecCtrl, SCE_AUDIODEC_TYPE_M4AAC);

    if (ret < 0) {
        // m_audioHandle is -1 here, so the next configure() call will try again.
        return MediaResult(MediaResult::ErrorInvalidParameter, ret);
    }

    m_audioHandle = ret;

    // Check if we need to resample to the renderer's supported output rate
    if (m_sampleRate != AudioRenderer::SupportedOutputSampleRate) {
        m_resampler = AudioResampler::create(m_numChannels, m_sampleRate, AudioRenderer::SupportedOutputSampleRate);
        assert(m_resampler);
        if (!m_resampler) {
            // Tear the decoder down again. Otherwise the handle and cached parameters would
            // look valid, a retry with the same format would return Ok without a resampler,
            // and decode() would emit PCM at the wrong sample rate.
            deleteDecoder();
            return MediaResult::ErrorNotSupported;
        }
    } else {
        // Otherwise, no resampling is needed
        m_resampler.reset();
    }

    return MediaResult::Ok;
}

/*
 * Decode one AAC access unit and queue the resulting PCM for getOutput().
 *
 * Returns ErrorInvalidState when no decoder has been configured, ErrorInvalidData (with the
 * platform error code attached) when decoding fails for any reason other than a header error,
 * and Ok otherwise. A frame rejected with a header error is replaced by silence covering the
 * frame's duration; a frame that decodes to zero bytes produces no output.
 */
MediaResult AudioDecoder::decode(const MediaSampleBuffer& input)
{
    if (m_audioHandle < 0) {
        return MediaResult(MediaResult::ErrorInvalidState, m_audioHandle);
    }

    //TRACE_DEBUG("AudioDecoder::decode: %f with duration: %f isDecodeOnly: %s", input.decodeTime.seconds(),
    //    input.duration.seconds(), input.isDecodeOnly ? "true" : "false");

    // The access-unit descriptor stores a non-const pointer, hence the const_cast.
    m_au.pAuAddr = reinterpret_cast<void*>(const_cast<uint8_t*>(input.buffer.data()));
    m_au.uiAuSize = input.buffer.size();

    // Scratch buffer the decoder writes PCM into; trimmed to the decoded size afterwards.
    MediaSampleBuffer::BufferType sampleBuffer(MaxPcmSize, 0);
    m_pcm.pPcmAddr = sampleBuffer.data();
    m_pcm.uiPcmSize = sampleBuffer.size();

    int ret = sceAudiodecDecode(m_audioHandle, &m_audiodecCtrl);

    if (ret < 0) {
        SceAudiodecM4aacInfo* info = reinterpret_cast<SceAudiodecM4aacInfo*>(m_audiodecCtrl.pBsiInfo);
        fprintf(stderr, "MPEG4-AAC decoder error code: ret: %08x %08x\n", ret, info->iResult);
        PS4Platform::traceError("sceAudiodecDecode failed", ret);
        dumpAACFrameInfo(input);

        if (static_cast<unsigned int>(info->iResult) == SCE_AUDIODEC_M4AAC_RESULT_HEADER_ERROR) {
            TRACE_WARN("Skipping corrupted audio at PTS=%f", input.presentationTime.seconds());

            // Buffer size = input.duration (sec) * sampling rate (samples/sec) * number of channels * 4 bytes/sample
            const auto& duration = input.duration;
            size_t bufferSize = duration.count() * m_sampleRate * m_numChannels * sizeof(float);
            assert(bufferSize % duration.timebase() == 0);
            bufferSize /= duration.timebase();

            // Add an empty buffer as a silent audio sample. The buffer is rebuilt from scratch
            // rather than resized: the scratch buffer was handed to the decoder and may hold
            // partially decoded data, which would not be silence.
            sampleBuffer = MediaSampleBuffer::BufferType(bufferSize, 0);

            addDecodedOutput(input, sampleBuffer);
            return MediaResult::Ok;
        }

        return MediaResult(MediaResult::ErrorInvalidData, ret);
    }

    if (m_pcm.uiPcmSize == 0) {
        return MediaResult::Ok;
    }

    // Check if the sample needs to be trimmed
    if (m_pcm.uiPcmSize != sampleBuffer.size()) {
        sampleBuffer.resize(m_pcm.uiPcmSize);
    }

    // Did we decode the expected amount of data according to the duration ? If not, fix the duration.
    float decodedTime = static_cast<float>(sampleBuffer.size()) / sizeof(float) / m_numChannels / m_sampleRate;
    if (!equalWithEpsilon(decodedTime, input.duration.seconds(), kEpsilon)) {
        bool success = fixupSampleBuffer(decodedTime, input.duration.seconds(), sampleBuffer);
        TRACE_ERROR("AudioDecoder::decode expected sample of size %f seconds but got %f seconds. fixUp: %s",
            input.duration.seconds(), decodedTime, success ? "true" : "false");

        // Want to catch these other edge cases where we fail at fixing it.
        assert(success);

        // Confirm fixup worked and got expected value
        decodedTime = static_cast<float>(sampleBuffer.size()) / sizeof(float) / m_numChannels / m_sampleRate;
        assert(equalWithEpsilon(decodedTime, input.duration.seconds(), kEpsilon));

        // Testing code - Assuming previous fixUp was from Mono to Stereo. This converts back to Mono, and then
        // back to Stereo again. Audio should still be good.
        // Force back to mono
        //m_numChannels = 1;
        //decodedTime = static_cast<float>(sampleBuffer.size()) / sizeof(float) / m_numChannels / m_sampleRate;
        //fixupSampleBuffer(decodedTime, input.duration.seconds(), sampleBuffer);

        //// And back to stereo once more
        //m_numChannels = 2;
        //decodedTime = static_cast<float>(sampleBuffer.size()) / sizeof(float) / m_numChannels / m_sampleRate;
        //fixupSampleBuffer(decodedTime, input.duration.seconds(), sampleBuffer);
    }

    addDecodedOutput(input, sampleBuffer);
    return MediaResult::Ok;
}

/*
 * Fix up the sample buffer in case we decoded more, or less, than expected. Within reason.
 *
 * decodedTime  - duration, in seconds, of the audio currently held in 'sampleBuffer'.
 * expectedTime - duration, in seconds, that the sample was supposed to have.
 * sampleBuffer - interleaved float PCM; converted in place.
 *
 * Only two cases are handled: mono data where stereo was expected (each sample is duplicated)
 * and stereo data where mono was expected (each pair is averaged).
 *
 * Return true if the buffer was fixed, false if it was left untouched.
 */
bool AudioDecoder::fixupSampleBuffer(
    float decodedTime, float expectedTime, MediaSampleBuffer::BufferType& sampleBuffer)
{
    if (sampleBuffer.empty()) {
        return false;
    }

    // Note: This assumes that the AudioRenderer is always configured with the same number of channels as the AudioDecoder
    // If this is not the case, then we will need to actually read 'm_numChannels' from the AudioRenderer

    const float timeRatio = expectedTime / decodedTime;

    // We expect Stereo output, but the decoded buffer is exactly 50% of the expectation. So it's likely we got mono
    // instead. Double the buffer to get stereo back.
    if (m_numChannels == 2) {
        if (!equalWithEpsilon(timeRatio, 2.f, kEpsilon)) {
            TRACE_DEBUG(
                "AudioDecoder::fixupSampleBuffer. Don't know how to fix decodedTime: %f seconds to expectedTime %f seconds given stereo output.",
                decodedTime, expectedTime);
            return false;
        }

        sampleBuffer.resize(sampleBuffer.size() * 2);

        // This code assumes the data is stored as floats
        assert(m_param.iBwPcm == SCE_AUDIODEC_WORD_SZ_FLOAT);

        float* bufferAsFloat = reinterpret_cast<float*>(&(sampleBuffer[0]));

        // Expand in place, walking backwards so a mono sample is never overwritten before it
        // has been copied into its left/right pair.
        int sourceIndex = sampleBuffer.size() / sizeof(float) / 2 - 1;
        int destIndex = sampleBuffer.size() / sizeof(float) - 1;

        for (; sourceIndex >= 0; sourceIndex--, destIndex -= 2) {
            bufferAsFloat[destIndex] = bufferAsFloat[sourceIndex];
            bufferAsFloat[destIndex - 1] = bufferAsFloat[sourceIndex];
        }

        return true;
    }

    // We expect Mono output, but the decoded buffer is exactly 200% of the expectation. So it's likely we got stereo
    // instead. Shrink the buffer by half to get mono back.
    // Requires a Fourier transform to do properly, but it's likely the Stereo stream has already been duplicated (it
    // was Mono originally), so running this should be fine since both channels are probably equal
    // (i.e.: bufferAsFloat[sourceIndex] == bufferAsFloat[sourceIndex+1])
    if (m_numChannels == 1) {
        if (!equalWithEpsilon(timeRatio, 0.5f, kEpsilon)) {
            TRACE_DEBUG(
                "AudioDecoder::fixupSampleBuffer. Don't know how to fix decodedTime: %f seconds to expectedTime %f seconds given mono output.",
                decodedTime, expectedTime);
            return false;
        }

        // This code assumes the data is stored as floats
        assert(m_param.iBwPcm == SCE_AUDIODEC_WORD_SZ_FLOAT);

        float* bufferAsFloat = reinterpret_cast<float*>(&(sampleBuffer[0]));
        int sourceIndex = 0;
        int destIndex = 0;

        int endDestIndex = sampleBuffer.size() / sizeof(float) / 2;

        // Collapse in place, walking forwards: the write position never passes the read position.
        for (; destIndex < endDestIndex; destIndex++, sourceIndex += 2) {
            bufferAsFloat[destIndex] = bufferAsFloat[sourceIndex] * 0.5f + bufferAsFloat[sourceIndex + 1] * 0.5f;
        }

        sampleBuffer.resize(sampleBuffer.size() / 2);

        return true;
    }

    // Any other channel layout: no conversion is implemented.
    TRACE_DEBUG("AudioDecoder:fixUpSampleBuffer. Don't know how to fix decodedTime: %f seconds to expectedTime %f seconds given numChannels:%d",
        decodedTime, expectedTime, m_numChannels);
    return false;
}

// Reports through 'hasOutput' whether at least one decoded sample is waiting in the output
// queue. Always returns Ok.
MediaResult AudioDecoder::hasOutput(bool& hasOutput)
{
    const bool queueIsEmpty = m_outputSamples.empty();
    hasOutput = !queueIsEmpty;
    return MediaResult::Ok;
}

// Hands the oldest queued decoded sample to the caller through 'output' and removes it from
// the queue. Returns ErrorInvalidState, leaving 'output' untouched, when nothing is queued.
MediaResult AudioDecoder::getOutput(std::shared_ptr<MediaSample>& output)
{
    if (m_outputSamples.empty()) {
        return MediaResult::ErrorInvalidState;
    }

    output = m_outputSamples.front();
    m_outputSamples.pop();
    return MediaResult::Ok;
}

// Diagnostic helper: writes the PTS, size and payload bytes (hex) of an AAC frame to the debug
// trace, and also dumps the payload bytes to stderr. decode() calls it when decoding fails.
void AudioDecoder::dumpAACFrameInfo(const MediaSampleBuffer& input) const
{
    // The payload is only read, so the const pointer is used directly.
    const uint8_t* data = input.buffer.data();
    const size_t size = input.buffer.size();

    std::stringstream ss;

    // The size is streamed before std::hex is set, so it stays decimal.
    ss << "Decoding AAC frame (PTS=" << input.presentationTime.seconds() << ") of size " << size
       << " with payload: ";

    // std::hex and the fill character are sticky; std::setw must be re-applied for every byte.
    ss << std::hex << std::setfill('0');

    for (size_t i = 0; i < size; ++i) {
        if (i != 0) {
            ss << ", ";
            fprintf(stderr, ", ");
        }

        ss << std::setw(2) << static_cast<int>(data[i]);
        fprintf(stderr, "0x%02x", data[i]);
    }

    // Terminate the stderr dump so the next message starts on its own line.
    if (size != 0) {
        fprintf(stderr, "\n");
    }

    ss << '\n';
    TRACE_DEBUG("%s", ss.str().c_str());
}

// Flushes the resampler when one is in use; does nothing otherwise. Always returns Ok.
// Samples already sitting in the output queue are not touched here.
MediaResult AudioDecoder::flush()
{
    if (!m_resampler) {
        return MediaResult::Ok;
    }

    m_resampler->flush();
    return MediaResult::Ok;
}

// Clears the platform decoder's context when a decoder exists; does nothing otherwise.
// Always returns Ok: the result of sceAudiodecClearContext is not inspected.
MediaResult AudioDecoder::reset()
{
    const bool hasDecoder = m_audioHandle >= 0;
    if (hasDecoder) {
        sceAudiodecClearContext(m_audioHandle);
    }

    return MediaResult::Ok;
}

void AudioDecoder::addDecodedOutput(const MediaSampleBuffer& input, MediaSampleBuffer::BufferType& sampleBuffer)
{
    // Create output sample and resample to its buffer
    if (m_resampler) {
        AudioResampler::BufferType resampledBuffer;
        m_resampler->resample(resampledBuffer, sampleBuffer.data(), sampleBuffer.size());
        sampleBuffer = std::move(resampledBuffer);
    }

    // To increase audio clock granularity, send the sample buffer in smaller chunks
    size_t numSubsamples = 1;
    if (sampleBuffer.size() % 4 == 0) {
        numSubsamples = 4;
    } else if (sampleBuffer.size() % 2 == 0) {
        numSubsamples = 2;
    }

    //TRACE_DEBUG("AudioDecoder::addDecodedOutput: presentationTime: %f, total duration: %f with %d subsamples",
    //    input.presentationTime.seconds(), input.duration.seconds(), numSubsamples);

    const size_t subsampleSize = sampleBuffer.size() / numSubsamples;
    for (size_t i = 0; i < numSubsamples; i++) {
        auto ptsOffset = input.duration * static_cast<double>(i) / static_cast<double>(numSubsamples);

        auto outputSample = std::make_shared<MediaSampleBuffer>();
        outputSample->presentationTime = input.presentationTime + ptsOffset;
        outputSample->buffer = MediaSampleBuffer::BufferType(
            sampleBuffer.begin() + i * subsampleSize, sampleBuffer.begin() + (i + 1) * subsampleSize);
        outputSample->isDecodeOnly = input.isDecodeOnly;
        outputSample->isSyncSample = input.isSyncSample;
        outputSample->isDiscontinuity = input.isDiscontinuity;
        outputSample->duration = input.duration / static_cast<double>(numSubsamples);
        m_outputSamples.push(outputSample);

        //TRACE_DEBUG("AudioDecoder::addDecodedOutput[%d/%d], presentationTime: %f, duration: %f", i, numSubsamples,
        //    outputSample->presentationTime.seconds(), outputSample->duration.seconds());
    }
}

// Decides whether configure() has to (re)create the platform decoder: true when the requested
// sample rate or channel count differs from the cached values, or when no decoder exists yet.
bool AudioDecoder::shouldReconfigure(uint32_t sampleRate, uint32_t numChannels) const
{
    const bool formatChanged = (sampleRate != m_sampleRate) || (numChannels != m_numChannels);
    const bool decoderMissing = m_audioHandle < 0;

    return formatChanged || decoderMissing;
}
