package main

/*
#include <security/pam_appl.h>
#include <security/pam_ext.h>
#include <errno.h>
#include <pwd.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

char *string_from_argv(int i, char **argv) {
  return strdup(argv[i]);
}

// get_user returns a heap-allocated copy of the PAM_USER item stored in pamh,
// or NULL if pamh is NULL, the item cannot be read, no user has been set yet,
// or the copy cannot be allocated. The caller must free the result.
char *get_user(pam_handle_t *pamh) {
  if (!pamh)
    return NULL;

  const void *item = NULL;
  if (pam_get_item(pamh, PAM_USER, &item) != PAM_SUCCESS)
    return NULL;

  // pam_get_item can succeed while leaving the item NULL when PAM_USER has
  // not been set; strdup(NULL) is undefined behavior, so report "no user".
  if (item == NULL)
    return NULL;

  return strdup((const char *)item);
}

// get_uid returns the uid for the given username, or -1 if user is NULL,
// the user does not exist, the lookup fails, or memory cannot be allocated.
int get_uid(char *user) {
  if (!user)
    return -1;

  // getpwnam_r fails with ERANGE when the scratch buffer is too small for the
  // passwd entry. Start at 8 KiB (or the system's suggested size if larger)
  // and double the buffer until the entry fits, capped at 1 MiB.
  long hint = sysconf(_SC_GETPW_R_SIZE_MAX);
  size_t size = 8192;
  if (hint > 8192)
    size = (size_t)hint;
  const size_t max_size = (size_t)1 << 20;

  for (;;) {
    struct passwd pw, *result = NULL;
    char *buf = malloc(size);
    if (!buf)
      return -1;

    int err = getpwnam_r(user, &pw, buf, size, &result);
    if (err == ERANGE && size < max_size) {
      free(buf);
      size *= 2;
      continue;
    }

    // Only pw_uid (a plain integer) is read, so buf can be released at once.
    int uid = (err == 0 && result) ? (int)pw.pw_uid : -1;
    free(buf);
    return uid;
  }
}

// change_euid calls seteuid with the given uid and returns its result:
// 0 on success, -1 on failure (with errno set by seteuid).
int change_euid(int uid) {
  int rc = seteuid((uid_t)uid);
  return rc;
}

// _pam_error sends msg to the application as an error message via pam_error
// and returns pam_error's status code. msg is passed as the argument to a
// fixed "%s" format so any '%' characters in it are printed literally instead
// of being interpreted as format directives.
int _pam_error(pam_handle_t *pamh, const char *msg) {
  return pam_error(pamh, "%s", msg);
}

// _pam_info forwards msg to pam_info as an informational message and returns
// pam_info's status code unchanged. The fixed "%s" format keeps '%' sequences
// in msg from being treated as format directives.
int _pam_info(pam_handle_t *pamh, const char *msg) {
  int rc = pam_info(pamh, "%s", msg);
  return rc;
}
*/
import "C"
import (
	"fmt"
	"unsafe"
)

// PamHandle wraps a raw PAM handle together with the user name and uid that
// NewPamHandle resolved from it.
type PamHandle struct {
	// pamh is the underlying C handle.
	pamh *C.pam_handle_t
	// user is the PAM_USER value read from pamh.
	user string
	// uid is the uid looked up for user.
	uid int
}

// NewPamHandle wraps pamh, resolving the PAM user name and its uid up front.
//
// It returns a *PamError with RetCode PAM_SERVICE_ERR if pamh is nil, and
// with RetCode PAM_USER_UNKNOWN if no user name can be read from the handle
// or the name does not resolve to a uid.
func NewPamHandle(pamh *C.pam_handle_t) (*PamHandle, error) {
	if pamh == nil {
		return nil, &PamError{
			RetCode: C.PAM_SERVICE_ERR,
			Msg:     "invalid pamh passed",
		}
	}

	cUsername := C.get_user(pamh)
	if cUsername == nil {
		return nil, &PamError{
			RetCode: C.PAM_USER_UNKNOWN,
			Msg:     "unknown user",
		}
	}
	defer C.free(unsafe.Pointer(cUsername))

	// Copy the name into Go memory once; it is used both in the error
	// message and in the returned handle.
	user := C.GoString(cUsername)

	uid := int(C.get_uid(cUsername))
	if uid < 0 {
		// %q escapes control characters in the (untrusted) user name.
		return nil, &PamError{
			RetCode: C.PAM_USER_UNKNOWN,
			Msg:     fmt.Sprintf("no uid for user %q", user),
		}
	}

	return &PamHandle{
		pamh: pamh,
		user: user,
		uid:  uid,
	}, nil
}

// User returns the PAM user name captured when the handle was created.
func (p *PamHandle) User() (name string) {
	name = p.user
	return name
}

// UID returns the uid that was looked up for the PAM user at creation time.
func (p *PamHandle) UID() (uid int) {
	uid = p.uid
	return uid
}

// PrintError sends msg to the PAM application as an error message. A
// non-zero status from pam_error is returned as an error.
func (p *PamHandle) PrintError(msg string) error {
	cs := C.CString(msg)
	status := C._pam_error(p.pamh, cs)
	// The C copy is only needed for the call above.
	C.free(unsafe.Pointer(cs))

	if status == 0 {
		return nil
	}
	return fmt.Errorf("pam_error fail: %d", int(status))
}

// PrintMsg sends msg to the PAM application as an informational message. A
// non-zero status from pam_info is returned as an error.
func (p *PamHandle) PrintMsg(msg string) error {
	text := C.CString(msg)
	defer C.free(unsafe.Pointer(text))

	switch ret := C._pam_info(p.pamh, text); ret {
	case 0:
		return nil
	default:
		return fmt.Errorf("pam_info fail: %d", int(ret))
	}
}

// PamError is an error carrying a PAM return code and a descriptive message.
type PamError struct {
	RetCode int    // PAM return code, e.g. PAM_USER_UNKNOWN
	Msg     string // human-readable description of the failure
}

// Error formats the error as "pam error (<RetCode>): <Msg>".
func (e *PamError) Error() string {
	return fmt.Sprintf("pam error (%d): %s", e.RetCode, e.Msg)
}

// Is reports whether target is a *PamError with the same RetCode; Msg is
// ignored. This supports errors.Is(err, &PamError{RetCode: code}). A nil
// *PamError target never matches (rather than causing a nil dereference).
func (e *PamError) Is(target error) bool {
	t, ok := target.(*PamError)
	if !ok || t == nil {
		return false
	}
	return e.RetCode == t.RetCode
}

// seteuid sets the effective user ID to uid (for example, to drop
// privileges) by calling the C change_euid helper. It reports whether that
// call returned 0.
func seteuid(uid int) bool {
	rc := C.change_euid(C.int(uid))
	return rc == 0
}
