#include "pch.h"
#include "../Shared/Internal.h"
#include "../Shared/HttpUtil.h"
#include "../Shared/Http.h"

#pragma comment(lib, "winhttp.lib")

using namespace Twitch;

struct WebSocket::WebSocketImpl {
	// WinHTTP-based implementation behind WebSocket.
	//
	// Do not invoke Open or delete the WebSocketImpl instance in the OnClosed
	// handler or a deadlock will occur.  Do not invoke Close, Open, or delete the
	// WebSocketImpl instance in the OnReceived handler or a deadlock will occur.
	//
	// NOTE(review): the receive handler installed by Open captures `this`, so
	// moving an instance while a connection is open leaves that handler
	// pointing at the moved-from object.

	// Invoked after Close (or destruction) actually closes an open connection.
	ClosedFn OnClosed = DefaultClosedFn;
	// Invoked with each received payload.  The view wraps the pointer/length
	// handed to Internets::onReceive; presumably it is valid only for the
	// duration of the call -- TODO confirm.
	ReceivedFn OnReceived = DefaultReceivedFn;
	bool GetIsOpen() const { return handle.IsOpen; }
	__declspec(property(get = GetIsOpen)) bool const IsOpen;

	WebSocketImpl() = default;
	WebSocketImpl(WebSocketImpl const&) = delete;
	WebSocketImpl(WebSocketImpl&&) = default;
	WebSocketImpl& operator=(WebSocketImpl const&) = delete;
	WebSocketImpl& operator=(WebSocketImpl&&) = default;
	~WebSocketImpl() {
		Close();
	}

	// Closes the connection if it is open and, only in that case, notifies
	// OnClosed.  Calling Close on an already-closed instance does nothing.
	void Close() {
		if (handle.Close()) {
			SafelyInvoke(OnClosed);
		}
	}

	// Opens a WebSocket connection to `url` and queues the first asynchronous
	// receive.  Returns 0 on success.  On failure it returns the nonzero error
	// code and leaves the instance closed (OnClosed is not invoked for a
	// connection that never opened).  Calling Open while already open is a
	// usage error: it asserts in debug builds and closes the old connection.
	int Open(string_t url, std::chrono::milliseconds timeout) {
		if (IsOpen) {
			DebugWriteLine(_T("[WebSocketImpl::Open] warning:  Open called without preceding Close"));
			assert(!IsOpen);
			Close();
		}
		int errorCode;
		uintptr_t handle_;
		std::tie(errorCode, handle_) = Platform::StartHttpRequest(url, HttpRequestType::HTTP_WEB_SOCKET_REQUEST,
			timeout, std::vector<HttpParam>(), HttpRequest::EmptyBody);
		if (errorCode) {
			return errorCode;
		}

		handle = SynchronizedHandle(handle_, timeout);
		auto* internets = reinterpret_cast<Internets*>(handle_);
		// Route payloads through SafelyInvoke so an exception thrown by the
		// user's handler does not unwind into the WinHTTP callback.
		internets->onReceive = [this](char* p, size_t n) {
			SafelyInvoke([this, p, n] { OnReceived(std::string_view(p, n)); });
		};
		auto webSocket = internets->handles.back();
		constexpr auto const notifications = WINHTTP_CALLBACK_STATUS_READ_COMPLETE | WINHTTP_CALLBACK_STATUS_WRITE_COMPLETE | WINHTTP_CALLBACK_STATUS_HANDLE_CLOSING | WINHTTP_CALLBACK_STATUS_REQUEST_ERROR;
		if (WinHttpSetStatusCallback(webSocket, &Internets::WebSocketCallback, notifications, 0) == WINHTTP_INVALID_STATUS_CALLBACK) {
			// Capture the error before closing, which may overwrite it.
			auto const callbackError = GetLastError();
			handle.Close();
			return FromPlatformError(callbackError);
		}
#pragma prefast(suppress: 6387, "WinHttpWebSocketReceive ignores the fourth and fifth parameters in asynchronous mode.")
		auto const receiveError = WinHttpWebSocketReceive(webSocket, &internets->buffer[0], static_cast<DWORD>(internets->buffer.size()), nullptr, nullptr);
		if (receiveError != ERROR_SUCCESS) {
			// Without a pending receive the socket would never deliver data.
			handle.Close();
			return FromPlatformError(receiveError);
		}
		return 0;
	}

