package cmd

import (
	"encoding/json"
	"fmt"
	"os"
	"sort"
	"strconv"

	"github.com/olekukonko/tablewriter"
	"github.com/spf13/cobra"
	"go.uber.org/zap"

	pb "a.yandex-team.ru/infra/rsm/nvgpumanager/api"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/ilog"
	//"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/client"
)

// listCmd is the `list` subcommand: it prints the GPU devices known to the
// manager API, delegating all work to doList.
var listCmd = &cobra.Command{
	Use:   "list",
	Short: "List gpu resources",
	RunE: func(_ *cobra.Command, _ []string) error {
		return doList()
	},
}

// Driver filters for `list`, bound to --show-nvgpu and --show-vfio in init.
// When neither is set, doList shows devices of both kinds.
var showNvGpu, showVFio bool

// init attaches the list command to the root command and declares its flags.
// --json writes into the package-level dumpJSON (declared elsewhere in this
// package); --show-nvgpu / --show-vfio select which driver's devices to show.
func init() {
	rootCmd.AddCommand(listCmd)

	flags := listCmd.Flags()
	flags.BoolVarP(&dumpJSON, "json", "J", false, "Json output format")
	flags.BoolVar(&showNvGpu, "show-nvgpu", false, "Show only gpu devices controlled by nvidia driver")
	flags.BoolVar(&showVFio, "show-vfio", false, "show only gpu devices controlled by vfio-pci driver")
}

// doList fetches all GPU devices from the manager API and prints those
// selected by --show-nvgpu / --show-vfio. With --json the selection is written
// to stdout as one indented JSON array (nvidia devices first, then vfio);
// otherwise one table per driver kind is rendered. If neither filter flag is
// set, devices of both kinds are shown. Devices whose driver is neither
// nvidia nor vfio are not listed.
func doList() error {
	ll := ilog.Log()

	// Use local copies so defaulting to "show all" does not rewrite the
	// package-level flag variables as a side effect.
	wantNv, wantVfio := showNvGpu, showVFio
	if !wantNv && !wantVfio {
		wantNv, wantVfio = true, true
	}

	ll.Debug("call api.List")
	r, err := apiClient.Client.ListDevices(apiCtx, &pb.Empty{})
	ll.Debug("ret api.List", zap.Any("reply", r), zap.Error(err))
	if err != nil {
		return fmt.Errorf("list devices: %w", err)
	}

	// Show devices in deterministic order: by bus id, then by device id.
	// The tie-break must be a strict "<" for the comparator to be a valid
	// ordering as required by sort.Slice.
	devices := r.Devices
	sort.Slice(devices, func(i, j int) bool {
		a, b := devices[i].Meta, devices[j].Meta
		if a.BusId != b.BusId {
			return a.BusId < b.BusId
		}
		return a.Id < b.Id
	})

	// Partition by the driver the device is bound to.
	var nvDev, vfioDev []*pb.GpuDevice
	for _, d := range devices {
		ll.Debug("walk", zap.Any("device", d))
		switch d.Spec.Driver.(type) {
		case *pb.GpuDeviceSpec_Nvidia:
			nvDev = append(nvDev, d)
		case *pb.GpuDeviceSpec_Vfio:
			vfioDev = append(vfioDev, d)
		}
	}

	if dumpJSON {
		// Non-nil so an empty selection encodes as [] rather than null.
		selected := make([]*pb.GpuDevice, 0, len(nvDev)+len(vfioDev))
		if wantNv {
			selected = append(selected, nvDev...)
		}
		if wantVfio {
			selected = append(selected, vfioDev...)
		}
		enc := json.NewEncoder(os.Stdout)
		enc.SetIndent("", "\t")
		return enc.Encode(selected)
	}

	if wantNv {
		if err := dumpNvGpu(nvDev); err != nil {
			return err
		}
	}
	if wantVfio {
		if err := dumpVFioGpu(vfioDev); err != nil {
			return err
		}
	}
	return nil
}

