#include "pch.h"
#include "Internal.h"
#include "../Shared/HttpUtil.h"
#include "../Shared/Http.h"
#include <utility>

#pragma comment(lib, "winhttp.lib")

using namespace Twitch;

namespace {
	// Owns the WinHTTP handles created for one request, in creation order:
	// session (WinHttpOpen), connection (WinHttpConnect) and request
	// (WinHttpOpenRequest). For a WebSocket request the last handle is swapped
	// for the WebSocket handle by ReplaceHttpRequest. A pointer to an instance
	// is the opaque uintptr_t handle exchanged with the Platform HTTP functions.
	struct Internets {
		Internets() = default;

		// The destructor closes every owned handle, so a copy would close each
		// handle twice. Deleting copy also suppresses the implicit moves; in this
		// file instances are held through std::unique_ptr or the opaque handle.
		Internets(Internets const&) = delete;
		Internets& operator=(Internets const&) = delete;

		// Close the handles in reverse creation order (request/WebSocket first,
		// session last).
		~Internets() {
			std::for_each(handles.crbegin(), handles.crend(), WinHttpCloseHandle);
		}

		// Take ownership of a newly created handle.
		void Add(HINTERNET internet) {
			handles.emplace_back(internet);
		}

		// Close the HTTP request handle and take ownership of `internet` in its
		// place. Only valid once session, connection and request were all added.
		void ReplaceHttpRequest(HINTERNET internet) {
			assert(handles.size() == 3);
			WinHttpCloseHandle(handles.back());
			handles.back() = internet;
		}

		// Return the most recently added handle of the Internets object that
		// `handle` points to, or a null HINTERNET when `handle` is 0.
		static HINTERNET GetHttpRequest(uintptr_t handle) {
			if(handle) {
				auto const* p = reinterpret_cast<Internets const*>(handle);
				return p->handles.back();
			}
			return HINTERNET();
		}

	private:
		std::vector<HINTERNET> handles;
	};
}

// Evaluates `result` exactly once; if it is falsy, returns
// { FromPlatformError(GetLastError()), type() } from the enclosing function.
// GetLastError() is read while the return value is being built, i.e. before
// any local handle owners are destroyed by the return.
#define ReturnIfFailedT(result,type) do { auto result_= (result); if(!result_) return { FromPlatformError(GetLastError()), type() }; } while(false)
// Variant for functions whose second tuple element is a uintptr_t handle.
#define ReturnIfFailed(result) ReturnIfFailedT((result), uintptr_t)

