#include "pch.h"
#include <wrl.h>
#include <ixmlhttprequest2.h>
#include "Internal.h"
#include "../Shared/HttpUtil.h"
#include "../Shared/Http.h"

// Evaluates an HRESULT expression exactly once and returns it from the
// enclosing function if it is a failure code.  This form serves the COM helpers
// in the anonymous namespace below, which return HRESULT; it is #undef'd and
// redefined further down for the tuple-returning Platform functions.
#define ReturnIfFailed(hr) do { HRESULT hr_= (hr); if(FAILED(hr_)) return hr_; } while(false)

using namespace Microsoft::WRL;
using namespace Twitch;

namespace {
	// COM callback sink for one IXMLHTTPRequest2 request.  It records the HTTP
	// status code, the parsed response headers and the response body, and it
	// resolves the promise handed to its constructor from either
	// OnResponseReceived (success path) or OnError (failure path) with the
	// resulting error code (S_OK when everything succeeded).
	class InternalResponse : public RuntimeClass<RuntimeClassFlags<ClassicCom>, IXMLHTTPRequest2Callback> {
	public:
		explicit InternalResponse(std::promise<int>&& responsePromise);
		virtual ~InternalResponse();

		// Results.  The callbacks write these fields without any locking, so they
		// should only be read after the promise has been resolved.
		int GetErrorCode() const { return errorCode; }
		std::map<tstring, tstring> const& GetHeaders() const { return headerMap; }
		std::vector<char> const& GetResponse() const { return responseData; }
		// The HTTP status code when no error was recorded, otherwise the error code.
		int GetResponseCode() const { return SUCCEEDED(errorCode) ? static_cast<int>(responseCode) : errorCode; }

		// MSVC property syntax over the getters above.
		__declspec(property(get = GetErrorCode)) int const ErrorCode;
		__declspec(property(get = GetHeaders)) std::map<tstring, tstring> const& Headers;
		__declspec(property(get = GetResponse)) std::vector<char> const& Response;
		__declspec(property(get = GetResponseCode)) int const ResponseCode;

		// IXMLHTTPRequest2Callback
		// Redirects are accepted as-is (S_OK, target not inspected).
		HRESULT OnRedirect(IXMLHTTPRequest2* /*request*/, WCHAR const* /*redirectUrl*/) override { return S_OK; }
		HRESULT OnHeadersAvailable(IXMLHTTPRequest2* request, DWORD responseCode, wchar_t const* status) override;
		HRESULT OnDataAvailable(IXMLHTTPRequest2* request, ISequentialStream* responseStream) override;
		HRESULT OnResponseReceived(IXMLHTTPRequest2* request, ISequentialStream* responseStream) override;
		HRESULT OnError(IXMLHTTPRequest2* /*request*/, HRESULT /*error*/) override;

	private:
		std::promise<int> responsePromise;      // resolved once with the final error code
		std::map<tstring, tstring> headerMap;   // parsed response headers
		std::vector<char> responseData;         // accumulated response body
		HRESULT errorCode;                      // S_OK, or the failure recorded by a callback
		DWORD responseCode;                     // HTTP status from OnHeadersAvailable

		// Appends all data currently readable from `stream` to `data`.
		static HRESULT AppendResponse(ISequentialStream* stream, std::vector<char>& data);
	};

	InternalResponse::InternalResponse(std::promise<int>&& responsePromise_) :
		responsePromise(std::move(responsePromise_)), // TODO:  remove std::move?
		errorCode(S_OK),
		responseCode(0) {}

	// Members (promise, header map, body buffer) clean themselves up.
	InternalResponse::~InternalResponse() = default;

	// Called when the status line and headers arrive.  Stores the HTTP status
	// code and parses the headers into headerMap.  If the headers cannot be
	// retrieved, that failure is kept in errorCode and the request is aborted;
	// OnError, which is expected to follow the abort, keeps an already-recorded
	// failure when it resolves the promise.
	HRESULT InternalResponse::OnHeadersAvailable(IXMLHTTPRequest2* request, DWORD responseCode_, wchar_t const* /*status*/) {
		// Save the response code to use as the final result of this network
		// activity if everything else is successful.
		responseCode = responseCode_;

		// Get and parse the response headers.
		wchar_t* headers = nullptr;
		errorCode = request->GetAllResponseHeaders(&headers);
		if(SUCCEEDED(errorCode)) {
			ParseHeaders(headers, headerMap);
			::CoTaskMemFree(headers);
		} else {
			// Abort, but keep the GetAllResponseHeaders failure as the result.
			// Storing Abort()'s own (normally successful) HRESULT here would make
			// the aborted request look like it had succeeded.
			request->Abort();
		}
		return errorCode;
	}

	HRESULT InternalResponse::OnDataAvailable(IXMLHTTPRequest2* request, ISequentialStream* responseStream) {
		// Drain whatever the stream currently holds into the response buffer.
		// A read failure is recorded in errorCode and the request is aborted so
		// that the failure path (OnError) can resolve the promise.
		errorCode = AppendResponse(responseStream, responseData);
		if(SUCCEEDED(errorCode)) {
			return errorCode;
		}
		request->Abort();
		return errorCode;
	}

