#include "pch.h"
#include "HttpClient.hpp"
#include "HttpRequest.hpp"

namespace {
    // HTTP status codes that this file reports on its own behalf.
    // RequestTimeout (408) is handed to the error handler when a response's
    // timer reports that it has timed out.
    enum HttpStatusCode { RequestTimeout = 408 };

    // Opens an asynchronous WinHTTP session that uses automatic proxy settings.
    // Returns the session handle; throws std::runtime_error when WinHttpOpen fails.
    inline HINTERNET CreateSession() {
        const HINTERNET session = WinHttpOpen(
            L"PlayerCore (Windows)",
            WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY,
            WINHTTP_NO_PROXY_NAME,
            WINHTTP_NO_PROXY_BYPASS,
            WINHTTP_FLAG_ASYNC);
        if(session) {
            return session;
        }
        throw std::runtime_error("CreateSession.WinHttpOpen");
    }

    // Queries the numeric HTTP status code of a request whose headers are available.
    // Throws std::runtime_error when WinHttpQueryHeaders fails.
    inline int GetStatusCode(HINTERNET httpRequest) {
        int code;
        auto codeSize = static_cast<DWORD>(sizeof(code));
        // WINHTTP_QUERY_FLAG_NUMBER asks for the value as a number rather than a string.
        constexpr DWORD queryFlags = WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER;
        const auto queried = WinHttpQueryHeaders(
            httpRequest, queryFlags, WINHTTP_HEADER_NAME_BY_INDEX, &code, &codeSize, WINHTTP_NO_HEADER_INDEX);
        if(queried) {
            return code;
        }
        throw std::runtime_error("GetStatusCode.WinHttpQueryHeaders");
    }

    // Installs (or, with a null callback, removes) the status callback on a handle,
    // subscribing to all notifications. Returns false when WinHTTP reports
    // WINHTTP_INVALID_STATUS_CALLBACK.
    bool SetStatusCallback(HINTERNET httpHandle, WINHTTP_STATUS_CALLBACK callback) {
        return WINHTTP_INVALID_STATUS_CALLBACK
            != WinHttpSetStatusCallback(httpHandle, callback, WINHTTP_CALLBACK_FLAG_ALL_NOTIFICATIONS, 0);
    }
}

using namespace twitch::windows;

// Creates the WinHTTP session shared by every request of this client and
// turns on all supported response decompression schemes for it.
HttpClient::HttpClient()
    : m_httpSession(CreateSession())
{
    // WinHttpSetOption takes the option value by pointer, so it needs an lvalue.
    DWORD enabledSchemes = WINHTTP_DECOMPRESSION_FLAG_ALL;
    verify(WinHttpSetOption(
        m_httpSession,
        WINHTTP_OPTION_DECOMPRESSION,
        &enabledSchemes,
        sizeof(enabledSchemes)));
}

// Closes the session handle, if one is held.
HttpClient::~HttpClient()
{
    if(m_httpSession) {
        verify(WinHttpCloseHandle(m_httpSession));
    }
}

// Builds a Windows-specific request object bound to this client's session
// and returns it through the platform-independent request interface.
std::shared_ptr<twitch::HttpRequest> HttpClient::createRequest(const std::string& url, twitch::HttpMethod method)
{
    auto windowsRequest = std::make_shared<HttpRequest>(m_httpSession, url, method);
    return windowsRequest;
}

// Dispatches a request. The request is downcast (unchecked) to the Windows
// implementation and asked to send itself with the given handlers.
void HttpClient::send(std::shared_ptr<twitch::HttpRequest> request, ResponseHandler onResponse, ErrorHandler onError)
{
    std::static_pointer_cast<HttpRequest>(request)->send(onResponse, onError);
}

// Wraps a WinHTTP request handle. The base response starts with status 0;
// the real status is filled in once the headers are available.
//   httpRequest - request handle this response operates on
//   onResponse  - invoked with this response when headers are available
//   onError     - invoked with an error code when request processing fails
HttpResponse::HttpResponse(
    HINTERNET httpRequest,
    HttpClient::ResponseHandler onResponse,
    ErrorHandler onError)
    : twitch::HttpResponse(0)
    , m_httpResponse(httpRequest)
    , m_onResponse(onResponse)
    , m_onError(onError)
{
}

