package rdmaib

import (
	"fmt"

	"a.yandex-team.ru/infra/rsm/sysconf/internal/plugin"

	"a.yandex-team.ru/infra/rsm/sysconf/internal/sysinfo"
	"a.yandex-team.ru/infra/rsm/sysconf/pkg/fs"
)

const (
	// PCIMaxReadReq is the MaxReadReq value that checkMaxReadReqFeature expects
	// (via PCI.IsMaxReadReq) and enableMaxReadReqFeature writes (via
	// PCI.SetMaxReadReq).
	// NOTE(review): the digits match the "68.w=5937" register value from vendor
	// setpci tuning guides (hex, encoding a 4096-byte MaxReadReq), but here it is
	// a decimal literal — verify how the fs PCI helpers interpret it.
	PCIMaxReadReq   = 5937
	// PCICurLinkSpeed is the exact content checkCurLinkSpeed expects in a
	// ConnectX-6 device's "device/current_link_speed" file.
	PCICurLinkSpeed = "16 GT/s"
)

// pl is the rdmaib plugin; its state is captured once by New.
type pl struct {
	// hostName is the local host name, used as the first half of the
	// node_desc string written to every device (see openIBFormat).
	hostName string
	// ibdevs holds the devices returned by ClassInfiniband in New.
	// NOTE(review): presumably keyed by device name — confirm in package fs.
	ibdevs   map[string]*fs.IBDev
}

// New builds the rdmaib plugin. It captures the local host name and the
// InfiniBand devices reported by fs.NewDefaultSys().ClassInfiniband().
//
// Errors from either lookup are wrapped with context (and remain matchable with
// errors.Is/As), so the panic raised by init identifies which step failed.
func New() (*pl, error) {
	hostName, _, err := sysinfo.Hostname()
	if err != nil {
		return nil, fmt.Errorf("rdmaib: resolving hostname: %w", err)
	}

	ibdevs, err := fs.NewDefaultSys().ClassInfiniband()
	if err != nil {
		return nil, fmt.Errorf("rdmaib: listing infiniband devices: %w", err)
	}

	return &pl{
		hostName: hostName,
		ibdevs:   ibdevs,
	}, nil
}

func init() {
	p, err := New()
	if err != nil {
		panic(err)
	}
	plugin.Register(p)
}

// Name returns the plugin's identifier, "rdmaib".
func (p *pl) Name() string {
	const pluginName = "rdmaib"
	return pluginName
}

// Doc returns the reference URL (a tracker ticket) describing this plugin.
func (p *pl) Doc() string {
	const ticketURL = "https://st.yandex-team.ru/RTCNETWORK-636"
	return ticketURL
}

// IsApplicable reports StatusSkip when New found no InfiniBand devices, and
// StatusEnable otherwise.
func (p *pl) IsApplicable() (st plugin.State) {
	if len(p.ibdevs) == 0 {
		st.Status = plugin.StatusSkip
		return st
	}
	st.Status = plugin.StatusEnable
	return st
}

// Check runs every feature check, in order (NodeDesc, CurLinkSpeed,
// MaxReadReq), and collects one State per check. Nothing is modified.
func (p *pl) Check() (sts plugin.States) {
	checks := []func() plugin.State{
		p.checkNodeDescFeature,
		p.checkCurLinkSpeed,
		p.checkMaxReadReqFeature,
	}
	for _, run := range checks {
		sts.Add(run())
	}
	return sts
}

// Enable re-checks each fixable feature (NodeDesc, then MaxReadReq). When a
// check reports StatusDiff, it runs the matching enable step and records that
// step's result; otherwise it records the check result as-is.
//
// CurLinkSpeed is check-only and is not handled here. The force flag is
// currently ignored.
func (p *pl) Enable(force bool) (sts plugin.States) {
	steps := []struct {
		check, apply func() plugin.State
	}{
		{check: p.checkNodeDescFeature, apply: p.enableNodeDescFeature},
		{check: p.checkMaxReadReqFeature, apply: p.enableMaxReadReqFeature},
	}
	for _, step := range steps {
		result := step.check()
		if result.Status == plugin.StatusDiff {
			result = step.apply()
		}
		sts.Add(result)
	}
	return sts
}

// Disable is not implemented. It always returns a single StatusSkip state named
// "self" carrying plugin.ErrNotImpl; force is ignored.
func (p *pl) Disable(force bool) (sts plugin.States) {
	notImplemented := plugin.State{
		Name:   "self",
		Status: plugin.StatusSkip,
		Err:    plugin.ErrNotImpl,
	}
	sts.Add(notImplemented)
	return sts
}

// enableNodeDescFeature writes "<hostname> <device name>" (see openIBFormat)
// into the node_desc file of every InfiniBand device.
//
// It stops at the first write error and returns StatusFail with that error.
// Devices already written are left as-is. On success it returns StatusOk.
func (p *pl) enableNodeDescFeature() (st plugin.State) {
	st.Name = "NodeDesc"
	for _, dev := range p.ibdevs {
		desc := []byte(openIBFormat(p.hostName, dev.Name))
		if err := dev.Fs.WriteFile(desc, "node_desc"); err != nil {
			st.Status = plugin.StatusFail
			st.Err = err
			return st
		}
	}
	st.Status = plugin.StatusOk
	return st
}

