package main

import (
	"flag"
	"fmt"
	"io/ioutil"
	"os"
	"path/filepath"

	flatbuffers "github.com/google/flatbuffers/go"

	"a.yandex-team.ru/infra/cauth/agent/linux/flatcache/go/flatcache"
	"a.yandex-team.ru/infra/cauth/agent/linux/nss_export/provider"
	"a.yandex-team.ru/infra/cauth/agent/linux/nss_export/provider/cauth"
	"a.yandex-team.ru/infra/cauth/agent/linux/nss_export/provider/ldap"
)

// dbMarker is the fixed byte sequence that writeCache emits immediately
// before and immediately after the serialized cache payload.
var dbMarker = []byte("DEADBEEF")

// chooseProvider maps a provider name taken from the command line to a
// freshly constructed provider.Provider.
//
// The recognised names are "cauth" and "ldap"; any other value yields a nil
// provider and an error that mentions the rejected name.
func chooseProvider(p string) (provider.Provider, error) {
	if p == "cauth" {
		return cauth.NewProvider(), nil
	}
	if p == "ldap" {
		return ldap.NewProvider(), nil
	}
	return nil, fmt.Errorf("unknown provider: %s", p)
}

// fatalf formats a message in the style of fmt.Printf, writes it to stderr
// followed by a newline, and terminates the process with exit status 1.
// It never returns.
func fatalf(format string, args ...interface{}) {
	const exitStatus = 1

	line := format + "\n"
	// A failed write to stderr is ignored on purpose: the process is about to
	// exit and there is no better place left to report the problem.
	_, _ = fmt.Fprintf(os.Stderr, line, args...)

	os.Exit(exitStatus)
}

// writeCache stores data at targetPath, framed by dbMarker on both sides,
// without ever exposing a partially written file under targetPath.
//
// The content is first written to a temporary file created in the same
// directory as targetPath (so the final rename stays on one filesystem), the
// temporary file is made world-readable (0644), and it is then renamed over
// targetPath. If targetPath already exists, it is hard-linked to
// targetPath+".old" before being replaced, so the previous database remains
// available as a backup.
//
// On any failure the temporary file is closed and removed, and the returned
// error wraps the underlying cause.
func writeCache(targetPath string, data []byte) error {
	dir, name := filepath.Split(targetPath)
	if dir == "" {
		// ioutil.TempFile falls back to the system temp directory for an empty
		// dir; that may live on another filesystem, where the rename below
		// cannot work. A bare file name means the current directory.
		dir = "."
	}

	f, err := ioutil.TempFile(dir, name)
	if err != nil {
		return fmt.Errorf("can't create temporary file: %w", err)
	}
	tmpPath := f.Name()

	// Until the rename succeeds the temporary file is garbage: make sure no
	// error path leaves it (or its descriptor) behind.
	committed := false
	defer func() {
		if committed {
			return
		}
		// Close may report "already closed" here; only the cleanup matters.
		_ = f.Close()
		_ = os.Remove(tmpPath)
	}()

	sections := []struct {
		what    string
		payload []byte
	}{
		{"start marker", dbMarker},
		{"data", data},
		{"end marker", dbMarker},
	}
	for _, s := range sections {
		if _, err := f.Write(s.payload); err != nil {
			return fmt.Errorf("can't write %s to the temporary file: %w", s.what, err)
		}
	}

	if err := f.Close(); err != nil {
		return fmt.Errorf("can't close temporary file: %w", err)
	}

	// Temporary files are created with mode 0600; readers of the database
	// need it to be world-readable.
	if err := os.Chmod(tmpPath, 0644); err != nil {
		return fmt.Errorf("can't chmod temporary file: %w", err)
	}

	if _, statErr := os.Stat(targetPath); statErr == nil {
		backupPath := targetPath + ".old"
		// Best effort: the previous backup usually exists but does not have
		// to. If it cannot be removed, the Link below reports the failure.
		_ = os.Remove(backupPath)
		if err := os.Link(targetPath, backupPath); err != nil {
			return fmt.Errorf("can't backup old database: %w", err)
		}
	}

	if err := os.Rename(tmpPath, targetPath); err != nil {
		return fmt.Errorf("can't move temporary file into place: %w", err)
	}
	committed = true
	return nil
}

// main exports passwd/group data from the selected provider into a
// flatbuffers cache file.
//
// Usage: nss_export [-provider NAME] /path/to/db
//
// The -provider flag defaults to "cauth". Exactly one positional argument,
// the destination path of the database, is required. Any failure is reported
// on stderr and ends the process with a non-zero exit status.
func main() {
	rawProvider := "cauth"
	flag.StringVar(&rawProvider, "provider", rawProvider, "passwd/group provider to use")
	flag.Parse()

	if flag.NArg() != 1 {
		fatalf("usage: nss_export /path/to/db")
	}

	// Named src rather than provider so the imported provider package is not
	// shadowed for the rest of the function.
	src, err := chooseProvider(rawProvider)
	if err != nil {
		fatalf("%v", err)
	}

	// Groups are built first: getUsers takes the resulting groups as input.
	// Both share one builder so everything ends up in a single buffer.
	flatBuilder := flatbuffers.NewBuilder(0)
	groups, err := getGroups(src, flatBuilder)
	if err != nil {
		// Report the provider that was actually selected instead of a
		// hard-coded name.
		fatalf("can't get groups from %s: %v", rawProvider, err)
	}

	users, err := getUsers(src, flatBuilder, groups)
	if err != nil {
		fatalf("can't get users from %s: %v", rawProvider, err)
	}

	// Assemble the root Cache table from the previously built vectors.
	flatcache.CacheStart(flatBuilder)
	flatcache.CacheAddUsers(flatBuilder, users.UsersOffset)
	flatcache.CacheAddUserNames(flatBuilder, users.NamesOffset)
	flatcache.CacheAddGroups(flatBuilder, groups.GroupsOffset)
	flatcache.CacheAddGroupNames(flatBuilder, groups.NamesOffset)
	flatBuilder.Finish(flatcache.CacheEnd(flatBuilder))

	if err := writeCache(flag.Arg(0), flatBuilder.FinishedBytes()); err != nil {
		fatalf("write cache failed: %v", err)
	}
}