std::tuple<int, uintptr_t> Platform::StartHttpRequest(string_t url,
	HttpRequestType httpReqType,
	std::chrono::milliseconds timeout,
	std::vector<HttpParam> const& requestHeaders,
	std::vector<char> const& requestBody) {
	// Opens a WinHTTP session, connection and request for `url`, sends the
	// request with the given headers and body, and waits for the response.
	// On success returns { 0, handle }, where handle is a pointer to the
	// Internets object owning the WinHTTP handles (FinishHttpRequest destroys
	// it). On failure returns { FromPlatformError(error), 0 }; every handle
	// opened so far is closed when `internets` goes out of scope.
	// For HTTP_WEB_SOCKET_REQUEST the request is upgraded and the returned
	// object holds the WebSocket handle in place of the HTTP request handle.

	// Check the URL scheme to determine if SSL is necessary.
	Uri uri{ url };
	bool useSsl;
	if(uri.GetProtocol() == _T("https") || uri.GetProtocol() == _T("wss")) {
		useSsl = true;
	} else if(uri.GetProtocol() == _T("http") || uri.GetProtocol() == _T("ws")) {
		useSsl = false;
	} else {
		return { FromPlatformError(ERROR_BAD_ARGUMENTS), 0 };
	}

	// Open access to the Internet.
	auto internet = WinHttpOpen(_T("Twitch Game SDK"), WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY,
		WINHTTP_NO_PROXY_NAME, WINHTTP_NO_PROXY_BYPASS, 0);
	ReturnIfFailed(internet);
	auto internets = std::make_unique<Internets>();
	internets->Add(internet);

	// Enable deflate and gzip decoding. Best effort: the result is deliberately
	// ignored so a failure here does not abort the request.
	DWORD decompressionFlag = WINHTTP_DECOMPRESSION_FLAG_ALL;
	WinHttpSetOption(internet, WINHTTP_OPTION_DECOMPRESSION, &decompressionFlag, sizeof(decompressionFlag));

	// Set all time-outs to the requested value.
	auto timeoutMs = static_cast<DWORD>(timeout.count());
	BOOL areTimeoutsSet = WinHttpSetOption(internet, WINHTTP_OPTION_CONNECT_TIMEOUT, &timeoutMs, sizeof(timeoutMs)) &&
		WinHttpSetOption(internet, WINHTTP_OPTION_RECEIVE_RESPONSE_TIMEOUT, &timeoutMs, sizeof(timeoutMs)) &&
		WinHttpSetOption(internet, WINHTTP_OPTION_SEND_TIMEOUT, &timeoutMs, sizeof(timeoutMs)) &&
		WinHttpSetOption(internet, WINHTTP_OPTION_RECEIVE_TIMEOUT, &timeoutMs, sizeof(timeoutMs));
	if(!areTimeoutsSet) {
		// Capture the error once: the logging call may overwrite the thread's
		// last-error value before it would be reported to the caller.
		auto const lastError = GetLastError();
		DebugWriteLine(_T("[Platform::StartHttpRequest] unable to set timeouts: error %d"), lastError);
		return { FromPlatformError(lastError), 0 };
	}

	// Connect with the server. Without an explicit port, INTERNET_DEFAULT_PORT
	// is passed so WinHTTP chooses the default for the scheme.
	uint16_t port;
	if(!uri.GetPort(port)) {
		port = INTERNET_DEFAULT_PORT;
	}
	auto httpConnection = WinHttpConnect(internet, uri.GetHostName().c_str(), port, 0);
	ReturnIfFailed(httpConnection);
	internets->Add(httpConnection);

	// Determine the HTTP verb. Anything else, including WebSocket, uses GET.
	string_t requestType;
	switch(httpReqType) {
	case HTTP_PUT_REQUEST:
		requestType = _T("PUT");
		break;
	case HTTP_POST_REQUEST:
		requestType = _T("POST");
		break;
	case HTTP_DELETE_REQUEST:
		requestType = _T("DELETE");
		break;
	default:
		requestType = _T("GET");
		break;
	}

	// Open a request for the desired resource.
	DWORD flags = WINHTTP_FLAG_REFRESH | (useSsl ? WINHTTP_FLAG_SECURE : 0);
	auto httpRequest = WinHttpOpenRequest(httpConnection, requestType, uri.GetResource().c_str(),
		nullptr, WINHTTP_NO_REFERER, WINHTTP_DEFAULT_ACCEPT_TYPES, flags);
	ReturnIfFailed(httpRequest);
	internets->Add(httpRequest);

	// If it's a WebSocket request, configure it as such.
	if(httpReqType == HTTP_WEB_SOCKET_REQUEST) {
		ReturnIfFailed(WinHttpSetOption(httpRequest, WINHTTP_OPTION_UPGRADE_TO_WEB_SOCKET, nullptr, 0));
	}

	// Build the headers.
	auto requestHeaderData = BuildHttpHeader(requestHeaders);

	// Send the HTTP request on a worker thread so the wait can be bounded by
	// `timeout`. The lambda captures locals by reference; this is safe because
	// every path below calls task.get() before those locals are destroyed.
	auto fn = [httpRequest, &requestHeaderData, &requestBody] {
		auto const* const headers = requestHeaderData.c_str();
		auto const headersSize = static_cast<DWORD>(requestHeaderData.size());
		auto* const body = const_cast<char*>(requestBody.data());
		auto const bodySize = static_cast<DWORD>(requestBody.size());
		if(WinHttpSendRequest(httpRequest, headers, headersSize, body, bodySize, bodySize, 0)) {
			return static_cast<decltype(GetLastError())>(0);
		} else {
			return GetLastError();
		}
	};
	auto task = std::async(std::launch::async, fn);
	if(task.wait_for(timeout) == std::future_status::timeout) {
		// Closing the handles makes the blocked send return; then join it.
		internets.reset();
		task.get();
		return { FromPlatformError(ERROR_TIMEOUT), uintptr_t() };
	} else {
		auto errorCode = task.get();
		if(errorCode) {
			return { FromPlatformError(errorCode), uintptr_t() };
		}
	}

	// Receive the HTTP response.
	ReturnIfFailed(WinHttpReceiveResponse(httpRequest, nullptr));

	// If it's a WebSocket request, replace the HTTP request handle with the
	// WebSocket handle in the return value.
	if(httpReqType == HTTP_WEB_SOCKET_REQUEST) {
		auto webSocketHandle = WinHttpWebSocketCompleteUpgrade(httpRequest, 0);
		ReturnIfFailed(webSocketHandle);
		internets->ReplaceHttpRequest(webSocketHandle);
	}

	// Hand ownership of all handles to the caller as an opaque integer.
	return { 0, reinterpret_cast<uintptr_t>(internets.release()) };
}

