package putty

import (
	"context"
	"encoding/binary"
	"io"
	"reflect"
	"runtime"
	"syscall"
	"unsafe"

	"golang.org/x/sys/windows"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/security/skotty/skotty/internal/logger"
)

// Lazily loaded system DLLs. Nothing is resolved until a procedure is first used.
var (
	k32 = windows.NewLazySystemDLL("Kernel32.dll")
	u32 = windows.NewLazySystemDLL("User32.dll")
)

// Kernel32 entry points.
var pOpenFileMapping = k32.NewProc("OpenFileMappingA")

// User32 entry points: window class and window lifecycle, then the message pump.
var (
	pRegisterClass    = u32.NewProc("RegisterClassExW")
	pUnregisterClass  = u32.NewProc("UnregisterClassW")
	pCreateWindowEx   = u32.NewProc("CreateWindowExW")
	pDestroyWindow    = u32.NewProc("DestroyWindow")
	pDefWindowProc    = u32.NewProc("DefWindowProcW")
	pGetMessage       = u32.NewProc("GetMessageW")
	pTranslateMessage = u32.NewProc("TranslateMessage")
	pDispatchMessage  = u32.NewProc("DispatchMessageW")
)

const (
	// className is used as both the window class name and the window title.
	className = "Pageant"

	// agentCopydataID is the value required in COPYDATASTRUCT.dwData for a
	// WM_COPYDATA message to be treated as an agent request.
	agentCopydataID = 0x804e50ba

	// agentMaxMsglen is how many bytes of the shared mapping are viewed as the
	// request/response buffer, length prefix included.
	agentMaxMsglen = 8 << 10

	// Access masks passed to OpenFileMapping and MapViewOfFile respectively.
	fileMapAllAccess = 0xF001F
	fileMapWrite     = 0x02

	// wmCopyData is the WM_COPYDATA window message identifier.
	wmCopyData = 0x4A
)

// PageantWindow owns the registered window class and the hidden window that
// receives agent requests. Requests decoded by the window procedure are
// handed to Accept through requestCh.
type PageantWindow struct {
	class     *wndClassEx
	window    windows.Handle
	requestCh chan request
	debug     bool
}

// request is one agent message copied out of shared memory, together with
// the channel on which the window procedure waits for its reply.
type request struct {
	data     []byte
	response chan response
}

// response carries the reply bytes for a request, or the error that
// prevented producing one.
type response struct {
	data []byte
	err  error
}

// createWindow reports the outcome of CreateWindowEx from the goroutine that
// owns the window. The field order matters: it is built positionally.
type createWindow struct {
	handle uintptr
	err    error
}

// NewPageantWindow registers the "Pageant" window class and creates a window
// of that class (className is used as both class name and title).
//
// The window is created on a dedicated goroutine that stays locked to its OS
// thread and then runs eventLoop for that window. Requests received by the
// window procedure are delivered through Accept.
//
// If window creation fails, the class is unregistered and the error reported
// by CreateWindowEx is returned.
func NewPageantWindow() (*PageantWindow, error) {
	classNamePtr, err := syscall.UTF16PtrFromString(className)
	if err != nil {
		return nil, err
	}

	// requestCh must exist before the window does. Once CreateWindowEx runs,
	// wndProc may receive WM_COPYDATA and send on this channel at any time.
	// Assigning it afterwards would race with the window thread and could
	// leave it blocked forever on a nil channel.
	win := &PageantWindow{
		requestCh: make(chan request),
	}

	wcex := &wndClassEx{
		WndProc:   windows.NewCallback(win.wndProc),
		ClassName: classNamePtr,
	}
	if err := wcex.register(); err != nil {
		return nil, err
	}
	win.class = wcex

	ch := make(chan createWindow)
	go func() {
		// The window and its message queue belong to the creating OS thread,
		// so this goroutine must not migrate while it pumps messages.
		runtime.LockOSThread()
		defer runtime.UnlockOSThread()

		windowHandle, _, callErr := pCreateWindowEx.Call(
			uintptr(0),
			uintptr(unsafe.Pointer(classNamePtr)),
			uintptr(unsafe.Pointer(classNamePtr)),
			uintptr(0),
			uintptr(0),
			uintptr(0),
			uintptr(0),
			uintptr(0),
			uintptr(0),
			uintptr(0),
			uintptr(0),
			uintptr(0),
		)
		ch <- createWindow{windowHandle, callErr}
		if windowHandle != 0 {
			eventLoop(windowHandle)
		}
	}()

	result := <-ch
	if result.handle == 0 {
		_ = wcex.unregister()
		// Report the CreateWindowEx error. The outer err is nil at this
		// point, and returning it would yield (nil, nil).
		return nil, result.err
	}
	win.window = windows.Handle(result.handle)
	return win, nil
}

