package nssdb

import (
	"bytes"
	"errors"
	"fmt"
	"io/ioutil"
	"sync"

	"google.golang.org/protobuf/proto"

	"a.yandex-team.ru/infra/cauth/agent/linux/cauth-agent/internal/config"
	"a.yandex-team.ru/infra/cauth/agent/linux/cauth-agent/pkg/cauthrpc"
	"a.yandex-team.ru/infra/cauth/agent/linux/flatcache/go/flatcache"
	"a.yandex-team.ru/library/go/core/log"
)

// dbMarker frames the flatcache payload: Reload requires the db file to both
// begin and end with exactly these bytes.
var (
	dbMarker = []byte("DEADBEEF")
)

// StorageConfig carries the location of the db file.
// NOTE(review): not referenced in this chunk; Reload reads the path from
// config.NSSConfig.DBPath instead — confirm whether this type is still used.
type StorageConfig struct{ dbPath string }

// Storage is an in-memory NSS database populated by Reload from a flatcache
// file. Users and groups are kept as pre-marshaled protobuf messages in dense
// tables indexed by (id - min id), plus name-to-id maps for lookups by name.
type Storage struct {
	// cfg supplies DBPath, the file Reload reads.
	cfg        config.NSSConfig
	log        log.Logger
	// minUID and maxUID bound the UIDs loaded by the last successful Reload;
	// uids and userGroups are indexed by uid-minUID.
	minUID     uint32
	maxUID     uint32
	// uids[uid-minUID] is a marshaled cauthrpc.User; nil where no user was
	// stored for that UID.
	uids       [][]byte
	// userGroups[uid-minUID] is a marshaled cauthrpc.UserGroups.
	userGroups [][]byte
	// userNames maps a login to its UID.
	userNames  map[string]uint32
	// gids[gid-minGID] is a marshaled cauthrpc.Group; nil where no group was
	// stored for that GID.
	gids       [][]byte
	// minGID and maxGID bound the GIDs loaded by the last successful Reload.
	minGID     uint32
	maxGID     uint32
	// groupNames maps a group name to its GID.
	groupNames map[string]uint32
	// mu: Reload assigns all table fields above while holding the write
	// lock; lookup methods take the read lock.
	mu         sync.RWMutex
}

// NewStorage creates an empty Storage that reads its db from cfg and logs
// through l. No users or groups are available until Reload succeeds.
func NewStorage(cfg config.NSSConfig, l log.Logger) *Storage {
	st := new(Storage)
	st.cfg = cfg
	st.log = l
	return st
}

