package installer

import (
	"crypto/sha512"
	"crypto/subtle"
	"crypto/tls"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"net/url"
	"os"
	"os/exec"
	"path"
	"strings"
	"syscall"
	"unsafe"

	"github.com/go-resty/resty/v2"
	"golang.org/x/sys/windows"
	"golang.org/x/sys/windows/registry"

	"a.yandex-team.ru/library/go/certifi"
	"a.yandex-team.ru/security/skotty/win32ssh-installer/internal/logger"
	"a.yandex-team.ru/security/skotty/win32ssh-installer/internal/sshrelease"
)

// TargetFolder is the directory that Install adds to the machine-wide PATH;
// presumably the location the Win32-OpenSSH MSI installs its binaries into.
const TargetFolder = `C:\Program Files\OpenSSH`

// SystemEnvPath is the registry key, relative to HKEY_LOCAL_MACHINE, whose
// "Path" value is read and rewritten when updating the machine-wide PATH.
const SystemEnvPath = `SYSTEM\CurrentControlSet\Control\Session Manager\Environment`

// Install installs Win32-OpenSSH of the given version:
//   - downloads the MSI release from sshrelease.DownloadURL(version);
//   - verifies it against sha512, the expected hex-encoded SHA-512 digest;
//   - runs msiexec to install the "Client" feature;
//   - adds TargetFolder to the machine-wide PATH;
//   - disables the OpenSSH agent.
//
// The downloaded installer is removed before Install returns. If msiexec fails
// with an exit code that most likely means an existing installation is in the
// way, a hint asking the user to uninstall it first is returned instead of the
// raw exit error.
func Install(version, sha512 string) error {
	// msiexec exit codes interpreted as "an existing installation blocks us".
	// 1638 is ERROR_PRODUCT_VERSION (another version of this product is already
	// installed). 1603 is the generic ERROR_INSTALL_FAILURE, so mapping it to
	// "already installed" is a heuristic.
	const (
		msiErrInstallFailure = 1603
		msiErrProductVersion = 1638
	)

	logger.Infof("Installation Win32OpenSSH v%s started", version)
	downloadURI := sshrelease.DownloadURL(version)
	logger.Infof("Download release: %s", downloadURI)
	msiPath, err := download(downloadURI)
	if err != nil {
		return fmt.Errorf("download failed: %w", err)
	}
	// Best-effort cleanup: a leftover temp file is not worth failing over.
	defer func() { _ = os.RemoveAll(msiPath) }()

	logger.Info("Checks installer hash")
	if err := checkFileHash(msiPath, sha512); err != nil {
		return fmt.Errorf("downloaded corrupted msi installer: %w", err)
	}

	logger.Info("Run installer")
	if err := install(msiPath); err != nil {
		var exitErr *exec.ExitError
		if errors.As(err, &exitErr) {
			if status, ok := exitErr.Sys().(syscall.WaitStatus); ok {
				switch status.ExitStatus() {
				case msiErrInstallFailure, msiErrProductVersion:
					return errors.New("seems Win32OpenSSH is already installed. Please uninstall installed version first")
				}
			}
		}
		return fmt.Errorf("msi installer failed: %w", err)
	}

	logger.Info("Update PATH env")
	if err := updateEnv(TargetFolder); err != nil {
		return err
	}

	logger.Info("Disable OpenSSH agent")
	if err := disableAgent(); err != nil {
		return fmt.Errorf("disable OpenSSH agent: %w", err)
	}

	logger.Info("Done")
	return nil
}

