package storage

import (
	"context"
	"fmt"
	"io"
	"os"
	"strings"

	"a.yandex-team.ru/infra/cauth/agent/linux/yandex-cauth-userd/internal/cauth"
	"a.yandex-team.ru/infra/cauth/agent/linux/yandex-cauth-userd/internal/passwd"
	"a.yandex-team.ru/infra/cauth/agent/linux/yandex-cauth-userd/internal/utils"
	"a.yandex-team.ru/infra/cauth/agent/linux/yandex-cauth-userd/pkg/cauthrpc"
	"github.com/golang/protobuf/proto"
)

// Index defines read-only lookups over cached user, group and SSH key data.
// Every method returns the stored payload together with a flag reporting
// whether the lookup succeeded.
type Index interface {
	// UserByUID returns serialized in wire format cauthrpc.User
	// and lookup success flag searching by uid.
	UserByUID(uid uint32) ([]byte, bool)
	// UserByLogin returns serialized in wire format cauthrpc.User
	// and lookup success flag searching by login.
	UserByLogin(login string) ([]byte, bool)
	// UserGroups returns serialized in wire format cauthrpc.UserGroups
	// (the GIDs of the groups the login belongs to)
	// and lookup success flag searching by login.
	UserGroups(login string) ([]byte, bool)
	// GroupByGID returns serialized in wire format cauthrpc.Group
	// and lookup success flag searching by gid.
	GroupByGID(gid uint32) ([]byte, bool)
	// GroupByName returns serialized in wire format cauthrpc.Group
	// and lookup success flag searching by group name.
	GroupByName(name string) ([]byte, bool)
	// KeysByLogin returns user keys by login in authorized_keys file format (one key per line)
	// and lookup success flag.
	KeysByLogin(login string) ([]byte, bool)
}

const (
	// serveradminsGID is the gid of the synthetic "serveradmins" group built
	// from the admin users list. 0xfffffffd is (gid_t)-3, i.e. the largest
	// 32-bit gid (0xffffffff) minus 2.
	serveradminsGID = 0xfffffffd
	// virtualUsersGroupName is the name of the group whose members are
	// treated as virtual users when user overrides are applied.
	virtualUsersGroupName = "dpt_virtual"
)

// inMemoryIndex implements the Index interface by keeping every lookup table
// in memory. The map values are pre-serialized payloads that the accessor
// methods return verbatim; the accessors only read the maps. In this file
// the maps are populated by NewInMemoryIndex.
type inMemoryIndex struct {
	// uidToUser maps a UID to its serialized cauthrpc.User.
	uidToUser map[uint32][]byte
	// loginToUser maps a login to its serialized cauthrpc.User.
	loginToUser map[string][]byte
	// gidToGroup maps a GID to its serialized cauthrpc.Group.
	gidToGroup map[uint32][]byte
	// groupNameToGroup maps a group name to its serialized cauthrpc.Group.
	groupNameToGroup map[string][]byte
	// loginToKeys maps a login to its authorized_keys lines joined by "\n".
	loginToKeys map[string][]byte
	// loginToGroups maps a login to its serialized cauthrpc.UserGroups.
	loginToGroups map[string][]byte
}

// UserByUID returns the serialized cauthrpc.User stored under uid and true,
// or nil and false if there is none.
func (d *inMemoryIndex) UserByUID(uid uint32) ([]byte, bool) {
	b, ok := d.uidToUser[uid]
	return b, ok
}

// UserByLogin returns the serialized cauthrpc.User stored under login and
// true, or nil and false if there is none.
func (d *inMemoryIndex) UserByLogin(login string) ([]byte, bool) {
	b, ok := d.loginToUser[login]
	return b, ok
}

// UserGroups returns the serialized cauthrpc.UserGroups stored under login
// and true, or nil and false if the login belongs to no indexed group.
func (d *inMemoryIndex) UserGroups(login string) ([]byte, bool) {
	b, ok := d.loginToGroups[login]
	return b, ok
}

// GroupByGID returns the serialized cauthrpc.Group stored under gid and
// true, or nil and false if there is none.
func (d *inMemoryIndex) GroupByGID(gid uint32) ([]byte, bool) {
	b, ok := d.gidToGroup[gid]
	return b, ok
}

// GroupByName returns the serialized cauthrpc.Group stored under name and
// true, or nil and false if there is none.
func (d *inMemoryIndex) GroupByName(name string) ([]byte, bool) {
	b, ok := d.groupNameToGroup[name]
	return b, ok
}

