#include "pch.h"
#include <random>
#include "RawSocket.h"
#include "SecureSocket.h"
#include "Http.h"
#include "HttpUtil.h"

using namespace Twitch;

// Platform WebSocket client (RFC 6455) behind WebSocket's implementation pointer.
// Owns the transport socket and a background task that reads frames and passes each payload to
// OnReceived; that task invokes OnClosed when it is the one that detects the connection ending.
struct WebSocket::WebSocketImpl {
	ClosedFn OnClosed = DefaultClosedFn;
	ReceivedFn OnReceived = DefaultReceivedFn;
	bool GetIsOpen() const { return isOpen; }
	__declspec(property(get = GetIsOpen)) bool const IsOpen;

	WebSocketImpl() = default;
	WebSocketImpl(WebSocketImpl const&) = delete;
	// NOTE(review): the receive task captures `this`, so moving an instance while it is open would
	// leave that task referring to the moved-from object; only move closed instances.
	WebSocketImpl(WebSocketImpl&&) = default;
	WebSocketImpl& operator=(WebSocketImpl const&) = delete;
	WebSocketImpl& operator=(WebSocketImpl&&) = default;
	~WebSocketImpl() {
		// Reset the handlers first, presumably so a receive task that is still finishing does not
		// call back into an owner that is being destroyed.
		OnClosed = DefaultClosedFn;
		OnReceived = DefaultReceivedFn;
		Close();
	}

	// Closes the connection if it is open and waits for the receive task to finish.
	// Only the side that flips isOpen from true to false performs the shutdown; if the receive task
	// wins that race it closes the socket and invokes OnClosed itself.
	// Must not be called from OnReceived/OnClosed: those run on the receive task, and
	// receiveTask.get() would then wait for itself.
	void Close() {
		if(__sync_swap(&isOpen, false)) {
			socket->Close();
			socket = nullptr;
			receiveTask.get();
		}
	}

	// Connects to url ("ws://..." or "wss://..."), performs the HTTP upgrade handshake, and starts
	// the receive task.
	// Returns 0 on success, a negative error from the socket layer, or ERROR_INVALID_RESPONSE when
	// the server does not answer with "HTTP/1.1 101".  On failure the socket is released, so a later
	// Send reports ERROR_INVALID_STATE instead of writing to a half-opened connection.
	int Open(string_t url, std::chrono::milliseconds timeout) {
		if(IsOpen) {
			DebugWriteLine(_T("[WebSocketImpl::Open] warning:  Open called without preceding Close"));
			assert(!IsOpen);
			Close();
		}

		// Connect to the server.
		Uri uri(url);
		bool const isSecure = uri.GetProtocol() == _T("wss");
		unsigned short const defaultPort = isSecure ? 443 : 80;
		unsigned short port;
		if(!uri.GetPort(port)) {
			port = defaultPort;
		}
		socket = isSecure ? std::static_pointer_cast<Socket>(std::make_shared<SecureSocket>()) : std::make_shared<RawSocket>();
		int errorCode = socket->Connect(uri.GetHostName().c_str(), port, timeout);
		if(errorCode < 0) {
			socket = nullptr;
			return errorCode;
		}

		// Send the WebSocket HTTP request.
		std::stringstream ss;
		ss << "GET " << uri.GetPath() << " HTTP/1.1\r\n"
			"Upgrade: websocket\r\n"
			"Connection: Upgrade\r\n"
			"Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"
			"Sec-WebSocket-Version: 13\r\n"
			"Host: " << uri.GetHostName();
		if(port != defaultPort) {
			ss << ':' << port;
		}
		ss << "\r\n\r\n";
		auto const request = ss.str();
		errorCode = socket->SendAll(request.c_str(), request.size());
		if(errorCode < 0) {
			socket->Close();
			socket = nullptr;
			return errorCode;
		}

		// Receive and interpret the HTTP response.  Bytes the server sent after the response header
		// already belong to the frame stream, so they seed the receive task's buffer.
		std::deque<uint8_t> pending;
		errorCode = ReceiveHandshakeResponse(pending);
		if(errorCode != 0) {
			socket->Close();
			socket = nullptr;
			return errorCode;
		}

		// Start the receiver task.
		isOpen = true;
		receiveTask = std::async(std::launch::async, [this, weak = std::weak_ptr<Socket>(socket), data = std::move(pending)]() mutable {
			try {
				for(;;) {
					std::vector<uint8_t> frameData;
					auto strong = weak.lock();
					if(!strong || !ReceiveFrame(strong, data, frameData)) {
						break;
					}
					SafelyInvoke([this, &frameData]() mutable { OnReceived(frameData); });
				}
			} catch(std::exception const& ex) {
				DebugWriteLine(_T("[PlatformWebSocket::Connect.receiveTask] unexpected exception \"%s\""), ex.what());
			}
			// If Close() has not already claimed the shutdown, this task performs it and reports it.
			if(__sync_swap(&isOpen, false)) {
				socket->Close();
				socket = nullptr;
				SafelyInvoke(OnClosed);
			}
		});
		return 0;
	}

