package irqbalance

import (
	"fmt"
	"io/ioutil"
	"path/filepath"
	"regexp"
	"sort"
	"strconv"
	"strings"

	"a.yandex-team.ru/infra/rsm/sysconf/internal"
	"a.yandex-team.ru/infra/rsm/sysconf/internal/plugin"
	"a.yandex-team.ru/infra/rsm/sysconf/pkg/ethtool"
)

func init() {
	p, err := New()
	if err != nil {
		panic(err)
	}
	plugin.Register(p)
}

const (
	// ifacesRgxp is a filepath.Glob pattern for the interface directories.
	// NOTE(review): "[0-9]" matches exactly one digit, so eth10 and above are
	// not discovered — confirm this is intended.
	ifacesRgxp = "/sys/class/net/eth[0-9]"
	// cpuRgxp is a glob for the per-CPU sysfs directories; New counts them.
	cpuRgxp = "/sys/devices/system/cpu/cpu[0-9]*"
	// numaRgxp is a glob for the NUMA node directories; New reads each cpumap.
	numaRgxp = "/sys/devices/system/node/node[0-9]*"
	// rpsSockFlowEntriesPath is the global RFS socket flow table sysctl.
	rpsSockFlowEntriesPath = "/proc/sys/net/core/rps_sock_flow_entries"
	// rpsCpusRgxp is a glob, relative to an interface directory, for the
	// per-RX-queue RPS CPU masks.
	rpsCpusRgxp = "queues/rx-*/rps_cpus"
	// interruptsPath is scanned by New for NIC channel interrupts.
	interruptsPath = "/proc/interrupts"
	// ethChannelRgxp is a regular expression matched against whole
	// /proc/interrupts lines to pick NIC channel IRQs, named e.g.:
	//   eth2-TxRx-11
	//   eth1-rx-0
	//   eth2-tx-1
	//   mlx4-1@0000:04:00.0
	//   mlx5_comp13@pci:0000:86:00.0
	// The class is "[-trx]": a comma inside a character class would be a
	// literal comma, not a separator.
	ethChannelRgxp = ".*eth[0-9]+[-trx]+.*|.*mlx[0-9][-_a-z]+[0-9]+@"
)

// iface is the sysfs directory of a network interface, e.g. "/sys/class/net/eth0".
type iface string

// name returns the interface name, i.e. the last element of the sysfs path.
func (i iface) name() string {
	return filepath.Base(string(i))
}

// model returns the trimmed contents of <iface>/device/device (the callers in
// this package compare it with "0x10d3"). A read error is returned rather than
// swallowed so that callers can tell "unknown model" from a real value.
func (i iface) model() (string, error) {
	data, err := ioutil.ReadFile(filepath.Join(string(i), "device/device"))
	if err != nil {
		return "", err
	}
	return strings.Trim(string(data), "\n "), nil
}

// numa returns the NUMA node of the interface's device, read from
// <iface>/device/numa_node.
//
// The kernel reports -1 when the node is unknown (typical on single-node
// machines); that is mapped to node 0 instead of failing to parse.
func (i iface) numa() (uint32, error) {
	data, err := ioutil.ReadFile(filepath.Join(string(i), "device/numa_node"))
	if err != nil {
		return 0, err
	}
	n, err := strconv.ParseInt(strings.Trim(string(data), "\n "), 10, 32)
	if err != nil {
		return 0, err
	}
	if n < 0 {
		return 0, nil
	}
	return uint32(n), nil
}

// pl is the irqbalance plugin. Its fields are a snapshot of host topology
// taken once by New.
type pl struct {
	// irqs maps an IRQ number from /proc/interrupts whose line matched
	// ethChannelRgxp to the last column of that line.
	irqs map[int]string
	// numa maps a NUMA node key to the logical CPUs decoded from that
	// node's cpumap.
	numa map[uint32][]uint32
	// ifaces holds the sysfs paths of the interfaces selected by New.
	ifaces []iface
	// coresCount is the number of cpuN directories under /sys/devices/system/cpu.
	coresCount uint32
	// pcoresCount is derived from coresCount by New (presumably assuming two
	// hardware threads per physical core).
	pcoresCount uint32
}