// KeysByLogin returns the authorized_keys content stored under login and
// true, or nil and false if the login has no keys.
func (d *inMemoryIndex) KeysByLogin(login string) ([]byte, bool) {
	b, ok := d.loginToKeys[login]
	return b, ok
}

// serializedUser pairs a parsed passwd entry (u) with b, the proto.Marshal
// encoding of the cauthrpc.User built from it by serializeUsers (user
// overrides already applied). indexUsers keys b by u.UID and u.Login.
type serializedUser struct {
	u *passwd.User
	b []byte
}

// serializedGroup pairs a parsed group entry (g) with b, the proto.Marshal
// encoding of the cauthrpc.Group built from it by serializeGroups.
// indexGroups keys b by g.GID and g.Name.
type serializedGroup struct {
	g *passwd.Group
	b []byte
}

// NewInMemoryIndex fetches everything it needs through client and builds a
// fully populated in-memory index. The fetches run sequentially:
//  1. passwd, then admin users, merged into a single user list;
//  2. groups, extended with the synthetic serveradmins group;
//  3. user keys, key-source metadata, and any CA keys those sources require.
//
// The first error aborts construction and is returned unchanged.
// resolveGroupMembers controls whether serialized groups carry member lists;
// overrides is passed to serializeUsers for shell/GID overrides.
func NewInMemoryIndex(ctx context.Context, client cauth.Client, resolveGroupMembers bool, overrides *UserOverrides) (*inMemoryIndex, error) {
	passwdUsers, err := loadPasswd(ctx, client)
	if err != nil {
		return nil, err
	}
	admins, err := loadAdminUsers(ctx, client)
	if err != nil {
		return nil, err
	}
	allUsers := mergeAdminUsersToUsers(admins, passwdUsers)

	fetchedGroups, err := loadGroups(ctx, client)
	if err != nil {
		return nil, err
	}
	encodedGroups, err := serializeGroups(append(fetchedGroups, makeServeradminGroup(admins)), resolveGroupMembers)
	if err != nil {
		return nil, err
	}

	userKeys, err := loadKeys(ctx, client)
	if err != nil {
		return nil, err
	}
	info, err := loadKeysInfo(ctx, client)
	if err != nil {
		return nil, err
	}
	caKeys, fromStaff, err := loadAdditionalKeys(ctx, client, info)
	if err != nil {
		return nil, err
	}

	idx := &inMemoryIndex{
		uidToUser:        make(map[uint32][]byte),
		loginToUser:      make(map[string][]byte),
		gidToGroup:       make(map[uint32][]byte),
		groupNameToGroup: make(map[string][]byte),
		loginToKeys:      make(map[string][]byte),
		loginToGroups:    make(map[string][]byte),
	}

	// Groups must be indexed before users: indexGroups collects the members
	// of the virtual-users group, which serializeUsers needs for overrides.
	virtual := make(map[string]struct{})
	if err := indexGroups(idx.gidToGroup, idx.groupNameToGroup, idx.loginToGroups, encodedGroups, virtual); err != nil {
		return nil, err
	}

	encodedUsers, err := serializeUsers(allUsers, virtual, overrides)
	if err != nil {
		return nil, err
	}
	indexUsers(idx.uidToUser, idx.loginToUser, encodedUsers)

	if err := indexKeys(idx.loginToKeys, userKeys, admins, fromStaff, caKeys); err != nil {
		return nil, err
	}
	return idx, nil
}

// indexUsers registers every serialized user in uids (by UID) and in logins
// (by login). When several users share a UID or a login, the one appearing
// last in sUsers wins.
func indexUsers(uids map[uint32][]byte, logins map[string][]byte, sUsers []*serializedUser) {
	for i := range sUsers {
		entry := sUsers[i]
		payload := entry.b
		uids[entry.u.UID] = payload
		logins[entry.u.Login] = payload
	}
}

// indexGroups registers every serialized group in gids (by GID) and names (by
// name). It also builds, for each member login, a cauthrpc.UserGroups listing
// the GIDs of the groups it belongs to (in sGroups order), and stores its
// wire encoding in logins. Members of the group named virtualUsersGroupName
// are additionally recorded in virtMembers. A marshaling error is returned
// as-is.
func indexGroups(gids map[uint32][]byte, names map[string][]byte, logins map[string][]byte, sGroups []*serializedGroup, virtMembers map[string]struct{}) error {
	memberships := make(map[string]*cauthrpc.UserGroups)
	for _, sg := range sGroups {
		grp := sg.g
		gids[grp.GID] = sg.b
		names[grp.Name] = sg.b

		virtual := grp.Name == virtualUsersGroupName
		for _, login := range grp.Members {
			if virtual {
				virtMembers[login] = struct{}{}
			}
			ug, found := memberships[login]
			if !found {
				ug = &cauthrpc.UserGroups{}
				memberships[login] = ug
			}
			ug.Groups = append(ug.Groups, grp.GID)
		}
	}

	for login, ug := range memberships {
		encoded, err := proto.Marshal(ug)
		if err != nil {
			return err
		}
		logins[login] = encoded
	}
	return nil
}