// Attaches this response to its request handle and sends the request.
//
// A heap-allocated shared_ptr to this object is stored as the handle's context
// value; the static status callback casts the context back to that pointer and
// deletes it when the handle is closing. If any setup step fails, the callback
// is removed, the heap-allocated pointer is freed here, and the exception is
// rethrown to the caller.
//
// Throws std::runtime_error when setting the context, registering the
// callback, or sending the request fails.
void HttpResponse::Initialize() {
    auto* const keepAlive = new std::shared_ptr<HttpResponse>(shared_from_this());
    try {
        auto contextValue = reinterpret_cast<DWORD_PTR>(keepAlive);

        const auto contextStored =
            WinHttpSetOption(m_httpResponse, WINHTTP_OPTION_CONTEXT_VALUE, &contextValue, sizeof(contextValue));
        if(!contextStored) {
            TRACE_ERROR("HttpResponse::Initialize.WinHttpSetOption error %d", GetLastError());
            throw std::runtime_error("HttpResponse::Initialize.WinHttpSetOption");
        }

        const bool callbackInstalled = SetStatusCallback(m_httpResponse, OnStatusChanged);
        if(!callbackInstalled) {
            TRACE_ERROR("HttpResponse::Initialize.SetStatusCallback error %d", GetLastError());
            throw std::runtime_error("HttpResponse::Initialize.SetStatusCallback");
        }

        const auto requestSent =
            WinHttpSendRequest(m_httpResponse, WINHTTP_NO_ADDITIONAL_HEADERS, 0, WINHTTP_NO_REQUEST_DATA, 0, 0, 0);
        if(!requestSent) {
            TRACE_ERROR("HttpResponse::Initialize.WinHttpSendRequest error %d", GetLastError());
            throw std::runtime_error("HttpResponse::Initialize.WinHttpSendRequest");
        }
    } catch(...) {
        // Remove the callback before freeing the context it would dereference.
        verify(SetStatusCallback(m_httpResponse, nullptr));
        delete keepAlive;
        throw;
    }
}

std::string HttpResponse::getHeader(const std::string& key) const
{
    auto name = AsWstring(key);
    auto n = DWORD();
    verify(!WinHttpQueryHeaders(m_httpResponse, WINHTTP_QUERY_CUSTOM, name.c_str(), WINHTTP_NO_OUTPUT_BUFFER, &n, 0));
    if(GetLastError() == ERROR_INSUFFICIENT_BUFFER) {
        std::wstring value(n / sizeof(std::wstring::value_type), L'\0');
        if(WinHttpQueryHeaders(m_httpResponse, WINHTTP_QUERY_CUSTOM, name.c_str(), &value[0], &n, 0)) {
            value.resize(n / sizeof(std::wstring::value_type));
            return std::string(value.cbegin(), value.cend());
        }
    }
    if(GetLastError() != ERROR_WINHTTP_HEADER_NOT_FOUND) {
        TRACE_ERROR("HttpResponse::getHeader.WinHttpQueryHeaders error %d", GetLastError());
        throw std::runtime_error("HttpResponse::getHeader.WinHttpQueryHeaders");
    }
    return std::string();
}

// Starts reading the response body.
//
//   contentHandler - receives body chunks as they are read, and a final call
//                    with a null buffer once no more data is available
//   errorHandler   - replaces the error handler supplied at construction
//
// Does nothing (beyond logging) when the request has already been cancelled.
// On the first call a read-timeout timer is created whose timeout function
// cancels this response; later calls only log a warning and reuse that timer.
//
// Throws std::runtime_error when WinHttpQueryDataAvailable fails; the timer
// is released before throwing.
void HttpResponse::read(ContentHandler contentHandler, ErrorHandler errorHandler)
{
    TRACE_DEBUG("HttpResponse::read()");
    if(!m_httpResponse) {
        TRACE_INFO("HttpResponse::read() - Request is cancelled");
        return;
    }
    if(m_onContent) {
        TRACE_WARN("HttpResponse::read invoked multiple times for same response");
    } else {
        m_timer = std::make_shared<HttpTimer>(getReadTimeout(), 0);
        // The timeout function holds a shared_ptr to this response; cancel()
        // clears the timeout function again.
        m_timer->setTimeoutFunction([self = shared_from_this()]{
            TRACE_ERROR("HttpResponse::read() Task cancelled due to timeout");
            self->cancel();
            });
    }
    m_onContent = contentHandler;
    m_onError = errorHandler;
    // Asynchronous query: the result arrives through the status callback.
    if(!WinHttpQueryDataAvailable(m_httpResponse, nullptr)) {
        // Name this function in the message so the failure is distinguishable
        // from the identical query issued by the READ_COMPLETE handler.
        TRACE_ERROR("HttpResponse::read.WinHttpQueryDataAvailable error %d", GetLastError());
        m_timer.reset();
        throw std::runtime_error("HttpResponse::read.WinHttpQueryDataAvailable");
    }
}

void HttpResponse::cancel()
{
    TRACE_DEBUG("HttpResponse::cancel()");
    if(m_timer) {
        m_timer->stop();
        m_timer->setTimeoutFunction(nullptr);
    }
    auto httpResponse = InterlockedExchangePointer(&m_httpResponse, nullptr);
    if(httpResponse && !WinHttpCloseHandle(httpResponse)) {
        TRACE_ERROR("HttpResponse::cancel.WinHttpCloseHandle error %d", GetLastError());
        throw std::runtime_error("HttpResponse::cancel.WinHttpCloseHandle");
    }
}

