#include "pch.h"
#include <winrt/Windows.Networking.Sockets.h>
#include <winrt/Windows.Storage.Streams.h>
#include <wrl.h>
#include "Internal.h"
#include "../Shared/Http.h"

using namespace Microsoft::WRL;
using namespace winrt::Windows::Storage::Streams;
using namespace winrt::Windows::Networking::Sockets;
using namespace Twitch;

using Uri = winrt::Windows::Foundation::Uri;
using AsyncStatus = winrt::Windows::Foundation::AsyncStatus;

// Implementation of WebSocket on top of winrt MessageWebSocket.
//
// Failures of Open and Send are reported as integer error codes (0 on success).
// OnClosed is invoked from the socket's Closed event handler and OnReceived from
// its MessageReceived event handler; exceptions derived from std::exception that
// those callbacks throw are logged and swallowed by SafelyInvoke.
struct WebSocket::WebSocketImpl {
	ClosedFn OnClosed = DefaultClosedFn;
	ReceivedFn OnReceived = DefaultReceivedFn;
	bool GetIsOpen() const { return isOpen; }
	__declspec(property(get = GetIsOpen)) bool const IsOpen;

	WebSocketImpl() {
		CreateWebSocket();
	}
	WebSocketImpl(WebSocketImpl const&) = delete;
	// NOTE(review): the event handlers registered in CreateWebSocket capture `this`
	// and the defaulted move operations do not re-register them -- confirm that
	// instances are never actually moved while handlers can still fire.
	WebSocketImpl(WebSocketImpl&&) = default;
	WebSocketImpl& operator=(WebSocketImpl const&) = delete;
	WebSocketImpl& operator=(WebSocketImpl&&) = default;
	~WebSocketImpl() {
		// Detach the user callbacks first so that closing does not call back into
		// code that may already be going away.
		OnClosed = DefaultClosedFn;
		OnReceived = DefaultReceivedFn;
		try {
			Close();
		} catch(...) {
			// A destructor must not throw; there is nothing more to do than log.
			DebugWriteLine(_T("[WebSocketImpl::~WebSocketImpl] unexpected failure closing the socket"));
		}
	}

	// Closes the connection if it is open and blocks until the Closed event
	// handler has finished (including the OnClosed callback).  Does nothing when
	// the socket is not open.  May throw if the underlying Close call fails.
	void Close() {
		if(IsOpen) {
			webSocket.Close();
			// The promise is armed by Open and fulfilled once by the Closed handler.
			// It is deliberately not replaced here: the handler may still be
			// returning from set_value when this thread wakes up.
			closePromise.get_future().get();
		}
	}

	// Connects to `url`, waiting at most `timeout` for the connection to be
	// established.
	// Returns 0 on success, FromPlatformError(ERROR_TIMEOUT) if the attempt was
	// cancelled because the timeout elapsed, otherwise the failure code.
	// If the socket is already open, it is closed first (and asserts in debug).
	int Open(string_t url, std::chrono::milliseconds timeout) {
		if(IsOpen) {
			DebugWriteLine(_T("[WebSocketImpl::Open] warning:  Open called without preceding Close"));
			assert(!IsOpen);
			Close();
		}
		try {
			// Constructed inside the try block: an invalid URL makes the Uri
			// constructor throw, which must become an error code like any other failure.
			Uri uri(url);

			// Arm a fresh close signal for this connection.  A previous connection
			// closed by the remote end leaves the old promise fulfilled and unconsumed.
			closePromise = std::promise<void>();

			auto action = webSocket.ConnectAsync(uri);
			// Wait for completion on a worker so that the wait can be bounded.  The
			// outcome is read from action.Status() below, so a failure or cancellation
			// thrown by get() is intentionally ignored here.
			auto task = std::async(std::launch::async, [&action] {
				try {
					action.get();
				} catch(winrt::hresult_error const&) {
				}
			});
			if(task.wait_for(timeout) == std::future_status::timeout) {
				action.Cancel();
			}
			task.get();
			switch(action.Status()) {
			case AsyncStatus::Canceled:
				return FromPlatformError(ERROR_TIMEOUT);
			case AsyncStatus::Completed:
				isOpen = true;
				break;
			case AsyncStatus::Error:
				return action.ErrorCode();
			default:
				assert(action.Status() == AsyncStatus::Completed);
				return E_UNEXPECTED;
			}
			return 0;
		} catch(winrt::hresult_error const& ex) {
			return ex.code();
		}
	}