// makeCAKey renders an authorized_keys line that trusts key as a certificate
// authority, restricted to certificates for the given principals (joined with
// commas inside the principals="..." option).
func makeCAKey(key string, principals ...string) string {
	allowed := strings.Join(principals, ",")
	return fmt.Sprintf(`cert-authority,principals="%s" %s`, allowed, key)
}

// indexKeys builds the authorized_keys content for every login and stores it
// in logins as lines joined with "\n" (no trailing newline).
//
// root always receives one cert-authority line per CA key, restricted to the
// admin logins from adminUsers.
//
// When staffSource is true:
//   - every key in keys is indexed for its login;
//   - keys of admin users are also granted to root;
//   - each login appearing in keys gets one cert-authority line per CA key,
//     restricted to that login.
//
// When staffSource is false, the user keys themselves are not indexed. Each
// login appearing in keys gets only the per-user cert-authority lines, and
// root keeps just its admin-principal CA lines.
//
// The returned error is currently always nil.
func indexKeys(logins map[string][]byte, keys []*passwd.Key, adminUsers []*passwd.User, staffSource bool, caKeys []string) error {
	roots, rootsList := rootPrincipals(adminUsers)
	loginToKeys := make(map[string][]string)

	// root trusts every CA key, but only for certificates issued to admins.
	for _, ak := range caKeys {
		loginToKeys["root"] = append(loginToKeys["root"], makeCAKey(ak, rootsList...))
	}

	// caDone records logins whose per-user cert-authority lines were already
	// emitted, so a user with several keys gets each CA line exactly once.
	caDone := make(map[string]struct{})
	addUserCAKeys := func(login string) {
		if _, ok := caDone[login]; ok {
			return
		}
		caDone[login] = struct{}{}
		for _, ak := range caKeys {
			loginToKeys[login] = append(loginToKeys[login], makeCAKey(ak, login))
		}
	}

	if staffSource {
		for _, k := range keys {
			loginToKeys[k.Login] = append(loginToKeys[k.Login], k.Key)
			addUserCAKeys(k.Login)
			if _, ok := roots[k.Login]; ok {
				loginToKeys["root"] = append(loginToKeys["root"], k.Key)
			}
		}
	} else {
		// root already has its entries from above, so it is skipped here.
		caDone["root"] = struct{}{}
		for _, k := range keys {
			addUserCAKeys(k.Login)
		}
	}

	for login, lines := range loginToKeys {
		logins[login] = []byte(strings.Join(lines, "\n"))
	}
	return nil
}

// loadPasswd fetches the passwd file through fetcher and parses it into
// users. The fetched stream is closed once parsing finishes. On any error
// the returned slice is nil.
func loadPasswd(ctx context.Context, fetcher cauth.PasswdFetcher) ([]*passwd.User, error) {
	src, fetchErr := fetcher.FetchPasswd(ctx)
	if fetchErr != nil {
		return nil, fetchErr
	}
	defer src.Close()

	parsed, parseErr := passwd.ParsePasswd(src)
	if parseErr != nil {
		return nil, parseErr
	}
	return parsed, nil
}

// loadGroups fetches the group file through fetcher and parses it into
// groups. The fetched stream is closed once parsing finishes. On any error
// the returned slice is nil.
func loadGroups(ctx context.Context, fetcher cauth.GroupFetcher) ([]*passwd.Group, error) {
	src, fetchErr := fetcher.FetchGroup(ctx)
	if fetchErr != nil {
		return nil, fetchErr
	}
	defer src.Close()

	parsed, parseErr := passwd.ParseGroup(src)
	if parseErr != nil {
		return nil, parseErr
	}
	return parsed, nil
}

// loadKeys fetches the user keys file through fetcher and parses it with
// passwd.ParseKeys. The fetched stream is closed once parsing finishes.
// On any error the returned slice is nil.
func loadKeys(ctx context.Context, fetcher cauth.KeysFetcher) ([]*passwd.Key, error) {
	reader, err := fetcher.FetchKeys(ctx)
	if err != nil {
		return nil, err
	}
	// Release the fetched stream, as loadPasswd and loadGroups do.
	defer reader.Close()

	keys, err := passwd.ParseKeys(reader)
	if err != nil {
		return nil, err
	}
	return keys, nil
}