// Readiness formats a readiness condition for table output. It returns "OK"
// when the condition's Status is true. Otherwise it returns the condition's
// Message prefixed with "[CRIT] ".
func Readiness(ready *pb.Condition) string {
	if !ready.Status {
		return "[CRIT] " + ready.Message
	}
	return "OK"
}

// Throttled formats a throttling condition for table output. It returns
// "OK" when the condition's Status is false. Otherwise it returns the
// condition's Message prefixed with "[WARN] ".
func Throttled(throttled *pb.Condition) string {
	if !throttled.Status {
		return "OK"
	}
	return "[WARN] " + throttled.Message
}

// dumpNvGpu renders devices bound to the nvidia driver as a table on stdout.
// Each row holds the PCI identity and static spec (model, memory, NUMA node,
// driver and CUDA versions), followed by the current status (readiness,
// throttling, temperature, memory used/free, SM utilization/occupancy). The
// "memory" column is MemorySizeGb multiplied by 1024.
//
// Callers must pass only nvidia devices: the nvidia spec/status and their
// nested fields are dereferenced without nil checks. It always returns nil.
func dumpNvGpu(devices []*pb.GpuDevice) error {
	table := tablewriter.NewWriter(os.Stdout)
	table.SetHeader([]string{"bus id", "id", "model", "memory", "numa", "driver", "cuda", "ready", "throttled", "temp", "used", "free", "sm_util", "sm_occup"})
	for _, d := range devices {
		spec := d.Spec.GetNvidia()
		status := d.Status.GetNvidia()
		pci := d.Spec.PciDevice

		table.Append([]string{
			d.Meta.BusId,
			d.Meta.Id,
			pci.ModelName,
			// Widen before shifting so the GB->MB scaling cannot overflow the
			// field's own (possibly narrower) integer type.
			strconv.FormatUint(uint64(pci.MemorySizeGb)<<10, 10),
			strconv.Itoa(int(pci.NumaNode)),
			spec.DriverVersion,
			fmt.Sprintf("%d.%d", spec.CudaVersion.Major, spec.CudaVersion.Minor),
			Readiness(d.Status.Ready),
			Throttled(status.Throttle),
			strconv.FormatUint(uint64(status.Temperature), 10),
			strconv.FormatUint(status.MemoryUsedMb, 10),
			strconv.FormatUint(status.MemoryFreeMb, 10),
			strconv.FormatFloat(float64(status.SmUtilization), 'g', 4, 32),
			strconv.FormatFloat(float64(status.SmOccupancy), 'g', 4, 32),
		})
	}
	table.Render()
	return nil
}

// dumpVFioGpu renders devices bound to the vfio-pci driver as a table on
// stdout. Each row holds the PCI identity, model, memory (MemorySizeGb
// multiplied by 1024), NUMA node, IOMMU group and whether the device is
// active. The vfio spec/status are dereferenced without nil checks, so only
// vfio devices may be passed. It always returns nil.
func dumpVFioGpu(devices []*pb.GpuDevice) error {
	header := []string{"bus id", "id", "model", "memory", "numa", "iommu_group", "active"}
	tw := tablewriter.NewWriter(os.Stdout)
	tw.SetHeader(header)
	for i := range devices {
		dev := devices[i]
		vfioSpec, vfioStatus := dev.Spec.GetVfio(), dev.Status.GetVfio()
		memScaled := uint64(dev.Spec.PciDevice.MemorySizeGb) << 10
		tw.Append([]string{
			dev.Meta.BusId,
			dev.Meta.Id,
			dev.Spec.PciDevice.ModelName,
			strconv.FormatUint(memScaled, 10),
			strconv.Itoa(int(dev.Spec.PciDevice.NumaNode)),
			strconv.FormatUint(uint64(vfioSpec.IommuGroup), 10),
			strconv.FormatBool(vfioStatus.Active),
		})
	}
	tw.Render()
	return nil
}