// Status callback registered on the request handle.
//
// `context` is the heap-allocated shared_ptr stored as the handle's context
// value. The notification is forwarded to the member overload; when that
// reports the handle is closing, the heap-allocated shared_ptr is deleted.
//
// Any exception from the member overload is handled here: if the request
// handle is still held, the response is cancelled and the error handler is
// called with RequestTimeout when the timer has timed out, otherwise with the
// thread's last error code. Exceptions raised during that recovery are logged
// and dropped.
void WINAPI HttpResponse::OnStatusChanged(HINTERNET /*request*/, DWORD_PTR context, DWORD status, LPVOID information, DWORD size) {
    auto contextHolder = reinterpret_cast<std::shared_ptr<HttpResponse>*>(context);
    // Local copy keeps the response alive for the rest of this call, even
    // after the holder is deleted below.
    auto response = *contextHolder;
    try {
        const bool handleClosing = response->OnStatusChanged(status, information, size);
        if(handleClosing) {
            delete contextHolder;
        }
    } catch(...) {
        auto lastError = GetLastError();
        try {
            if(response->m_httpResponse) {
                response->cancel();
                const bool timedOut = response->m_timer && response->m_timer->hasTimedOut();
                response->m_onError(timedOut ? HttpStatusCode::RequestTimeout : lastError);
            }
        } catch(...) {
            TRACE_WARN("HttpResponse::OnStatusChanged:  ignoring exception during exception recovery");
        }
    }
}

// Handles one status notification for this response.
//
//   status      - WINHTTP_CALLBACK_STATUS_* value being reported
//   information - status-specific payload (see the individual cases)
//   size        - status-specific size; for READ_COMPLETE it is forwarded to
//                 the content handler as the chunk length
//
// Returns true only for HANDLE_CLOSING; false for every other status,
// including ones not handled here. Throws std::runtime_error when a WinHTTP
// call fails or when a request error is reported.
bool HttpResponse::OnStatusChanged(DWORD status, LPVOID information, DWORD size) {
    switch(status) {
    case WINHTTP_CALLBACK_STATUS_SENDREQUEST_COMPLETE: {
        // Request has been sent: start receiving the response.
        if(!WinHttpReceiveResponse(m_httpResponse, nullptr)) {
            TRACE_ERROR("HttpResponse::OnStatusChanged.WinHttpReceiveResponse error %d", GetLastError());
            throw std::runtime_error("HttpResponse::OnStatusChanged.WinHttpReceiveResponse");
        }
        return false;
    }
    case WINHTTP_CALLBACK_STATUS_HEADERS_AVAILABLE: {
        // Record the status code, then hand the response to the caller.
        setStatus(GetStatusCode(m_httpResponse));
        m_onResponse(shared_from_this());
        return false;
    }
    case WINHTTP_CALLBACK_STATUS_DATA_AVAILABLE: {
        // `information` points at the number of bytes available.
        const DWORD bytesAvailable = *reinterpret_cast<DWORD*>(information);
        if(!bytesAvailable) {
            // No more data: signal end of content and release the handle.
            m_onContent(nullptr, 0, true);
            cancel();
            return false;
        }
        if(!WinHttpReadData(m_httpResponse, &m_data[0], static_cast<DWORD>(m_data.size()), nullptr)) {
            TRACE_ERROR("HttpResponse::OnStatusChanged.WinHttpReadData error %d", GetLastError());
            throw std::runtime_error("HttpResponse::OnStatusChanged.WinHttpReadData");
        }
        m_timer->stop();
        return false;
    }
    case WINHTTP_CALLBACK_STATUS_READ_COMPLETE: {
        // `information` is the buffer passed to WinHttpReadData.
        assert(information == m_data.data());
        m_onContent(m_data.data(), size, false);
        // Ask for the next chunk.
        if(!WinHttpQueryDataAvailable(m_httpResponse, nullptr)) {
            TRACE_ERROR("HttpResponse::OnStatusChanged.WinHttpQueryDataAvailable error %d", GetLastError());
            throw std::runtime_error("HttpResponse::OnStatusChanged.WinHttpQueryDataAvailable");
        }
        return false;
    }
    case WINHTTP_CALLBACK_STATUS_HANDLE_CLOSING: {
        return true;
    }
    case WINHTTP_CALLBACK_STATUS_REQUEST_ERROR: {
        // Publish the reported error as the thread's last error so the
        // static callback's exception handler can pick it up.
        auto requestError = reinterpret_cast<WINHTTP_ASYNC_RESULT*>(information)->dwError;
        TRACE_ERROR("HttpResponse::OnStatusChanged REQUEST_ERROR %d", requestError);
        SetLastError(requestError);
        throw std::runtime_error("HttpResponse::OnStatusChanged REQUEST_ERROR");
    }
    }
    return false;
}