// Accept blocks until the window procedure delivers the next agent request
// and returns it wrapped in a memoryMapConn. If ctx is done first, it returns
// io.ErrClosedPipe.
func (s *PageantWindow) Accept(ctx context.Context) (io.ReadWriteCloser, error) {
	var next request
	select {
	case <-ctx.Done():
		return nil, io.ErrClosedPipe
	case next = <-s.requestCh:
	}
	return &memoryMapConn{req: next}, nil
}

// Close destroys the window and unregisters its class. Both steps are
// best-effort: their results are intentionally discarded.
//
// NOTE(review): Win32 documents that DestroyWindow cannot destroy a window
// created by a different thread. NewPageantWindow creates the window on its
// own locked goroutine, so this call likely fails when invoked from elsewhere.
// The event loop would then keep running and the class would stay registered.
// Consider SendMessage(WM_CLOSE) to the window instead, but beware it blocks
// while wndProc is waiting on requestCh. Confirm.
func (s *PageantWindow) Close() {
	hwnd := uintptr(s.window)
	_, _, _ = pDestroyWindow.Call(hwnd)

	class := s.class
	_ = class.unregister()
}

// eventLoop retrieves and dispatches messages for window on the calling OS
// thread. It returns when GetMessage yields 0 (WM_QUIT) or -1 (error).
func eventLoop(window uintptr) {
	// winMsg mirrors the layout of the Win32 MSG structure.
	type winMsg struct {
		WindowHandle windows.Handle
		Message      uint32
		Wparam       uintptr
		Lparam       uintptr
		Time         uint32
		Pt           point
	}
	msg := new(winMsg)

	for {
		// GetMessage returns nonzero for an ordinary message, zero for
		// WM_QUIT and -1 on error.
		// https://msdn.microsoft.com/en-us/library/windows/desktop/ms644936(v=vs.85).aspx
		// The unsafe.Pointer -> uintptr conversions are kept inside each Call
		// expression so the runtime keeps msg alive and in place.
		ret, _, _ := pGetMessage.Call(uintptr(unsafe.Pointer(msg)), window, 0, 0)
		if status := int32(ret); status == 0 || status == -1 {
			return
		}

		_, _, _ = pTranslateMessage.Call(uintptr(unsafe.Pointer(msg)))
		_, _, _ = pDispatchMessage.Call(uintptr(unsafe.Pointer(msg)))
	}
}