	HRESULT InternalResponse::OnResponseReceived(IXMLHTTPRequest2* /*request*/, ISequentialStream* responseStream) {
		// Success-path completion: collect any body bytes still in the stream,
		// then publish the outcome (S_OK or the read failure) to the waiting future.
		HRESULT const readResult = AppendResponse(responseStream, responseData);
		errorCode = readResult;
		responsePromise.set_value(readResult);
		return readResult;
	}

	// Failure-path completion (request error or abort).  Resolves the promise
	// with the failure an earlier callback already recorded, if any; otherwise
	// with `error`.  The promise is never resolved with a success code here.
	HRESULT InternalResponse::OnError(IXMLHTTPRequest2* /*request*/, HRESULT error) {
		// Previously an E_ABORT arriving while errorCode was still S_OK left the
		// result at S_OK, reporting an aborted request as successful.  Now any
		// error replaces a success code; E_FAIL guards against being handed a
		// non-failure code.
		if(SUCCEEDED(errorCode)) {
			errorCode = FAILED(error) ? error : E_FAIL;
		}
		responsePromise.set_value(errorCode);
		return S_OK;
	}

	// Reads everything `stream` currently has to offer and appends it to `data`.
	// Returns S_OK once the stream reports S_FALSE or a zero-byte read, or the
	// failing HRESULT from ISequentialStream::Read.  On every return path `data`
	// holds exactly the bytes that were successfully read; no scratch space from
	// the chunked reads is left behind.
	HRESULT InternalResponse::AppendResponse(ISequentialStream* stream, std::vector<char>& data) {
		static ULONG const readSize = 4096;
		for(;;) {
			// Grow by one chunk and read straight into the new space.  The write
			// position is taken after resize(), because resizing may reallocate
			// and would invalidate any pointer computed beforehand.
			auto const oldSize = data.size();
			data.resize(oldSize + readSize);
			ULONG n = 0;
			HRESULT const result = stream->Read(data.data() + oldSize, readSize, &n);

			// Keep only the bytes actually delivered: Read may return fewer than
			// requested even when it succeeds.
			data.resize(oldSize + (SUCCEEDED(result) ? n : 0));
			if(FAILED(result)) {
				return result;
			}
			if(result == S_FALSE || n == 0) {
				return S_OK;
			}
		}
	}

	// Read-only ISequentialStream over a caller-owned byte buffer, used to feed
	// the request body to IXMLHTTPRequest2::Send.  The bytes are neither copied
	// nor owned, so the buffer must stay alive for as long as Read can be called.
	class HttpRequestStream : public RuntimeClass<RuntimeClassFlags<ClassicCom>, ISequentialStream> {
	public:
		HttpRequestStream(char const* buffer, size_t bufferSize);
		~HttpRequestStream() {}

		// ISequentialStream
		HRESULT Read(void* buffer, ULONG bufferSize, ULONG* bytesRead) override;
		// Writing is not supported; this stream only serves the request body.
		HRESULT Write(void const* /*buffer*/, ULONG /*bufferSize*/, ULONG* /*bytesWritten*/) override { return E_NOTIMPL; }

	private:
		char const* m_buffer;   // first byte of the body (not owned)
		size_t m_bufferSize;    // total number of bytes in the body
		size_t m_seekLocation;  // offset of the next byte Read will return
	};

	HttpRequestStream::HttpRequestStream(char const* buffer, size_t bufferSize) :
		m_buffer(buffer),
		m_bufferSize(bufferSize),
		m_seekLocation(0) {}

	// ISequentialStream::Read over the in-memory body.  Copies up to `bufferSize`
	// bytes from the current position and advances it.  Returns S_OK when the
	// full amount was copied, or S_FALSE when fewer bytes remained (zero at the
	// end of the body).  `bytesRead` is optional and may be null.
	HRESULT HttpRequestStream::Read(void* buffer, ULONG bufferSize, ULONG* bytesRead) {
		// Validate the parameters.
		if(buffer == nullptr) {
			return FromPlatformError(ERROR_BAD_ARGUMENTS);
		}

		// Compare against the remaining byte count instead of adding the request
		// size to the position, so the check cannot overflow with a 32-bit size_t.
		size_t const remaining = m_bufferSize - m_seekLocation;
		size_t count = bufferSize;
		HRESULT result = S_OK;
		if(count > remaining) {
			count = remaining;
			result = S_FALSE;
		}
		if(count != 0) {
			memcpy(buffer, m_buffer + m_seekLocation, count);
		}
		m_seekLocation += count;

		// count <= bufferSize, so it always fits in a ULONG.
		if(bytesRead != nullptr) {
			*bytesRead = static_cast<ULONG>(count);
		}
		return result;
	}
}

