#include "QualitySelector.hpp"

#include <memory>
#include <string>

#include "BandwidthFilter.hpp"
#include "BitrateFilter.hpp"
#include "BufferFilter.hpp"
#include "DroppedFrameFilter.hpp"
#include "MaxBufferFilter.hpp"
#include "PercentileEstimator.hpp"
#include "RandomFilter.hpp"
#include "ResolutionFilter.hpp"
#include "ViewportFilter.hpp"
#include "player/MediaRequest.hpp"

namespace twitch {
namespace abr {
/**
 * Builds the selector with its fixed filter chain.
 *
 * Filters are evaluated by nextQuality() in the order they are added here,
 * so the insertion order below is significant.
 *
 * @param platform supplies the log sink and the capability flag telling the
 *                 BandwidthFilter whether HTTP response timing is available.
 */
QualitySelector::QualitySelector(std::shared_ptr<Platform> platform)
    : m_log(platform->getLog(), "ABR ")
    , m_playbackRate(1.0)
    , m_state(BufferState::Filling)
    , m_type(StreamType::Live)
{
    // Each filter is handed to an owning smart pointer *before* it reaches the
    // container. With emplace_back(new T()), a throwing reallocation inside
    // the vector would leak the raw pointer because no owner exists yet.
    m_filters.emplace_back(std::unique_ptr<Filter>(new BitrateFilter()));
    m_filters.emplace_back(std::unique_ptr<Filter>(new ResolutionFilter()));
    m_filters.emplace_back(std::unique_ptr<Filter>(new ViewportFilter()));
    m_filters.emplace_back(std::unique_ptr<Filter>(new DroppedFrameFilter(m_log)));
    m_filters.emplace_back(std::unique_ptr<Filter>(
        new BandwidthFilter(m_log, platform->getCapabilities().supportsHttpResponseTiming)));
    m_filters.emplace_back(std::unique_ptr<Filter>(new MaxBufferFilter()));
    m_filters.emplace_back(std::unique_ptr<Filter>(new BufferFilter(m_log)));

    // Start with low-latency mode explicitly disabled on the filters that support it.
    setLowLatencyMode(false);
}

// Network callback: a request has been sent. Only media segment requests are
// forwarded, via filter<BandwidthFilter>(), to BandwidthFilter::onRequestSent.
void QualitySelector::onRequestSent(const MediaSource::Request& request)
{
    if (request.getType() != MediaRequest::Type::Segment) {
        return; // playlists and other request types are ignored here
    }
    filter<BandwidthFilter>(&BandwidthFilter::onRequestSent, request);
}

// Network callback: a response has started arriving. Only media segment
// requests are forwarded to BandwidthFilter::onResponseReceived.
void QualitySelector::onResponseReceived(const MediaSource::Request& request)
{
    if (request.getType() != MediaRequest::Type::Segment) {
        return; // non-segment traffic is not forwarded
    }
    filter<BandwidthFilter>(&BandwidthFilter::onResponseReceived, request);
}

// Network callback: a chunk of `bytes` arrived for `request`. Only media
// segment requests are forwarded to BandwidthFilter::onResponseBytes.
void QualitySelector::onResponseBytes(const MediaSource::Request& request, size_t bytes)
{
    if (request.getType() != MediaRequest::Type::Segment) {
        return; // non-segment traffic is not forwarded
    }
    filter<BandwidthFilter>(&BandwidthFilter::onResponseBytes, request, bytes);
}

// Network callback: the response for `request` finished. Only media segment
// requests are forwarded to BandwidthFilter::onResponseEnd.
void QualitySelector::onResponseEnd(const MediaSource::Request& request)
{
    if (request.getType() != MediaRequest::Type::Segment) {
        return; // non-segment traffic is not forwarded
    }
    filter<BandwidthFilter>(&BandwidthFilter::onResponseEnd, request);
}

// Network callback: `request` failed with code `error`. Only media segment
// requests are forwarded (with the error code) to BandwidthFilter::onRequestError.
void QualitySelector::onRequestError(const MediaSource::Request& request, int error)
{
    if (request.getType() != MediaRequest::Type::Segment) {
        return; // non-segment traffic is not forwarded
    }
    filter<BandwidthFilter>(&BandwidthFilter::onRequestError, request, error);
}

void QualitySelector::onBufferDurationChange(const TimeRange& range)
{
    m_buffered = range.duration;
}

// Remembers the latest buffer state reported by the player.
void QualitySelector::onBufferStateChange(BufferState state)
{
    this->m_state = state;
}

void QualitySelector::onStatistics(const Statistics& statistics, const Quality& quality)
{
    for (auto& filter : m_filters) {
        filter->onStatistics(statistics, quality);
    }
}

void QualitySelector::onStreamChange()
{
    m_selected = Quality();
    for (auto& filter : m_filters) {
        filter->onStreamChange();
    }
}

// Marks `quality` as rejected for the current selection pass and, the first
// time it is rejected, appends "name (bitrate)" to the pending filter log.
// The rejecting filter itself is not recorded.
void QualitySelector::filter(const Filter& filter, const Quality& quality)
{
    (void)filter;

    // Already rejected during this pass: nothing new to record or log.
    if (m_filtered.count(quality) != 0) {
        return;
    }
    m_filtered.insert(quality);

    // Entries in the log are comma separated.
    if (!m_filterlog.empty()) {
        m_filterlog += ", ";
    }
    m_filterlog += quality.name;
    m_filterlog += " (";
    m_filterlog += std::to_string(quality.bitrate);
    m_filterlog += ")";
}

/**
 * Runs the filter chain over the auto-selectable qualities and returns the
 * selected quality.
 *
 * Each enabled filter (in chain order) may reject qualities through
 * filter(const Filter&, const Quality&). A filter whose filter() call returns
 * false stops the chain; qualities it rejected still count. The first
 * non-rejected entry of the auto quality list becomes the target (see
 * getTarget()), and a switch is logged when its bitrate differs from the
 * previous selection.
 *
 * @return reference to the member holding the current selection.
 */
const Quality& QualitySelector::nextQuality(const Qualities& qualities)
{
    const auto& autoQualities = qualities.getAutoQualities();

    // Rejections do not persist between calls; every filter must re-add them
    // on each pass.
    m_filtered.clear();
    // Drop log text left over from a previous pass so it can never be
    // attributed to a filter in this pass.
    m_filterlog.clear();

    for (auto& filter : m_filters) {
        // skip disabled filters from the decision
        if (m_disabledFilters.find(filter->getName()) != m_disabledFilters.end()) {
            continue;
        }

        const bool continueChain = filter->filter(autoQualities, *this);

        // Log what this filter rejected. This also runs when the filter ends
        // the chain, because its rejections still affect the target.
        if (!m_filterlog.empty()) {
            m_log.info("%s: filtered %s", filter->getName().c_str(), m_filterlog.c_str());
            m_filterlog.clear();
        }

        if (!continueChain) {
            m_log.info("%s disabled filter chain", filter->getName().c_str());
            break;
        }
    }

    // match to the highest non filtered quality
    Quality target = getTarget(autoQualities);

    if (target.bitrate != m_selected.bitrate) {
        m_selected = target;
        m_log.info("switch quality %s (%d)", m_selected.name.c_str(), m_selected.bitrate);
    }

    return m_selected;
}

/**
 * Decides whether it is worth replacing `duration` of already-buffered media.
 *
 * Takes the quality that Qualities::match() pairs with the current bandwidth
 * estimate and scales its bitrate by the duration (divided by 8). It logs the
 * figures and returns true only when the estimate exceeds that value.
 */
bool QualitySelector::canReplaceBuffer(const Qualities& qualities, MediaTime duration)
{
    const int bandwidth = getBandwidthEstimate();
    const int qualityBitrate = qualities.match(bandwidth).bitrate;
    const auto seconds = duration.seconds();

    // NOTE(review): seconds * bitrate / 8 is a byte count, yet it is logged
    // as "kbps" and compared with the bandwidth estimate below. Confirm the
    // intended units.
    const int targetBitrate = static_cast<int>((seconds * qualityBitrate) / 8);

    m_log.info("Buffer replace %.2f s with %.2f kbps need %.2f kbps have %.2f kbps",
        seconds, qualityBitrate / 1000.0, targetBitrate / 1000.0, bandwidth / 1000.0);

    return targetBitrate < bandwidth;
}

// Average bitrate as reported by the BandwidthFilter (located by name), or 0
// when no such filter is registered.
int QualitySelector::getAverageBitrate() const
{
    for (auto it = m_filters.begin(); it != m_filters.end(); ++it) {
        if ((*it)->getName() != BandwidthFilter::Name) {
            continue;
        }
        auto* bandwidthFilter = static_cast<BandwidthFilter*>(it->get());
        return bandwidthFilter->getAverageBitrate();
    }
    return 0;
}

// Current bandwidth estimate from the BandwidthFilter (located by name), or
// Estimator::NoEstimate when no such filter is registered.
int QualitySelector::getBandwidthEstimate() const
{
    for (auto it = m_filters.begin(); it != m_filters.end(); ++it) {
        if ((*it)->getName() != BandwidthFilter::Name) {
            continue;
        }
        auto* bandwidthFilter = static_cast<BandwidthFilter*>(it->get());
        return bandwidthFilter->getBandwidthEstimate();
    }
    return Estimator::NoEstimate;
}

// Transfer history held by the BandwidthFilter (located by name). When no such
// filter exists, a reference to a function-local static empty queue is
// returned; being static, it outlives the call.
const CircularQueue<RequestMetric>& QualitySelector::getTransferHistory() const
{
    for (auto it = m_filters.begin(); it != m_filters.end(); ++it) {
        if ((*it)->getName() != BandwidthFilter::Name) {
            continue;
        }
        auto* bandwidthFilter = static_cast<BandwidthFilter*>(it->get());
        return bandwidthFilter->getTransferHistory();
    }

    static CircularQueue<RequestMetric> empty(0);
    return empty;
}

// Stores the playback rate. A rate of exactly zero is ignored and the previous
// rate is kept.
// NOTE(review): negative and NaN rates are stored as-is. Confirm that is intended.
void QualitySelector::setPlaybackRate(float rate)
{
    if (rate != 0.0f) {
        m_playbackRate = rate;
    }
}

// Updates the stream type. Switching to VOD also turns low-latency mode off;
// setting the same type again does nothing.
void QualitySelector::setStreamType(StreamType type)
{
    if (m_type == type) {
        return;
    }

    m_type = type;

    if (type == StreamType::VOD) {
        setLowLatencyMode(false);
    }
}

// Forwards the target buffer duration to BufferFilter::setTargetBufferSize.
void QualitySelector::setTargetBufferSize(MediaTime duration)
{
    this->filter<BufferFilter>(&BufferFilter::setTargetBufferSize, duration);
}

void QualitySelector::setLowLatencyMode(bool enable)
{
    filter<BandwidthFilter>(&BandwidthFilter::setLowLatencyMode, enable);
    filter<BufferFilter>(&BufferFilter::setLowLatencyMode, enable);
}

void QualitySelector::setInitialBitrate(int bitrate)
{
    filter<BandwidthFilter>(&BandwidthFilter::setInitialBitrate, bitrate);
}

void QualitySelector::setMaxBitrate(int bitrate)
{
    filter<BitrateFilter>(&BitrateFilter::setMaxBitrate, bitrate);
}

void QualitySelector::setMaxVideoSize(int width, int height)
{
    filter<ResolutionFilter>(&ResolutionFilter::setMaxResolution, width, height);
}

void QualitySelector::setViewportSize(int width, int height)
{
    filter<ViewportFilter>(&ViewportFilter::setViewportSize, width, height);
}

void QualitySelector::setViewportScale(float scale)
{
    filter<ViewportFilter>(&ViewportFilter::setViewportScale, scale);
}

// Minimum buffer target from the BufferFilter (located by name), or
// MediaTime::zero() when no such filter is registered.
MediaTime QualitySelector::getMinBufferTarget() const
{
    for (auto it = m_filters.begin(); it != m_filters.end(); ++it) {
        if ((*it)->getName() != BufferFilter::Name) {
            continue;
        }
        auto* bufferFilter = static_cast<BufferFilter*>(it->get());
        return bufferFilter->getMinBufferTarget();
    }
    return MediaTime::zero();
}

// Returns the first entry of `qualities` (in list order) that no filter
// rejected in the current pass. If every entry was rejected, or the list is
// empty, the previous selection (m_selected) is returned unchanged.
Quality QualitySelector::getTarget(const std::vector<Quality>& qualities) const
{
    for (auto it = qualities.begin(); it != qualities.end(); ++it) {
        if (m_filtered.count(*it) == 0) {
            return *it;
        }
    }
    return m_selected;
}
}
}