	// Sends `size` bytes starting at `data` as one UTF-8 WebSocket message.
	// Returns 0 on success or a nonzero error code (including when the socket
	// is not open or `size` does not fit in a DWORD).
	int Send(void const* data, size_t size) {
		if (static_cast<DWORD>(size) != size) {
			// WinHttpWebSocketSend takes a DWORD length; refuse to truncate.
			return FromPlatformError(ERROR_INVALID_PARAMETER);
		}
		auto webSocket = handle.Get();
		if (webSocket == nullptr) {
			return FromPlatformError(ERROR_INVALID_HANDLE);
		}
		auto errorCode = WinHttpWebSocketSend(webSocket,
			WINHTTP_WEB_SOCKET_UTF8_MESSAGE_BUFFER_TYPE,
			const_cast<PVOID>(data), static_cast<DWORD>(size));
		if (errorCode != ERROR_SUCCESS) {
			return FromPlatformError(errorCode);
		}
		return 0;
	}

private:
	// Owns the platform request handle returned by StartHttpRequest (zero
	// means "none") together with the timeout used to finish it.  Close, Get,
	// and move assignment are serialized by a mutex.
	struct SynchronizedHandle {
		explicit SynchronizedHandle(uintptr_t handle, std::chrono::milliseconds timeout) : handle(handle), timeout(timeout) {}
		SynchronizedHandle() = default;
		SynchronizedHandle(SynchronizedHandle const&) = delete;

		SynchronizedHandle& operator=(SynchronizedHandle const&) = delete;

		// Exchanges handle and timeout with `that`.  The handle previously held
		// here is handed to `that`, whose destructor closes it.
		SynchronizedHandle& operator=(SynchronizedHandle&& that) noexcept {
			if (this != &that) {
				std::scoped_lock lock(mutex, that.mutex);
				std::swap(handle, that.handle);
				std::swap(timeout, that.timeout);
			}
			return *this;
		}

		~SynchronizedHandle() {
			Close();
		}

		// NOTE: reads the handle without taking the mutex.
		bool GetIsOpen() const { return static_cast<bool>(handle); }
		__declspec(property(get = GetIsOpen)) bool const IsOpen;

		// Finishes the request if one is held.  Returns true only when this
		// call actually closed it.
		bool Close() {
			std::lock_guard<decltype(mutex)> lock(mutex);
			if (!handle) {
				return false;
			}
			Platform::FinishHttpRequest(handle, timeout);
			// Forget the handle so IsOpen reports false and a later Close (for
			// example from the destructor) does not finish it a second time.
			handle = 0;
			return true;
		}

		// Returns the WinHTTP request handle, or nullptr when nothing is open.
		HINTERNET Get() {
			std::lock_guard<decltype(mutex)> lock(mutex);
			if (!handle) {
				return nullptr;
			}
			return Internets::GetHttpRequest(handle);
		}

	private:
		std::mutex mutex;
		uintptr_t handle{};
		std::chrono::milliseconds timeout{};
	};

	SynchronizedHandle handle;

	// Invokes `fn`, logging and asserting on (but never propagating) any
	// exception it throws.  Used for user-supplied handlers, which run from
	// the destructor and the receive path, where exceptions must not escape.
	template<typename FN>
	void SafelyInvoke(FN&& fn) {
		try {
			fn();
		} catch (std::exception const& ex) {
			UNREFERENCED_PARAMETER(ex);
			DebugWriteLine(_T("[WebSocketImpl::SafelyInvoke] function threw an exception \"%hs\""), ex.what());
			assert(false);
		} catch (...) {
			DebugWriteLine(_T("[WebSocketImpl::SafelyInvoke] function threw a non-standard exception"));
			assert(false);
		}
	}
};

#include "WebSocket.inl"
