package cmd

import (
	"context"
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"encoding/base64"
	"fmt"
	"io/ioutil"
	"os"
	"os/user"
	"sort"
	"strings"
	"time"

	"a.yandex-team.ru/library/go/yandex/oauth"
	"a.yandex-team.ru/library/go/yandex/yav/httpyav"
	"a.yandex-team.ru/solomon/protos/secrets"
	"a.yandex-team.ru/solomon/tools/secrets/internal/kms"
	"github.com/golang/protobuf/proto"
	"github.com/spf13/cobra"
)

// encryptFlags holds the command-line flag values of the "encrypt" command;
// each field is bound to its flag in init.
var encryptFlags struct {
	kmsKey         string // --kms-key: KMS key id; when set, KMS encryption is used instead of a YAV key
	kmsEnv         string // --kms-env: KMS environment name, parsed case-insensitively by parseCloudEnv
	yavKey         string // --yav-key: "<secret id>/<key>" in YAV (AES key, or service account key ref in KMS mode)
	inputFilename  string // --in: file to encrypt; "-" reads standard input
	outputFilename string // --out: file the serialized secrets.SecretsStore is written to
}

// init builds the "encrypt" subcommand, binds its flags to encryptFlags and
// attaches it to rootCmd.
func init() {
	encryptCmd := &cobra.Command{
		Use:   "encrypt",
		Short: "Encrypt file using KMS or by private key stored in YAV",
		RunE:  runEncryptCmd,
	}

	flags := encryptCmd.Flags()
	flags.StringVar(&encryptFlags.kmsKey, "kms-key", "", "id of KMS key")
	flags.StringVar(&encryptFlags.kmsEnv, "kms-env", "", "environment of KMS service {preprod|prod}")
	flags.StringVar(&encryptFlags.yavKey, "yav-key", "", "secret id in YAV where private key is stored")
	flags.StringVar(&encryptFlags.inputFilename, "in", "", "input filename")
	flags.StringVar(&encryptFlags.outputFilename, "out", "", "output filename")

	// The error is deliberately ignored: MarkFlagRequired only fails for a flag
	// name that was never registered, and both names are registered just above.
	for _, name := range []string{"in", "out"} {
		_ = encryptCmd.MarkFlagRequired(name)
	}

	rootCmd.AddCommand(encryptCmd)
}

// currentLogin returns the login used to request an OAuth token.
//
// Sources are consulted in this order:
//   - SUDO_USER == "z2": the program runs from z2, so robot-skc is used;
//   - any other non-empty SUDO_USER: the user who invoked sudo;
//   - a non-empty USER environment variable;
//   - the username of the OS user running the process.
func currentLogin() (string, error) {
	switch sudoUser := os.Getenv("SUDO_USER"); sudoUser {
	case "z2":
		// When started from z2 the token has to be obtained for robot-skc@.
		return "robot-skc", nil
	case "":
		// Not started via sudo: fall back to the remaining sources below.
	default:
		return sudoUser, nil
	}

	if login := os.Getenv("USER"); login != "" {
		return login, nil
	}

	osUser, err := user.Current()
	if err != nil {
		return "", err
	}
	return osUser.Username, nil
}

// loadYavKey reads a single value from YAV.
//
// secretKey must have the form "<secret id>/<key>" with both parts non-empty.
// The secret is fetched with an OAuth token obtained via SSH for the login
// returned by currentLogin, and the value stored under <key> is returned.
// Network calls honour ctx.
func loadYavKey(ctx context.Context, secretKey string) (string, error) {
	parts := strings.Split(secretKey, "/")
	// Reject empty halves ("/key", "sec/", "/") up front instead of sending a
	// request for an empty secret id or searching for an empty key.
	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
		return "", fmt.Errorf("invalid format of secret key: %s, expected <secret id>/<key>", secretKey)
	}
	secretID, valueKey := parts[0], parts[1]

	login, err := currentLogin()
	if err != nil {
		return "", fmt.Errorf("cannot get login of current user, %w", err)
	}

	token, err := oauth.GetTokenBySSH(ctx, clientID, clientSecret, oauth.WithUserLogin(login))
	if err != nil {
		return "", fmt.Errorf("cannot get OAuth token by SSH, %w", err)
	}

	client, err := httpyav.NewClient(httpyav.WithOAuthToken(token))
	if err != nil {
		return "", fmt.Errorf("cannot initialize yav client, %w", err)
	}

	// NOTE(review): GetVersion is called with the bare secret id, which
	// presumably resolves to the latest version of the secret — confirm.
	ver, err := client.GetVersion(ctx, secretID)
	if err != nil {
		return "", fmt.Errorf("cannot load secret %s, %w", secretID, err)
	}

	for _, value := range ver.Version.Values {
		if value.Key == valueKey {
			return value.Value, nil
		}
	}

	return "", fmt.Errorf("cannot find %s in secret %s", valueKey, secretID)
}

