package frontapi

import (
	"bytes"
	"context"
	"crypto/hmac"
	"crypto/sha1"
	"errors"
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/go-chi/chi/v5"

	"a.yandex-team.ru/library/go/yandex/blackbox"
	"a.yandex-team.ru/library/go/yandex/tvm"
)

// contextKey is an unexported type for context keys, so values stored by this
// package cannot collide with keys of the same string defined elsewhere.
type contextKey string

const (
	// contextUserKey is the context key under which the authenticated User
	// is stored (read back with userFromContext).
	contextUserKey = contextKey("user")
	// csrfTokenTTL is the maximum accepted CSRF token age, in seconds (one day).
	csrfTokenTTL = 86400
)

// User describes the authenticated caller. authMiddleware fills it from the
// Blackbox session response and the incoming request, and stores it in the
// request context; handlers read it back with userFromContext.
type User struct {
	// UID is the user ID taken from the Blackbox response (rsp.User.UID.ID).
	UID        blackbox.ID
	// Login is the user's login as returned by Blackbox.
	Login      string
	// UserTicket is the user ticket returned by Blackbox (requested with
	// GetUserTicket); checkAccessMiddleware validates it through TVM.
	UserTicket string
	// YandexUID is the value of the "yandexuid" cookie.
	YandexUID  string
	// IP is copied from r.RemoteAddr. NOTE(review): net/http sets this to
	// "IP:port" unless an upstream middleware rewrites it — confirm.
	IP         string
}

// rolesMapping lists the Tirole role names that let a user act on another
// user's resources in checkAccessMiddleware: holding any role from Full grants
// access for every HTTP method, holding any role from Read grants access only
// for safe methods (GET, HEAD, OPTIONS).
type rolesMapping struct {
	Full, Read []string
}

// role is the elevated access level resolved from a user's Tirole roles.
// The zero value means no elevated access.
type role int8

const (
	roleFull role = 1 // may use any HTTP method on another user's resources
	roleRead role = 2 // may only use GET, HEAD and OPTIONS
)

// checkAccessMiddleware returns middleware that authorizes access to resources
// owned by the login taken from the chi URL parameter loginParam.
//
// The request is allowed when the authenticated user from the request context
// (see userFromContext) has that same login. Otherwise the user's ticket is
// checked through TVM and resolved to Tirole roles:
//   - a role listed in allowedRoles.Full allows any HTTP method;
//   - a role listed in allowedRoles.Read allows only GET, OPTIONS and HEAD.
//
// Any other case, or any error while resolving roles, is reported through
// RespErrorf and the next handler is not called.
func checkAccessMiddleware(loginParam string, allowedRoles rolesMapping, tvmClient tvm.Client) func(http.Handler) http.Handler {
	// checkTVMRole returns the highest access level granted to the holder of
	// rawTicket, or 0 when none of the configured roles match.
	checkTVMRole := func(ctx context.Context, rawTicket string) (role, error) {
		roles, err := tvmClient.GetRoles(ctx)
		if err != nil {
			return 0, fmt.Errorf("unable to get Tirole roles: %w", err)
		}

		ticket, err := tvmClient.CheckUserTicket(ctx, rawTicket)
		if err != nil {
			return 0, fmt.Errorf("blackbox returned invalid user ticket: %w", err)
		}

		userRoles, err := roles.GetRolesForUser(ticket, nil)
		if err != nil {
			return 0, err
		}

		// hasAnyRole reports whether the user holds at least one of roleNames.
		// The loop variable is deliberately not called "role" so it does not
		// shadow the role type.
		hasAnyRole := func(roleNames []string) bool {
			for _, name := range roleNames {
				if userRoles.HasRole(name) {
					return true
				}
			}
			return false
		}

		// Full access takes precedence over read-only access.
		if hasAnyRole(allowedRoles.Full) {
			return roleFull, nil
		}
		if hasAnyRole(allowedRoles.Read) {
			return roleRead, nil
		}
		return 0, nil
	}

	// checkACL returns nil when the request may proceed, or an error that
	// explains why access was denied.
	checkACL := func(r *http.Request) error {
		login := chi.URLParam(r, loginParam)
		if login == "" {
			return errors.New("no login in request")
		}

		user, ok := userFromContext(r.Context())
		if !ok {
			return errors.New("unknown user")
		}

		// Owners always have access to their own resources.
		if user.Login == login {
			return nil
		}

		userRole, err := checkTVMRole(r.Context(), user.UserTicket)
		if err != nil {
			return fmt.Errorf("unable to check admin access while user mismatch %q != %q: %w", user.Login, login, err)
		}

		switch userRole {
		case roleFull:
			return nil
		case roleRead:
			switch r.Method {
			case http.MethodGet, http.MethodOptions, http.MethodHead:
				return nil
			}

			// Read-only role with a mutating method: deny like any other mismatch.
			fallthrough
		default:
			return fmt.Errorf("user mismatch %q != %q", user.Login, login)
		}
	}

	return func(next http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			if err := checkACL(r); err != nil {
				// Pass the message as an argument, never as the format string:
				// it embeds the login from the URL, and any '%' in it would
				// otherwise be interpreted as a formatting verb.
				RespErrorf(w, "%v", err)
				return
			}

			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(fn)
	}
}

