package cmd

import (
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"sort"
	"strings"

	"github.com/spf13/cobra"

	pb "a.yandex-team.ru/infra/rsm/nvgpumanager/api"
)

// allocCmd is the "alloc" subcommand. Argument handling, the RPC call and
// output formatting all live in doAlloc, which is installed directly as the
// command's RunE handler.
var allocCmd = &cobra.Command{
	Use:   "alloc",
	Short: "Alloc gpu resources",
	RunE:  doAlloc,
}

// init attaches the alloc subcommand to the root command and declares its
// flags: --driver (stored in driver, default "host") and --json/-J (stored in
// dumpJSON, default false).
func init() {
	rootCmd.AddCommand(allocCmd)

	fs := allocCmd.Flags()
	fs.StringVar(&driver, "driver", "host", "driver type, available choice [host, vfio]")
	fs.BoolVarP(&dumpJSON, "json", "J", false, "Json output format")
}

// doAlloc requests allocation of the GPU devices whose IDs are passed as
// positional arguments, using the driver selected with --driver ("host" or
// "vfio").
//
// With --json the ContainerResponse part of the reply is written to stdout as
// tab-indented JSON. Otherwise it is written as porto container properties:
//
//	env='K=V;...' bind='HOST CONTAINER ro|rw;...' devices='PATH PERMS;...'
//
// Env entries are sorted as "key=value" strings and mounts by container path
// (the reply's Mounts slice is reordered in place), so repeated runs print the
// same line; devices keep the order of the reply. Each property is printed
// only when it has at least one entry, and no trailing newline is written.
//
// An error is returned when no device IDs are given, the driver is not one of
// the supported names, the Allocate RPC fails, or (in porto mode) the reply
// carries no container response.
func doAlloc(cmd *cobra.Command, args []string) error {
	if len(args) == 0 {
		return errors.New("not enough arguments")
	}

	req := &pb.AllocateRequest{
		ContainerRequests: &pb.ContainerAllocateRequest{
			DevicesIDs: args,
		},
	}
	switch driver {
	case "host":
		req.ContainerRequests.DriverName = "host"
	case "vfio":
		req.ContainerRequests.DriverName = "vfio"
	default:
		return errors.New("unsupported driver type: " + driver)
	}

	reply, err := apiClient.Client.Allocate(apiCtx, req)
	if err != nil {
		return fmt.Errorf("allocate %v: %w", args, err)
	}
	if dumpJSON {
		enc := json.NewEncoder(os.Stdout)
		enc.SetIndent("", "\t")
		return enc.Encode(reply.ContainerResponse)
	}

	// By default dump as porto spec. Every field below is read through
	// resp, so a missing response must be reported rather than dereferenced.
	resp := reply.ContainerResponse
	if resp == nil {
		return errors.New("allocate reply has no container response")
	}

	if len(resp.Envs) > 0 {
		envs := make([]string, 0, len(resp.Envs))
		for k, v := range resp.Envs {
			envs = append(envs, k+"="+v)
		}
		// Map iteration order is random; sort for stable output.
		sort.Strings(envs)
		fmt.Printf("env='%s' ", strings.Join(envs, ";"))
	}

	if len(resp.Mounts) > 0 {
		mounts := resp.Mounts
		sort.SliceStable(mounts, func(i, j int) bool {
			return mounts[i].ContainerPath < mounts[j].ContainerPath
		})
		binds := make([]string, 0, len(mounts))
		for _, m := range mounts {
			mode := "rw"
			if m.ReadOnly {
				mode = "ro"
			}
			binds = append(binds, m.HostPath+" "+m.ContainerPath+" "+mode)
		}
		fmt.Printf("bind='%s' ", strings.Join(binds, ";"))
	}

	if len(resp.Devices) > 0 {
		devices := make([]string, 0, len(resp.Devices))
		for _, d := range resp.Devices {
			devices = append(devices, d.HostPath+" "+d.Permissions)
		}
		fmt.Printf("devices='%s'", strings.Join(devices, ";"))
	}
	return nil
}