// loadAdminUsers fetches the admin users list through fetcher and parses it
// in passwd format. The fetched stream is closed once parsing finishes.
// On any error the returned slice is nil.
func loadAdminUsers(ctx context.Context, fetcher cauth.AdminsUsersFetcher) ([]*passwd.User, error) {
	reader, err := fetcher.FetchAdminUsers(ctx)
	if err != nil {
		return nil, err
	}
	// Release the fetched stream, as loadPasswd and loadGroups do.
	defer reader.Close()

	users, err := passwd.ParsePasswd(reader)
	if err != nil {
		return nil, err
	}
	return users, nil
}

// loadKeysInfo fetches and parses the key-source metadata.
//
// If the fetch fails with an error classified by os.IsNotExist, a built-in
// default is returned instead: the Skotty CA endpoints below, with "staff"
// as the only key source. Any other fetch or parse error is returned as-is.
// The fetched stream is closed once parsing finishes.
func loadKeysInfo(ctx context.Context, fetcher cauth.KeysInfoFetcher) (*cauth.KeysInfo, error) {
	reader, err := fetcher.FetchKeysInfo(ctx)
	if os.IsNotExist(err) {
		// NOTE(review): os.IsNotExist does not unwrap errors wrapped with %w;
		// if FetchKeysInfo ever wraps its errors, errors.Is(err, os.ErrNotExist)
		// would be needed for this fallback to trigger.
		keysInfo := &cauth.KeysInfo{
			InsecureCAListURL: "https://skotty.sec.yandex-team.ru/api/v1/ca/pub/insecure",
			SudoCAListURL:     "https://skotty.sec.yandex-team.ru/api/v1/ca/pub/sudo",
			KrlURL:            "https://skotty.sec.yandex-team.ru/api/v1/ca/krl/all.zst",
			SecureCAListURL:   "https://skotty.sec.yandex-team.ru/api/v1/ca/pub/secure",
			KeySources:        []string{"staff"},
		}
		return keysInfo, nil
	}
	if err != nil {
		return nil, err
	}
	// Release the fetched stream, as loadPasswd and loadGroups do.
	defer reader.Close()

	keysInfo, _, err := cauth.ParseKeysInfo(reader)
	if err != nil {
		return nil, err
	}
	return keysInfo, nil
}

// loadCAKeys obtains a stream from fetcher and parses it with utils.ParseKeys
// into a list of CA public keys. The stream is closed once parsing finishes;
// previously it was leaked. On any error the returned slice is nil.
func loadCAKeys(ctx context.Context, fetcher func(context.Context) (io.ReadCloser, error)) ([]string, error) {
	reader, err := fetcher(ctx)
	if err != nil {
		return nil, err
	}
	defer reader.Close()

	keys, err := utils.ParseKeys(reader)
	if err != nil {
		return nil, err
	}
	return keys, nil
}

// loadAdditionalKeys walks keysInfo.KeySources in order.
//
//   - "insecure" and "secure" cause the corresponding CA public keys to be
//     fetched (via client.FetchInsecure / client.FetchSecure) and appended to
//     the result.
//   - "staff" marks that staff keys are a key source.
//   - Unknown source names are ignored.
//
// The first fetch or parse error aborts the walk; it is returned with a nil
// key list and false.
func loadAdditionalKeys(ctx context.Context, client cauth.Client, keysInfo *cauth.KeysInfo) ([]string, bool, error) {
	var caKeys []string
	useStaff := false
	for _, source := range keysInfo.KeySources {
		var fetch func(context.Context) (io.ReadCloser, error)
		switch source {
		case "insecure":
			fetch = client.FetchInsecure
		case "secure":
			fetch = client.FetchSecure
		case "staff":
			useStaff = true
			continue
		default:
			continue
		}
		loaded, err := loadCAKeys(ctx, fetch)
		if err != nil {
			return nil, false, err
		}
		caKeys = append(caKeys, loaded...)
	}
	return caKeys, useStaff, nil
}

// mergeAdminUsersToUsers appends to users every admin user whose login is not
// already present and returns the resulting slice.
//
// Entries already in users take precedence over admin entries with the same
// login. An admin login is appended at most once, even if adminUsers lists it
// several times. As with append, the result may share users' backing array.
func mergeAdminUsersToUsers(adminUsers, users []*passwd.User) []*passwd.User {
	known := make(map[string]struct{}, len(users)+len(adminUsers))
	for _, u := range users {
		known[u.Login] = struct{}{}
	}
	for _, u := range adminUsers {
		if _, ok := known[u.Login]; ok {
			continue
		}
		// Record the login so a repeated admin entry is not appended twice.
		known[u.Login] = struct{}{}
		users = append(users, u)
	}
	return users
}