// wndProc is the WindowProc callback registered for the Pageant window class.
// https://msdn.microsoft.com/en-us/library/windows/desktop/ms633573(v=vs.85).aspx
//
// Every message except WM_COPYDATA is forwarded to DefWindowProc. A
// WM_COPYDATA message is handled as follows:
//   - lParam is treated as a *copyDataStruct whose dwData must equal
//     agentCopydataID;
//   - lpData is passed to OpenFileMapping, and the owner SID of the resulting
//     mapping must equal either GetUserSID or GetDefaultSID;
//   - the first agentMaxMsglen bytes of the mapping are read as a 4-byte
//     big-endian length followed by that many request bytes;
//   - a copy of the request is sent on requestCh (blocking until Accept
//     takes it) and the reply is awaited on a per-request channel;
//   - on success the reply is written back into the mapping.
//
// It returns 1 when a reply was written back and 0 on any failure.
func (s *PageantWindow) wndProc(hWnd windows.Handle, message uint32, wParam, lParam uintptr) uintptr {
	if message != wmCopyData {
		lResult, _, _ := pDefWindowProc.Call(
			uintptr(hWnd),
			uintptr(message),
			wParam,
			lParam,
		)
		return lResult
	}

	// ugly hack: lParam is reinterpreted as a pointer by way of a
	// windows.Pointer instead of a direct uintptr -> unsafe.Pointer
	// conversion (presumably to keep go vet's unsafeptr check quiet).
	//TODO(buglloc): do smth better
	var fake windows.Pointer
	*(*uintptr)(unsafe.Pointer(&fake)) = lParam
	copyData := (*copyDataStruct)(unsafe.Pointer(fake))
	if copyData.dwData != agentCopydataID {
		logger.Error("putty: invalid copy data id")
		return 0
	}

	fileMap, err := OpenFileMapping(fileMapAllAccess, 0, copyData.lpData)
	if err != nil {
		logger.Error("putty: OpenFileMapping error", log.Error(err))
		return 0
	}
	defer func() {
		_ = windows.CloseHandle(fileMap)
	}()

	// check security: only serve mappings owned by one of our SIDs
	ourself, err := GetUserSID()
	if err != nil {
		logger.Error("putty: GetUserSID error", log.Error(err))
		return 0
	}

	ourself2, err := GetDefaultSID()
	if err != nil {
		logger.Error("putty: GetDefaultSID error", log.Error(err))
		return 0
	}

	mapOwner, err := GetHandleSID(fileMap)
	if err != nil {
		logger.Error("putty: GetHandleSID error", log.Error(err))
		return 0
	}

	if !windows.EqualSid(mapOwner, ourself) && !windows.EqualSid(mapOwner, ourself2) {
		logger.Error("putty: wrong owning SID of file mapping")
		return 0
	}

	// get map view
	sharedMemory, err := windows.MapViewOfFile(fileMap, fileMapWrite, 0, 0, 0)
	if err != nil {
		logger.Error("putty: MapViewOfFile error", log.Error(err))
		return 0
	}
	defer func() { _ = windows.UnmapViewOfFile(sharedMemory) }()

	// NOTE(review): this assumes the client's mapping is at least
	// agentMaxMsglen bytes long; a smaller mapping would fault here. Confirm.
	var sharedMemoryArray []byte
	h := (*reflect.SliceHeader)(unsafe.Pointer(&sharedMemoryArray))
	h.Data = sharedMemory
	h.Len = agentMaxMsglen
	h.Cap = agentMaxMsglen

	// check buffer size. The length prefix comes from the client. Compare it
	// against the room left after the 4-byte prefix rather than computing
	// length+4 first: that sum wraps around in uint32 for lengths near the
	// maximum and would slip past the bound.
	bodyLen := binary.BigEndian.Uint32(sharedMemoryArray[:4])
	if bodyLen > agentMaxMsglen-4 {
		logger.Error("putty: invalid message length", log.UInt32("size", bodyLen), log.UInt32("max_size", agentMaxMsglen-4))
		return 0
	}
	size := bodyLen + 4

	// send data to handler (blocks until Accept picks the request up)
	data := make([]byte, size)
	copy(data, sharedMemoryArray[:size])
	ch := make(chan response)
	s.requestCh <- request{
		data:     data,
		response: ch,
	}

	// wait for response
	resp := <-ch
	if resp.err != nil {
		return 0
	}
	// A reply that does not fit would be silently truncated by copy and then
	// misparsed by the client. Report failure instead.
	if len(resp.data) > len(sharedMemoryArray) {
		logger.Error("putty: response too long", log.UInt32("size", uint32(len(resp.data))), log.UInt32("max_size", agentMaxMsglen))
		return 0
	}
	copy(sharedMemoryArray, resp.data)
	return 1
}
