package storage

import (
	"context"
	"fmt"
	"io"
	"os"
	"path"
	"path/filepath"
	"regexp"
	"sort"

	"a.yandex-team.ru/infra/cauth/agent/linux/yandex-cauth-userd/internal/config"
)

// versionRe matches names that consist solely of decimal digits. gcVersions
// uses it to pick version directories out of the repo directory listing.
var versionRe = regexp.MustCompile(`^\d+$`)

// Repo stores versions and tracks which one is current.
//
// The method comments below describe what the timestampRepo implementation in
// this file does; other implementations are presumably expected to follow the
// same contract (NOTE(review): confirm if more implementations exist).
type Repo interface {
	// CurrentVersion returns the version that is marked as current.
	CurrentVersion() (Version, error)
	// PersistVersion fetches the files of version and stores them in the repo.
	// It does not change which version is current.
	PersistVersion(ctx context.Context, version Version) error
	// UpdateCurrent marks version as the current one.
	UpdateCurrent(version Version) error
}

// timestampRepo is a Repo backed by a directory on the local filesystem.
// Every persisted version lives in a subdirectory of repoPath named after the
// Unix timestamp of the version, and the "current" symlink inside repoPath
// points at the active one.
type timestampRepo struct {
	// repoPath is the repo directory; NewTimestampRepo stores it as an
	// absolute path.
	repoPath     string
	// keepVersions is the number of version directories gcVersions leaves in
	// place.
	keepVersions int
}

// NewTimestampRepo builds a timestampRepo from config. The configured path is
// converted to an absolute one; a failure to do so is returned as is.
func NewTimestampRepo(config *config.RepoConfig) (*timestampRepo, error) {
	absPath, absErr := filepath.Abs(config.Path)
	if absErr != nil {
		return nil, absErr
	}
	repo := &timestampRepo{
		repoPath:     absPath,
		keepVersions: config.KeepVersions,
	}
	return repo, nil
}

// CurrentVersion reads the "current" symlink inside the repo directory and
// returns the version stored in the directory it points to.
//
// A relative link target is resolved against the repo directory (the directory
// that holds the link) rather than the process working directory. It is an
// error if the link is missing or its target is not a directory.
//
// Errors from os.Readlink and os.Stat are returned unwrapped on purpose, so
// callers can keep classifying them with helpers such as os.IsNotExist.
func (r *timestampRepo) CurrentVersion() (Version, error) {
	target, err := os.Readlink(path.Join(r.repoPath, "current"))
	if err != nil {
		return nil, err
	}
	// Readlink returns the link contents verbatim; anchor relative targets to
	// the directory the link lives in.
	if !path.IsAbs(target) {
		target = path.Join(r.repoPath, target)
	}
	info, err := os.Stat(target)
	if err != nil {
		return nil, err
	}
	if !info.IsDir() {
		return nil, fmt.Errorf("current symlink target %s is not a directory", target)
	}
	return NewLocalVersion(target)
}

// fetcherFunc opens a stream with the content of a single file. fetchAndSave
// closes the returned reader once it is done with it.
type fetcherFunc func(ctx context.Context) (io.ReadCloser, error)

// fetchAndSave streams the data produced by f into a newly created file at
// path (mode 0644) and syncs the file to disk before returning.
//
// The file must not exist yet: if it does, f is not called and an error is
// returned. The file is created with O_EXCL, so a file (or symlink) that
// appears at path between the existence check and the open is never truncated
// or written through. An error from closing the written file is reported
// unless an earlier error already occurred.
func fetchAndSave(ctx context.Context, f fetcherFunc, path string) (err error) {
	// Cheap pre-check so that no fetch is started for an existing target.
	if _, statErr := os.Stat(path); statErr == nil {
		return fmt.Errorf("path %s already exists", path)
	}
	reader, err := f(ctx)
	if err != nil {
		return err
	}
	defer reader.Close()

	writer, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o644)
	if err != nil {
		if os.IsExist(err) {
			return fmt.Errorf("path %s already exists", path)
		}
		return err
	}
	defer func() {
		// Surface a failed close, but never mask the primary error.
		if closeErr := writer.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
	}()

	if _, err := io.Copy(writer, reader); err != nil {
		return err
	}
	return writer.Sync()
}

