package app

import (
	"context"
	"fmt"
	"net"
	"net/http"
	"time"

	"github.com/go-chi/chi/v5/middleware"
	"golang.org/x/sys/unix"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/ctxlog"
)

// contextKey is an unexported key type for values this package stores in a
// context.Context. Using a distinct type keeps these keys from colliding with
// keys defined by other packages.
type contextKey struct {
	key string
}

// connCredsKey is the context key under which the peer's *unix.Ucred is
// stored and later looked up.
var connCredsKey = &contextKey{key: "creds"}

// SavePeerCreds returns ctx extended with the peer credentials of c, stored
// under connCredsKey so that credsFromContext can retrieve them later. It has
// the shape of http.Server.ConnContext.
//
// If the credentials cannot be read (for example, c is not a Unix socket),
// the error is logged and ctx is returned unchanged.
func (a *App) SavePeerCreds(ctx context.Context, c net.Conn) context.Context {
	creds, err := readCreds(c)
	if err == nil {
		return context.WithValue(ctx, connCredsKey, creds)
	}

	a.log.Error("can't read uds creds", log.Error(err))
	return ctx
}

// credsFromContext returns the peer credentials stored in ctx under
// connCredsKey, or nil when ctx carries no such value.
func credsFromContext(ctx context.Context) *unix.Ucred {
	// A failed type assertion yields the zero value (nil). That is exactly
	// the "no credentials" result, so the ok flag is not needed.
	creds, _ := ctx.Value(connCredsKey).(*unix.Ucred)
	return creds
}

// rootOnly wraps next so that it is invoked only for requests whose peer
// credentials report uid 0. Every other request gets 403 Forbidden with an
// empty body, including requests with no credentials in their context.
func rootOnly(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if creds := credsFromContext(r.Context()); creds != nil && creds.Uid == 0 {
			next.ServeHTTP(w, r)
			return
		}

		w.WriteHeader(http.StatusForbidden)
	})
}

// logRequest returns middleware that writes one info-level entry to l per
// request. The entry is emitted after the wrapped handler returns and
// contains:
//   - the peer uid and pid, which are zero when no credentials are in the
//     request context;
//   - the request URI;
//   - the status recorded by chi's WrapResponseWriter;
//   - the elapsed wall-clock time.
func logRequest(l log.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			var (
				uid uint32
				pid int32
			)
			creds := credsFromContext(r.Context())
			if creds != nil {
				uid, pid = creds.Uid, creds.Pid
			}

			start := time.Now()
			rec := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
			// Log from a deferred call so the entry is still written if
			// next panics.
			defer func() {
				ctxlog.Info(r.Context(), l,
					fmt.Sprintf("req: %s", r.RequestURI),
					log.UInt32("uid", uid),
					log.Int32("pid", pid),
					log.String("uri", r.RequestURI),
					log.Int("status", rec.Status()),
					log.String("elapsed", time.Since(start).String()),
				)
			}()

			next.ServeHTTP(rec, r)
		})
	}
}

// readCreds returns the peer credentials (SO_PEERCRED) of a Unix domain
// socket connection.
//
// It returns an error if c is not a *net.UnixConn, if the raw connection
// cannot be obtained, or if the getsockopt call fails. Underlying errors are
// wrapped with %w, so callers can inspect them with errors.Is or errors.As.
func readCreds(c net.Conn) (*unix.Ucred, error) {
	uc, ok := c.(*net.UnixConn)
	if !ok {
		return nil, fmt.Errorf("unexpected socket type %T", c)
	}

	raw, err := uc.SyscallConn()
	if err != nil {
		return nil, fmt.Errorf("error opening raw connection: %w", err)
	}

	// RawConn.Control cannot return the callback's results directly, so the
	// callback stores them in these dedicated variables. This avoids reusing
	// the outer err, which previously made the error flow hard to follow.
	var (
		cred    *unix.Ucred
		credErr error
	)
	if err := raw.Control(func(fd uintptr) {
		cred, credErr = unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
	}); err != nil {
		// If Control itself fails, the callback may not have run, so
		// credErr is not meaningful. Report the Control failure first.
		return nil, fmt.Errorf("control error: %w", err)
	}
	if credErr != nil {
		return nil, fmt.Errorf("getsockoptUcred error: %w", credErr)
	}

	return cred, nil
}
