#include "pch.h"
#include "Internal.h"
#include "../Shared/HttpUtil.h"
#include "../Shared/Http.h"

#pragma comment(lib, "winhttp.lib")

using namespace Twitch;

namespace {
	// Number of elements (chars, i.e. bytes) the receive buffer is resized to
	// once a request has been upgraded to a WebSocket: 4 KiB.
	constexpr size_t webSocketBufferSize = 4096;
}

// Evaluates `expr` exactly once. If the value is falsy (FALSE, a null handle,
// or a false comparison), returns from the enclosing function a tuple of
// FromPlatformError(GetLastError()) and a value-initialized `valueType`.
#define ReturnIfFailedT(expr, valueType) \
	do { \
		auto returnIfFailedValue_ = (expr); \
		if (!returnIfFailedValue_) { \
			return { FromPlatformError(GetLastError()), valueType() }; \
		} \
	} while (false)
// Shorthand for functions returning std::tuple<int, uintptr_t>; redefined
// further down for functions with other tuple payload types.
#define ReturnIfFailed(expr) ReturnIfFailedT((expr), uintptr_t)

/// Starts an HTTP (or WebSocket) request and waits until its response headers
/// have been received.
///
/// @param url            Absolute URL. "https"/"wss" select TLS, "http"/"ws"
///                       do not; any other scheme is rejected with
///                       ERROR_BAD_ARGUMENTS.
/// @param httpReqType    Selects the verb (PUT, POST, DELETE, otherwise GET).
///                       HTTP_WEB_SOCKET_REQUEST also asks WinHTTP to upgrade
///                       the request to a WebSocket once the response arrives.
/// @param timeout        Applied to every WinHTTP time-out option and used as
///                       the limit for each wait on an asynchronous completion.
/// @param requestHeaders Additional request headers.
/// @param requestBody    Request payload; may be empty.
/// @return On success {0, handle}, where handle owns the Internets context
///         (FinishHttpRequest takes ownership of it and releases it).
///         On failure {error, 0}, with error produced by FromPlatformError.
std::tuple<int, uintptr_t> Platform::StartHttpRequest(string_t url,
	HttpRequestType httpReqType,
	std::chrono::milliseconds timeout,
	std::vector<HttpParam> const& requestHeaders,
	std::vector<char> const& requestBody) {
	// Check the URL scheme to determine if SSL is necessary.
	Uri uri{ url };
	bool useSsl;
	if (uri.GetProtocol() == _T("https") || uri.GetProtocol() == _T("wss")) {
		useSsl = true;
	} else if (uri.GetProtocol() == _T("http") || uri.GetProtocol() == _T("ws")) {
		useSsl = false;
	} else {
		return { FromPlatformError(ERROR_BAD_ARGUMENTS), 0 };
	}

	// Create an asynchronous HTTP session. Every handle created below is
	// registered with `internets`. Any early return destroys it; presumably its
	// destructor closes the registered handles (TODO confirm in Internets).
	auto httpSession = WinHttpOpen(_T("Twitch Game SDK"), WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY, WINHTTP_NO_PROXY_NAME, WINHTTP_NO_PROXY_BYPASS, WINHTTP_FLAG_ASYNC);
	ReturnIfFailed(httpSession);
	auto internets = std::make_unique<Internets>();
	internets->Add(httpSession);

	// Enable deflate and gzip decoding. Best effort: the result is not checked.
	DWORD decompressionFlag = WINHTTP_DECOMPRESSION_FLAG_ALL;
	WinHttpSetOption(httpSession, WINHTTP_OPTION_DECOMPRESSION, &decompressionFlag, sizeof(decompressionFlag));

	// Set all time-outs to the requested value.
	auto timeoutMs = static_cast<DWORD>(timeout.count());
	BOOL areTimeoutsSet = WinHttpSetOption(httpSession, WINHTTP_OPTION_CONNECT_TIMEOUT, &timeoutMs, sizeof(timeoutMs)) &&
		WinHttpSetOption(httpSession, WINHTTP_OPTION_RESOLVE_TIMEOUT, &timeoutMs, sizeof(timeoutMs)) &&
		WinHttpSetOption(httpSession, WINHTTP_OPTION_RECEIVE_RESPONSE_TIMEOUT, &timeoutMs, sizeof(timeoutMs)) &&
		WinHttpSetOption(httpSession, WINHTTP_OPTION_SEND_TIMEOUT, &timeoutMs, sizeof(timeoutMs)) &&
		WinHttpSetOption(httpSession, WINHTTP_OPTION_RECEIVE_TIMEOUT, &timeoutMs, sizeof(timeoutMs));
	if (!areTimeoutsSet) {
		// Capture the error once. Logging may overwrite the thread's last error,
		// so it must not be queried again afterwards.
		auto const errorCode = GetLastError();
		DebugWriteLine(_T("[Platform::StartHttpRequest.WinHttpSetOption(timeouts)] error %d"), errorCode);
		return { FromPlatformError(errorCode), 0 };
	}

	// Create a connection. With no explicit port, INTERNET_DEFAULT_PORT lets
	// WinHTTP choose 80 or 443 depending on whether the request is secure.
	std::uint16_t port;
	if (!uri.GetPort(port)) {
		port = INTERNET_DEFAULT_PORT;
	}
	auto httpConnection = WinHttpConnect(httpSession, uri.GetHostName().c_str(), port, 0);
	ReturnIfFailed(httpConnection);
	internets->Add(httpConnection);

	// Determine the HTTP verb. WebSocket requests use GET.
	string_t verb;
	switch (httpReqType) {
	case HttpRequestType::HTTP_PUT_REQUEST:
		verb = _T("PUT");
		break;
	case HttpRequestType::HTTP_POST_REQUEST:
		verb = _T("POST");
		break;
	case HttpRequestType::HTTP_DELETE_REQUEST:
		verb = _T("DELETE");
		break;
	default:
		verb = _T("GET");
		break;
	}

	// Open a request for the desired resource. WinHttpOpenRequest takes a
	// C string for the verb, not a string object.
	DWORD flags = WINHTTP_FLAG_REFRESH | (useSsl ? WINHTTP_FLAG_SECURE : 0);
	auto httpRequest = WinHttpOpenRequest(httpConnection, verb.c_str(), uri.GetResource().c_str(),
		nullptr, WINHTTP_NO_REFERER, WINHTTP_DEFAULT_ACCEPT_TYPES, flags);
	ReturnIfFailed(httpRequest);
	internets->Add(httpRequest);

	// If it's a WebSocket request, configure it as such.
	if (httpReqType == HttpRequestType::HTTP_WEB_SOCKET_REQUEST) {
		ReturnIfFailed(WinHttpSetOption(httpRequest, WINHTTP_OPTION_UPGRADE_TO_WEB_SOCKET, nullptr, 0));
	}

	// Build the headers.
	auto requestHeaderData = BuildHttpHeader(requestHeaders);

	// Send the HTTP request. Completion is reported through SendCallback,
	// which receives `internets` as its context and fulfils sendPromise.
	// NOTE(review): on the timeout return below, `internets` is destroyed while
	// the send may still be outstanding. Confirm that Internets' teardown makes
	// that safe for the pending callback.
	constexpr auto const sendNotifications = WINHTTP_CALLBACK_STATUS_SENDREQUEST_COMPLETE | WINHTTP_CALLBACK_STATUS_HANDLE_CLOSING | WINHTTP_CALLBACK_STATUS_REQUEST_ERROR;
	ReturnIfFailed(WinHttpSetStatusCallback(httpRequest, &Internets::SendCallback, sendNotifications, 0) != WINHTTP_INVALID_STATUS_CALLBACK);
	auto const* const headers = requestHeaderData.c_str();
	auto const headersSize = static_cast<DWORD>(requestHeaderData.size());
	// WinHttpSendRequest takes a non-const pointer for the optional body.
	auto* const body = const_cast<char*>(requestBody.data());
	auto const bodySize = static_cast<DWORD>(requestBody.size());
	ReturnIfFailed(WinHttpSendRequest(httpRequest, headers, headersSize, body, bodySize, bodySize, reinterpret_cast<DWORD_PTR>(internets.get())));
	auto sendFuture = internets->sendPromise.get_future();
	if (sendFuture.wait_for(timeout) == std::future_status::timeout) {
		return { FromPlatformError(ERROR_TIMEOUT), 0 };
	} else {
		auto errorCode = sendFuture.get();
		if (errorCode) {
			return { FromPlatformError(errorCode), 0 };
		}
	}

	// Receive the headers of the response to the request. Completion is
	// reported through ResponseCallback, which fulfils responsePromise.
	constexpr auto const responseNotifications = WINHTTP_CALLBACK_STATUS_HEADERS_AVAILABLE | WINHTTP_CALLBACK_STATUS_HANDLE_CLOSING | WINHTTP_CALLBACK_STATUS_REQUEST_ERROR;
	ReturnIfFailed(WinHttpSetStatusCallback(httpRequest, &Internets::ResponseCallback, responseNotifications, 0) != WINHTTP_INVALID_STATUS_CALLBACK);
	ReturnIfFailed(WinHttpReceiveResponse(httpRequest, nullptr));
	auto responseFuture = internets->responsePromise.get_future();
	if (responseFuture.wait_for(timeout) == std::future_status::timeout) {
		return { FromPlatformError(ERROR_TIMEOUT), 0 };
	} else {
		auto errorCode = responseFuture.get();
		if (errorCode) {
			return { FromPlatformError(errorCode), 0 };
		}
	}

	// If it's a WebSocket request, replace the HTTP request handle with the
	// WebSocket handle in the return value and size its receive buffer.
	if (httpReqType == HttpRequestType::HTTP_WEB_SOCKET_REQUEST) {
		auto webSocketHandle = WinHttpWebSocketCompleteUpgrade(httpRequest, reinterpret_cast<DWORD_PTR>(internets.get()));
		ReturnIfFailed(webSocketHandle);
		internets->ReplaceHttpRequest(webSocketHandle);
		internets->buffer.resize(webSocketBufferSize);
	}

	// Hand ownership of the context to the caller as an opaque integer handle.
	return { 0, reinterpret_cast<uintptr_t>(internets.release()) };
}