// download fetches uri over HTTP(S) into a new temporary file and returns that
// file's path. The temporary file name ends with "-" followed by the last
// element of the URI path.
//
// The caller owns the returned file and must remove it. On any error the
// partially written temporary file is closed and removed, and "" is returned.
// A response with a non-2xx status is treated as an error, so an error page
// is never mistaken for the downloaded payload.
//
// Root CAs come from certifi.NewCertPool when it succeeds; otherwise the resty
// client's default TLS configuration is used.
func download(uri string) (_ string, retErr error) {
	parsed, err := url.Parse(uri)
	if err != nil {
		return "", fmt.Errorf("invalid downloading uri %q: %w", uri, err)
	}

	targetFile, err := os.CreateTemp("", "*-"+path.Base(parsed.Path))
	if err != nil {
		return "", fmt.Errorf("unable to create temporary file: %w", err)
	}
	defer func() {
		if retErr == nil {
			return
		}
		// Close before removing: an open file cannot be deleted on Windows.
		// A second Close (after a failed explicit Close) is harmless.
		silentClose(targetFile)
		_ = os.Remove(targetFile.Name())
	}()

	httpc := resty.New().SetDoNotParseResponse(true)
	if certPool, err := certifi.NewCertPool(); err == nil {
		httpc.SetTLSClientConfig(&tls.Config{RootCAs: certPool})
	}

	rsp, err := httpc.R().Get(uri)
	if err != nil {
		return "", fmt.Errorf("request %q: %w", uri, err)
	}
	defer silentClose(rsp.RawBody())

	if !rsp.IsSuccess() {
		return "", fmt.Errorf("unexpected response status for %q: %s", uri, rsp.Status())
	}

	if _, err := io.Copy(targetFile, rsp.RawBody()); err != nil {
		return "", fmt.Errorf("unable to save file: %w", err)
	}
	if err := targetFile.Close(); err != nil {
		return "", fmt.Errorf("unable to save file: %w", err)
	}

	return targetFile.Name(), nil
}

// checkFileHash verifies that the SHA-512 digest of the file at filepath
// matches expectedSha512, a hex-encoded digest. Hex digits are accepted in
// either case (some tools print upper-case hashes), and surrounding whitespace
// is ignored.
//
// It returns an error in three cases:
//   - expectedSha512 is not a valid hex-encoded SHA-512 digest;
//   - the file cannot be opened or read;
//   - the digests differ.
func checkFileHash(filepath, expectedSha512 string) error {
	expected, err := hex.DecodeString(strings.TrimSpace(expectedSha512))
	if err != nil {
		return fmt.Errorf("invalid expected sha512 %q: %w", expectedSha512, err)
	}
	if len(expected) != sha512.Size {
		return fmt.Errorf("invalid expected sha512 %q: want %d hex characters", expectedSha512, 2*sha512.Size)
	}

	f, err := os.Open(filepath)
	if err != nil {
		return fmt.Errorf("unable to open file: %w", err)
	}
	defer func() { _ = f.Close() }()

	h := sha512.New()
	if _, err := io.Copy(h, f); err != nil {
		return fmt.Errorf("unable to calculate file hash: %w", err)
	}

	// Compare raw digests rather than hex strings, so the letter case of the
	// expected value does not matter.
	actual := h.Sum(nil)
	if subtle.ConstantTimeCompare(expected, actual) != 1 {
		return fmt.Errorf("hash mismatch: %s (expected) != %s (actual)", expectedSha512, hex.EncodeToString(actual))
	}
	return nil
}

// install runs the MSI package at msiPath through runMSI, passing
// ADDLOCAL=Client so msiexec installs only the feature named "Client".
func install(msiPath string) error {
	const clientFeatureOnly = "ADDLOCAL=Client"
	return runMSI(msiPath, clientFeatureOnly)
}

// runMSI executes `msiexec /i <msiPath> <params...>`, logs the command line
// and waits for msiexec to exit. A non-zero exit code surfaces as an
// *exec.ExitError returned from cmd.Run.
func runMSI(msiPath string, params ...string) error {
	cmdArgs := make([]string, 0, 2+len(params))
	cmdArgs = append(cmdArgs, "/i", msiPath)
	cmdArgs = append(cmdArgs, params...)

	msiexec := exec.Command("msiexec", cmdArgs...)
	logger.Infof("starts: %s", msiexec)
	return msiexec.Run()
}