// New inspects the running host and returns a plugin primed with:
//   - the interfaces matching ifacesRgxp whose operstate is "up";
//   - the number of logical CPUs matching cpuRgxp;
//   - the CPU list of every NUMA node matching numaRgxp, keyed by node id;
//   - the IRQs of /proc/interrupts lines matching ethChannelRgxp.
//
// It returns the first glob, read or parse error it encounters.
func New() (*pl, error) {
	// net: keep only interfaces whose operstate reads "up".
	ifacePaths, err := filepath.Glob(ifacesRgxp)
	if err != nil {
		return nil, err
	}
	ifaces := []iface{}
	for _, path := range ifacePaths {
		if err := internal.IsFileContent(filepath.Join(path, "operstate"), "up"); err == nil {
			ifaces = append(ifaces, iface(path))
		}
	}

	// cpu
	cpuPaths, err := filepath.Glob(cpuRgxp)
	if err != nil {
		return nil, err
	}
	coresCount := uint32(len(cpuPaths))
	// NOTE(review): assumes two hardware threads per physical core; the real
	// SMT topology is not consulted. Clamp to 1 so a host with a single
	// logical CPU does not get a zero channel target or a zero divisor in
	// the IRQ affinity round-robin.
	pcoresCount := coresCount / 2
	if pcoresCount == 0 {
		pcoresCount = 1
	}

	// numa
	numa := map[uint32][]uint32{}
	numaPaths, err := filepath.Glob(numaRgxp)
	if err != nil {
		return nil, err
	}
	for _, numaPath := range numaPaths {
		// Key by the id in the directory name, not by the position in the
		// glob result: Glob sorts lexically (node10 precedes node2) and node
		// ids may be sparse, so the position is not the node id.
		id, err := strconv.ParseUint(strings.TrimPrefix(filepath.Base(numaPath), "node"), 10, 32)
		if err != nil {
			return nil, fmt.Errorf("parsing NUMA node id from %q: %w", numaPath, err)
		}
		cbm, err := internal.GetCPUBitMapFromFile(filepath.Join(numaPath, "cpumap"))
		if err != nil {
			return nil, err
		}
		cpus, err := internal.CPUBitmapToCpus(cbm)
		if err != nil {
			return nil, err
		}
		numa[uint32(id)] = cpus
	}

	// irqs: remember NIC channel interrupts. The key is the IRQ number from
	// the first column ("45:"), the value is the last column of the line.
	rgxp := regexp.MustCompile(ethChannelRgxp)
	irqs := map[int]string{}
	data, err := ioutil.ReadFile(interruptsPath)
	if err != nil {
		return nil, err
	}
	for _, l := range strings.Split(string(data), "\n") {
		if !rgxp.MatchString(l) {
			continue
		}
		// A matching line contains "eth" or "mlx", so it has at least one field.
		fields := strings.Fields(l)
		irq, err := strconv.Atoi(strings.Trim(fields[0], ":"))
		if err != nil {
			return nil, err
		}
		irqs[irq] = fields[len(fields)-1]
	}

	return &pl{
		ifaces:      ifaces,
		numa:        numa,
		coresCount:  coresCount,
		pcoresCount: pcoresCount,
		irqs:        irqs,
	}, nil
}

// Name returns the plugin identifier, "irqbalance".
func (p *pl) Name() string {
	const pluginName = "irqbalance"
	return pluginName
}

// Doc returns the URL documenting this plugin (an internal tracker link).
func (p *pl) Doc() string {
	const docURL = "https://st.yandex-team.ru/RTCNETWORK-92"
	return docURL
}

// IsApplicable always returns a state whose Status is plugin.StatusSkip and
// whose other fields are zero. It does not inspect the host.
func (p *pl) IsApplicable() (st plugin.State) {
	return plugin.State{Status: plugin.StatusSkip}
}