/// Returns the numeric HTTP status code of the response behind `handle`.
/// If WinHTTP cannot provide it, returns FromPlatformError(GetLastError())
/// instead.
int Platform::GetHttpResponseResultCode(uintptr_t handle) {
	DWORD code = 0;
	DWORD codeSize = sizeof(code);
	auto const request = Internets::GetHttpRequest(handle);

	// WINHTTP_QUERY_FLAG_NUMBER makes WinHTTP write the status as a DWORD
	// rather than as text.
	auto const queried = WinHttpQueryHeaders(request, WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER,
		WINHTTP_HEADER_NAME_BY_INDEX, &code, &codeSize, WINHTTP_NO_HEADER_INDEX);
	if (!queried) {
		return FromPlatformError(GetLastError());
	}
	return static_cast<int>(code);
}

#undef ReturnIfFailed
#define ReturnIfFailed(result) ReturnIfFailedT((result), stringmap)

/// Returns the raw response headers of the request behind `handle`, parsed
/// into a map by ParseHeaders.
/// @return {0, headers} on success. An empty map is returned if the size
///         query unexpectedly succeeds with no buffer. On failure, returns
///         {error, {}} with error produced by FromPlatformError.
std::tuple<int, stringmap> Platform::GetHttpResponseHeaders(uintptr_t handle) {
	// Determine the size of the headers. With no buffer, the call is expected
	// to fail with ERROR_INSUFFICIENT_BUFFER and report the required size in
	// bytes, including the terminating null.
	auto httpRequest = Internets::GetHttpRequest(handle);
	DWORD headersSize = 0;
	auto result = WinHttpQueryHeaders(httpRequest, WINHTTP_QUERY_RAW_HEADERS_CRLF,
		WINHTTP_HEADER_NAME_BY_INDEX, nullptr, &headersSize, WINHTTP_NO_HEADER_INDEX);
	if (result) {
		return { 0, {} };
	}
	ReturnIfFailed(GetLastError() == ERROR_INSUFFICIENT_BUFFER);

	// Extract the headers.
	std::vector<TCHAR> headerBuffer(headersSize / sizeof(TCHAR));
	result = WinHttpQueryHeaders(httpRequest, WINHTTP_QUERY_RAW_HEADERS_CRLF,
		WINHTTP_HEADER_NAME_BY_INDEX, headerBuffer.data(), &headersSize, WINHTTP_NO_HEADER_INDEX);
	ReturnIfFailed(result);

	// On success, headersSize holds the bytes actually written, not counting
	// the terminating null. Drop the terminator so ParseHeaders only sees
	// header text.
	headerBuffer.resize(headersSize / sizeof(TCHAR));

	// Parse the headers into a dictionary.
	stringmap responseHeaders;
	ParseHeaders(headerBuffer.data(), headerBuffer.size(), responseHeaders);
	return { 0, std::move(responseHeaders) };
}

