#include "pch.h"
#include <ddeml.h>
#include "Internal.h"
#include "DdeServer.h"

using namespace Twitch;

namespace {
	inline void ThrowIfError(LSTATUS result) {
		if(result != ERROR_SUCCESS) {
			throw TwitchException(FromPlatformError(result));
		}
	}

	tstring expectedScheme;
	tstring ddeExecutionResult;
	DWORD ddeInstance;

	HDDEDATA CALLBACK DdeCallback(UINT uType, UINT uFmt, HCONV hconv,
		HSZ hsz1, HSZ hsz2, HDDEDATA hdata, ULONG_PTR dwData1, ULONG_PTR dwData2) {
		DebugWriteLine(_T("uType %d, uFmt %d, hconv 0x%Ix, hsz1 0x%Ix, hsz2 0x%Ix, hdata 0x%Ix, dwData1 0x%Ix, dwData2 0x%Ix"),
			uType, uFmt, hconv, hsz1, hsz2, hdata, dwData1, dwData2);
		UNREFERENCED_PARAMETER(uFmt);
		UNREFERENCED_PARAMETER(hconv);
		UNREFERENCED_PARAMETER(dwData1);
		UNREFERENCED_PARAMETER(dwData2);
		switch(uType) {
			TCHAR sz1[88];
			TCHAR sz2[888];
		case XTYP_REGISTER:
			DdeQueryString(ddeInstance, hsz1, sz1, _countof(sz1), CP_WINNEUTRAL);
			DdeQueryString(ddeInstance, hsz2, sz2, _countof(sz2), CP_WINNEUTRAL);
			DebugWriteLine(_T("registering '%s:%s'"), sz1, sz2);
			break;
		case XTYP_CONNECT:
			DdeQueryString(ddeInstance, hsz1, sz1, _countof(sz1), CP_WINNEUTRAL);
			if(_tcscmp(sz1, _T("identity")) != 0) {
				return nullptr;
			}
			DdeQueryString(ddeInstance, hsz2, sz2, _countof(sz2), CP_WINNEUTRAL);
			DebugWriteLine(_T("connecting '%s:%s'"), sz1, sz2);
			return reinterpret_cast<HDDEDATA>(expectedScheme == sz2);
		case XTYP_EXECUTE:
			DdeQueryString(ddeInstance, hsz1, sz1, _countof(sz1), CP_WINNEUTRAL);
			if(_tcscmp(sz1, _T("identity")) == 0) {
				DdeGetData(hdata, reinterpret_cast<LPBYTE>(sz2), sizeof(sz2), 0);
				DebugWriteLine(_T("executing '%s:%ls'"), sz1, sz2);
				ddeExecutionResult = sz2;
				PostQuitMessage(0);
				return reinterpret_cast<HDDEDATA>(DDE_FACK);
			}
			DebugWriteLine(_T("not executing topic '%s'"), sz1);
			break;
		}
		return nullptr;
	}