// readFile returns the whole content of filename, or of standard input when
// filename is "-". On failure it returns nil and an error naming the source.
func readFile(filename string) ([]byte, error) {
	var (
		data []byte
		err  error
	)

	if filename != "-" {
		if data, err = ioutil.ReadFile(filename); err != nil {
			return nil, fmt.Errorf("cannot read file %s, %v", filename, err)
		}
		return data, nil
	}

	if data, err = ioutil.ReadAll(os.Stdin); err != nil {
		return nil, fmt.Errorf("cannot read from stdin, %v", err)
	}
	return data, nil
}

// parseCloudEnv converts a case-insensitive environment name into a
// secrets.CloudEnv value.
//
// CLOUD_ENV_UNSPECIFIED is not a usable environment, so it is rejected like any
// unknown name. The error lists the accepted names in sorted order so the
// message is stable between runs (map iteration order is random).
func parseCloudEnv(str string) (secrets.CloudEnv, error) {
	env, ok := secrets.CloudEnv_value[strings.ToUpper(str)]
	if ok && env != int32(secrets.CloudEnv_CLOUD_ENV_UNSPECIFIED) {
		return secrets.CloudEnv(env), nil
	}

	known := make([]string, 0, len(secrets.CloudEnv_name))
	for k, v := range secrets.CloudEnv_name {
		if k != int32(secrets.CloudEnv_CLOUD_ENV_UNSPECIFIED) {
			known = append(known, v)
		}
	}
	sort.Strings(known)

	return secrets.CloudEnv_CLOUD_ENV_UNSPECIFIED,
		fmt.Errorf("unknown env type '%s', known values are: {%s}", str, strings.Join(known, ", "))
}

// encryptWithKms encrypts the input file with the Cloud KMS key
// encryptFlags.kmsKey in environment encryptFlags.kmsEnv.
//
// The IAM token for KMS is obtained by getIamToken from encryptFlags.yavKey,
// which is also recorded as ServiceAccountKeyRef in the returned store. All
// network calls share a single 20 second deadline.
func encryptWithKms() (*secrets.SecretsStore, error) {
	// Validate the environment before reading the input: with "--in -" the read
	// blocks on stdin, and a mistyped --kms-env should be reported before that.
	env, err := parseCloudEnv(encryptFlags.kmsEnv)
	if err != nil {
		return nil, err
	}

	plaintext, err := readFile(encryptFlags.inputFilename)
	if err != nil {
		return nil, err
	}

	// Start the deadline only after the input is read so slow stdin does not
	// consume the time budget of the network calls.
	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
	defer cancel()

	iamToken, err := getIamToken(ctx, env, encryptFlags.yavKey)
	if err != nil {
		return nil, fmt.Errorf("cannot get IAM token, %w", err)
	}

	client, err := kms.NewClient(ctx, iamToken, env)
	if err != nil {
		return nil, fmt.Errorf("cannot create KMS client, %w", err)
	}
	defer client.Close()

	ciphertext, err := client.Encrypt(ctx, encryptFlags.kmsKey, plaintext)
	if err != nil {
		return nil, fmt.Errorf("cannot encrypt with KMS key %s, %w", encryptFlags.kmsKey, err)
	}

	return &secrets.SecretsStore{
		Encryption: &secrets.SecretsStore_CloudKms{
			CloudKms: &secrets.CloudKmsEncryption{
				Env:                  env,
				KeyId:                encryptFlags.kmsKey,
				ServiceAccountKeyRef: encryptFlags.yavKey,
			},
		},
		Ciphertext: ciphertext,
	}, nil
}