	// Sends data[0..size) as a single binary message frame.
	// Returns 0 on success, the negative error from the socket layer, or ERROR_INVALID_STATE when
	// there is no socket or the send reports nothing written.
	int Send(void const* data, size_t size) {
		// Take a local reference so the socket cannot be released underneath this call.
		auto const activeSocket = socket;
		if(!activeSocket) {
			return ERROR_INVALID_STATE;
		}

		std::vector<uint8_t> frame;
		frame.reserve(2 + 8 + 4 + size);
		frame.push_back(static_cast<uint8_t>(0x82)); // FIN + binary frame

		// Widen before shifting so the 64-bit length encoding is well-defined even where size_t is
		// 32 bits.  0x80 in the second byte is the MASK bit.
		uint64_t const payloadSize = size;
		int lengthBytes;
		if(payloadSize < 126) {
			frame.push_back(static_cast<uint8_t>(0x80 | payloadSize));
			lengthBytes = 0;
		} else if(payloadSize < 65536) {
			frame.push_back(static_cast<uint8_t>(0x80 | 126));
			lengthBytes = 2;
		} else {
			frame.push_back(static_cast<uint8_t>(0x80 | 127));
			lengthBytes = 8;
		}
		for(int shift = 8 * (lengthBytes - 1); shift >= 0; shift -= 8) {
			frame.push_back(static_cast<uint8_t>((payloadSize >> shift) & 0xff));
		}

		// RFC 6455 section 5.3:  every client-to-server frame is masked with a fresh random key.
		uint32_t const maskValue = static_cast<uint32_t>(maskGenerator());
		uint8_t const maskKey[4] = {
			static_cast<uint8_t>(maskValue >> 24),
			static_cast<uint8_t>(maskValue >> 16),
			static_cast<uint8_t>(maskValue >> 8),
			static_cast<uint8_t>(maskValue),
		};
		frame.insert(frame.end(), maskKey, maskKey + 4);
		auto const payload = static_cast<uint8_t const*>(data);
		for(size_t j = 0; j < size; ++j) {
			frame.push_back(static_cast<uint8_t>(payload[j] ^ maskKey[j & 3]));
		}

		int const result = activeSocket->SendAll(frame);
		return result < 0 ? result : result ? 0 : ERROR_INVALID_STATE;
	}

private:
	std::shared_ptr<Socket> socket;
	std::future<void> receiveTask;
	bool isOpen = false;
	// Source of the masking keys for outgoing frames.
	std::mt19937 maskGenerator{std::random_device{}()};

	// Reads the server's reply to the upgrade request until the blank line that ends its header.
	// Returns 0 for an "HTTP/1.1 101" response, a negative socket error, or ERROR_INVALID_RESPONSE
	// (other status, connection closed first, or header larger than the buffer).  On success, bytes
	// received after the header are appended to leftover.
	int ReceiveHandshakeResponse(std::deque<uint8_t>& leftover) {
		char const eom[] = "\r\n\r\n";
		size_t const eomLength = sizeof(eom) - 1;
		char const statusPrefix[] = "HTTP/1.1 ";
		size_t const prefixLength = sizeof(statusPrefix) - 1;

		char responseText[1024];
		size_t received = 0;
		while(received < sizeof(responseText)) {
			int result = socket->Receive(responseText + received, sizeof(responseText) - received);
			if(result < 0) {
				return result;
			}
			if(result == 0) {
				// The server closed the connection before finishing the response header.
				return ERROR_INVALID_RESPONSE;
			}

			// Back up so a terminator split across two reads is still found.
			size_t const searchStart = received < eomLength - 1 ? 0 : received - (eomLength - 1);
			received += static_cast<size_t>(result);
			char const* const end = responseText + received;
			char const* const it = std::search(static_cast<char const*>(responseText) + searchStart, end, eom, eom + eomLength);
			if(it != end) {
				// The status line is at the start of the whole response, not of the latest chunk.
				// atoi stops at the first non-digit, and the terminator guarantees one exists.
				bool const isSwitchingProtocols = received > prefixLength &&
					std::equal(statusPrefix, statusPrefix + prefixLength, static_cast<char const*>(responseText)) &&
					atoi(responseText + prefixLength) == 101;
				if(!isSwitchingProtocols) {
					return ERROR_INVALID_RESPONSE;
				}
				// TODO:  generate a random value for the Sec-WebSocket-Key request header
				// and validate the Sec-WebSocket-Accept header value in the response.
				leftover.insert(leftover.end(), it + eomLength, end);
				return 0;
			}
		}
		return ERROR_INVALID_RESPONSE;
	}

