#include "DrmClient.hpp"
#include "media/ProtectionSystem.hpp"
#include "playercore/MediaType.hpp"
#include "twitch/DrmKeyOS.hpp"

namespace twitch {
// Request type tags: stored on each tracked MediaRequest and printed in the "DRM %s POST %s" log line.
const std::string RequestKey("LicenseKey");
const std::string RequestProvisioning("Provisioning");

// Constructs a DRM client that reports to `listener`.
// Uses the platform's logger and creates a fresh async HTTP client bound to `scheduler`
// (the same scheduler the ScopedScheduler base uses for deferred callbacks).
// The protection system defaults to ClearKey until onProtectedMedia() matches a pssh box.
// No DRM session or key provider is created here; that also happens in onProtectedMedia().
DrmClient::DrmClient(NativePlatform& platform, Listener& listener, const std::shared_ptr<Scheduler>& scheduler)
    : ScopedScheduler(scheduler)
    , m_platform(platform)
    , m_listener(listener)
    , m_log(platform.getLog())
    , m_httpClient(platform.createAsyncHttpClient(scheduler))
    , m_systemID(media::ProtectionSystemClearKey)
    , m_providerPrepared(false)
    , m_keyState(Initialized)
{
}

DrmClient::~DrmClient()
{
    // Cancel every request still tracked before tearing down the session.
    // NOTE(review): presumably cancel() keeps their HTTP callbacks from reaching this object — confirm.
    std::for_each(m_requests.begin(), m_requests.end(), [](MediaRequest& inFlight) {
        inFlight.cancel();
    });
    m_session.reset();
}

// Picks the first pssh entry in `format` whose protection system is supported by the
// platform, then (re)creates the DRM session and key provider for it.
//
// Returns true when a session is configured for the matching init data: either it was
// already configured with identical data, or a new session was created. Returns false
// (and logs a warning) if no pssh entry matched a supported system, or if session creation
// failed for every match.
bool DrmClient::onProtectedMedia(const MediaFormat& format)
{
    // Only MP4 'pssh' boxes are understood: 4-byte size, 4-byte type, 4-byte version/flags,
    // followed by the protection system id.
    constexpr size_t SystemIdOffset = 12;

    // loop-invariant: the platform's supported systems don't depend on the pssh entry
    const auto& supportedSystems = m_platform.getSupportedProtectionSystems();

    for (const auto& pssh : format.getProtectionData()) {
        for (const auto& uuid : supportedSystems) {
            // Check the length before forming the iterator: begin() + offset past end() is
            // undefined behavior. Use uuid's actual length so the bound matches std::equal.
            if (pssh.size() < SystemIdOffset + uuid.size()) {
                continue;
            }
            auto uuidStart = pssh.begin() + SystemIdOffset;
            if (!std::equal(uuid.begin(), uuid.end(), uuidStart)) {
                continue;
            }

            m_systemID = Uuid::fromBytes(uuid);

            if (pssh == m_initializationData) {
                // already configured with this initialization data
                return true;
            }

            if (!m_initializationData.empty()) {
                // different init data than the current session: drop the old session first
                m_log->info("Reset DRM session");
                m_session.reset();
            }

            m_initializationData = pssh;
            m_session = m_platform.createDrmSession(uuid, *this);

            if (!m_session) {
                // Forget the init data so a later call with the same pssh retries creation
                // instead of taking the "already configured" early return above.
                m_initializationData.clear();
                continue;
            }

            // initialize the session
            m_session->initialize();
            // TODO create provider based on the protection system and provider
            // hardcoding to BuyDRM Widevine for now
            m_provider.reset(new DrmKeyOs(*this, m_httpClient, m_systemID));
            m_providerPrepared = false;
            m_provider->prepare(format);
            return true;
        }
    }
    m_log->warn("Failed to create session for supported type");
    return false;
}

void DrmClient::onKeyExpired(DrmSession& session)
{
    (void)session;
    schedule([=]() {
        m_keyState = Expired;
        m_listener.onProtectionError(Error(ErrorSource::Decode, MediaResult::Error, "Media keys expired"));
    });
}

// Session callback: a license (key) request is needed.
// On the scheduler, generates the key request from the init data current when the task
// runs, then posts it.
// NOTE(review): `session` is captured by reference and must still be alive when the task
// runs — confirm it cannot be reset in onProtectedMedia() in between.
void DrmClient::onKeyRequested(DrmSession& session)
{
    schedule([this, &session]() {
        const auto& keyRequest = session.generateKeyRequest(m_initializationData);
        postRequest(session, RequestKey, keyRequest);
    });
}

// Session callback: the device needs provisioning.
// On the scheduler, posts a provisioning request unless one is already in flight.
// NOTE(review): `session` is captured by reference and must outlive the scheduled task — confirm.
void DrmClient::onProvisioningRequested(DrmSession& session)
{
    schedule([this, &session]() {
        // only one provisioning request may be outstanding at a time
        if (!isProvisioning()) {
            auto provisionRequest = session.generateProvisionRequest();
            postRequest(session, RequestProvisioning, provisionRequest);
        }
    });
}

void DrmClient::onError(DrmSession& session, const Error& error)
{
    (void)session;
    schedule([=]() {
        m_listener.onProtectionError(error);
    });
}

// Provider callback: the key provider finished preparing.
// Records that, and requests a license right away unless provisioning is still in
// progress. If provisioning is in progress, onResponseData() issues the key request once
// provisioning completes.
void DrmClient::onProviderPrepared()
{
    m_providerPrepared = true;
    // m_session is null if onProtectedMedia() reset it and recreating it failed, while the
    // previous provider still exists. Dereferencing it here would be undefined behavior.
    if (m_session && !isProvisioning()) {
        onKeyRequested(*m_session);
    }
}

// Provider callback: forwards a key-provider error straight to the listener.
// Unlike the session callbacks, this does not defer through schedule().
void DrmClient::onProviderError(const Error& error)
{
    m_listener.onProtectionError(error);
}

// Builds and sends the HTTP POST for a key (license) or provisioning request produced by
// `session`.
// The request is tracked in m_requests until onResponseData() or onRequestError() removes it.
//   type          - RequestKey or RequestProvisioning
//   opaqueRequest - url and payload generated by the session
void DrmClient::postRequest(DrmSession& session, const std::string& type, const DrmSession::OpaqueRequest& opaqueRequest)
{
    using namespace twitch::media;
    // Build the target url. Key requests fall back to the provider's license server when
    // the session did not supply one.
    std::string url = opaqueRequest.url;
    if (type == RequestKey && url.empty()) {
        url = m_provider->getLicenseAcquisitionUrl();
    } else if (type == RequestProvisioning) {
        if (m_systemID == ProtectionSystemWidevine) {
            // widevine sends the provisioning data as a query param and not in the body
            std::string data(reinterpret_cast<const char*>(opaqueRequest.data.data()), opaqueRequest.data.size());
            url += "&signedRequest=" + data;
        }
    }

    auto httpRequest = m_httpClient->createRequest(url, HttpMethod::POST);

    // key request
    if (type == RequestKey) {
        httpRequest->setContent(opaqueRequest.data);
        if (m_systemID == ProtectionSystemPlayReady) {
            httpRequest->setHeader("Content-Type", "text/xml");
            httpRequest->setHeader("SOAPAction", "http://schemas.microsoft.com/DRM/2007/03/protocols/AcquireLicense");
        } else if (m_systemID == ProtectionSystemWidevine) {
            httpRequest->setHeader("Content-Type", MediaType::Application_OctetStream.name);
        } else if (m_systemID == ProtectionSystemClearKey) {
            httpRequest->setHeader("Content-Type", MediaType::Application_Json.name);
        }
        m_provider->onLicenseRequest(*httpRequest);
        m_keyState = KeyState::Pending;
    }
    // provisioning request
    else if (type == RequestProvisioning) {
        if (m_systemID != ProtectionSystemWidevine) {
            httpRequest->setContent(opaqueRequest.data);
        }
    }

    m_requests.emplace_back(type);
    auto& request = m_requests.back();
    request.onRequest(httpRequest);
    request.setUrl(url);
    m_log->info("DRM %s POST %s", type.c_str(), url.c_str());

    // The HTTP callbacks run asynchronously. They must not hold a reference into
    // m_requests: later emplace_back() calls can reallocate the container, and
    // removeRequest()'s erase-remove shifts elements. Either leaves a captured reference
    // dangling or pointing at a different request. Capture the id and look it up instead.
    // If the request is gone, its callbacks are ignored.
    const int requestId = request.getId();
    auto findRequest = [this](int id) -> MediaRequest* {
        for (auto& tracked : m_requests) {
            if (tracked.getId() == id) {
                return &tracked;
            }
        }
        return nullptr;
    };

    // NOTE(review): `session` is still captured by reference and must outlive the request — confirm.
    m_httpClient->send(
        httpRequest,
        [this, &session, requestId, findRequest](std::shared_ptr<HttpResponse> response) {
            // handle http response
            if (MediaRequest* current = findRequest(requestId)) {
                onResponse(*current, *response, session);
            }
        },
        [this, requestId, findRequest](int error) {
            // handle request error
            if (MediaRequest* current = findRequest(requestId)) {
                onRequestError(*current, MediaResult::ErrorNetworkIO, error, "");
            }
        });
}

// Handles the HTTP response for a tracked DRM request.
// On success, the body is accumulated across read chunks and handed to onResponseData() at
// end of stream. On an HTTP error status, the body text is read and reported through
// onRequestError() together with the status code. Read failures are reported as
// ErrorNetworkIO in both cases.
void DrmClient::onResponse(MediaRequest& request, HttpResponse& response, DrmSession& session)
{
    request.onResponse(response);

    // The read callbacks below run asynchronously. Keep the request id rather than a
    // reference into m_requests, which may be reallocated or shifted by other requests
    // being added or removed in the meantime. A request that is gone is ignored.
    const int requestId = request.getId();
    auto findRequest = [this](int id) -> MediaRequest* {
        for (auto& tracked : m_requests) {
            if (tracked.getId() == id) {
                return &tracked;
            }
        }
        return nullptr;
    };

    if (response.isSuccess()) {
        // shared so each chunk callback appends to the same buffer
        auto bytes = std::make_shared<std::vector<uint8_t>>();
        response.read(
            [this, &session, bytes, requestId, findRequest](const uint8_t* data, size_t size, bool endOfStream) {
                bytes->insert(bytes->end(), data, data + size);
                if (endOfStream) {
                    if (MediaRequest* current = findRequest(requestId)) {
                        onResponseData(*current, session, *bytes);
                    }
                }
            },
            [this, requestId, findRequest](int code) {
                if (MediaRequest* current = findRequest(requestId)) {
                    onRequestError(*current, MediaResult::ErrorNetworkIO, code, "");
                }
            });
    } else {
        const int status = response.getStatus();
        request.readString(
            response,
            [this, status, requestId, findRequest](const std::string& content) {
                // get server error
                if (MediaRequest* current = findRequest(requestId)) {
                    onRequestError(*current, MediaResult::ErrorNetwork, status, content);
                }
            },
            [this, requestId, findRequest](int code) {
                if (MediaRequest* current = findRequest(requestId)) {
                    onRequestError(*current, MediaResult::ErrorNetworkIO, code, "");
                }
            });
    }
}

// Applies a successful response body to the session and stops tracking the request.
//   Key request:          passes the license to the session and notifies the listener if
//                         the key state is still Updated afterwards.
//   Provisioning request: passes the provisioning response to the session, then requests
//                         a license if the provider is already prepared.
void DrmClient::onResponseData(MediaRequest& request, DrmSession& session, const std::vector<uint8_t>& data)
{
    request.onCompleted();
    if (request.getType() == RequestKey) {
        // Mark the key updated before handing the license to the session. Notify the
        // listener only if the state is still Updated afterwards.
        // NOTE(review): in the code visible here, nothing resets m_keyState synchronously from
        // inside updateKeyResponse(). onError() and onKeyExpired() both defer via schedule(),
        // so this recheck only helps if the session signals failure some other way — confirm.
        m_keyState = Updated;
        session.updateKeyResponse(data);
        if (m_keyState == Updated) {
            m_listener.onKeysProvided();
        }
    } else if (request.getType() == RequestProvisioning) {
        session.updateProvisionResponse(data);
        // Provisioning finished: now get the key, if the provider is ready.
        // m_session is null if onProtectedMedia() failed to recreate it. Guard the dereference.
        if (m_providerPrepared && m_session) {
            onKeyRequested(*m_session);
        }
    }
    removeRequest(request.getId());
}

// Reports a failed DRM request to the listener as a Decode protection error and stops
// tracking it.
// If `text` is empty, a default message is chosen from the request type. A failed key
// request also moves the key state to Errored.
void DrmClient::onRequestError(MediaRequest& request, MediaResult result, int error, const std::string& text)
{
    const bool isKeyRequest = request.getType() == RequestKey;

    std::string message = text;
    if (message.empty()) {
        if (isKeyRequest) {
            message = "Key request failed";
        } else if (request.getType() == RequestProvisioning) {
            message = "Provision request failed";
        }
    }

    if (isKeyRequest) {
        m_keyState = Errored;
    }

    const Error failure(ErrorSource::Decode, MediaResult(result, error), message);
    m_listener.onProtectionError(failure);
    removeRequest(request.getId());
}

void DrmClient::removeRequest(int id)
{
    m_requests.erase(std::remove_if(m_requests.begin(), m_requests.end(),
                         [id](MediaRequest& r) { return r.getId() == id; }),
        m_requests.end());
}

bool DrmClient::isProvisioning() const
{
    // check for already in progress provisioning request
    for (auto& request : m_requests) {
        if (request.getType() == RequestProvisioning && request.isPending()) {
            return true;
        }
    }
    return false;
}
}