// encryptWithYav encrypts the input file with a 256-bit AES key that is stored
// Base64-encoded in YAV under encryptFlags.yavKey ("<secret id>/<key>").
//
// The stored ciphertext is the random GCM nonce followed by the output of
// gcm.Seal (encrypted data plus authentication tag). Despite the AesGcmSiv
// naming of the proto message, this is plain AES-GCM from crypto/cipher, not
// AES-GCM-SIV. NOTE(review): the decrypting side must expect this
// nonce || sealed layout — keep the two in sync.
func encryptWithYav() (*secrets.SecretsStore, error) {
	const keySize = 32 // AES-256

	// Bound the OAuth and YAV requests with the same deadline the KMS path
	// uses, so an unresponsive service cannot hang the command forever.
	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
	defer cancel()

	yavKey, err := loadYavKey(ctx, encryptFlags.yavKey)
	if err != nil {
		return nil, err
	}

	yavKeyBytes, err := base64.StdEncoding.DecodeString(yavKey)
	if err != nil {
		return nil, fmt.Errorf("cannot decode yav key with Base64 decoder, %w", err)
	}
	if len(yavKeyBytes) != keySize {
		return nil, fmt.Errorf("invalid yav key length %d, must be %d", len(yavKeyBytes), keySize)
	}

	plaintext, err := readFile(encryptFlags.inputFilename)
	if err != nil {
		return nil, err
	}

	block, err := aes.NewCipher(yavKeyBytes)
	if err != nil {
		return nil, fmt.Errorf("cannot create AES cipher, %w", err)
	}

	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("cannot create GCM cipher, %w", err)
	}

	// A fresh random nonce per encryption; reusing a nonce with the same key
	// breaks GCM's confidentiality and integrity guarantees.
	nonce := make([]byte, gcm.NonceSize())
	if _, err := rand.Read(nonce); err != nil {
		return nil, fmt.Errorf("cannot generate random nonce, %w", err)
	}

	// Seal appends to its first argument, so the result is nonce || sealed data.
	ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)

	return &secrets.SecretsStore{
		Encryption: &secrets.SecretsStore_AesGcmSiv{
			AesGcmSiv: &secrets.AesGcmSivEncryption{
				KeyRef: encryptFlags.yavKey,
			},
		},
		Ciphertext: ciphertext,
	}, nil
}

// encrypt selects the encryption backend from the flags.
//
// A non-empty --kms-key selects KMS, which additionally requires --kms-env and
// --yav-key. Otherwise a non-empty --yav-key selects the YAV-stored AES key.
// With neither flag set an error is returned.
func encrypt() (*secrets.SecretsStore, error) {
	viaKms := encryptFlags.kmsKey != ""

	switch {
	case viaKms && encryptFlags.kmsEnv == "":
		return nil, fmt.Errorf("--kms-env cannot be empty")
	case viaKms && encryptFlags.yavKey == "":
		return nil, fmt.Errorf("--yav-key cannot be empty")
	case viaKms:
		return encryptWithKms()
	case encryptFlags.yavKey != "":
		return encryptWithYav()
	default:
		return nil, fmt.Errorf("--yav-key or --kms-key must be provided")
	}
}

// runEncryptCmd is the RunE handler of the "encrypt" command. It encrypts the
// input selected by the flags, serializes the resulting secrets.SecretsStore
// with protobuf and writes it to --out.
func runEncryptCmd(cmd *cobra.Command, args []string) error {
	store, err := encrypt()
	if err != nil {
		return err
	}

	content, err := proto.Marshal(store)
	if err != nil {
		return fmt.Errorf("cannot marshal secrets store, %w", err)
	}

	// 0644 is acceptable here: the payload is ciphertext plus key references,
	// never the plaintext.
	if err := ioutil.WriteFile(encryptFlags.outputFilename, content, 0644); err != nil {
		return fmt.Errorf("cannot write file %s, %w", encryptFlags.outputFilename, err)
	}
	return nil
}