#undef ReturnIfFailed
// Tuple-returning variant for the Platform functions below, which return
// (error code, payload) pairs.  ReturnIfFailedT(result, type) evaluates `result`
// once and, if it is a failure code, returns std::make_tuple(result, type()):
// the error plus a value-initialized payload.
#define ReturnIfFailedT(result,type) do { auto result_= (result); if(FAILED(result_)) return std::make_tuple(result_, type()); } while(false)
// ReturnIfFailed is redefined ahead of functions with a different payload type;
// here the payload is a uintptr_t handle.
#define ReturnIfFailed(result) ReturnIfFailedT((result), uintptr_t)

std::tuple<int, uintptr_t> Platform::StartHttpRequest(string_t url, HttpRequestType requestType,
	std::chrono::milliseconds timeout, std::vector<HttpParam> const& requestHeaders, std::vector<char> const& requestBody) {
	// Performs the request synchronously: this call blocks until the
	// InternalResponse callback object resolves its promise.
	//
	// Returns (0, handle) on success.  The handle owns one reference to the
	// InternalResponse and must be released with FinishHttpRequest.  On failure
	// returns (error code, 0) and there is nothing to release.
	//
	// requestBody is read in place by HttpRequestStream (no copy); it remains
	// valid because this function does not return until the response is in.

	// Create an HTTP request object.
	ComPtr<IXMLHTTPRequest2> request;
	auto result = ::CoCreateInstance(__uuidof(FreeThreadedXMLHTTP60), nullptr, CLSCTX_SERVER, __uuidof(IXMLHTTPRequest2), &request);
	ReturnIfFailed(result);

	// Determine the HTTP verb; unrecognized request types are sent as GET.
	string_t method;
	switch(requestType) {
	case HTTP_PUT_REQUEST:
		method = _T("PUT");
		break;
	case HTTP_POST_REQUEST:
		method = _T("POST");
		break;
	case HTTP_DELETE_REQUEST:
		method = _T("DELETE");
		break;
	default:
		method = _T("GET");
		break;
	}

	// Create an InternalResponse object to handle the HTTP response.  Make
	// signals allocation failure with an empty ComPtr rather than throwing.
	std::promise<int> responsePromise;
	std::future<int> responseFuture = responsePromise.get_future();
	auto response = Make<InternalResponse>(std::move(responsePromise));
	ReturnIfFailed(response.Get() != nullptr ? S_OK : E_OUTOFMEMORY);

	// Open a request for the desired resource.
	result = request->Open(method, url, response.Get(), nullptr, nullptr, nullptr, nullptr);
	ReturnIfFailed(result);

	// Set the time-out to the requested value.  A failure is reported instead
	// of silently running the request with a different time-out.
	if(timeout.count()) {
		result = request->SetProperty(XHR_PROP_TIMEOUT, timeout.count());
		ReturnIfFailed(result);
	}

	// Add the headers.  A header that cannot be set (e.g. authorization or
	// content type) changes the meaning of the request, so fail instead of
	// sending the request without it.
	for(auto const& header : requestHeaders) {
		result = request->SetRequestHeader(header.paramName.c_str(), header.paramValue.c_str());
		ReturnIfFailed(result);
	}

	// Send the HTTP request, streaming the body from requestBody if present.
	if(requestBody.empty()) {
		result = request->Send(nullptr, 0);
	} else {
		auto stream = Make<HttpRequestStream>(requestBody.data(), requestBody.size());
		// Never announce a non-zero body length with a null stream.
		ReturnIfFailed(stream.Get() != nullptr ? S_OK : E_OUTOFMEMORY);
		result = request->Send(stream.Get(), requestBody.size());
	}
	ReturnIfFailed(result);

	// Await the response.  The InternalResponse object will handle it and resolve the promise.
	result = responseFuture.get();
	ReturnIfFailed(result);

	// Transfer our reference to the caller; FinishHttpRequest releases it.
	return { 0, reinterpret_cast<uintptr_t>(response.Detach()) };
}

int Platform::GetHttpResponseResultCode(uintptr_t handle) {
	return reinterpret_cast<InternalResponse*>(handle)->ResponseCode;
}

#undef ReturnIfFailed
#define T std::map<tstring, tstring>
#define ReturnIfFailed(result) ReturnIfFailedT((result), T)

std::tuple<int, std::map<tstring, tstring>> Platform::GetHttpResponseHeaders(uintptr_t handle) {
	auto* p = reinterpret_cast<InternalResponse*>(handle);
	return { p->ErrorCode, p->Headers };
}

#undef ReturnIfFailed
// Variant for functions returning std::tuple<int, std::vector<char>>: on failure
// it returns (error code, empty vector).
// NOTE(review): appears unused by the functions that follow — confirm before removing.
#define ReturnIfFailed(result) ReturnIfFailedT((result), std::vector<char>)

std::tuple<int, std::vector<char>> Platform::GetHttpResponse(uintptr_t handle) {
	auto* p = reinterpret_cast<InternalResponse*>(handle);
	return { p->ErrorCode, p->Response };
}

void Platform::FinishHttpRequest(uintptr_t& handle) {
	ComPtr<InternalResponse> p;
	p.Attach(reinterpret_cast<InternalResponse*>(handle));
	handle = 0;
}