	// Tries to take one complete text (type 1) or binary (type 2) frame off the front of data.
	// Returns 1 and stores the unmasked payload in frameData when a frame with the FIN bit set was
	// consumed, 0 when data does not yet hold a complete frame, and -1 when the stream contains a
	// frame this implementation rejects (the caller then closes the connection).  Complete frames
	// without the FIN bit are discarded; fragmented messages are not reassembled.
	static int HasFrame(std::deque<uint8_t>& data, std::vector<uint8_t>& frameData) {
		for(;;) {
			if(data.size() < 2) {
				return 0;
			}
			bool const isFinal = (data[0] & 0x80) != 0;
			auto const type = data[0] & 0xf;
			bool const isMasked = (data[1] & 0x80) != 0;
			auto const size0 = data[1] & 0x7f;
			// A 7-bit length of 126 or 127 means a 16-bit or 64-bit big-endian length follows.
			size_t const lengthSize = size0 == 126 ? 2 : size0 == 127 ? 8 : 0;
			size_t const maskOffset = 2 + lengthSize;
			size_t const headerSize = maskOffset + (isMasked ? 4 : 0);
			if(data.size() < headerSize) {
				return 0;
			}
			uint64_t size = lengthSize == 0 ? static_cast<uint64_t>(size0) : 0;
			for(size_t j = 0; j < lengthSize; ++j) {
				size = (size << 8) | data[2 + j];
			}
			if(size & 0x8000000000000000ull) {
				// https://tools.ietf.org/html/rfc6455 says "the most significant bit MUST be 0."
				DebugWriteLine(_T("[PlatformWebSocket::HasFrame] frame has invalid frame length; closing"));
				return -1;
			}
			if(data.size() < headerSize + size) {
				return 0;
			}
			if(type != 1 && type != 2) {
				DebugWriteLine(_T("[PlatformWebSocket::HasFrame] unexpected type %d"), type);
				return -1;
			}
			if(isMasked) {
				// The 4-byte masking key sits immediately after the length field.
				for(uint64_t j = 0; j < size; ++j) {
					data[headerSize + j] ^= data[maskOffset + (j & 0x3)];
				}
			}
			auto const frameEnd = data.cbegin() + static_cast<std::ptrdiff_t>(headerSize + size);
			if(isFinal) {
				frameData.assign(data.cbegin() + headerSize, frameEnd);
				data.erase(data.cbegin(), frameEnd);
				return 1;
			}
			data.erase(data.cbegin(), frameEnd);
		}
	}

	// Produces the next complete frame's payload in frameData.  Frames already in data are used
	// first; otherwise it reads from socket until one is complete.
	// Returns false when the socket reports closure (0) or an error (< 0), or when HasFrame rejects
	// the stream.
	static bool ReceiveFrame(std::shared_ptr<Socket> const& socket, std::deque<uint8_t>& data, std::vector<uint8_t>& frameData) {
		for(;;) {
			int const frameStatus = HasFrame(data, frameData);
			if(frameStatus != 0) {
				return frameStatus > 0;
			}
			uint8_t buffer[1024];
			int result = socket->Receive(buffer, sizeof(buffer));
			if(result <= 0) {
				if(result < 0) {
					DebugWriteLine(_T("[PlatformWebSocket::Connect.receiveTask] socket->Receive error %#x (%d)"),
						result, sce_net_errno);
				}
				return false;
			}
			data.insert(data.cend(), buffer, buffer + result);
		}
	}

	// Calls fn, logging (and asserting on) any std::exception it throws so a misbehaving handler
	// cannot terminate the receive task.
	template<typename FN>
	void SafelyInvoke(FN fn) {
		try {
			fn();
		} catch(std::exception const& ex) {
			DebugWriteLine(_T("[WebSocketImpl::SafelyInvoke] function threw an exception \"%s\""), ex.what());
			assert(false);
		}
	}
};

#include "WebSocket.inl"
