#include "pch.h"
#include <atomic>
#include "../../base-sdk/Shared/Internal.h"
#include "../Shared/Twitch.h"
#include "../Shared/RawSocket.h"
#include "../Shared/SecureSocket.h"

using namespace std::literals;
using namespace Microsoft::VisualStudio::CppUnitTestFramework;
using namespace Twitch;

namespace WindowsLibTest {
	// Integration tests for the socket layer. They connect to live hosts
	// (google.com, dev.twitch.tv, metadata.twitch.tv, echo.websocket.org), so they
	// need network access and depend on those servers' current responses.
	TEST_CLASS(SocketTest) {
public:
	// Plain TCP: connect to google.com:80, send httpRequest, and expect the
	// response to begin with "HTTP/1.1 404 Not Found".
	TEST_METHOD(Connect_Success) {
		RawSocket socket;
		Assert::AreEqual(0, socket.Connect(_T("google.com"), 80));
		AssertRequestGetsStatusLine(socket, _T("HTTP/1.1 404 Not Found\r\n"));
	}

	// TLS: send the same request over SecureSocket to dev.twitch.tv:443 and expect
	// "HTTP/1.1 403 Forbidden". Only exercised on __ORBIS__; elsewhere it passes trivially.
	TEST_METHOD(SecureConnect_Success) {
#ifdef __ORBIS__
		SecureSocket socket;
		Assert::AreEqual(0, socket.Connect(_T("dev.twitch.tv"), 443));
		AssertRequestGetsStatusLine(socket, _T("HTTP/1.1 403 Forbidden\r\n"));
#else
		Assert::IsTrue(true);
#endif
	}

	// Opens a secure WebSocket to the metadata ingest endpoint, sends a connect
	// message carrying a dummy token, and expects the first message received
	// within 4s to be the server's "invalid_connect_token" error.
	TEST_METHOD(WebSocket_Success) {
		// Declared before webSocket so they are destroyed after it: the
		// OnReceived callback below captures both by reference.
		std::promise<std::string> promise;
		std::atomic<bool> hasReceived{ false };
		WebSocket webSocket;
		Assert::IsFalse(webSocket.IsOpen);
		webSocket.OnReceived = [&promise, &hasReceived](std::string_view data) {
			// Only the first message resolves the promise; calling set_value a
			// second time would throw std::future_error inside the callback.
			if (!hasReceived.exchange(true)) {
				promise.set_value(std::string(data));
			}
		};
		auto future = promise.get_future();
		Assert::AreEqual(0, webSocket.Open(_T("wss://metadata.twitch.tv/api/ingest")));
		Assert::IsTrue(webSocket.IsOpen);
		auto message = std::string("{\"connect\":{\"data\":{\"value\":0},\"debug\":true,\"env\":\"dev\",\"game_id\":\"1\",\"session_id\":\"test\",\"token\":\"test\"}}");
		webSocket.Send(message);
#ifdef __NX__
		// On NX the socket is pumped by calling Process() on a worker thread.
		// Atomic because this thread writes it while the worker reads it.
		std::atomic<bool> isRunning{ true };
		auto task = std::async(std::launch::async, [&isRunning, &webSocket] {
			while (isRunning && webSocket.Process()) {
				continue;
			}
		});
#endif
		auto const status = future.wait_for(4s);
#ifdef __NX__
		// Stop and join the pump before any Assert can throw; otherwise the
		// std::async future's destructor would block on a loop that never ends.
		isRunning = false;
		task.get();
		webSocket.Close();
#endif
		if (status != std::future_status::ready) {
			Assert::Fail(_T("Did not receive a message"));
		}
		auto const expected = std::string("{\"error\":{\"code\":\"invalid_connect_token\",\"message\":\"auth token is not valid\"}}");
		Assert::AreEqual(expected, future.get());
	}

	// Opening a WebSocket to a local port with nothing listening must fail
	// with the platform's "connection refused" error and leave the socket closed.
	TEST_METHOD(WebSocket_ConnectFailure) {
		WebSocket webSocket;
		Assert::IsFalse(webSocket.IsOpen);
		long result = webSocket.Open(_T("ws://localhost:65533"));
#ifdef __ORBIS__
		Assert::AreEqual(SCE_NET_ERROR_ECONNREFUSED, result);
#else
		Assert::AreEqual(FromPlatformError(ERROR_CANNOT_CONNECT), result);
#endif
		Assert::IsFalse(webSocket.IsOpen);
	}

	// Opening with a 110ms timeout is expected to report a timeout error and
	// leave the socket closed. Not run on __NX__.
	TEST_METHOD(WebSocket_TimeoutFailure) {
#ifndef __NX__
		WebSocket webSocket;
		Assert::IsFalse(webSocket.IsOpen);
		long result = webSocket.Open(_T("wss://echo.websocket.org"), 110ms);
# ifdef __ORBIS__
		// NOTE(review): raw platform error value; confirm which SCE_NET_* constant
		// this corresponds to and use the named constant instead.
		Assert::AreEqual(-68, result);
# else
		Assert::AreEqual(FromPlatformError(ERROR_TIMEOUT), result);
# endif
		Assert::IsFalse(webSocket.IsOpen);
#endif
	}

private:
	// HTTP/1.1 GET shared by the raw and secure connect tests. Its Host header
	// (www.example.com) differs from the host actually connected to; presumably
	// that is why the expected status lines are 404/403 — confirm if changing it.
	static constexpr char httpRequest[] = "GET / HTTP/1.1\r\nUser-Agent: curl/7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l zlib/1.2.3\r\nHost: www.example.com\r\nAccept-Language: en\r\n\r\n";

	// Sends httpRequest over an already-connected socket, reads exactly 99 bytes
	// of response, and asserts the response begins with expectedStatusLine.
	// SocketT must offer SendAll(char const*, size_t) and ReceiveAll(std::vector<char>&)
	// as used here (RawSocket and SecureSocket both do).
	template<typename SocketT>
	static void AssertRequestGetsStatusLine(SocketT& socket, tstring const& expectedStatusLine) {
		size_t const n = sizeof(httpRequest) - 1; // exclude the terminating NUL
		Assert::AreEqual(static_cast<int>(n), socket.SendAll(httpRequest, n));
		std::vector<char> response(99);
		Assert::AreEqual(static_cast<int>(response.size()), socket.ReceiveAll(response));
		// Keep the iterator arithmetic below inside the buffer.
		Assert::IsTrue(expectedStatusLine.size() <= response.size());
		Assert::AreEqual(expectedStatusLine, tstring(response.cbegin(), response.cbegin() + expectedStatusLine.size()));
	}
	};
}