// Check reports, for every interface collected by New, the state of each
// tuned feature in this order: ring buffers, channels, RPS, RFS, IRQ
// affinity, XPS. It only runs the check* methods; nothing is modified.
func (p *pl) Check() (sts plugin.States) {
	checks := [...]func(iface) plugin.State{
		p.checkRingBufFeature,
		p.checkChannelsFeature,
		p.checkRPSFeature,
		p.checkRFSFeature,
		p.checkAFFFeature,
		p.checkXPSFeature,
	}
	for _, i := range p.ifaces {
		for _, check := range checks {
			sts.Add(check(i))
		}
	}
	return sts
}

// Enable tunes every feature of every interface collected by New. For each
// feature it runs the check and, only when the check reports StatusDiff, the
// matching enable step. The resulting state (from the check, or from the
// enable step if it ran) is collected per interface and feature, in the same
// feature order as Check.
//
// Without force nothing is touched, and a single "self" StatusSkip state
// carrying plugin.ErrNeedForce is returned.
func (p *pl) Enable(force bool) (sts plugin.States) {
	// TODO: Remove it after deploy on prestable.
	// We must make sure that this plugin is working.
	// Temporarily manual drive!
	if !force {
		sts.Add(plugin.State{Name: "self", Status: plugin.StatusSkip, Err: plugin.ErrNeedForce})
		return sts
	}

	type step struct {
		check  func(iface) plugin.State
		enable func(iface) plugin.State
	}
	steps := []step{
		{p.checkRingBufFeature, p.enableRingBufFeature},
		{p.checkChannelsFeature, func(i iface) plugin.State { return p.enableChannelsFeature(i, force) }},
		{p.checkRPSFeature, p.enableRPSFeature},
		{p.checkRFSFeature, p.enableRFSFeature},
		{p.checkAFFFeature, p.enableAFFFeature},
		{p.checkXPSFeature, p.enableXPSFeature},
	}
	for _, i := range p.ifaces {
		for _, s := range steps {
			st := s.check(i)
			if st.Status == plugin.StatusDiff {
				st = s.enable(i)
			}
			sts.Add(st)
		}
	}
	return sts
}

// Disable is not implemented. Regardless of force it returns a single "self"
// state with StatusSkip and plugin.ErrNotImpl.
func (p *pl) Disable(force bool) (sts plugin.States) {
	notImplemented := plugin.State{Name: "self", Status: plugin.StatusSkip, Err: plugin.ErrNotImpl}
	sts.Add(notImplemented)
	return sts
}

// checkRingBufFeature reports whether the interface's RX and TX ring sizes
// are both at their hardware maximum. It returns StatusOk when they are,
// StatusDiff (with current and maximum sizes in Err) when they are not, and
// StatusFail if the ring parameters cannot be read.
func (p *pl) checkRingBufFeature(i iface) (st plugin.State) {
	st = plugin.State{Name: fmt.Sprintf("%s.%s", i.name(), "ringBufSize"), Status: plugin.StatusFail}
	ring, err := ethtool.GetRingParam(i.name())
	switch {
	case err != nil:
		st.Err = err
	case ring.Rx == ring.RxMax && ring.Tx == ring.TxMax:
		st.Status = plugin.StatusOk
	default:
		st.Status = plugin.StatusDiff
		st.Err = fmt.Errorf("RXmax:%d, RX:%d; TXmax:%d, TX:%d", ring.RxMax, ring.Rx, ring.TxMax, ring.Tx)
	}
	return st
}

// enableRingBufFeature sets the RX, RX mini, RX jumbo and TX ring sizes of
// the interface to the maxima reported by ethtool. It returns StatusOk on
// success and StatusFail with the ethtool error otherwise.
func (p *pl) enableRingBufFeature(i iface) (st plugin.State) {
	name := i.name()
	st = plugin.State{Name: fmt.Sprintf("%s.%s", name, "ringBufSize"), Status: plugin.StatusFail}
	ring, err := ethtool.GetRingParam(name)
	if err != nil {
		st.Err = err
		return st
	}
	if _, st.Err = ethtool.SetRingParam(name, ring.RxMax, ring.RxMiniMax, ring.RxJumboMax, ring.TxMax); st.Err != nil {
		return st
	}
	st.Status = plugin.StatusOk
	return st
}