	tstring Register(LPCTSTR scheme) {
		// Expect this scheme during DDE protocol exchange.
		expectedScheme = scheme;

		// Create the "Software\Classes\{scheme}" key.
		tstring ddeKeyName = _T("Software\\Classes\\");
		ddeKeyName += scheme;
		RegKey ddeKey;
		DWORD disposition;
		auto result = RegCreateKeyEx(HKEY_LOCAL_MACHINE, ddeKeyName.c_str(),
			0, nullptr, 0, KEY_ALL_ACCESS, nullptr, &ddeKey, &disposition);
		if(result == ERROR_ACCESS_DENIED) {
			result = RegCreateKeyEx(HKEY_CURRENT_USER, ddeKeyName.c_str(),
				0, nullptr, 0, KEY_ALL_ACCESS, nullptr, &ddeKey, &disposition);
		}
		ThrowIfError(result);

		// Set its default and "URL Protocol" values.
		tstring ddeKeyDefaultValue = _T("URL:");
		ddeKeyDefaultValue += scheme;
		auto const n = static_cast<DWORD>(ddeKeyDefaultValue.size() * sizeof(ddeKeyDefaultValue[0]));
		ThrowIfError(RegSetValueEx(ddeKey, nullptr, 0, REG_SZ, reinterpret_cast<BYTE const*>(ddeKeyDefaultValue.c_str()), n));
		ThrowIfError(RegSetValueEx(ddeKey, _T("URL Protocol"), 0, REG_SZ, nullptr, 0));

		// Create the "shell\open\command" key under the scheme key.
		RegKey commandKey;
		ThrowIfError(RegCreateKeyEx(ddeKey, _T("shell\\open\\command"),
			0, nullptr, 0, KEY_ALL_ACCESS, nullptr, &commandKey, &disposition));

		// Set its default value to the currently running program.
		//tstring commandDefaultValue(MAX_PATH, _T('\0'));
		tstring commandDefaultValue(20, _T('\0'));
		for(;;) {
			auto const pathSize = GetModuleFileName(NULL, &commandDefaultValue[0], static_cast<DWORD>(commandDefaultValue.size()));
			result = GetLastError();
			if(result == ERROR_INSUFFICIENT_BUFFER) {
				commandDefaultValue.resize(commandDefaultValue.size() * 2);
			} else if(result == ERROR_SUCCESS) {
				commandDefaultValue.resize(pathSize);
				break;
			} else {
				ThrowIfError(result);
			}
		}
		ThrowIfError(RegSetValueEx(commandKey, nullptr, 0, REG_EXPAND_SZ,
			reinterpret_cast<BYTE const*>(commandDefaultValue.c_str()),
			static_cast<DWORD>(commandDefaultValue.size() * sizeof(commandDefaultValue[0]))));

		// Create the "shell\open\ddeexec" key under the scheme key.
		RegKey ddeexecKey;
		ThrowIfError(RegCreateKeyEx(ddeKey, _T("shell\\open\\ddeexec"),
			0, nullptr, 0, KEY_ALL_ACCESS, nullptr, &ddeexecKey, &disposition));

		// Set its default value to a replacement for the DDE result.
		TCHAR const ddeexecDefaultValue[] = _T("%1");
		ThrowIfError(RegSetValueEx(ddeexecKey, nullptr, 0, REG_SZ,
			reinterpret_cast<BYTE const*>(ddeexecDefaultValue), sizeof(ddeexecDefaultValue)));

		// Create the "application" key under the ddeexec key.
		RegKey applicationKey;
		ThrowIfError(RegCreateKeyEx(ddeexecKey, _T("application"), 0, nullptr,
			0, KEY_ALL_ACCESS, nullptr, &applicationKey, &disposition));

		// Set its default value to the URI scheme.
		ThrowIfError(RegSetValueEx(applicationKey, nullptr, 0, REG_SZ,
			reinterpret_cast<BYTE const*>(scheme), lstrlen(scheme) * sizeof(scheme[0])));

		// Create the "topic" key under the ddeexec key.
		RegKey topicKey;
		ThrowIfError(RegCreateKeyEx(ddeexecKey, _T("topic"), 0, nullptr, 0, KEY_ALL_ACCESS, nullptr, &topicKey, &disposition));

		// Set its default value to "identity", the one topic for this DDE server.
		TCHAR const topicDefaultValue[] = _T("identity");
		ThrowIfError(RegSetValueEx(topicKey, nullptr, 0, REG_SZ,
			reinterpret_cast<BYTE const*>(topicDefaultValue), sizeof(topicDefaultValue)));
		return ddeKeyName;
	}

	void Unregister(LPCTSTR ddeKeyName) noexcept {
		RegDeleteTree(HKEY_LOCAL_MACHINE, ddeKeyName);
		RegDeleteTree(HKEY_CURRENT_USER, ddeKeyName);
	}
}