// PersistVersion downloads all files of version into the repo.
//
// The files are first written to a scratch directory "incoming" inside the
// repo directory, which is wiped and recreated on every call. After all files
// have been fetched and the admin sudoers file has been generated from the
// admin users file, "incoming" is renamed to a directory named after the
// version's Unix timestamp, and old version directories are garbage collected.
//
// The first error encountered is returned; in that case "incoming" may be left
// behind and is cleaned up by the next call.
func (r *timestampRepo) PersistVersion(ctx context.Context, version Version) error {
	incomingPath := path.Join(r.repoPath, "incoming")
	// RemoveAll succeeds on a missing path, so no existence check is needed.
	// Removing unconditionally also gets rid of leftovers a Stat-based check
	// would miss, such as a dangling symlink.
	if err := os.RemoveAll(incomingPath); err != nil {
		return fmt.Errorf("failed to clean up incoming dir: %w", err)
	}
	if err := os.MkdirAll(incomingPath, 0o755); err != nil {
		return err
	}

	adminUsersPath := path.Join(incomingPath, adminUsersFile)
	jobs := []struct {
		f fetcherFunc
		p string
	}{
		{version.FetchPasswd, path.Join(incomingPath, passwdFile)},
		{version.FetchGroup, path.Join(incomingPath, groupFile)},
		{version.FetchAccess, path.Join(incomingPath, accessFile)},
		{version.FetchAdminUsers, adminUsersPath},
		{version.FetchKeys, path.Join(incomingPath, keysFile)},
		{version.FetchSudoers, path.Join(incomingPath, sudoersFile)},
		{version.FetchKeysInfo, path.Join(incomingPath, keysInfoFile)},
		{version.FetchInsecure, path.Join(incomingPath, insecureFile)},
		{version.FetchSecure, path.Join(incomingPath, secureFile)},
		{version.FetchSudo, path.Join(incomingPath, sudoersCAFile)},
		{version.FetchKRL, path.Join(incomingPath, KRLFile)},
	}
	for _, job := range jobs {
		if err := fetchAndSave(ctx, job.f, job.p); err != nil {
			return err
		}
	}

	// Generate NOPASSWD admin sudoers from the admin users data fetched above.
	if err := persistAdminSudoers(adminUsersPath, path.Join(incomingPath, adminSudoersFile)); err != nil {
		return err
	}

	// Publish the fully populated directory under its final name in one step.
	tsStr := fmt.Sprintf("%d", version.Timestamp().Unix())
	if err := os.Rename(incomingPath, path.Join(r.repoPath, tsStr)); err != nil {
		return err
	}
	return r.gcVersions()
}

// persistAdminSudoers reads the admin users file at usersPath, writes the
// sudoers file generated from it to sudoersPath (mode 0600) and syncs it to
// disk. Both files are closed before it returns; an error from closing the
// written file is reported unless an earlier error already occurred.
func persistAdminSudoers(usersPath, sudoersPath string) (err error) {
	users, err := os.Open(usersPath)
	if err != nil {
		return err
	}
	defer users.Close()

	sudoers, err := os.OpenFile(sudoersPath, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0o600)
	if err != nil {
		return err
	}
	defer func() {
		if closeErr := sudoers.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
	}()

	if err := generateAdminSudoers(users, sudoers); err != nil {
		return err
	}
	return sudoers.Sync()
}

// gcVersions deletes the oldest version directories so that at most
// keepVersions of them remain.
//
// Version directories are the subdirectories of the repo directory whose names
// match versionRe. They are ordered by the numeric value of their name, so
// names with a different number of digits are still ordered by age. The
// directory the "current" symlink points to is never deleted, even when it is
// among the oldest; in that case one directory more than keepVersions remains.
// A negative keepVersions is treated like zero.
func (r *timestampRepo) gcVersions() error {
	entries, err := os.ReadDir(r.repoPath)
	if err != nil {
		return err
	}
	var versions []string
	for _, ent := range entries {
		if ent.IsDir() && versionRe.MatchString(ent.Name()) {
			versions = append(versions, ent.Name())
		}
	}

	excess := len(versions) - r.keepVersions
	if excess <= 0 {
		return nil
	}
	// Guard against a negative keepVersions asking for more removals than
	// there are directories.
	if excess > len(versions) {
		excess = len(versions)
	}
	sort.Slice(versions, func(i, j int) bool {
		return versionNameLess(versions[i], versions[j])
	})

	// Name of the directory "current" points to, if the link can be read.
	// A missing or unreadable link simply means there is nothing to protect.
	current := ""
	if target, err := os.Readlink(path.Join(r.repoPath, "current")); err == nil {
		current = path.Base(target)
	}

	for _, name := range versions[:excess] {
		if name == current {
			continue
		}
		if err := os.RemoveAll(path.Join(r.repoPath, name)); err != nil {
			return err
		}
	}
	return nil
}

// versionNameLess reports whether the decimal digit string a denotes a smaller
// number than b. It works on strings of any length, so it cannot overflow.
func versionNameLess(a, b string) bool {
	trim := func(s string) string {
		for len(s) > 1 && s[0] == '0' {
			s = s[1:]
		}
		return s
	}
	ta, tb := trim(a), trim(b)
	if len(ta) != len(tb) {
		return len(ta) < len(tb)
	}
	if ta != tb {
		return ta < tb
	}
	// Same number, different spelling (leading zeros): keep the order stable.
	return a < b
}

// UpdateCurrent makes the "current" symlink in the repo directory point to the
// directory of version (named after the version's Unix timestamp).
//
// The link is first created under the temporary name "new" and then renamed
// over "current", so "current" is replaced in a single rename instead of being
// removed and recreated. A "new" link left behind by an earlier interrupted
// call is removed first; otherwise creating the link would fail on every
// subsequent call.
func (r *timestampRepo) UpdateCurrent(version Version) error {
	tsStr := fmt.Sprintf("%d", version.Timestamp().Unix())
	vPath := path.Join(r.repoPath, tsStr)
	newCurrentPath := path.Join(r.repoPath, "new")
	if err := os.Remove(newCurrentPath); err != nil && !os.IsNotExist(err) {
		return err
	}
	if err := os.Symlink(vPath, newCurrentPath); err != nil {
		return err
	}
	return os.Rename(newCurrentPath, path.Join(r.repoPath, "current"))
}