// checkChannelsFeature verifies Receive-Side Scaling (RSS) channel counts:
// https://github.com/torvalds/linux/blob/v4.19/Documentation/networking/scaling.txt#L20
//
// The RX, TX and combined channel counts must each equal
// min(hardware maximum, pcoresCount); "other" channels are not compared.
// Model "0x10d3" is skipped with plugin.ErrNotImpl.
func (p *pl) checkChannelsFeature(i iface) (st plugin.State) {
	name := i.name()
	st = plugin.State{Name: fmt.Sprintf("%s.%s", name, "channelsCount"), Status: plugin.StatusFail}
	// TODO: channel tuning for model 0x10d3 is not implemented; that model is
	// the one checkRPSFeature handles instead.
	if model, _ := i.model(); model == "0x10d3" {
		st.Status, st.Err = plugin.StatusSkip, plugin.ErrNotImpl
		return st
	}
	ch, err := ethtool.GetChannelsParam(name)
	if err != nil {
		st.Err = err
		return st
	}
	wantRx := internal.MinUInt(ch.RxMax, p.pcoresCount)
	wantTx := internal.MinUInt(ch.TxMax, p.pcoresCount)
	wantCombined := internal.MinUInt(ch.CombinedMax, p.pcoresCount)

	if ch.Rx == wantRx && ch.Tx == wantTx && ch.Combined == wantCombined {
		st.Status = plugin.StatusOk
		return st
	}
	// The "...max" labels in this message carry the target counts.
	st.Status = plugin.StatusDiff
	st.Err = fmt.Errorf("RXmax:%d, RX:%d; TXmax:%d, TX:%d; CombinedMax:%d, Combined:%d", wantRx, ch.Rx, wantTx, ch.Tx, wantCombined, ch.Combined)
	return st
}

// enableChannelsFeature sets the RX, TX, other and combined channel counts
// of the interface to min(hardware maximum, pcoresCount). Without force it
// does nothing and returns StatusSkip with plugin.ErrNeedForce.
func (p *pl) enableChannelsFeature(i iface, force bool) (st plugin.State) {
	name := i.name()
	st = plugin.State{Name: fmt.Sprintf("%s.%s", name, "channelsCount"), Status: plugin.StatusFail}
	if !force {
		st.Status, st.Err = plugin.StatusSkip, plugin.ErrNeedForce
		return st
	}
	ch, err := ethtool.GetChannelsParam(name)
	if err != nil {
		st.Err = err
		return st
	}
	_, st.Err = ethtool.SetChannelsParam(name,
		internal.MinUInt(ch.RxMax, p.pcoresCount),
		internal.MinUInt(ch.TxMax, p.pcoresCount),
		internal.MinUInt(ch.OtherMax, p.pcoresCount),
		internal.MinUInt(ch.CombinedMax, p.pcoresCount),
	)
	if st.Err == nil {
		st.Status = plugin.StatusOk
	}
	return st
}

// checkRPSFeature verifies Receive Packet Steering (RPS):
// https://github.com/torvalds/linux/blob/v4.19/Documentation/networking/scaling.txt#L99
//
// Only for network cards with 1 channel: it applies to model "0x10d3" and
// skips every other model with plugin.ErrNotImpl. Each queues/rx-*/rps_cpus
// mask of the interface must equal the CPU bitmap of the NIC's NUMA node.
//
// It returns StatusOk when all masks match, StatusDiff on the first
// mismatch, and StatusFail when the NUMA node is unreadable or has no known
// CPUs, or when a mask cannot be read.
func (p *pl) checkRPSFeature(i iface) (st plugin.State) {
	st = plugin.State{Name: fmt.Sprintf("%s.%s", i.name(), "RPS"), Status: plugin.StatusFail}
	model, _ := i.model()
	if model != "0x10d3" {
		st.Status = plugin.StatusSkip
		st.Err = plugin.ErrNotImpl
		return
	}
	numa, err := i.numa()
	if err != nil {
		st.Err = err
		return
	}
	// An unknown or CPU-less node would otherwise produce a bitmap with no
	// CPUs set (presumably an all-zero mask, which disables RPS), and that
	// would become the target.
	cpus := p.numa[numa]
	if len(cpus) == 0 {
		st.Err = fmt.Errorf("no CPUs known for NUMA node %d", numa)
		return
	}
	cbm := internal.CpusToCPUBitMap(p.coresCount, cpus)
	files, err := filepath.Glob(filepath.Join("/sys/class/net/", i.name(), rpsCpusRgxp))
	if err != nil {
		st.Err = err
		return
	}
	for _, f := range files {
		cbmCurrent, err := internal.GetCPUBitMapFromFile(f)
		if err != nil {
			st.Err = err
			return
		}
		if cbm != cbmCurrent {
			st.Status = plugin.StatusDiff
			st.Err = fmt.Errorf("%s doesn't content %s", f, cbm)
			return
		}
	}
	st.Status = plugin.StatusOk
	return
}

