package cmd

import (
	"a.yandex-team.ru/solomon/libs/go/iam"
	"a.yandex-team.ru/solomon/protos/secrets"
	"a.yandex-team.ru/solomon/tools/secrets/internal/kms"
	"context"
	"crypto/aes"
	"crypto/cipher"
	"encoding/base64"
	"fmt"
	"github.com/spf13/cobra"
	"google.golang.org/protobuf/proto"
	"os"
	"os/user"
	"strconv"
	"time"
)

// init registers the "decrypt" subcommand on rootCmd.
//
// Flags:
//
//	--in   (required) file containing the serialized ciphertext store
//	--out  (required) file the plaintext is written to
//	--user (optional) name of the user that should own the output file
func init() {
	command := &cobra.Command{
		Use:   "decrypt",
		Short: "Decrypt prepared file with ciphertext",
		Args:  cobra.ExactArgs(0),
		RunE:  runDecryptCmd,
	}

	fs := command.Flags()
	fs.String("in", "", "filename of cipher file")
	fs.String("out", "", "filename of plaintext file")
	fs.String("user", "", "username of output file owner")

	// Errors are deliberately ignored: both flag names are registered just above.
	for _, required := range []string{"in", "out"} {
		_ = command.MarkFlagRequired(required)
	}

	rootCmd.AddCommand(command)
}

// decryptWithYav decrypts ciphertext with an AES-256 key stored in Yav.
//
// enc.KeyRef names the Yav secret passed to loadYavKey. The secret's value must
// be a standard Base64 string that decodes to exactly 32 bytes. ciphertext is
// expected to be laid out as nonce || sealed data, where the nonce is
// gcm.NonceSize() bytes long and the sealed data ends with the GCM tag.
//
// NOTE(review): despite the AesGcmSivEncryption message name, this uses
// crypto/cipher's standard GCM, not GCM-SIV — confirm the encryptor matches.
func decryptWithYav(enc *secrets.AesGcmSivEncryption, ciphertext []byte) ([]byte, error) {
	yavKey, err := loadYavKey(context.Background(), enc.KeyRef)
	if err != nil {
		return nil, err
	}

	yavKeyBytes, err := base64.StdEncoding.DecodeString(yavKey)
	if err != nil {
		return nil, fmt.Errorf("cannot decode yav key with Base64 decoder, %w", err)
	}
	if len(yavKeyBytes) != 32 {
		return nil, fmt.Errorf("invalid yav key length %d, must be 32", len(yavKeyBytes))
	}

	block, err := aes.NewCipher(yavKeyBytes)
	if err != nil {
		return nil, fmt.Errorf("cannot create AES cipher, %w", err)
	}

	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("cannot create GCM cipher, %w", err)
	}

	// Guard the slicing below: a truncated or empty store would otherwise panic
	// with an index-out-of-range instead of returning an error.
	nonceSize := gcm.NonceSize()
	if minLen := nonceSize + gcm.Overhead(); len(ciphertext) < minLen {
		return nil, fmt.Errorf("ciphertext is too short: %d bytes, need at least %d", len(ciphertext), minLen)
	}
	nonce, sealed := ciphertext[:nonceSize], ciphertext[nonceSize:]

	plaintext, err := gcm.Open(nil, nonce, sealed, nil)
	if err != nil {
		return nil, fmt.Errorf("cannot decrypt ciphertext, %w", err)
	}

	return plaintext, nil
}

// getIamToken returns an IAM token for the given cloud environment.
//
// It first asks the local metadata service. If that call returns an error, the
// error is returned immediately. If it succeeds but yields an empty token, the
// service-account key stored in Yav under secretKey is loaded and exchanged for
// a token through the IAM token service for env. The token source is reported
// on stderr.
//
// ctx is only used to load the Yav secret; the iam calls here take no context.
func getIamToken(ctx context.Context, env secrets.CloudEnv, secretKey string) (string, error) {
	token, err := iam.GetIamTokenFromLocalMetaDataService()
	if err != nil {
		return "", fmt.Errorf("unable to get IAM token from local metadata service: %w", err)
	}
	if token != "" {
		fmt.Fprintln(os.Stderr, "got IAM token from local metadata service")
		return token, nil
	}

	// No token from metadata: fall back to the service-account key kept in Yav.
	keyJSON, err := loadYavKey(ctx, secretKey)
	if err != nil {
		return "", fmt.Errorf("unable to load yav secret %s: %w", secretKey, err)
	}

	token, err = iam.GetIamTokenFromTokenService(env, keyJSON)
	if err != nil {
		return "", fmt.Errorf("unable to get IAM token from IamTokenService: %w", err)
	}

	fmt.Fprintln(os.Stderr, "got IAM token from IamTokenService")
	return token, nil
}

// decryptWithKms decrypts ciphertext through Cloud KMS.
//
// An IAM token is obtained with getIamToken for enc.Env, using
// enc.ServiceAccountKeyRef as the Yav reference of the service-account key.
// A KMS client is then opened for the same environment, and ciphertext is
// decrypted with the key enc.KeyId. The context handed to token acquisition,
// client creation and Decrypt carries a 20-second deadline.
func decryptWithKms(enc *secrets.CloudKmsEncryption, ciphertext []byte) ([]byte, error) {
	const kmsDeadline = 20 * time.Second

	ctx, cancel := context.WithTimeout(context.Background(), kmsDeadline)
	defer cancel()

	token, tokenErr := getIamToken(ctx, enc.Env, enc.ServiceAccountKeyRef)
	if tokenErr != nil {
		return nil, tokenErr
	}

	kmsClient, clientErr := kms.NewClient(ctx, token, enc.Env)
	if clientErr != nil {
		return nil, clientErr
	}
	defer kmsClient.Close()

	decrypted, decryptErr := kmsClient.Decrypt(ctx, enc.KeyId, ciphertext)
	if decryptErr != nil {
		return nil, decryptErr
	}
	return decrypted, nil
}