// authMiddleware returns middleware that authenticates requests through the
// Blackbox SessionID method, using the "Session_id" cookie.
//
// Requests without a Session_id cookie, or whose session Blackbox rejects,
// get 401 Unauthorized. On success a User is built from the Blackbox response
// and the request (remote address and the optional "yandexuid" cookie). It is
// stored in the request context under contextUserKey, where userFromContext
// reads it, and the request is passed on to next.
func authMiddleware(bbClient blackbox.Client) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			sessionID, err := r.Cookie("Session_id")
			if err != nil {
				http.Error(w, fmt.Sprintf("no Session_id cookie: %v", err), http.StatusUnauthorized)
				return
			}

			// NOTE(review): the net/http server sets RemoteAddr to "IP:port"
			// unless an upstream middleware rewrites it; confirm that Blackbox
			// accepts this form for UserIP.
			rsp, err := bbClient.SessionID(r.Context(), blackbox.SessionIDRequest{
				SessionID:     sessionID.Value,
				UserIP:        r.RemoteAddr,
				GetUserTicket: true,
				Host:          r.Host,
			})
			if err != nil {
				http.Error(w, fmt.Sprintf("failed to auth in BB with session cookie: %v", err), http.StatusUnauthorized)
				return
			}

			// The yandexuid cookie is optional here. r.Cookie returns a nil
			// *http.Cookie when it is missing, so reading .Value
			// unconditionally would panic. Fall back to an empty string.
			var yandexUID string
			if c, cookieErr := r.Cookie("yandexuid"); cookieErr == nil {
				yandexUID = c.Value
			}

			user := User{
				UID:        rsp.User.UID.ID,
				Login:      rsp.User.Login,
				UserTicket: rsp.UserTicket,
				IP:         r.RemoteAddr,
				YandexUID:  yandexUID,
			}
			r = r.WithContext(context.WithValue(r.Context(), contextUserKey, user))
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(fn)
	}
}

// checkCsrfMiddleware returns middleware that enforces CSRF protection on
// state-changing requests.
//
// GET, HEAD and OPTIONS requests pass through unchecked. Every other request
// must carry an "X-CSRF-Token" header and a "yandexuid" cookie, and the token
// must be valid for that yandexuid under secret (see validateCSRFToken).
// Failures are reported through RespErrorf and the next handler is not called.
func checkCsrfMiddleware(secret string) func(http.Handler) http.Handler {
	key := []byte(secret)

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Safe methods are exempt from CSRF checks.
			if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodOptions {
				next.ServeHTTP(w, r)
				return
			}

			token := r.Header.Get("X-CSRF-Token")
			if token == "" {
				RespErrorf(w, "empty 'X-CSRF-Token' header")
				return
			}

			uidCookie, err := r.Cookie("yandexuid")
			if err != nil {
				RespErrorf(w, "can't get yandexuid cookie: %v", err)
				return
			}

			if err = validateCSRFToken(key, token, uidCookie.Value); err != nil {
				RespErrorf(w, "CSRF token validation failed: %v", err)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}

// userFromContext returns the User stored in ctx under contextUserKey. The
// boolean is false, and the User is the zero value, when no User is present.
func userFromContext(ctx context.Context) (User, bool) {
	if user, found := ctx.Value(contextUserKey).(User); found {
		return user, true
	}
	return User{}, false
}

// validateCSRFToken checks a CSRF token of the form "<hex hmac>:<unix seconds>"
// against the given yandexuid.
//
// The token is rejected when:
//   - it does not have exactly two ':'-separated parts;
//   - its timestamp is not a base-10 int64;
//   - its timestamp is more than csrfTokenTTL seconds in the past or the future;
//   - it differs from what generateCSRFToken produces for the same secret,
//     yandexuid and timestamp.
//
// The final comparison uses hmac.Equal, which runs in constant time.
func validateCSRFToken(secret []byte, token, yandexuid string) error {
	data := strings.Split(token, ":")
	if len(data) != 2 {
		return errors.New("wrong format: must be 'token:timestamp'")
	}

	timestamp, err := strconv.ParseInt(data[1], 10, 64)
	if err != nil {
		return fmt.Errorf("timestamp parsing error: %w", err)
	}

	// Bound the token age on both sides. With a one-sided check, a
	// future-dated token would never expire, and an extreme timestamp could
	// overflow the subtraction into the accepted range. Any wrapped result
	// lands far outside [-TTL, TTL], so this check also covers overflow.
	// The symmetric window still tolerates clock skew between hosts.
	age := time.Now().Unix() - timestamp
	if age > csrfTokenTTL {
		return fmt.Errorf("expired: %d > %d", age, csrfTokenTTL)
	}
	if age < -csrfTokenTTL {
		return fmt.Errorf("timestamp is too far in the future: %d < %d", age, -csrfTokenTTL)
	}

	expectedToken, err := generateCSRFToken(secret, yandexuid, timestamp)
	if err != nil {
		return fmt.Errorf("generate check token failed: %w", err)
	}

	if !hmac.Equal([]byte(token), []byte(expectedToken)) {
		return errors.New("verification failed")
	}
	return nil
}

// generateCSRFToken builds a CSRF token bound to yandexuid. The token is the
// hex HMAC-SHA1 (keyed by secret) of "<yandexuid>:<timestamp>", followed by
// ":<timestamp>". A zero timestamp means "now" (time.Now().Unix()).
//
// The returned error comes from writing to the HMAC.
func generateCSRFToken(secret []byte, yandexuid string, timestamp int64) (string, error) {
	ts := timestamp
	if ts == 0 {
		ts = time.Now().Unix()
	}
	stamp := strconv.FormatInt(ts, 10)

	var payload bytes.Buffer
	payload.WriteString(yandexuid + ":" + stamp)

	mac := hmac.New(sha1.New, secret)
	_, writeErr := mac.Write(payload.Bytes())
	digest := mac.Sum(nil)

	return fmt.Sprintf("%x:%d", digest, ts), writeErr
}