// enableRPSFeature writes the CPU bitmap of the NIC's NUMA node into every
// queues/rx-*/rps_cpus file of the interface. It refuses (StatusFail) when
// the node has no known CPUs, so that it never writes an empty mask.
func (p *pl) enableRPSFeature(i iface) (st plugin.State) {
	st = plugin.State{Name: fmt.Sprintf("%s.%s", i.name(), "RPS"), Status: plugin.StatusFail}
	numa, err := i.numa()
	if err != nil {
		st.Err = err
		return
	}
	cpus := p.numa[numa]
	if len(cpus) == 0 {
		st.Err = fmt.Errorf("no CPUs known for NUMA node %d", numa)
		return
	}
	cbm := internal.CpusToCPUBitMap(p.coresCount, cpus)
	files, err := filepath.Glob(filepath.Join("/sys/class/net/", i.name(), rpsCpusRgxp))
	if err != nil {
		st.Err = err
		return
	}
	for _, f := range files {
		// The file already exists in sysfs, so the permission argument is unused.
		if st.Err = ioutil.WriteFile(f, []byte(cbm), 0); st.Err != nil {
			return
		}
	}
	st.Status = plugin.StatusOk
	return
}

// checkRFSFeature verifies Receive Flow Steering (RFS):
// https://github.com/torvalds/linux/blob/v4.19/Documentation/networking/scaling.txt#L225
// https://access.redhat.com/documentation/en-us/red_hat_enterprise_linux/6/html/performance_tuning_guide/network-rfs
//
// The global rps_sock_flow_entries sysctl must hold getRpsSockFlowEntries,
// and every queues/rx-*/rps_flow_cnt of the interface must hold
// getRpsFlowCnt. Any error from internal.IsFileContent (a mismatch, and
// presumably also a read failure) yields StatusDiff.
func (p *pl) checkRFSFeature(i iface) (st plugin.State) {
	st = plugin.State{Name: fmt.Sprintf("%s.%s", i.name(), "RFS"), Status: plugin.StatusFail}
	st.Err = internal.IsFileContent(rpsSockFlowEntriesPath, fmt.Sprintf("%d", getRpsSockFlowEntries(p.coresCount)))
	if st.Err != nil {
		st.Status = plugin.StatusDiff
		return
	}
	files, err := filepath.Glob(filepath.Join("/sys/class/net/", i.name(), "/queues/rx-*/rps_flow_cnt"))
	if err != nil {
		// Previously execution fell through here and ended as StatusOk with
		// a non-nil Err.
		st.Err = err
		return
	}
	// Same value for every queue; getRpsFlowCnt reads the model file, so compute it once.
	want := fmt.Sprintf("%d", getRpsFlowCnt(i, p.coresCount))
	for _, f := range files {
		if st.Err = internal.IsFileContent(f, want); st.Err != nil {
			st.Status = plugin.StatusDiff
			return
		}
	}
	st.Status = plugin.StatusOk
	return
}

