#include "Mp2TReader.hpp"
#include "../aac/aacutil.hpp"
#include "../avc/avcutil.hpp"
#include "media/SourceFormat.hpp"

namespace twitch {
namespace media {
Mp2TReader::Mp2TReader(const Platform& platform, MediaReader::Listener& listener, const std::string& path)
    : m_path(path)
    , m_listener(listener)
    , m_avcFormat(platform.getCapabilities().avcFormat)
    , m_log(platform.getLog())
{
    // The transport-stream parser reports elementary-stream events back to this reader.
    m_parser.reset(new TransportStream(*this));

    // Frames produced by the caption extractor are published on the metadata track.
    // The first caption frame announces a JSON text format for that track (if none
    // exists yet) and is marked as a discontinuity.
    auto onCaptionFrame = [this](const std::shared_ptr<MediaSampleBuffer>& captionFrame) {
        TrackId metaTrack = MetaTrackId;
        const bool needsFormat = m_formats.find(metaTrack) == m_formats.end();
        if (needsFormat) {
            auto jsonFormat = std::make_shared<SourceFormat>(MediaType::Text_Json);
            jsonFormat->setPath(m_path);
            m_formats[metaTrack] = jsonFormat;
            m_listener.onMediaTrack(metaTrack, jsonFormat);
            captionFrame->isDiscontinuity = true;
        }
        m_listener.onMediaSample(metaTrack, captionFrame);
    };
    m_captions.reset(new CEACaptions(onCaptionFrame));
}

Mp2TReader::~Mp2TReader()
{
    // Release the parser explicitly, ahead of the implicit member teardown. It was
    // constructed with a reference to this reader, presumably so it cannot call back
    // into a partially destroyed object.
    m_parser = nullptr;
}

// Maps a transport-stream elementary stream type to the reader's track id.
// Unrecognised stream types fall back to the audio track.
MediaReader::TrackId getTrackIdForStream(uint8_t type)
{
    switch (type) {
    case ElementaryStream::TypeAVC:
        return MediaReader::VideoTrackId;
    case ElementaryStream::TypeID3:
        return MediaReader::MetaTrackId;
    case ElementaryStream::TypeAAC:
    default:
        return MediaReader::AudioTrackId;
    }
}

void Mp2TReader::onElementaryDiscontinuity(uint8_t type)
{
    TrackId trackId = getTrackIdForStream(type);
    m_formats.erase(trackId); // recreate the format
    m_trackSampleCounts[trackId] = 0;
}

void Mp2TReader::onElementarySample(uint8_t type, const std::shared_ptr<MediaSampleBuffer>& sample)
{
    TrackId trackId = getTrackIdForStream(type);

    switch (type) {
    case ElementaryStream::TypeAAC: {
        if (m_formats.count(trackId) == 0) {
            createAACFormat(m_parser->audioExtradata());
        }
    } break;
    case ElementaryStream::TypeAVC: {
        if (m_formats.count(trackId) == 0) {
            createAVCFormat(m_parser->videoExtradata());
        }

        if (m_captions) {
            m_captions->fromMediaSampleBuffer(sample);
        }

        // TODO Keep frame as is and convert to annexB in platform when necessary
        if (m_avcFormat == AVCFormatType::AnnexB) {
            auto& avcc = m_formats[trackId]->getCodecData(MediaFormat::Video_AVC_AVCC);
            sample->buffer = AVCParser::toAnnexB(sample->buffer, avcc);
        }
    } break;
    case ElementaryStream::TypeID3: {
        if (m_formats.count(trackId) == 0) {
            auto format = std::make_shared<SourceFormat>(MediaType::Text_Json);
            format->setPath(m_path);
            m_formats[trackId] = format;
            m_listener.onMediaTrack(trackId, format);
        }
    } break;
    default:
        m_log->error("Received unknown frame type %d", type);
        return;
    }

    if (m_trackSampleCounts[trackId] == 0) {
        sample->isDiscontinuity = true;
    }
    m_trackSampleCounts[trackId]++;
    m_listener.onMediaSample(trackId, sample);
}

// Builds the AAC audio format from the parser's extradata and announces it on the
// audio track. If the extradata cannot be parsed, an error is reported and no format
// is stored.
void Mp2TReader::createAACFormat(const std::vector<uint8_t>& extradata)
{
    adts_t adts {};
    if (!aac_parse_extradata(&adts, extradata.data(), extradata.size())) {
        m_listener.onMediaError(Error(ErrorSource::Source, MediaResult::Error, "Failed to parse AAC extra data"));
        return;
    }

    const int channels = static_cast<int16_t>(adts.channelConfig);
    const int sampleRate = aac_frequency(&adts);

    std::shared_ptr<SourceFormat> format = SourceFormat::createAudioFormat(MediaType::Audio_AAC, channels, sampleRate, 16);
    format->setCodecData(MediaFormat::Audio_AAC_ESDS, extradata);
    format->setPath(m_path);

    m_formats[AudioTrackId] = format;
    m_listener.onMediaTrack(AudioTrackId, format);
}

void Mp2TReader::createAVCFormat(const std::vector<uint8_t>& extradata)
{
    auto ed = AVCParser::parseExtradata(extradata);
    if (ed.sps.empty() || ed.pps.empty()) {
        m_listener.onMediaError(Error(ErrorSource::Source, MediaResult::Error, "Failed to parse AVC extra data"));
        return;
    }

    const auto& res = AVCParser::parseSps(ed.sps[0]).resolution();
    std::shared_ptr<SourceFormat> format = SourceFormat::createVideoFormat(MediaType::Video_AVC, res.first, res.second);
    //m_log->info("AVC profile %d, level %d NAL length %d", ed.profile, ed.level, ed.lengthSize);

    format->setInt(MediaFormat::Video_AVC_NAL_LengthSize, ed.lengthSize);
    format->setInt(MediaFormat::Video_AVC_Profile, ed.profile);
    format->setInt(MediaFormat::Video_AVC_Level, ed.level);
    format->setCodecData(MediaFormat::Video_AVC_SPS, ed.sps[0]);
    format->setCodecData(MediaFormat::Video_AVC_PPS, ed.pps[0]);
    format->setCodecData(MediaFormat::Video_AVC_AVCC, extradata);
    m_formats[VideoTrackId] = format;
    m_listener.onMediaTrack(VideoTrackId, format);
}

// Repositions the parser at `time` and forgets all announced tracks. After the seek,
// each track's format is recreated and its first sample is flagged as a discontinuity.
void Mp2TReader::seekTo(MediaTime time)
{
    // The parser expects positions in TSTimeScale units.
    const auto targetTicks = time.scaleTo(TSTimeScale).count();
    m_parser->seek(targetTicks);

    m_trackSampleCounts.clear();
    m_formats.clear();
}

// Feeds a chunk of transport-stream bytes to the parser, then notifies the listener
// via onMediaFlush(). The end-of-stream flag is currently not acted upon.
void Mp2TReader::addData(const uint8_t* data, size_t size, bool /*endOfStream*/)
{
    m_parser->addData(data, size);
    m_listener.onMediaFlush();
}

// Any stream-level discontinuity, regardless of its flags, triggers a full reset of
// the parser and all per-track state.
void Mp2TReader::onDiscontinuity(uint32_t /*flags*/)
{
    reset();
}

// Finishes the current parse, then forgets all announced track formats and per-track
// sample counts. The next sample on each track therefore re-announces its format and
// is flagged as a discontinuity.
void Mp2TReader::reset()
{
    // NOTE(review): finish() runs before the maps are cleared. If finish() emits
    // trailing samples through the elementary-stream callbacks, clearing afterwards
    // guarantees the next data starts from a clean slate. Keep this order.
    m_parser->finish();
    m_formats.clear();
    m_trackSampleCounts.clear();
}

void Mp2TReader::setStream(std::unique_ptr<Stream> stream)
{
    m_stream = std::move(stream);
}

// Pulls data from the attached stream into the parser until the parsed duration has
// advanced by at least `duration` since entry, then notifies the listener via
// onMediaFlush().
// - No stream attached: reports an invalid-state error.
// - Read failure: reports an invalid-data error.
// - End of stream: publishes the final duration and signals end-of-stream.
void Mp2TReader::readSamples(MediaTime duration)
{
    if (!m_stream) {
        m_listener.onMediaError(Error(ErrorSource::Source, MediaResult::ErrorInvalidState, "No stream to read"));
        return;
    }

    MediaTime startDuration = getDuration();
    const int ReadBufferSize = 16 * 1024;
    uint8_t buffer[ReadBufferSize];

    while (duration > (getDuration() - startDuration)) {
        int64_t read = m_stream->read(buffer, ReadBufferSize);
        if (read < 0) {
            // Treat any negative result as a failure. Checking only for -1 would let other
            // negative values reach the size_t cast below and hand the parser a huge length.
            m_listener.onMediaError(Error(ErrorSource::Source, MediaResult::ErrorInvalidData, "Error reading TS"));
            return;
        }
        if (read == 0) {
            m_listener.onMediaDurationChanged(getDuration());
            m_listener.onMediaEndOfStream();
            return;
        }
        m_parser->addData(buffer, static_cast<size_t>(read));
    }

    m_listener.onMediaFlush();
}

// Returns the duration parsed so far, converting the parser's TSTimeScale tick count
// into a MediaTime.
MediaTime Mp2TReader::getDuration() const
{
    const auto ticks = m_parser->duration();
    return MediaTime(ticks, TSTimeScale);
}

// Returns the format currently announced for track `id`, or nullptr if that track has
// no format yet.
std::shared_ptr<const MediaFormat> Mp2TReader::getTrackFormat(TrackId id)
{
    // Look up without operator[]. Inserting a null entry for a not-yet-created track
    // would make onElementarySample()'s existence check think the format already
    // exists, so the format would never be created.
    auto it = m_formats.find(id);
    if (it == m_formats.end()) {
        return nullptr;
    }
    return it->second;
}
}
}