// Registers the running program as the DDE handler for the scheme of
// `redirectUri` (the text before the first ':'), then starts a thread that
// services DDE until a command is executed on the "identity" topic or Cancel
// is called.  Throws TwitchException if the URI has no scheme, if
// registration fails, or if DDE initialization fails; in the latter cases
// the registration is removed before throwing.
DdeServer::DdeServer(tstring const& redirectUri) {
	// Register this server as a handler of the URI scheme in the URL.
	auto const i = redirectUri.find(_T(':'));
	if(i == redirectUri.npos) {
		throw TwitchException(FromPlatformError(ERROR_BAD_ARGUMENTS));
	}
	tstring const scheme = redirectUri.substr(0, i);
	ddeKeyName = Register(scheme.c_str());

	// The result is file-scope state; clear anything left by a previous
	// server so a cancelled wait doesn't report a stale URI.
	ddeExecutionResult.clear();

	// Obtain the future before the DDE thread exists so get_future() can
	// never race with that thread's set_value().
	std::promise<UINT> initializationPromise;
	auto initializationFuture = initializationPromise.get_future();

	auto const fn = [this, scheme, &initializationPromise] {
		// Open the DDE service.
		auto result = DdeInitialize(&ddeInstance, DdeCallback,
			APPCLASS_STANDARD | CBF_SKIP_CONNECT_CONFIRMS | CBF_SKIP_DISCONNECTS, 0);
		if(result == DMLERR_NO_ERROR) {
			auto ddeServiceName = DdeCreateStringHandle(ddeInstance, scheme.c_str(), CP_WINNEUTRAL);
			if(ddeServiceName != nullptr) {
				if(DdeNameService(ddeInstance, ddeServiceName, nullptr, DNS_REGISTER)) {
					// Opening complete.  Signal the invoking thread.  After
					// this the promise must not be touched again: the
					// constructor may already have returned and destroyed it.
					ddeThreadId = GetCurrentThreadId();
					initializationPromise.set_value(DMLERR_NO_ERROR);

					// Pump messages for this thread.  GetMessage returns -1 on
					// error, so only a positive value means "keep going".
					MSG message;
					while(GetMessage(&message, nullptr, 0, 0) > 0) {
						TranslateMessage(&message);
						DispatchMessage(&message);
					}

					// Close the DDE service.  Clear the thread identifier so
					// the Cancel method doesn't post to some other thread due
					// to thread identifier recycling.
					ddeThreadId = 0;
					DdeNameService(ddeInstance, ddeServiceName, nullptr, DNS_UNREGISTER);
				} else {
					result = DdeGetLastError(ddeInstance);
					DebugWriteLine(_T("DdeNameService failed: 0x%x"), result);
				}
				DdeFreeStringHandle(ddeInstance, ddeServiceName);
			} else {
				result = DdeGetLastError(ddeInstance);
				DebugWriteLine(_T("DdeCreateStringHandle failed: 0x%x"), result);
			}
			DdeUninitialize(ddeInstance);
			ddeInstance = 0;
		} else {
			DebugWriteLine(_T("DdeInitialize failed: 0x%x"), result);
		}

		// Check for errors during opening.
		if(result != DMLERR_NO_ERROR) {
			initializationPromise.set_value(result);
		}
	};

	// Start the DDE thread; if it cannot be started, undo the registration.
	try {
		ddeTask = std::async(std::launch::async, fn);
	} catch(...) {
		Unregister(ddeKeyName.c_str());
		throw;
	}

	// Await DDE initialization.
	auto const result = initializationFuture.get();
	if(result != DMLERR_NO_ERROR) {
		// DDE initialization failed.  Unregister this server and throw.
		Unregister(ddeKeyName.c_str());
		throw TwitchException(FromPlatformError(result));
	}
}

// Stops the DDE thread if it is still running, waits for it to exit, and
// removes the URI scheme registration.
DdeServer::~DdeServer() {
	// Without this, destroying a server that never received a command and
	// was never cancelled (e.g. during exception unwinding) would block
	// forever in wait().  If the thread has already exited, ddeThreadId is 0
	// and the post simply fails.
	Cancel();
	ddeTask.wait();
	Unregister(ddeKeyName.c_str());
}

// Blocks until the DDE thread has finished, either because a command was
// executed on the "identity" topic or because Cancel was called.  Then
// returns the received command string, which is the redirected URI.
tstring DdeServer::GetUri() const {
	ddeTask.wait();
	tstring uri = ddeExecutionResult;
	return uri;
}

void DdeServer::Cancel() {
	PostThreadMessage(ddeThreadId, WM_QUIT, 0, 0);
}