// decrypt reads a serialized secrets.SecretsStore protobuf from filename and
// returns its decrypted plaintext.
//
// The store's Encryption oneof selects the backend:
//   - AesGcmSiv: decrypted with a key fetched from Yav (decryptWithYav);
//   - CloudKms: decrypted through Cloud KMS (decryptWithKms).
//
// A store without encryption metadata is rejected. A store whose encryption
// kind this tool does not handle is also rejected, with a distinct error
// naming the unsupported type.
func decrypt(filename string) ([]byte, error) {
	content, err := readFile(filename)
	if err != nil {
		return nil, err
	}

	var secretsStore secrets.SecretsStore
	if err := proto.Unmarshal(content, &secretsStore); err != nil {
		return nil, fmt.Errorf("cannot parse secrets store from %s: %w", filename, err)
	}

	switch e := secretsStore.Encryption.(type) {
	case *secrets.SecretsStore_AesGcmSiv:
		return decryptWithYav(e.AesGcmSiv, secretsStore.Ciphertext)
	case *secrets.SecretsStore_CloudKms:
		return decryptWithKms(e.CloudKms, secretsStore.Ciphertext)
	case nil:
		return nil, fmt.Errorf("secrets store does not contain encryption metadata")
	default:
		// A newer writer may use an encryption kind this binary predates.
		return nil, fmt.Errorf("secrets store uses unsupported encryption type %T", e)
	}
}

// securelyWriteFile replaces filename with content via a temporary file.
//
// Steps:
//  1. content is written to "<filename>.tmp". The file is created exclusively
//     (O_EXCL) with mode 0440, so only the owner and group can read it; the
//     process umask may narrow this further.
//  2. The data is synced to disk.
//  3. When username is non-empty, the file is chowned to that user's uid and
//     primary gid.
//  4. The file is renamed over filename.
//
// If any step after creation fails, the temp file is removed. This ensures no
// partially written or wrongly owned plaintext is left next to the target.
func securelyWriteFile(filename string, username string, content []byte) (err error) {
	tmpFile := filename + ".tmp"

	// Drop a stale temp file from an earlier failed run; O_EXCL below would fail on it.
	_ = os.Remove(tmpFile)

	// open file in exclusive mode and allow only reads for future user and its group
	file, err := os.OpenFile(tmpFile, os.O_EXCL|os.O_CREATE|os.O_WRONLY, 0440)
	if err != nil {
		return fmt.Errorf("cannot open temp file %s for write plaintext, %w", tmpFile, err)
	}

	// The temp file now exists and may hold plaintext: delete it on any later failure.
	defer func() {
		if err != nil {
			_ = os.Remove(tmpFile)
		}
	}()

	if _, err = file.Write(content); err != nil {
		_ = file.Close()
		return fmt.Errorf("cannot write plaintext into %s, %w", tmpFile, err)
	}

	// Make sure the data is durable before the rename exposes it under filename.
	if err = file.Sync(); err != nil {
		_ = file.Close()
		return fmt.Errorf("cannot sync file %s, %w", tmpFile, err)
	}

	if err = file.Close(); err != nil {
		return fmt.Errorf("cannot close file %s, %w", tmpFile, err)
	}

	if username != "" {
		var sysUser *user.User
		if sysUser, err = user.Lookup(username); err != nil {
			return fmt.Errorf("cannot find user %s, %w", username, err)
		}

		// Do not fall back to 0 (root) when the IDs are not numeric (e.g. Windows SIDs).
		var uid, gid int
		if uid, err = strconv.Atoi(sysUser.Uid); err != nil {
			return fmt.Errorf("cannot parse uid %q of user %s, %w", sysUser.Uid, username, err)
		}
		if gid, err = strconv.Atoi(sysUser.Gid); err != nil {
			return fmt.Errorf("cannot parse gid %q of user %s, %w", sysUser.Gid, username, err)
		}

		if err = os.Chown(tmpFile, uid, gid); err != nil {
			return fmt.Errorf("cannot change owner of file %s, %w", tmpFile, err)
		}
	}

	if err = os.Rename(tmpFile, filename); err != nil {
		return fmt.Errorf("cannot rename file %s to %s, %w", tmpFile, filename, err)
	}

	return nil
}

// runDecryptCmd implements the "decrypt" subcommand.
//
// It reads the --in, --out and --user flags. If the --in file does not exist,
// the command is a no-op and succeeds. Otherwise the file is decrypted and the
// plaintext is written to --out via securelyWriteFile. When --user is set, that
// user owns the output file.
func runDecryptCmd(cmd *cobra.Command, args []string) error {
	fs := cmd.Flags()

	// All three flags are registered as strings in init, so lookup errors are ignored.
	inPath, _ := fs.GetString("in")
	outPath, _ := fs.GetString("out")
	owner, _ := fs.GetString("user")

	// A missing input file is not an error: there is simply nothing to decrypt.
	_, statErr := os.Stat(inPath)
	if os.IsNotExist(statErr) {
		return nil
	}

	plaintext, err := decrypt(inPath)
	if err != nil {
		return err
	}
	return securelyWriteFile(outPath, owner, plaintext)
}