int Platform::GetHttpResponseResultCode(uintptr_t handle) {
	// Extract the status code from the response.
	DWORD statusCode = 0;
	DWORD statusCodeSize = sizeof(statusCode);
	auto httpRequest = Internets::GetHttpRequest(handle);
	auto result = WinHttpQueryHeaders(httpRequest, WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER,
		WINHTTP_HEADER_NAME_BY_INDEX, &statusCode, &statusCodeSize, WINHTTP_NO_HEADER_INDEX);
	return result ? static_cast<int>(statusCode) : FromPlatformError(GetLastError());
}

// GetHttpResponseHeaders returns a header map as its second tuple element, so
// rebind ReturnIfFailed to yield a value-initialized `stringmap` on failure.
// NOTE(review): `stringmap` is declared outside this file; presumably an alias
// for std::map<tstring, tstring> - confirm.
#undef ReturnIfFailed
#define ReturnIfFailed(result) ReturnIfFailedT((result), stringmap)

std::tuple<int, std::map<tstring, tstring>> Platform::GetHttpResponseHeaders(uintptr_t handle) {
	// Returns { 0, headers } with the response headers parsed into a map, or
	// { FromPlatformError(error), empty map } when WinHTTP cannot supply them.

	// Query with no buffer to learn the required size in bytes of the raw,
	// CRLF-separated header block (this size includes the terminating null).
	auto httpRequest = Internets::GetHttpRequest(handle);
	DWORD headersSize = 0;
	auto result = WinHttpQueryHeaders(httpRequest, WINHTTP_QUERY_RAW_HEADERS_CRLF,
		WINHTTP_HEADER_NAME_BY_INDEX, nullptr, &headersSize, WINHTTP_NO_HEADER_INDEX);
	if(result) {
		// The size query unexpectedly succeeded; report no headers.
		return { 0, std::map<tstring, tstring>() };
	} else if(GetLastError() != ERROR_INSUFFICIENT_BUFFER) {
		ReturnIfFailed(false);
	}

	// Extract the headers.
	std::vector<TCHAR> headerBuffer(headersSize / sizeof(TCHAR));
	result = WinHttpQueryHeaders(httpRequest, WINHTTP_QUERY_RAW_HEADERS_CRLF,
		WINHTTP_HEADER_NAME_BY_INDEX, headerBuffer.data(), &headersSize, WINHTTP_NO_HEADER_INDEX);
	ReturnIfFailed(result);

	// On success headersSize is the number of bytes written, excluding the
	// terminating null; trim so the parser does not see that null character.
	headerBuffer.resize(headersSize / sizeof(TCHAR));

	// Parse the headers into a dictionary.
	std::map<tstring, tstring> responseHeaders;
	ParseHeaders(headerBuffer.data(), headerBuffer.size(), responseHeaders);
	return { 0, std::move(responseHeaders) };
}

// GetHttpResponse returns the response body as its second tuple element, so
// rebind ReturnIfFailed to yield an empty std::vector<char> on failure.
#undef ReturnIfFailed
#define ReturnIfFailed(result) ReturnIfFailedT((result), std::vector<char>)

std::tuple<int, std::vector<char>> Platform::GetHttpResponse(uintptr_t handle) {
	// Reads the whole response body. Returns { 0, body } on success, or
	// { FromPlatformError(error), empty vector } if any read fails.
	auto httpRequest = Internets::GetHttpRequest(handle);
	char buffer[1024];
	DWORD nBytesRead = 0;
	std::vector<char> response;
	// Read fixed-size chunks until WinHttpReadData reports zero bytes read.
	do {
		auto result = WinHttpReadData(httpRequest, buffer, sizeof(buffer), &nBytesRead);
		ReturnIfFailed(result);
		response.insert(response.cend(), buffer, buffer + nBytesRead);
	} while(nBytesRead != 0);
	// Move, rather than copy, the accumulated body into the result tuple.
	return { 0, std::move(response) };
}

void Platform::FinishHttpRequest(uintptr_t& handle) {
	auto handle_ = InterlockedExchange(&handle, std::remove_reference_t<decltype(handle)>());
	std::unique_ptr<Internets>(reinterpret_cast<Internets*>(handle_));
}

#include "PlatformWebSocket.inl"