// updateEnv puts targetDir into the machine-wide PATH, the "Path" value under
// HKLM\<SystemEnvPath>. It then broadcasts WM_SETTINGCHANGE ("Environment"),
// so running applications that listen for it can pick up the new value.
//
// Before inserting, it drops two kinds of entries:
//   - entries equal to targetDir, compared case-insensitively and ignoring
//     trailing backslashes;
//   - empty entries.
//
// targetDir is then inserted:
//   - right before the first entry containing "openssh" found within the
//     leading run of system-root entries, so it takes precedence over it;
//   - otherwise right after that leading run (at the very front when the list
//     does not start with a system-root entry).
//
// A system-root entry is "%SystemRoot%" or anything below it. Both the
// unexpanded form and the form expanded with the current SystemRoot
// environment variable are recognized.
//
// The value is written back with its original registry type. Path is
// typically REG_EXPAND_SZ; rewriting it as REG_SZ would stop entries such as
// %SystemRoot%\system32 from being expanded. A missing value is created as
// REG_EXPAND_SZ.
func updateEnv(targetDir string) error {
	getRegEnv := func(key string) (string, uint32, error) {
		k, err := registry.OpenKey(registry.LOCAL_MACHINE, SystemEnvPath, registry.QUERY_VALUE)
		if err != nil {
			return "", 0, err
		}
		defer func() { _ = k.Close() }()

		return k.GetStringValue(key)
	}

	setRegEnv := func(key, value string, valType uint32) error {
		k, err := registry.OpenKey(registry.LOCAL_MACHINE, SystemEnvPath, registry.SET_VALUE)
		if err != nil {
			return err
		}
		defer func() { _ = k.Close() }()

		if valType == registry.SZ {
			return k.SetStringValue(key, value)
		}
		return k.SetExpandStringValue(key, value)
	}

	broadcastEnvChange := func() {
		const (
			hwndBroadcast          = windows.HWND(0xFFFF)
			wmSettingChange        = uint(0x001A)
			smtoNormal             = uint(0x0000)
			smtoAbortIfHung        = uint(0x0002)
			smtoNoTimeoutIfNotHung = uint(0x0008)
			timeoutMs              = uint(5000)
		)

		envUTF, err := windows.UTF16PtrFromString("Environment")
		if err != nil {
			return
		}

		proc := windows.NewLazySystemDLL("user32.dll").NewProc("SendMessageTimeoutW")
		// Best-effort: on failure or timeout, already-running programs just
		// won't see the new PATH until they are restarted.
		_, _, _ = proc.Call(
			uintptr(hwndBroadcast),
			uintptr(wmSettingChange),
			0,
			uintptr(unsafe.Pointer(envUTF)),
			uintptr(smtoNormal|smtoAbortIfHung|smtoNoTimeoutIfNotHung),
			uintptr(timeoutMs),
			0,
		)
	}

	pathEnv, valType, err := getRegEnv("Path")
	switch {
	case errors.Is(err, registry.ErrNotExist):
		pathEnv, valType = "", registry.EXPAND_SZ
	case err != nil:
		return fmt.Errorf("can't read HKLM\\%s[Path]: %w", SystemEnvPath, err)
	}

	const envDelim = string(os.PathListSeparator)
	sameDir := func(a, b string) bool {
		return strings.EqualFold(strings.TrimRight(a, `\`), strings.TrimRight(b, `\`))
	}

	var paths []string
	for _, p := range strings.Split(pathEnv, envDelim) {
		if p == "" || sameDir(p, targetDir) {
			continue
		}
		paths = append(paths, p)
	}

	// Lower-cased system root prefixes: the literal %SystemRoot% reference and,
	// when available, its current expansion (e.g. c:\windows).
	systemRoots := []string{"%systemroot%"}
	if root := strings.TrimRight(os.Getenv("SystemRoot"), `\`); root != "" {
		systemRoots = append(systemRoots, strings.ToLower(root))
	}
	underSystemRoot := func(lowerPath string) bool {
		for _, root := range systemRoots {
			if lowerPath == root || strings.HasPrefix(lowerPath, root+`\`) {
				return true
			}
		}
		return false
	}

	targetPos := 0
	for i, p := range paths {
		lp := strings.ToLower(p)
		if strings.Contains(lp, "openssh") {
			// Shadow an OpenSSH directory that is already on PATH.
			targetPos = i
			break
		}
		if !underSystemRoot(lp) {
			break
		}
		targetPos = i + 1
	}

	newPaths := make([]string, 0, len(paths)+1)
	newPaths = append(newPaths, paths[:targetPos]...)
	newPaths = append(newPaths, targetDir)
	newPaths = append(newPaths, paths[targetPos:]...)
	if err := setRegEnv("Path", strings.Join(newPaths, envDelim), valType); err != nil {
		return fmt.Errorf("can't update HKLM\\%s[Path]: %w", SystemEnvPath, err)
	}

	broadcastEnvChange()
	return nil
}

// silentClose calls closer.Close and deliberately discards the returned error.
// It is used for deferred, best-effort cleanup where a close failure cannot be
// acted upon.
func silentClose(closer io.Closer) {
	err := closer.Close()
	_ = err // intentionally ignored
}