	// Sends `size` bytes starting at `data` as one write to the socket's output
	// stream.  Returns 0 on success, otherwise the failure code.
	int Send(void const* data, size_t size) {
		try {
			DataWriter writer(webSocket.OutputStream());
			auto const result = Send(writer, data, size);
			// Detach so that the writer does not keep hold of the socket's output
			// stream once it goes out of scope.
			writer.DetachStream();
			return result;
		} catch(winrt::hresult_error const& ex) {
			// Obtaining the stream or detaching it failed; report it as a code
			// like every other failure of this function.
			return ex.code();
		}
	}

private:
	MessageWebSocket webSocket;
	// Fulfilled by the Closed handler; armed anew by every Open.
	std::promise<void> closePromise;
	// NOTE(review): written by the event handlers and read by the public methods
	// without synchronization -- confirm whether this needs to be atomic.
	bool isOpen = false;

	// Replaces `webSocket` with a fresh instance and registers the Closed and
	// MessageReceived handlers on it.
	void CreateWebSocket() {
		webSocket = MessageWebSocket();
		webSocket.Closed([this](IWebSocket, WebSocketClosedEventArgs) {
			isOpen = false;
			// Prepare a new socket so that Open can be called again.
			CreateWebSocket();
			SafelyInvoke(OnClosed);
			// Signal last: Close() (and therefore the destructor) must not resume
			// before this handler is done with the object.
			SignalClosed();
		});
		webSocket.MessageReceived([this](IWebSocket, MessageWebSocketMessageReceivedEventArgs args) {
			try {
				auto reader = args.GetDataReader();
				auto size = reader.UnconsumedBufferLength();
				std::vector<uint8_t> data(size);
				reader.ReadBytes(data);
				SafelyInvoke([this, &data]() mutable { OnReceived(data); });
			} catch(winrt::hresult_error const& ex) {
				DebugWriteLine(_T("[WebSocketImpl::CreateWebSocket.MessageReceived] unexpected failure %#x (%d)"),
					ex.code(), ex.code() & 0xffff);
				// Reading the message failed; give up on the connection.
				try {
					webSocket.Close();
				} catch(...) {
					DebugWriteLine(_T("[WebSocketImpl::CreateWebSocket.MessageReceived] unexpected failure"));
				}
			} catch(std::exception const& ex) {
				DebugWriteLine(_T("[WebSocketImpl::CreateWebSocket.MessageReceived] unexpected failure \"%hs\""), ex.what());
			}
		});
	}

	// Fulfils `closePromise`, tolerating the case where it has already been
	// fulfilled so that the error does not escape from the event handler.
	void SignalClosed() noexcept {
		try {
			closePromise.set_value();
		} catch(std::future_error const&) {
			// Already signalled for this connection; nothing left to do.
		}
	}

	// Writes the bytes to `writer` and stores them to the underlying stream.
	// Returns 0 on success, otherwise the failure code.
	static int Send(DataWriter& writer, void const* data, size_t size) noexcept {
		try {
			auto const* begin = reinterpret_cast<uint8_t const*>(data);
			auto const* end = begin + size;
			winrt::array_view view(begin, end);
			writer.WriteBytes(view);
			writer.StoreAsync().get();
		} catch(winrt::hresult_error const& ex) {
			return ex.code();
		}
		return 0;
	}

	// Invokes `fn`, logging (and asserting in debug builds) instead of
	// propagating anything derived from std::exception.
	template<typename FN>
	void SafelyInvoke(FN fn) {
		try {
			fn();
		} catch(std::exception const& ex) {
			DebugWriteLine(_T("[WebSocketImpl::SafelyInvoke] function threw an exception \"%hs\""), ex.what());
			assert(false);
		}
	}
};

#include "WebSocket.inl"