#undef ReturnIfFailed
#define ReturnIfFailed(result) ReturnIfFailedT((result), std::vector<char>)

std::tuple<int, std::vector<char>> Platform::GetHttpResponse(uintptr_t handle, std::chrono::milliseconds timeout) {
	// Receive the payload of the response to the request.
	auto* internets = reinterpret_cast<Internets*>(handle);
	auto request = Internets::GetHttpRequest(handle);
	constexpr auto const dataNotifications = WINHTTP_CALLBACK_STATUS_DATA_AVAILABLE | WINHTTP_CALLBACK_STATUS_READ_COMPLETE | WINHTTP_CALLBACK_STATUS_HANDLE_CLOSING | WINHTTP_CALLBACK_STATUS_REQUEST_ERROR;
	ReturnIfFailed(WinHttpSetStatusCallback(request, &Internets::DataCallback, dataNotifications, 0) != WINHTTP_INVALID_STATUS_CALLBACK);
	ReturnIfFailed(WinHttpQueryDataAvailable(request, nullptr));
	auto dataFuture = internets->dataPromise.get_future();
	if (dataFuture.wait_for(timeout) == std::future_status::timeout) {
		return { FromPlatformError(ERROR_TIMEOUT), {} };
	} else {
		auto errorCode = dataFuture.get();
		if (errorCode) {
			return { FromPlatformError(errorCode), {} };
		}
	}
	return { 0, internets->buffer };
}

void Platform::FinishHttpRequest(uintptr_t& handle, std::chrono::milliseconds timeout) {
	auto handle_ = InterlockedExchange(&handle, std::remove_reference_t<decltype(handle)>());
	if (handle_) {
		std::unique_ptr<Internets> internets(reinterpret_cast<Internets*>(handle_));
		if (!WinHttpWebSocketShutdown(internets->handles.back(), WINHTTP_WEB_SOCKET_ENDPOINT_TERMINATED_CLOSE_STATUS, nullptr, 0)) {
			auto future = internets->shutdownPromise.get_future();
			if (future.wait_for(timeout) == std::future_status::ready) {
				future.get();
			}
		} else if (timeout.count()) {
			DebugWriteLine(_T("WinHttpWebSocketShutdown error %d"), GetLastError());
		}
	}
}