// Reload reads the flatcache database at s.cfg.DBPath, rebuilds the lookup
// tables and swaps them into s under s.mu.
//
// The file must start and end with dbMarker; the bytes between the markers
// are decoded as a flatcache.Cache. Users and groups that cannot be decoded
// or marshaled are logged and skipped. On any error (including a malformed
// payload) the previously loaded tables are left untouched.
func (s *Storage) Reload() (err error) {
	data, err := ioutil.ReadFile(s.cfg.DBPath)
	if err != nil {
		return fmt.Errorf("can't read db file: %w", err)
	}

	if len(data) < len(dbMarker)*2 {
		return fmt.Errorf("db is too small: %d < %d", len(data), len(dbMarker)*2)
	}

	if !bytes.Equal(data[:len(dbMarker)], dbMarker) {
		return fmt.Errorf("start marker not found")
	}

	if !bytes.Equal(data[len(data)-len(dbMarker):], dbMarker) {
		return fmt.Errorf("end marker not found")
	}

	// The markers only detect truncation. Decoding reads offsets straight out
	// of the payload, so an empty or corrupted payload can panic with an
	// out-of-range index. Report that as an error instead of crashing; the
	// swap at the end is never reached, so the current tables are kept.
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("malformed db: %v", r)
		}
	}()

	cache := flatcache.GetRootAsCache(data[len(dbMarker):len(data)-len(dbMarker)], 0)

	// idRange returns the lowest and highest id among the first n entries for
	// which get succeeds, and found=false if none do. Every entry is scanned
	// rather than trusting the db's sort order: the tables below are indexed
	// by id-min, so an out-of-order entry would otherwise fall outside them.
	idRange := func(n int, get func(i int) (uint32, bool)) (lo, hi uint32, found bool) {
		for i := 0; i < n; i++ {
			id, ok := get(i)
			if !ok {
				continue
			}
			if !found || id < lo {
				lo = id
			}
			if !found || id > hi {
				hi = id
			}
			found = true
		}
		return lo, hi, found
	}

	minGID, maxGID, found := idRange(cache.GroupsLength(), func(i int) (uint32, bool) {
		var group flatcache.Group
		if !cache.Groups(&group, i) {
			return 0, false
		}
		return group.Gid(), true
	})
	if !found {
		return fmt.Errorf("can't calculate min gid: %w", errors.New("no groups found"))
	}

	minUID, maxUID, found := idRange(cache.UsersLength(), func(i int) (uint32, bool) {
		var user flatcache.User
		if !cache.Users(&user, i) {
			return 0, false
		}
		return user.Uid(), true
	})
	if !found {
		return fmt.Errorf("can't calculate min uid: %w", errors.New("no users found"))
	}

	uids := make([][]byte, maxUID-minUID+1)
	userGroups := make([][]byte, maxUID-minUID+1)
	userNames := make(map[string]uint32, cache.UsersLength())
	for i := 0; i < cache.UsersLength(); i++ {
		var flatUser flatcache.User
		if !cache.Users(&flatUser, i) {
			s.log.Error("can't unmarshal flatUser, skip it", log.Int("pos", i))
			continue
		}

		groups := cauthrpc.UserGroups{
			Groups: make([]uint32, flatUser.GroupsLength()),
		}
		for k := 0; k < flatUser.GroupsLength(); k++ {
			groups.Groups[k] = flatUser.Groups(k)
		}

		login := string(flatUser.Login())
		user := cauthrpc.User{
			Uid:     flatUser.Uid(),
			Gid:     flatUser.Gid(),
			HomeDir: string(flatUser.HomeDir()),
			Login:   login,
			Shell:   string(flatUser.Shell()),
		}

		protoUser, err := proto.Marshal(&user)
		if err != nil {
			s.log.Error("can't marshal user, skip it", log.Int("pos", i), log.Any("uid", user.Uid))
			continue
		}

		protoUserGroups, err := proto.Marshal(&groups)
		if err != nil {
			s.log.Error("can't marshal user groups, skip it", log.Int("pos", i), log.Any("uid", user.Uid))
			continue
		}

		uids[user.Uid-minUID] = protoUser
		userGroups[user.Uid-minUID] = protoUserGroups
		userNames[login] = user.Uid
	}

	gids := make([][]byte, maxGID-minGID+1)
	groupNames := make(map[string]uint32, cache.GroupsLength())
	for i := 0; i < cache.GroupsLength(); i++ {
		var flatGroup flatcache.Group
		if !cache.Groups(&flatGroup, i) {
			s.log.Error("can't unmarshal group, skip it", log.Int("pos", i))
			continue
		}

		name := string(flatGroup.Name())
		group := cauthrpc.Group{
			Gid:     flatGroup.Gid(),
			Name:    name,
			Members: make([]string, flatGroup.MembersLength()),
		}
		for k := 0; k < flatGroup.MembersLength(); k++ {
			group.Members[k] = string(flatGroup.Members(k))
		}

		protoGroup, err := proto.Marshal(&group)
		if err != nil {
			s.log.Error("can't marshal group, skip it", log.Int("pos", i), log.Any("gid", group.Gid))
			continue
		}

		gids[group.Gid-minGID] = protoGroup
		groupNames[name] = group.Gid
	}

	s.mu.Lock()
	s.minUID = minUID
	s.maxUID = maxUID
	s.uids = uids
	s.userGroups = userGroups
	s.userNames = userNames
	s.minGID = minGID
	s.maxGID = maxGID
	s.gids = gids
	s.groupNames = groupNames
	s.mu.Unlock()

	// Log outside the critical section so readers are not blocked on I/O.
	s.log.Info("nssdb: users loaded", log.Int("count", len(userNames)))
	s.log.Info("nssdb: groups loaded", log.Int("count", len(groupNames)))

	return nil
}