// enableRFSFeature writes getRpsSockFlowEntries into the global
// rps_sock_flow_entries sysctl and getRpsFlowCnt into every
// queues/rx-*/rps_flow_cnt of the interface. It stops at the first error,
// which leaves StatusFail.
func (p *pl) enableRFSFeature(i iface) (st plugin.State) {
	st = plugin.State{Name: fmt.Sprintf("%s.%s", i.name(), "RFS"), Status: plugin.StatusFail}
	// The target files already exist, so WriteFile's permission argument is unused.
	st.Err = ioutil.WriteFile(rpsSockFlowEntriesPath, []byte(fmt.Sprintf("%d", getRpsSockFlowEntries(p.coresCount))), 0)
	if st.Err != nil {
		return
	}
	files, err := filepath.Glob(filepath.Join("/sys/class/net/", i.name(), "/queues/rx-*/rps_flow_cnt"))
	if err != nil {
		st.Err = err
		return
	}
	// Same value for every queue; getRpsFlowCnt reads the model file, so compute it once.
	flowCnt := []byte(fmt.Sprintf("%d", getRpsFlowCnt(i, p.coresCount)))
	for _, f := range files {
		if st.Err = ioutil.WriteFile(f, flowCnt, 0); st.Err != nil {
			return
		}
	}
	st.Status = plugin.StatusOk
	return
}

// checkAFFFeature verifies IRQ affinity for the interface's channel IRQs
// (as returned by getIRQQueues, in ascending order). The k-th IRQ must be
// pinned to exactly one CPU: cpus[k % span] of the NIC's NUMA node, where
// span = min(pcoresCount, number of CPUs known for that node).
//
// It returns StatusOk when every /proc/irq/N/smp_affinity matches,
// StatusDiff on the first mismatch, and StatusFail on read errors or when
// there are IRQs but no usable CPUs on the node.
func (p *pl) checkAFFFeature(i iface) (st plugin.State) {
	st = plugin.State{Name: fmt.Sprintf("%s.%s", i.name(), "Affinity"), Status: plugin.StatusFail}
	numa, err := i.numa()
	if err != nil {
		st.Err = err
		return
	}
	irqs, err := p.getIRQQueues(i)
	if err != nil {
		st.Err = err
		return
	}
	// Bound the round-robin by the node's actual CPU list: indexing with
	// only pcoresCount panicked on an unknown node, on nodes with fewer CPUs
	// than pcoresCount, and (divide by zero) when pcoresCount was 0.
	cpus := p.numa[numa]
	span := uint32(len(cpus))
	if p.pcoresCount < span {
		span = p.pcoresCount
	}
	if span == 0 && len(irqs) > 0 {
		st.Err = fmt.Errorf("no CPUs available for IRQ affinity on NUMA node %d", numa)
		return
	}
	for idx, irq := range irqs {
		cbm := internal.CpusToCPUBitMap(p.coresCount, []uint32{cpus[uint32(idx)%span]})
		f := fmt.Sprintf("/proc/irq/%d/smp_affinity", irq)
		cbmCurrent, err := internal.GetCPUBitMapFromFile(f)
		if err != nil {
			st.Err = err
			return
		}
		if cbm != cbmCurrent {
			st.Status = plugin.StatusDiff
			st.Err = fmt.Errorf("%s doesn't content %s", f, cbm)
			return
		}
	}
	st.Status = plugin.StatusOk
	return
}

// enableAFFFeature pins each of the interface's channel IRQs (from
// getIRQQueues, in ascending order) to a single CPU: the k-th IRQ goes to
// cpus[k % span] of the NIC's NUMA node, where span = min(pcoresCount,
// number of CPUs known for that node). It writes /proc/irq/N/smp_affinity and
// stops at the first error.
func (p *pl) enableAFFFeature(i iface) (st plugin.State) {
	st = plugin.State{Name: fmt.Sprintf("%s.%s", i.name(), "Affinity"), Status: plugin.StatusFail}
	numa, err := i.numa()
	if err != nil {
		st.Err = err
		return
	}
	irqs, err := p.getIRQQueues(i)
	if err != nil {
		st.Err = err
		return
	}
	// Bound the round-robin by the node's actual CPU list, so the indexing
	// below cannot panic (unknown node, small node, or pcoresCount == 0).
	cpus := p.numa[numa]
	span := uint32(len(cpus))
	if p.pcoresCount < span {
		span = p.pcoresCount
	}
	if span == 0 && len(irqs) > 0 {
		st.Err = fmt.Errorf("no CPUs available for IRQ affinity on NUMA node %d", numa)
		return
	}
	for idx, irq := range irqs {
		cbm := internal.CpusToCPUBitMap(p.coresCount, []uint32{cpus[uint32(idx)%span]})
		if st.Err = ioutil.WriteFile(fmt.Sprintf("/proc/irq/%d/smp_affinity", irq), []byte(cbm), 0); st.Err != nil {
			return
		}
	}
	st.Status = plugin.StatusOk
	return
}

