#include "MediaPlaylist.hpp"
#include "PlaylistParser.hpp"
#include "debug/trace.hpp"
#include "util/Base64.hpp"
#include "util/Hex.hpp"
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <ctime>

namespace twitch {
namespace hls {
namespace {
// Line framing used when emitting playlist text.
const std::string PREFIX = "#";
const std::string NEWLINE = "\n";

// Standard HLS tag names, stored without the leading PREFIX.
// A trailing ':' marks tags that are followed by a value on the same line.
const std::string EXTM3U = "EXTM3U";
const std::string EXTINF = "EXTINF:";
const std::string EXT_X_VERSION = "EXT-X-VERSION:";
const std::string EXT_X_TARGETDURATION = "EXT-X-TARGETDURATION:";
const std::string EXT_X_MEDIA_SEQUENCE = "EXT-X-MEDIA-SEQUENCE:";
const std::string EXT_X_PLAYLIST_TYPE = "EXT-X-PLAYLIST-TYPE:";
const std::string EXT_X_PROGRAM_DATE_TIME = "EXT-X-PROGRAM-DATE-TIME:";
const std::string EXT_X_DATERANGE = "EXT-X-DATERANGE:";
const std::string EXT_X_DISCONTINUITY = "EXT-X-DISCONTINUITY";
const std::string EXT_X_ENDLIST = "EXT-X-ENDLIST";
const std::string EXT_X_BYTERANGE = "EXT-X-BYTERANGE:";
const std::string EXT_X_KEY = "EXT-X-KEY:";
const std::string EXT_X_MAP = "EXT-X-MAP:";
const std::string EXT_X_START = "EXT-X-START:";
const std::string EXT_X_INDEPENDENT_SEGMENTS = "EXT-X-INDEPENDENT-SEGMENTS";

// Twitch-specific (non-standard) tags.
const std::string EXT_X_TWITCH_ELAPSED_SECS = "EXT-X-TWITCH-ELAPSED-SECS:";
const std::string EXT_X_TWITCH_PREFETCH = "EXT-X-TWITCH-PREFETCH:";
} // namespace

// Recognised EXT-X-PLAYLIST-TYPE values.
const std::string MediaPlaylist::TypeEvent = "EVENT";
const std::string MediaPlaylist::TypeVOD = "VOD";

namespace {
// Returned by reference from segmentAt() when no segment matches.
const Segment EmptySegment;
} // namespace

MediaPlaylist::MediaPlaylist()
    : m_version(-1)
    , m_type()
    , m_mediaType(MediaType::Video_MP2T)
    , m_targetDuration(-1)
    , m_ended(false)
    , m_independentSegments(false)
    , m_prefetchSegmentCount(0)
    , m_twitchElapsedTime(0)
{
}

// Parses a byte-range value of the form "<length>[@<offset>]" (e.g. "51501@360854")
// into info.rangeLength and, when an '@' is present, info.rangeOffset.
// rangeOffset is left untouched when no offset is given. Each number is read with
// strtod and truncated to int.
void parseByteRange(Segment& info, const std::string& range)
{
    const auto toInt = [](const std::string& text) {
        return static_cast<int>(std::strtod(text.c_str(), nullptr));
    };

    const auto separator = range.find_last_of('@');
    if (separator != std::string::npos) {
        info.rangeLength = toInt(range.substr(0, separator));
        info.rangeOffset = toInt(range.substr(separator + 1));
        return;
    }
    info.rangeLength = toInt(range);
}

/**
 * Parses media playlist text, replacing the current segment list.
 *
 * @param playlist raw playlist text
 * @param prefetch when true, EXT-X-TWITCH-PREFETCH lines become prefetch segments
 *                 (flagged `prefetch`, with the average duration of the preceding segments);
 *                 when false those lines are ignored
 *
 * If the text does not begin with the M3U header, the playlist is marked ended and the
 * existing segments are left untouched. Header state (version, type, ended flag, independent
 * segments, media type) is only ever set here, never reset.
 */
void MediaPlaylist::parse(const std::string& playlist, bool prefetch)
{
    PlaylistParser parser(playlist);
    parser.nextLine();

    if (!parser.readM3U()) {
        TRACE_ERROR("Invalid variant playlist");
        m_ended = true;
        return;
    }

    m_segments.clear();
    MediaTime duration(0);
    int sequenceNum = 0; // sequence number of next segment
    int prefetchCount = 0; // number of prefetch segments in this manifest
    Segment segment; // accumulates tags that apply to the next segment pushed
    std::shared_ptr<Segment> initializationSegment; // most recent EXT-X-MAP, shared by following segments

    while (parser.nextLine()) {
        // Only lines accepted by readCommentStart() (tag lines) are interpreted here.
        if (!parser.readCommentStart()) {
            continue;
        }

        if (parser.readPrefix(EXTINF) || (prefetch && parser.hasPrefix(EXT_X_TWITCH_PREFETCH))) {
            // EXT_X_MEDIA_SEQUENCE must be parsed before this is run
            segment.sequenceNumber = sequenceNum;

            // The URL and duration are parsed differently with prefetch
            if (parser.readPrefix(EXT_X_TWITCH_PREFETCH)) {
                // URL on rest of line
                segment.url = parser.getLine();
                // Use average duration for prefetch segments
                segment.duration = m_segments.empty() ? MediaTime::zero() : (duration / static_cast<double>(m_segments.size()));
                segment.prefetch = true;
                prefetchCount++;
            } else {
                // duration on rest of line
                segment.duration = MediaTime(parser.parseDouble());
                // read the rest of the segment attributes
                parser.nextLine();

                // Only DISCONTINUITY and BYTERANGE are recognised between EXTINF and the URL;
                // any other tag in that position is skipped.
                while (parser.readCommentStart()) {
                    if (parser.readPrefix(EXT_X_DISCONTINUITY)) {
                        segment.discontinuity = true;
                    } else if (parser.readPrefix(EXT_X_BYTERANGE)) {
                        // ex. 51501@360854 length then offset
                        parseByteRange(segment, parser.getLine());
                    }
                    // URL is on the next line
                    parser.nextLine();
                }

                segment.url = parser.getLine();
            }

            // Set cumulative duration as sum of all previous durations
            duration += segment.duration;
            segment.cumulativeDuration = duration;
            segment.initializationSegment = initializationSegment;

            // update program date time, program date is not required on all segments so use
            // previous date time if available
            if (!segment.prefetch) {
                if (segment.programDateTime == Segment::ProgramTimeNone && !m_segments.empty()) {
                    const auto& previous = m_segments.back();
                    if (previous.programDateTime != Segment::ProgramTimeNone) {
                        segment.programDateTime = previous.programDateTime + previous.duration.milliseconds();
                    }
                }
            }

            m_segments.push_back(std::move(segment));
            // prepare for next segment
            sequenceNum++;
            segment = Segment();
        } else if (parser.readPrefix(EXT_X_VERSION)) {
            m_version = parser.parseInt();
        } else if (parser.readPrefix(EXT_X_TARGETDURATION)) {
            m_targetDuration = parser.parseInt();
        } else if (parser.readPrefix(EXT_X_MEDIA_SEQUENCE)) {
            sequenceNum = parser.parseInt();
        } else if (parser.readPrefix(EXT_X_PLAYLIST_TYPE)) {
            m_type = parser.getLine();
        } else if (parser.readPrefix(EXT_X_DISCONTINUITY)) {
            segment.discontinuity = true;
        } else if (parser.readPrefix(EXT_X_START)) {
            std::map<std::string, std::string> startAttributes;
            parser.parseAttributes(startAttributes);
            segment.start = true;
            auto offset = startAttributes.find("TIME-OFFSET");
            if (offset != startAttributes.end()) {
                segment.startOffset = MediaTime(std::strtod(offset->second.c_str(), nullptr));
            }
        } else if (parser.readPrefix(EXT_X_ENDLIST)) {
            m_ended = true;
        } else if (parser.readPrefix(EXT_X_KEY)) {
            // NOTE(review): `encryption` is fully populated below but never stored on `segment`
            // or on the playlist, so key information is currently discarded. Confirm which
            // Segment field it should be attached to.
            auto encryption = Segment::Encryption();
            std::map<std::string, std::string> keyattributes;
            parser.parseAttributes(keyattributes);

            // "com.apple.streamingkeydelivery" // fairplay
            encryption.keyformat = keyattributes["KEYFORMAT"];
            if (encryption.keyformat.empty()) {
                encryption.keyformat = "identity";
            }

            encryption.uri = keyattributes["URI"];
            if ("twitch0" == encryption.keyformat && 0 == encryption.uri.compare(0, 12, "data:base64,")) {
                // TODO full parsing of 'data:' scheme
                encryption.key = Base64::decode(encryption.uri.substr(12));
                encryption.uri.clear();
            }

            auto& iv = keyattributes["IV"];
            if (iv.empty()) {
                // No IV attribute: the IV is the media sequence number as a 16-byte big-endian
                // value, zero-padded on the left (RFC 8216 section 5.2). The IV buffer itself must
                // be sized before writing its low bytes; resize() zero-fills the new elements.
                encryption.iv.resize(16);
                encryption.iv[12] = static_cast<uint8_t>(sequenceNum >> 24);
                encryption.iv[13] = static_cast<uint8_t>(sequenceNum >> 16);
                encryption.iv[14] = static_cast<uint8_t>(sequenceNum >> 8);
                encryption.iv[15] = static_cast<uint8_t>(sequenceNum);
            } else if (34 == iv.size() && (0 == iv.compare(0, 2, "0x") || 0 == iv.compare(0, 2, "0X"))) {
                // hexadecimal-sequence may use either "0x" or "0X" (RFC 8216 section 4.2)
                encryption.iv = Hex::decode(&iv[2], 32);
            } else {
                TRACE_WARN("Unknown IV format");
            }

            auto& method = keyattributes["METHOD"];
            if ("AES-128" == method) {
                encryption.method = Segment::Encryption::AES_128;
            } else if ("SAMPLE-AES" == method) {
                encryption.method = Segment::Encryption::SAMPLE_AES;
            } else if ("COMMON-CENC" == method) {
                encryption.method = Segment::Encryption::COMMON_CENC;
            }
        } else if (parser.readPrefix(EXT_X_MAP)) {
            std::map<std::string, std::string> mapAttributes;
            parser.parseAttributes(mapAttributes);

            initializationSegment = std::make_shared<Segment>();
            initializationSegment->isInitialization = true;
            auto uri = mapAttributes.find("URI");
            if (uri != mapAttributes.end()) {
                initializationSegment->url = uri->second;
                if (uri->second.find(".mp4") != std::string::npos || uri->second.find(".m4s") != std::string::npos) {
                    m_mediaType = MediaType::Video_MP4;
                }
            }

            auto byteRange = mapAttributes.find("BYTERANGE");
            if (byteRange != mapAttributes.end()) {
                parseByteRange(*initializationSegment, byteRange->second);
            }
        } else if (parser.readPrefix(EXT_X_PROGRAM_DATE_TIME)) {
            segment.programDateTime = parser.parseIso8601(parser.getLine());
        } else if (parser.readPrefix(EXT_X_DATERANGE)) {
            std::shared_ptr<Segment::DateRange> dateRange = std::make_shared<Segment::DateRange>();
            parser.parseAttributes(dateRange->attributes);
            dateRange->start = parser.parseIso8601(dateRange->attributes["START-DATE"]);
            dateRange->id = dateRange->attributes["ID"];
            dateRange->endOnNext = dateRange->attributes["END-ON-NEXT"] == "YES";
            // check duration
            auto rangeDuration = dateRange->attributes.find("DURATION");
            if (rangeDuration != dateRange->attributes.end()) {
                dateRange->duration = std::strtod(rangeDuration->second.c_str(), nullptr);
            } else {
                dateRange->duration = Segment::DateRangeInfinite;
            }
            segment.dateRanges.push_back(dateRange);
            // TODO END-DATE not supported
        } else if (parser.readPrefix(EXT_X_INDEPENDENT_SEGMENTS)) {
            m_independentSegments = true;
        } else if (parser.readPrefix(EXT_X_TWITCH_ELAPSED_SECS)) {
            m_twitchElapsedTime = MediaTime(parser.parseInt());
        }
    }

    m_prefetchSegmentCount = prefetchCount;
}

std::string MediaPlaylist::generate()
{
    if (m_segments.empty()) {
        TRACE_ERROR("Invalid Segment List");
        return "";
    }

    // Basic Tags
    std::string playlist = PREFIX + EXTM3U + NEWLINE;
    playlist += PREFIX + EXT_X_VERSION + std::to_string(m_version) + NEWLINE;

    // Media Playlist Tags
    playlist += PREFIX + EXT_X_TARGETDURATION + std::to_string(m_targetDuration) + NEWLINE;
    playlist += PREFIX + EXT_X_MEDIA_SEQUENCE + std::to_string(m_segments.front().sequenceNumber) + NEWLINE;
    if (!m_type.empty()) {
        playlist += PREFIX + EXT_X_PLAYLIST_TYPE + m_type + NEWLINE;
    }
    if (m_independentSegments) {
        playlist += PREFIX + EXT_X_INDEPENDENT_SEGMENTS + NEWLINE;
    }

    std::shared_ptr<Segment> initializationSegment;
    // Media Segment Tags
    for (auto& segment : m_segments) {
        if (segment.discontinuity) {
            playlist += PREFIX + EXT_X_DISCONTINUITY + NEWLINE;
        }

        //TODO generation for EXT_X_KEY
        if (segment.initializationSegment && segment.initializationSegment->isInitialization
            && initializationSegment != segment.initializationSegment) {
            initializationSegment = segment.initializationSegment;
            playlist += PREFIX + EXT_X_MAP;
            if (!initializationSegment->url.empty()) {
                playlist += "URI=\"" + initializationSegment->url + "\"";
            }
            if (initializationSegment->rangeLength >= 0) {
                playlist += ",BYTERANGE=\"" + std::to_string(initializationSegment->rangeLength);

                // ex. 51501@360854 length then offset
                if (initializationSegment->rangeOffset >= 0) {
                    playlist += '@' + std::to_string(initializationSegment->rangeOffset);
                }
                playlist += "\"";
            }
            playlist += NEWLINE;
        }

        if (segment.programDateTime != std::chrono::system_clock::time_point::min()) {
            playlist += PREFIX + EXT_X_PROGRAM_DATE_TIME + generateIso8601(segment.programDateTime) + NEWLINE;
        }

        for (const auto& dateRange : segment.dateRanges) {
            std::string attributes;
            for (const auto& entry : dateRange->attributes) {
                if (!attributes.empty()) {
                    attributes += ",";
                }
                attributes += entry.first;
                attributes += "=";
                attributes += "\"" + entry.second + "\"";
            }
            playlist += PREFIX + EXT_X_DATERANGE + attributes + NEWLINE;
        }

        //EXTINF
        //convert milliseconds to seconds with 3 decimal places.
        std::string duration(16, '\0');
        auto size = std::snprintf(&duration[0], duration.size(), "%.3f", segment.duration.seconds());
        duration.resize(size);
        playlist += PREFIX + EXTINF + duration + ',' + NEWLINE;
        if (segment.rangeLength >= 0) {
            playlist += PREFIX + EXT_X_BYTERANGE + std::to_string(segment.rangeLength);
            // ex. 51501@360854 length then offset
            if (segment.rangeOffset >= 0) {
                playlist += '@' + std::to_string(segment.rangeOffset);
            }
            playlist += NEWLINE;
        }
        playlist += segment.url + NEWLINE;
    }

    // Media Playlist Tags
    if (m_ended) {
        playlist += PREFIX + EXT_X_ENDLIST + NEWLINE;
    }

    return playlist;
}

bool MediaPlaylist::parsed() const
{
    // The playlist counts as parsed once it holds at least one segment.
    return m_segments.size() != 0;
}

const std::string& MediaPlaylist::getType() const
{
    return m_type;
}

bool MediaPlaylist::isEnded() const
{
    // An explicit EXT-X-ENDLIST ends the playlist; a VOD playlist is treated as ended regardless.
    if (m_ended) {
        return true;
    }
    return m_type == TypeVOD;
}

MediaTime MediaPlaylist::getDuration() const
{
    if (m_segments.empty()) {
        return MediaTime::zero();
    }
    // The last segment's cumulativeDuration is the running total of every segment duration.
    return m_segments.back().cumulativeDuration;
}

/**
 * Returns the sequence number at which playback should begin according to EXT-X-START,
 * or Segment::InvalidSequenceNumber when no segment carries the start marker.
 *
 * The TIME-OFFSET is applied relative to the segment the tag was attached to:
 * - a positive offset walks forward, subtracting segment durations;
 * - a negative offset walks backward, adding segment durations (stopping at the first segment).
 * If a positive offset runs past the last segment, the last segment is returned. RFC 8216
 * section 4.3.5.2 treats offsets beyond the playlist duration as the end of the playlist.
 */
int MediaPlaylist::getStartSequence() const
{
    for (auto it = m_segments.begin(); it != m_segments.end(); ++it) {
        if (!it->start) {
            continue;
        }

        MediaTime offset = it->startOffset;
        if (offset > MediaTime::zero()) {
            // go forwards until the offset is met
            while (offset > MediaTime::zero() && it != m_segments.end()) {
                offset -= it->duration;
                ++it;
            }
            if (it == m_segments.end()) {
                // Offset exceeds the remaining duration: clamp instead of dereferencing end().
                return m_segments.back().sequenceNumber;
            }
        } else if (offset < MediaTime::zero()) {
            // go backwards until the offset is met
            while (offset < MediaTime::zero() && it != m_segments.begin()) {
                offset += it->duration;
                --it;
            }
        }
        // A zero offset uses the marked segment itself.
        return it->sequenceNumber;
    }
    return Segment::InvalidSequenceNumber;
}

int MediaPlaylist::getTargetDuration() const
{
    return m_targetDuration;
}

const Segment& MediaPlaylist::segmentAt(MediaTime time) const
{
    // cumulativeDuration is each segment's end time, so the segment containing `time` is the
    // first one whose end lies strictly after it. Returns EmptySegment when time is past the end.
    const auto found = std::partition_point(m_segments.begin(), m_segments.end(),
        [time](const Segment& candidate) { return !(time < candidate.cumulativeDuration); });
    if (found == m_segments.end()) {
        return EmptySegment;
    }
    return *found;
}

const Segment& MediaPlaylist::segmentAt(Segment::ProgramTime datetime) const
{
    // Trailing prefetch segments are excluded from the search. The result is the first remaining
    // segment whose programDateTime is not earlier than `datetime`, or EmptySegment if none.
    const auto searchEnd = m_segments.end() - m_prefetchSegmentCount;
    const auto found = std::partition_point(m_segments.begin(), searchEnd,
        [datetime](const Segment& candidate) { return candidate.programDateTime < datetime; });
    if (found == searchEnd) {
        return EmptySegment;
    }
    return *found;
}

bool MediaPlaylist::isFinalSegment(int sequenceNumber) const
{
    // Only an ended, non-empty playlist has a final segment; any sequence number at or beyond
    // the last listed one qualifies.
    if (!isEnded() || m_segments.empty()) {
        return false;
    }
    return sequenceNumber >= m_segments.back().sequenceNumber;
}

const std::vector<Segment>& MediaPlaylist::segments() const
{
    return m_segments;
}

/**
 * Collects the EXT-X-DATERANGE entries that apply to the segment with `sequenceNumber`.
 *
 * Included are:
 * - ranges declared on that segment;
 * - ranges with infinite duration declared on earlier segments;
 * - earlier ranges whose DURATION exceeds the total duration (in seconds) of the segments from
 *   their own segment up to, but excluding, the target.
 *
 * END-ON-NEXT ranges are de-duplicated by CLASS, keeping the last one declared. They are
 * appended after the other ranges in CLASS order.
 * Assumes m_segments is ordered by sequence number (iteration stops at the first later segment).
 */
std::vector<std::shared_ptr<Segment::DateRange>> MediaPlaylist::getDateRanges(int sequenceNumber) const
{
    std::vector<std::shared_ptr<Segment::DateRange>> dateRanges;
    // END-ON-NEXT ranges keyed by CLASS; a later range of the same class replaces an earlier one.
    std::map<std::string, std::shared_ptr<Segment::DateRange>> classed;

    for (const auto& s : m_segments) {
        if (s.sequenceNumber > sequenceNumber) {
            break;
        }
        for (const auto& dateRange : s.dateRanges) {
            bool insert = dateRange->duration == Segment::DateRangeInfinite || s.sequenceNumber == sequenceNumber;

            if (!insert && dateRange->duration > s.duration.seconds()) {
                // check if the duration overlaps with this segment
                double duration = 0.0;
                for (const auto& x : m_segments) {
                    // check all segments ahead of s inclusive but before the input segment
                    if (x.sequenceNumber >= s.sequenceNumber && x.sequenceNumber < sequenceNumber) {
                        duration += x.duration.seconds();
                    }
                }
                insert = dateRange->duration > duration;
            }
            if (!insert) {
                continue;
            }

            if (dateRange->endOnNext) {
                // Look CLASS up without operator[]: that would insert an empty CLASS into the
                // shared attribute map, and generate() would then emit a spurious CLASS="".
                const auto& attributes = dateRange->attributes;
                const auto classEntry = attributes.find("CLASS");
                const std::string dateRangeClass = classEntry != attributes.end() ? classEntry->second : std::string();
                classed[dateRangeClass] = dateRange;
            } else {
                dateRanges.push_back(dateRange);
            }
        }
    }
    // add the non-ended date ranges
    for (const auto& entry : classed) {
        dateRanges.push_back(entry.second);
    }

    return dateRanges;
}

/**
 * Formats `datetime` as an ISO 8601 UTC timestamp with whole-second precision
 * ("YYYY-MM-DDTHH:MM:SSZ"); sub-second precision is dropped.
 *
 * @return the formatted timestamp, or an empty string if std::gmtime cannot represent the time
 *
 * Note: std::gmtime returns a pointer to shared static storage. That storage is only read here,
 * never modified.
 */
std::string MediaPlaylist::generateIso8601(Segment::ProgramTime datetime)
{
    const time_t t = std::chrono::system_clock::to_time_t(datetime);
    const std::tm* utc = std::gmtime(&t);
    if (utc == nullptr) {
        // Time not representable as a calendar date; avoid dereferencing a null result.
        return std::string();
    }

    const size_t MaxLength = 100;
    char date[MaxLength];
    // tm_year counts from 1900 and tm_mon from 0; adjust in the arguments rather than in *utc.
    std::snprintf(date, MaxLength, "%d-%02d-%02dT%02d:%02d:%02dZ", utc->tm_year + 1900, utc->tm_mon + 1,
        utc->tm_mday, utc->tm_hour, utc->tm_min, utc->tm_sec);
    return std::string(date);
}
}
}