// UserByUID returns the marshaled cauthrpc.User stored for uid, or
// (nil, false) if no user with that UID is loaded. It is safe to call before
// the first successful Reload.
func (s *Storage) UserByUID(uid uint32) ([]byte, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	// uids is indexed by uid-minUID. Reject anything outside the table,
	// including every uid while the table is still nil (nothing loaded yet).
	if uid < s.minUID || uint64(uid-s.minUID) >= uint64(len(s.uids)) {
		return nil, false
	}

	user := s.uids[uid-s.minUID]
	if len(user) == 0 {
		return nil, false
	}
	return user, true
}

// UserByLogin returns the marshaled cauthrpc.User for login, or (nil, false)
// if the login is unknown.
func (s *Storage) UserByLogin(login string) ([]byte, bool) {
	// Resolve the name and read the table under one read lock so both come
	// from the same Reload; otherwise a reload in between could return the
	// user that now owns the old UID. UserByUID is not called here because
	// re-acquiring RLock while holding it can deadlock behind a waiting writer.
	s.mu.RLock()
	defer s.mu.RUnlock()

	uid, ok := s.userNames[login]
	if !ok || uid < s.minUID || uint64(uid-s.minUID) >= uint64(len(s.uids)) {
		return nil, false
	}

	user := s.uids[uid-s.minUID]
	if len(user) == 0 {
		return nil, false
	}
	return user, true
}

// UserGroups returns the marshaled cauthrpc.UserGroups of the user with the
// given login, or (nil, false) if the login is unknown or has no entry.
func (s *Storage) UserGroups(login string) ([]byte, bool) {
	// Hold the read lock for the whole lookup: minUID and userGroups are
	// replaced by Reload, so reading them unlocked is a data race and could
	// index a newer, smaller table out of range.
	s.mu.RLock()
	defer s.mu.RUnlock()

	uid, ok := s.userNames[login]
	if !ok || uid < s.minUID || uint64(uid-s.minUID) >= uint64(len(s.userGroups)) {
		return nil, false
	}

	groups := s.userGroups[uid-s.minUID]
	if len(groups) == 0 {
		return nil, false
	}
	return groups, true
}

// GroupByGID returns the marshaled cauthrpc.Group stored for gid, or
// (nil, false) if no group with that GID is loaded. It is safe to call before
// the first successful Reload.
func (s *Storage) GroupByGID(gid uint32) ([]byte, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	// gids is indexed by gid-minGID. Reject anything outside the table,
	// including every gid while the table is still nil (nothing loaded yet).
	if gid < s.minGID || uint64(gid-s.minGID) >= uint64(len(s.gids)) {
		return nil, false
	}

	group := s.gids[gid-s.minGID]
	if len(group) == 0 {
		return nil, false
	}
	return group, true
}

// GroupByName returns the marshaled cauthrpc.Group named name, or
// (nil, false) if the name is unknown.
func (s *Storage) GroupByName(name string) ([]byte, bool) {
	// Resolve the name and read the table under one read lock so both come
	// from the same Reload; otherwise a reload in between could return the
	// group that now owns the old GID. GroupByGID is not called here because
	// re-acquiring RLock while holding it can deadlock behind a waiting writer.
	s.mu.RLock()
	defer s.mu.RUnlock()

	gid, ok := s.groupNames[name]
	if !ok || gid < s.minGID || uint64(gid-s.minGID) >= uint64(len(s.gids)) {
		return nil, false
	}

	group := s.gids[gid-s.minGID]
	if len(group) == 0 {
		return nil, false
	}
	return group, true
}
