#include "BandwidthFilter.hpp"
#include "PercentileEstimator.hpp"
#include "player/MediaRequest.hpp"

namespace twitch {
namespace abr {
// Definition of the static BandwidthFilter::Name member.
std::string BandwidthFilter::Name = "BandwidthFilter";
// Initial value of m_initialBitrate, which replaces the estimate whenever an
// estimator reports Estimator::NoEstimate. Estimates are logged as
// `estimate / 1000.0` kbps, so the value is in bits per second.
const int DefaultBitrate = 1000000;
// Fraction of the estimated bandwidth that filter() budgets for the stream;
// also the initial value of the extra penalty applied while refilling.
const double DefaultBandwidthUsage = 0.80;
// Lower bound applied to the target bitrate in filter() whenever the
// estimation mode is not EstimationMode::Normal.
const int LowLatencyMinBitrate = 650000;
// Passed to the m_transferHistory constructor (presumably its capacity —
// confirm against the container's declaration).
const size_t HistorySize = 100;

// Builds a filter in EstimationMode::Normal with the default bandwidth usage,
// rebuffering penalty and initial bitrate.
//
// log             - logger used for download and probe-estimate messages.
// useResponseTime - when true, range requests are timed from the moment the
//                   response is received rather than from when the request was
//                   initiated (see onResponseReceived); it also keeps the
//                   default bandwidth usage in low latency mode (see
//                   setLowLatencyMode).
//
// The constructor itself does not call createEstimator(); estimators are
// created by setInitialBitrate(), setLowLatencyMode() or onStreamChange().
BandwidthFilter::BandwidthFilter(Log& log, bool useResponseTime)
    : m_log(log)
    , m_useResponseTime(useResponseTime)
    , m_streamChanged(false)
    , m_bandwidthUsage(DefaultBandwidthUsage)
    , m_rebufferingUsagePenalty(DefaultBandwidthUsage)
    , m_initialBitrate(DefaultBitrate)
    , m_estimationMode(EstimationMode::Normal)
    // NOTE(review): 10 is presumably the averaging window size — confirm.
    , m_averageBitrate(10)
    , m_transferHistory(HistorySize)
{
}

// Begins tracking an outgoing request: records when it was initiated together
// with the media duration and bitrate reported by the request.
void BandwidthFilter::onRequestSent(const MediaSource::Request& request)
{
    // Keep the table of tracked requests bounded: once it holds more entries
    // than the limit, its first entry is dropped before the new one is added.
    const size_t TransferLimit = 10;
    const bool overLimit = m_requests.size() > TransferLimit;
    if (overLimit) {
        m_requests.erase(m_requests.begin());
    }

    auto& entry = m_requests[request.getId()];
    entry.initiated = MediaTime::now();
    entry.mediaBitrate = request.getMediaBitrate();
    entry.mediaDuration = request.getMediaDuration();
}

// Marks the start of a response: resets the byte counter and chooses the
// reference time from which the first bandwidth sample will be measured.
void BandwidthFilter::onResponseReceived(const MediaSource::Request& request)
{
    auto& entry = m_requests[request.getId()];
    const MediaTime receivedAt = MediaTime::now();
    entry.start = receivedAt;
    entry.bytes = 0;

    // This callback is not accurate on certain browsers, so for (small) range
    // requests the whole round trip is measured from the time the request was
    // initiated, unless response-time based measurement was asked for.
    const bool measureRoundTrip = request.isRangeRequest() && !m_useResponseTime;
    if (measureRoundTrip) {
        entry.lastUpdate = entry.initiated;
    } else {
        entry.lastUpdate = receivedAt;
    }
}

// Accounts for a newly received chunk of response data and feeds it to the
// estimator selected for this request.
void BandwidthFilter::onResponseBytes(const MediaSource::Request& request, size_t bytes)
{
    auto& entry = m_requests[request.getId()];
    entry.bytes += bytes;

    // One sample per chunk: the bytes just received over the time elapsed
    // since the previous update of this transfer.
    const auto receivedAt = MediaTime::now();
    Estimator& estimator = getEstimator(request);
    estimator.sample(receivedAt - entry.lastUpdate, bytes);
    entry.lastUpdate = receivedAt;
}

// Finalizes the bookkeeping for a request and stops tracking it.
//
// For non-range requests the transfer is logged, appended to the transfer
// history and its effective bitrate (bytes * 8 / media duration) is added to
// the rolling average. Range requests are only removed from the table.
void BandwidthFilter::onResponseEnd(const MediaSource::Request& request)
{
    auto& transfer = m_requests[request.getId()];
    transfer.end = MediaTime::now();
    transfer.mediaDuration = request.getMediaDuration();
    int estimate = getEstimator(request).estimate();

    // Fall back to the configured initial bitrate until the estimator has data.
    if (estimate == Estimator::NoEstimate) {
        estimate = m_initialBitrate;
    }

    if (!request.isRangeRequest()) {
        m_log.info("downloaded %04d bitrate %d in %.2f s, bandwidth estimate %.3f kbps",
            request.getId(),
            transfer.mediaBitrate,
            transfer.getDuration().seconds(),
            estimate / 1000.0);

        m_transferHistory.push_back(transfer);

        // A request without a positive media duration cannot yield a bitrate:
        // dividing by zero would produce inf/NaN, and converting that to int is
        // undefined behaviour that would corrupt the rolling average.
        const auto mediaSeconds = request.getMediaDuration().seconds();
        if (mediaSeconds > 0) {
            const int segmentBps = static_cast<int>(transfer.bytes * 8 / mediaSeconds);
            m_averageBitrate.add(segmentBps);
        }
    }

    m_requests.erase(request.getId());
}

// Handles a failed request by running the same end-of-response bookkeeping as
// a completed one (see onResponseEnd). The error code is not used.
void BandwidthFilter::onRequestError(const MediaSource::Request& request, int error)
{
    (void)error;
    onResponseEnd(request);
}

bool BandwidthFilter::filter(const std::vector<Quality>& qualities, Filter::Context& context)
{
    int estimatedBitrate = getEstimate(context);

    // take a proportion of the estimated download bandwidth and calculate the corresponding quality
    double estimate = (m_bandwidthUsage * estimatedBitrate) / context.getPlaybackRate();
    // if rebuffering further limit the bandwidth usage again
    double targetBitrate = context.getBufferState() == BufferState::Refilling
        ? (m_rebufferingUsagePenalty * estimate)
        : estimate;

    // low latency mode, low bitrate streams (e.g. 160p) are excluded
    if (m_estimationMode != EstimationMode::Normal) {
        targetBitrate = std::max(targetBitrate, static_cast<double>(LowLatencyMinBitrate));
    }

    // filter out qualities with bitrate greater than the target bitrate
    for (const auto& q : qualities) {
        if (q.bitrate > targetBitrate) {
            context.filter(*this, q);
        }
    }

    // switching to a new stream don't pick source (twitch specific optimization)
    if (m_streamChanged && !qualities.empty()) {
        context.filter(*this, qualities.front());
        m_streamChanged = false;
    }

    return true;
}

// No-op: this filter does not use playback statistics; both arguments are
// explicitly discarded.
void BandwidthFilter::onStatistics(const Statistics& statistics, const Quality& quality)
{
    (void)statistics;
    (void)quality;
}

// Notes that playback switched to a new stream so the next filter() pass can
// react to it, and discards the current estimators unless in normal mode.
void BandwidthFilter::onStreamChange()
{
    m_streamChanged = true;

    if (m_estimationMode == EstimationMode::Normal) {
        return;
    }
    // Outside normal mode the estimators start over for the new stream.
    createEstimator();
}

// Returns the bandwidth estimate that filter() should budget against.
//
// The regular estimator is consulted first; when it has no estimate (or has
// not been created yet) the configured initial bitrate is used. In probe mode
// the probe estimate takes precedence, unless it is unavailable or the buffer
// is low (draining and at or below the minimum buffer target).
int BandwidthFilter::getEstimate(const Filter::Context& context)
{
    // The estimators are only created by createEstimator(), which the
    // constructor does not call, so a missing estimator is treated the same
    // way as one without data instead of being dereferenced.
    int estimate = Estimator::NoEstimate;
    if (m_estimator) {
        estimate = m_estimator->getEstimate();
    }
    if (estimate == Estimator::NoEstimate) {
        estimate = m_initialBitrate;
    }

    if (m_estimationMode == EstimationMode::Probe && m_probeEstimator) {
        int probeEstimate = m_probeEstimator->getEstimate();
        bool isBufferLow = context.getBufferDuration() <= context.getMinBufferTarget()
            && context.getBufferState() == BufferState::Draining;
        if (probeEstimate != Estimator::NoEstimate && !isBufferLow) {
            m_log.info("Probe estimate %.2f kbps", probeEstimate / 1000.0f);
            estimate = probeEstimate;
        }
    }
    return estimate;
}

// Selects the estimator that samples for the given request go to: the probe
// estimator in probe mode for requests that are not media prefetches, the
// regular estimator otherwise.
Estimator& BandwidthFilter::getEstimator(const MediaSource::Request& request)
{
    // Request callbacks can arrive before anything has called
    // createEstimator(); create the estimators on demand rather than
    // dereferencing a null pointer.
    if (!m_estimator) {
        createEstimator();
    }

    if (m_estimationMode == EstimationMode::Probe && m_probeEstimator && !request.isMediaPrefetch()) {
        return *m_probeEstimator;
    }
    return *m_estimator;
}

// Replaces the estimators with fresh instances matching the current
// estimation mode: probe mode gets both a regular and a probe estimator,
// normal mode only the regular one.
void BandwidthFilter::createEstimator()
{
    const bool probing = m_estimationMode == EstimationMode::Probe;
    if (!probing && m_estimationMode != EstimationMode::Normal) {
        return;
    }

    m_estimator.reset(new PercentileEstimator());
    if (probing) {
        m_probeEstimator.reset(new PercentileEstimator());
    } else {
        m_probeEstimator.reset();
    }
}

// Sets the bitrate that stands in for the bandwidth estimate while an
// estimator reports Estimator::NoEstimate, and replaces the current
// estimators with fresh ones via createEstimator().
void BandwidthFilter::setInitialBitrate(int bitrate)
{
    m_initialBitrate = bitrate;
    createEstimator();
}

// Switches between probe (low latency) and normal estimation.
//
// The estimators are rebuilt and the bandwidth usage recomputed only when the
// mode actually changes or no estimator exists yet; otherwise the call has no
// effect.
void BandwidthFilter::setLowLatencyMode(bool enable)
{
    const auto wantedMode = enable ? EstimationMode::Probe : EstimationMode::Normal;
    const bool modeChanged = m_estimationMode != wantedMode;
    if (modeChanged) {
        m_estimationMode = wantedMode;
    }

    if (m_estimator && !modeChanged) {
        return;
    }

    createEstimator();

    // Low latency mode uses the full estimate unless response-time based
    // measurement is enabled; every other combination keeps the default share.
    if (enable && !m_useResponseTime) {
        m_bandwidthUsage = 1.00;
    } else {
        m_bandwidthUsage = DefaultBandwidthUsage;
    }
}

// Returns the current value of m_averageBitrate.average(); the averaged
// samples are the per-segment bitrates added by onResponseEnd().
int BandwidthFilter::getAverageBitrate() const
{
    return m_averageBitrate.average();
}

// Reports the raw estimate of the estimator that belongs to the current mode
// (regular estimator in normal mode, probe estimator in probe mode), or
// Estimator::NoEstimate when that estimator has not been created.
int BandwidthFilter::getBandwidthEstimate() const
{
    if (m_estimationMode == EstimationMode::Normal && m_estimator) {
        return m_estimator->getEstimate();
    }
    if (m_estimationMode == EstimationMode::Probe && m_probeEstimator) {
        return m_probeEstimator->getEstimate();
    }
    return Estimator::NoEstimate;
}
}
}