// makeServeradminGroup builds the synthetic "serveradmins" group (GID
// serveradminsGID) whose members are the admin logins, in first-seen order.
// Duplicate logins are collapsed, as rootPrincipals does. Otherwise
// indexGroups would list serveradminsGID more than once in that member's
// UserGroups. An empty adminUsers yields an empty (non-nil) member list.
func makeServeradminGroup(adminUsers []*passwd.User) *passwd.Group {
	members := make([]string, 0, len(adminUsers))
	seen := make(map[string]struct{}, len(adminUsers))
	for _, u := range adminUsers {
		if _, dup := seen[u.Login]; dup {
			continue
		}
		seen[u.Login] = struct{}{}
		members = append(members, u.Login)
	}
	return &passwd.Group{
		Name:    "serveradmins",
		GID:     serveradminsGID,
		Members: members,
	}
}

// rootPrincipals returns the admin logins allowed to act as root in two forms:
// a set for membership checks and a list in first-seen order with
// duplicates removed. The list is nil when adminUsers is empty.
func rootPrincipals(adminUsers []*passwd.User) (map[string]struct{}, []string) {
	seen := make(map[string]struct{})
	var ordered []string
	for _, admin := range adminUsers {
		if _, dup := seen[admin.Login]; dup {
			continue
		}
		seen[admin.Login] = struct{}{}
		ordered = append(ordered, admin.Login)
	}
	return seen, ordered
}

// isVirtualUser reports whether name is a virtual (non-human) account: either
// listed in virtMembers or carrying a "robot-" or "zomb-" login prefix.
func isVirtualUser(name string, virtMembers map[string]struct{}) bool {
	_, listed := virtMembers[name]
	switch {
	case listed:
		return true
	case strings.HasPrefix(name, "robot-"), strings.HasPrefix(name, "zomb-"):
		return true
	default:
		return false
	}
}

// serializeUsers converts users into wire-format cauthrpc.User messages and
// applies overrides along the way.
//
// Users considered virtual by isVirtualUser get the VirtualUsers* overrides;
// everyone else gets the RealUsers* overrides. An empty shell override or a
// negative GID override leaves the passwd value untouched. A nil overrides
// applies no overrides; previously it caused a nil-pointer panic.
//
// The first marshaling error aborts and is returned with a nil slice.
func serializeUsers(users []*passwd.User, virtMembers map[string]struct{}, overrides *UserOverrides) ([]*serializedUser, error) {
	rv := make([]*serializedUser, 0, len(users))
	for _, u := range users {
		// A fresh message per user avoids carrying state between iterations.
		rpcUser := cauthrpc.User{
			Uid:     u.UID,
			Gid:     u.GID,
			Login:   u.Login,
			HomeDir: u.HomeDir,
			Shell:   u.Shell,
			Gecos:   u.Gecos,
		}

		if overrides != nil {
			if isVirtualUser(u.Login, virtMembers) {
				if overrides.VirtualUsersShellOverride != "" {
					rpcUser.Shell = overrides.VirtualUsersShellOverride
				}
				if overrides.VirtualUsersGIDOverride >= 0 {
					rpcUser.Gid = uint32(overrides.VirtualUsersGIDOverride)
				}
			} else {
				if overrides.RealUsersShellOverride != "" {
					rpcUser.Shell = overrides.RealUsersShellOverride
				}
				if overrides.RealUsersGIDOverride >= 0 {
					rpcUser.Gid = uint32(overrides.RealUsersGIDOverride)
				}
			}
		}

		b, err := proto.Marshal(&rpcUser)
		if err != nil {
			return nil, err
		}
		rv = append(rv, &serializedUser{u: u, b: b})
	}
	return rv, nil
}

// serializeGroups encodes each group as a wire-format cauthrpc.Group,
// preserving input order. Member lists are included only when
// resolveGroupMembers is true; otherwise groups are encoded without members.
// The first marshaling error aborts and is returned with a nil slice.
func serializeGroups(groups []*passwd.Group, resolveGroupMembers bool) ([]*serializedGroup, error) {
	var out []*serializedGroup
	for _, grp := range groups {
		msg := cauthrpc.Group{Gid: grp.GID, Name: grp.Name}
		if resolveGroupMembers {
			msg.Members = grp.Members
		}
		encoded, err := proto.Marshal(&msg)
		if err != nil {
			return nil, err
		}
		out = append(out, &serializedGroup{g: grp, b: encoded})
	}
	return out, nil
}