// checkXPSFeature is a placeholder for Transmit Packet Steering (XPS). It
// always reports StatusSkip with plugin.ErrNotImpl.
// TODO: compare queues/tx-N/xps_cpus of the interface with the expected
// CPU bitmap.
func (p *pl) checkXPSFeature(i iface) (st plugin.State) {
	return plugin.State{
		Name:   fmt.Sprintf("%s.%s", i.name(), "XPS"),
		Status: plugin.StatusSkip,
		Err:    plugin.ErrNotImpl,
	}
}

// enableXPSFeature is a placeholder for Transmit Packet Steering (XPS). It
// changes nothing and always reports StatusSkip with plugin.ErrNotImpl.
// TODO: implement.
func (p *pl) enableXPSFeature(i iface) (st plugin.State) {
	return plugin.State{
		Name:   fmt.Sprintf("%s.%s", i.name(), "XPS"),
		Status: plugin.StatusSkip,
		Err:    plugin.ErrNotImpl,
	}
}

// getIRQQueues returns, in ascending order, the IRQ numbers listed under
// <iface>/device/msi_irqs that New also found in /proc/interrupts (p.irqs).
// A directory entry whose name is not an integer is an error.
func (p *pl) getIRQQueues(i iface) ([]int, error) {
	entries, err := ioutil.ReadDir(filepath.Join(string(i), "device/msi_irqs"))
	if err != nil {
		return nil, err
	}
	queues := make([]int, 0, len(entries))
	for _, e := range entries {
		n, err := strconv.Atoi(e.Name())
		if err != nil {
			return nil, err
		}
		if _, known := p.irqs[n]; !known {
			continue
		}
		queues = append(queues, n)
	}
	sort.Ints(queues)
	return queues, nil
}

// getRpsSockFlowEntries returns the target value of
// /proc/sys/net/core/rps_sock_flow_entries: 4096 entries per logical CPU,
// rounded up to the next power of two.
//
// The kernel rounds rps_sock_flow_entries and rps_flow_cnt up to a power of
// two (see networking/scaling.txt, RFS section), so comparing against the
// unrounded product reported a permanent diff on hosts whose CPU count is
// not a power of two (e.g. 24 or 56). Writing the rounded value has the same
// effect as writing the unrounded one.
func getRpsSockFlowEntries(cores uint32) uint32 {
	return roundUpPowerOfTwo(4096 * cores)
}

// roundUpPowerOfTwo returns the smallest power of two >= v. 0 maps to 0, and
// values above 1<<31 wrap to 0.
func roundUpPowerOfTwo(v uint32) uint32 {
	v--
	v |= v >> 1
	v |= v >> 2
	v |= v >> 4
	v |= v >> 8
	v |= v >> 16
	return v + 1
}

// getRpsFlowCnt returns the target per-RX-queue rps_flow_cnt for the
// interface. Model "0x10d3" (the model checkRPSFeature handles) gets the full
// global table size, getRpsSockFlowEntries(cores). Every other model, or an
// unreadable model file, gets a fixed 4096.
func getRpsFlowCnt(i iface, cores uint32) uint32 {
	// TODO: derive the per-queue count for multi-queue NICs instead of a constant.
	const defaultFlowCnt = 4096
	if model, _ := i.model(); model != "0x10d3" {
		return defaultFlowCnt
	}
	return getRpsSockFlowEntries(cores)
}