// checkNodeDescFeature verifies that every InfiniBand device's node_desc file
// holds "<hostname> <device name>" (see openIBFormat).
//
// Returns:
//   - StatusOk when every device matches (or there are no devices);
//   - StatusDiff with a descriptive error for the first mismatching device;
//   - StatusFail with the error when IsFileEqual itself returns one.
//
// Devices are visited in map order, so which device is reported first is not
// deterministic across runs.
func (p *pl) checkNodeDescFeature() (st plugin.State) {
	st = plugin.State{Name: "NodeDesc", Status: plugin.StatusFail}
	for _, ibdev := range p.ibdevs {
		want := openIBFormat(p.hostName, ibdev.Name)
		equal, err := ibdev.Fs.IsFileEqual(want, "node_desc")
		if err != nil {
			// Status is already StatusFail.
			st.Err = err
			return st
		}
		if !equal {
			// %q makes stray whitespace in the expected value visible.
			st.Err = fmt.Errorf("%s must contain %q", ibdev.Fs.Path("node_desc"), want)
			st.Status = plugin.StatusDiff
			return st
		}
	}
	st.Status = plugin.StatusOk
	return st
}

// https://st.yandex-team.ru/RTCNETWORK-693#601d7c2f9100531f27613461
//
// checkCurLinkSpeed verifies that every ConnectX-6 device (as identified by
// IsCX6) reports PCICurLinkSpeed in its device/current_link_speed file. Other
// devices are skipped.
//
// Returns StatusOk when all inspected devices match (or none qualify),
// StatusDiff for the first mismatch, or StatusFail when IsFileEqual returns an
// error. This is check-only: Enable has no step that changes the link speed.
func (p *pl) checkCurLinkSpeed() (st plugin.State) {
	const pth = "device/current_link_speed"
	st = plugin.State{Name: "CurLinkSpeed", Status: plugin.StatusFail}
	for _, ibdev := range p.ibdevs {
		if !IsCX6(ibdev) {
			continue
		}
		equal, err := ibdev.Fs.IsFileEqual(PCICurLinkSpeed, pth)
		if err != nil {
			// Status is already StatusFail.
			st.Err = err
			return st
		}
		if !equal {
			st.Err = fmt.Errorf("%s must contain %q", ibdev.Fs.Path(pth), PCICurLinkSpeed)
			st.Status = plugin.StatusDiff
			return st
		}
	}
	st.Status = plugin.StatusOk
	return st
}

// enableMaxReadReqFeature programs MaxReadReq to PCIMaxReadReq on every
// ConnectX-6 device (as identified by IsCX6).
//
// Only the devices that checkMaxReadReqFeature inspects are touched. Enable
// therefore never rewrites the PCI configuration of adapters the check ignores;
// previously every IB device was written. The function stops at the first
// failure and returns StatusFail, with the error wrapped with the device's PCI
// slot. On success it returns StatusOk.
func (p *pl) enableMaxReadReqFeature() (st plugin.State) {
	st = plugin.State{Name: "MaxReadReq", Status: plugin.StatusFail}
	for _, ibdev := range p.ibdevs {
		// Keep the scope identical to checkMaxReadReqFeature.
		if !IsCX6(ibdev) {
			continue
		}
		if err := ibdev.PCI.SetMaxReadReq(PCIMaxReadReq); err != nil {
			st.Err = fmt.Errorf("setting MaxReadReq on %s: %w", ibdev.PCI.Slot, err)
			return st
		}
	}
	st.Status = plugin.StatusOk
	return st
}

// checkMaxReadReqFeature verifies, via PCI.IsMaxReadReq, that every ConnectX-6
// device (as identified by IsCX6) has MaxReadReq set to PCIMaxReadReq. Other
// devices are skipped.
//
// Returns StatusOk when all inspected devices match (or none qualify),
// StatusDiff for the first mismatch, or StatusFail when IsMaxReadReq returns an
// error.
func (p *pl) checkMaxReadReqFeature() (st plugin.State) {
	st.Name = "MaxReadReq"
	for _, dev := range p.ibdevs {
		if !IsCX6(dev) {
			continue
		}
		matches, err := dev.PCI.IsMaxReadReq(PCIMaxReadReq)
		switch {
		case err != nil:
			st.Status, st.Err = plugin.StatusFail, err
			return st
		case !matches:
			st.Status = plugin.StatusDiff
			st.Err = fmt.Errorf("on %s (%s) not equal %d", dev.PCI.Slot, dev.Fs, PCIMaxReadReq)
			return st
		}
	}
	st.Status = plugin.StatusOk
	return st
}

// openIBFormat joins k and v with a single space. The plugin uses it to build
// the node_desc value "<hostname> <device name>".
func openIBFormat(k, v string) string {
	const sep = " "
	return k + sep + v
}

// IsCX6 reports whether dev is treated as a Mellanox Technologies MT28908
// Family [ConnectX-6] adapter.
//
// Only the PCI vendor string ("15B3") and class string ("20700") are compared.
// NOTE(review): vendor 0x15b3 is Mellanox and class 0x0207xx is the generic
// PCI InfiniBand-controller class, so this presumably matches any Mellanox IB
// adapter rather than ConnectX-6 alone — confirm whether a device-ID check is
// needed.
func IsCX6(dev *fs.IBDev) bool {
	const (
		mellanoxVendor = "15B3"
		ibClass        = "20700"
	)
	if dev.PCI.Vendor != mellanoxVendor {
		return false
	}
	return dev.PCI.Class == ibClass
}
