package pcm

import (
	"errors"
	"fmt"

	"a.yandex-team.ru/infra/gopcm/pkg/mcfg"
	"a.yandex-team.ru/infra/gopcm/pkg/pci"
)

const (
	// PcmIntelPciVendorID is the PCI vendor ID of Intel devices. Bus probing
	// compares it with the low 16 bits of PCI config-space dword 0.
	PcmIntelPciVendorID = 0x8086
	// McChPciPmonCtlEventNotKnl is the event-select code programmed into the
	// memory-channel counters for read and write traffic on non-KNL CPUs.
	McChPciPmonCtlEventNotKnl = 0x04
)

// Counter slots (indices into an mcCntConfig) used for the memory read and
// memory write events.
const (
	eventRead  = 0
	eventWrite = 1
)

type (
	// mcRegisterLocation holds, for each memory controller and each of its
	// channels, the PCI [device, function] pair of that channel.
	mcRegisterLocation [][][2]uint64
	// socket2iMCbus lists discovered iMC locations as [segment group, bus]
	// pairs, in discovery order.
	socket2iMCbus [][2]uint32

	// mcCntConfig holds one event-select value per memory-channel counter;
	// entry i is written to counter control register i.
	mcCntConfig []uint32
)

// pciCFGUncore is the per-socket uncore state reached through PCI
// configuration space: where the socket's iMC devices live and the PMUs
// built on top of their handles.
type pciCFGUncore struct {
	// iMCBus and groupNr are the PCI bus number and segment group on which
	// the socket's iMC devices were found.
	iMCBus, groupNr uint32
	// cpuModel is copied from Pcm.cpuModel; it selects model-specific
	// register offsets and unit-control bits.
	cpuModel int64

	// MCRegLoc holds the [device, function] pair of every controller channel.
	MCRegLoc mcRegisterLocation
	// IMCHandles are the PCI handles of the channels that could be opened.
	IMCHandles []*pci.PciHandle
	// IMCPmu holds one PMU per entry of IMCHandles, in the same order.
	IMCPmu []*imcUncorePMU
}

// McChPCIPmonCtlUmask shifts v into the unit-mask field of a memory-channel
// PMON control value. That field starts at bit 8. v is not masked, so callers
// should pass values that fit the field.
func McChPCIPmonCtlUmask(v uint32) uint32 {
	const umaskShift = 8
	return v << umaskShift
}

// initUncore builds the uncore iMC state for socket socketNum in three steps:
//  1. Resolve the memory-controller register addresses for pcm's CPU model.
//  2. Locate the socket's iMC PCI bus.
//  3. Open the channel devices and wrap them into PMUs.
//
// The first failing step aborts initialization, and its error is returned
// unchanged.
func initUncore(socketNum int, pcm *Pcm) (*pciCFGUncore, error) {
	uncore := &pciCFGUncore{cpuModel: pcm.cpuModel}

	if err := uncore.initRegisters(pcm); err != nil {
		return nil, err
	}
	if err := uncore.initBuses(socketNum, pcm); err != nil {
		return nil, err
	}
	if err := uncore.initDirect(socketNum, pcm); err != nil {
		return nil, err
	}
	return uncore, nil
}

// Close releases the uncore resources held by p. It first runs cleanup on
// every iMC PMU and then closes every PCI handle.
//
// Every PMU and handle is processed even if an earlier one fails, so one bad
// channel does not leak the remaining handles. The first error encountered is
// returned. Afterwards the PMU and handle slices are cleared, so a second Close
// does not touch the already-closed handles again.
func (p *pciCFGUncore) Close() error {
	var firstErr error

	// PMUs go first: their cleanup may still need the open handles.
	for _, pmu := range p.IMCPmu {
		if err := pmu.cleanup(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	for _, handle := range p.IMCHandles {
		if err := handle.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}

	p.IMCPmu = nil
	p.IMCHandles = nil
	return firstErr
}

// initMC builds the channel address table for numOfControllers memory
// controllers with numOfChannels channels each.
//
// The device and function numbers are looked up in the package-level registers
// map under the keys "<arch>Mc<controller>Ch<channel>RegisterDevAddr" and
// "<arch>Mc<controller>Ch<channel>RegisterFuncAddr". A key missing from the map
// yields 0. The result is nil when numOfControllers <= 0.
func initMC(numOfControllers, numOfChannels int, arch string) mcRegisterLocation {
	var table mcRegisterLocation
	for mc := 0; mc < numOfControllers; mc++ {
		row := [][2]uint64{}
		for ch := 0; ch < numOfChannels; ch++ {
			devKey := fmt.Sprintf("%sMc%dCh%dRegisterDevAddr", arch, mc, ch)
			funcKey := fmt.Sprintf("%sMc%dCh%dRegisterFuncAddr", arch, mc, ch)
			row = append(row, [2]uint64{
				uint64(registers[devKey]),
				uint64(registers[funcKey]),
			})
		}
		table = append(table, row)
	}
	return table
}

func (p *pciCFGUncore) initRegisters(pcm *Pcm) error {
	// we need only MC registers for now

	var res mcRegisterLocation

	if pcm.cpuModel == Jaketown || pcm.cpuModel == Ivytown {
		res = initMC(2, 4, "Jktivt")

	} else if pcm.cpuModel == Haswellx || pcm.cpuModel == BdxDe || pcm.cpuModel == Bdx {
		res = initMC(2, 4, "Hsx")

	} else if pcm.cpuModel == Skx {
		res = initMC(2, 4, "Skx")

		// } else if pcm.cpuModel == Knl {
		// 2 DDR4 Memory Controllers with 3 channels each
		//res = initMC(2, 3, "Knl")

		// 8iniMnnl[Stacked] DRAM) Memory Controllers

		// TODO: WHAT IS ECLK
		// mcRegisterLocation := initMC(2, 3, "KNL")
		// initMC(0, ECLK, "KNL")
		// initMC(1, ECLK, "KNL")
		// initMC(2, ECLK, "KNL")
		// initMC(3, ECLK, "KNL")
		// initMC(4, ECLK, "KNL")
		// initMC(5, ECLK, "KNL")
		// initMC(6, ECLK, "KNL")
		// initMC(7, ECLK, "KNL")
	} else {
		msg := fmt.Sprintf("uncore PMU for processor with model id %d is not supported.\n", p.cpuModel)
		return errors.New(msg)
	}
	p.MCRegLoc = res
	return nil
}

// initBuses finds the PCI segment group and bus carrying the iMC devices of
// socket socketNum and stores them in p.groupNr and p.iMCBus.
//
// On every bus listed in the MCFG table it probes the [device, function] of
// the first channel of the first memory controller (p.MCRegLoc[0][0]). It keeps
// the buses whose device ID is in ImcDevIds.
//
// The result is used only when the number of buses found equals the number of
// sockets. In that case the socketNum-th bus found is taken for socketNum.
// Other topologies are not implemented. initRegisters must have populated
// p.MCRegLoc first.
func (p *pciCFGUncore) initBuses(socketNum int, pcm *Pcm) error {
	if len(p.MCRegLoc) == 0 || len(p.MCRegLoc[0]) == 0 {
		return errors.New("can't init bus: memory controller registers are not initialized")
	}
	firstChannel := p.MCRegLoc[0][0]
	buses, err := initSocket2Bus(uint32(firstChannel[0]), uint32(firstChannel[1]), ImcDevIds)
	if err != nil {
		return err
	}

	totalSockets := pcm.numSockets
	if int32(len(buses)) != totalSockets {
		if totalSockets <= 4 {
			return errors.New("can't init bus: unimplemented case with total of cpu sockets <= 4")
		}
		return errors.New("can't init bus")
	}
	// Guard the index instead of panicking on a bad socket number.
	if socketNum < 0 || socketNum >= len(buses) {
		return fmt.Errorf("can't init bus: socket %d out of range (found %d sockets)", socketNum, len(buses))
	}

	// We don't cross-check with M2M devices here for now.
	p.groupNr = buses[socketNum][0]
	p.iMCBus = buses[socketNum][1]
	return nil
}

func (p *pciCFGUncore) initDirect(socketNum int, pcm *Pcm) error {

	for _, ctrl := range p.MCRegLoc {
		// fmt.Printf("Controller: %d\n", i)

		for _, ch := range ctrl {
			//fmt.Printf("\tChannel: %d\n", j)
			dev := uint32(ch[0])
			fnc := uint32(ch[1])
			// fmt.Printf("dev: %d", dev)
			// fmt.Printf("func: %d", fnc)

			imcHandle, err := pci.Init(uint16(p.groupNr), uint8(p.iMCBus), dev, fnc)
			if err != nil {
				continue
			}
			p.IMCHandles = append(p.IMCHandles, imcHandle)
		}
	}
	for _, h := range p.IMCHandles {
		if p.cpuModel == Knl {
			upmu := &imcUncorePMU{
				UnitControl:         &hwRegister32{h, KnxMcChPciPmonBoxCtlAddr},
				FixedCounterControl: &hwRegister32{h, KnxMcChPciPmonFixedCtlAddr},
				FixedCounterValue:   &hwRegister64{h, KnxMcChPciPmonFixedCtrAddr},
			}
			upmu.CounterControls = []*hwRegister32{
				&hwRegister32{h, KnxMcChPciPmonCtl0Addr},
				&hwRegister32{h, KnxMcChPciPmonCtl1Addr},
				&hwRegister32{h, KnxMcChPciPmonCtl2Addr},
				&hwRegister32{h, KnxMcChPciPmonCtl3Addr},
			}
			upmu.CounterValues = []*hwRegister64{
				&hwRegister64{h, KnxMcChPciPmonCtr0Addr},
				&hwRegister64{h, KnxMcChPciPmonCtr1Addr},
				&hwRegister64{h, KnxMcChPciPmonCtr2Addr},
				&hwRegister64{h, KnxMcChPciPmonCtr3Addr},
			}
			p.IMCPmu = append(p.IMCPmu, upmu)
		} else {
			upmu := &imcUncorePMU{
				UnitControl:         &hwRegister32{h, XpfMcChPciPmonBoxCtlAddr},
				FixedCounterControl: &hwRegister32{h, XpfMcChPciPmonFixedCtlAddr},
				FixedCounterValue:   &hwRegister64{h, XpfMcChPciPmonFixedCtrAddr},
			}
			upmu.CounterControls = []*hwRegister32{
				&hwRegister32{h, XpfMcChPciPmonCtl0Addr},
				&hwRegister32{h, XpfMcChPciPmonCtl1Addr},
				&hwRegister32{h, XpfMcChPciPmonCtl2Addr},
				&hwRegister32{h, XpfMcChPciPmonCtl3Addr},
			}
			upmu.CounterValues = []*hwRegister64{
				&hwRegister64{h, XpfMcChPciPmonCtr0Addr},
				&hwRegister64{h, XpfMcChPciPmonCtr1Addr},
				&hwRegister64{h, XpfMcChPciPmonCtr2Addr},
				&hwRegister64{h, XpfMcChPciPmonCtr3Addr},
			}
			p.IMCPmu = append(p.IMCPmu, upmu)
		}
	}
	if len(p.IMCPmu) == 0 {
		msg := "could not initialize imcPMUs"
		return errors.New(msg)

	}
	return nil
}

// initSocket2Bus scans every bus range listed in the ACPI MCFG table. It
// returns, in scan order, the [segment group, bus] pair of each bus on which
// PCI device/function is an Intel device whose device ID appears in devIDs.
//
// Buses where the device cannot be opened, or where its config dword 0 cannot
// be read, are skipped. An error (with a nil result) is returned when the MCFG
// table cannot be parsed or when closing a successfully read handle fails.
func initSocket2Bus(device, function uint32, devIDs []uint32) (socket2iMCbus, error) {
	_, mcfgR, err := mcfg.Parse()
	if err != nil {
		return nil, err
	}

	// readID returns config-space dword 0 of segment:bus:device.function. The
	// vendor ID is in the low 16 bits and the device ID in the high 16 bits.
	// ok is false when the device is absent or unreadable. The handle is always
	// closed so probing does not leak it.
	readID := func(segment uint16, bus uint8) (uint32, bool, error) {
		h, err := pci.Init(segment, bus, device, function)
		if err != nil {
			return 0, false, nil
		}
		value, readErr := h.Read32(0)
		if readErr != nil {
			// Best effort: this bus is skipped either way, so a close
			// failure here is not worth aborting the whole scan for.
			_ = h.Close()
			return 0, false, nil
		}
		if err := h.Close(); err != nil {
			return 0, false, err
		}
		return value, true, nil
	}

	var res socket2iMCbus
	for _, record := range mcfgR {
		// Iterate with an int: a uint8 counter would wrap around after 255
		// and never exceed EndBusNumber when EndBusNumber == 255.
		for b := int(record.StartBusNumber); b <= int(record.EndBusNumber); b++ {
			bus := uint8(b)
			value, ok, err := readID(record.SegmentGroupNumber, bus)
			if err != nil {
				return nil, err
			}
			if !ok || value&0xffff != PcmIntelPciVendorID {
				continue
			}
			deviceID := (value >> 16) & 0xffff
			for _, devID := range devIDs {
				if devID == deviceID {
					res = append(res, [2]uint32{uint32(record.SegmentGroupNumber), uint32(bus)})
					// One entry per bus, even if devIDs contains duplicates.
					break
				}
			}
		}
	}
	return res, nil
}

// getNumChannels reports how many memory channels on this socket have an
// initialized iMC PMU.
func (p *pciCFGUncore) getNumChannels() int {
	channels := len(p.IMCPmu)
	return channels
}

// getPMUCounter returns the current value of counter `counter` of the PMU at
// index id in pmu.
//
// It returns 0 in any of these cases, which callers cannot tell apart from a
// genuine zero reading:
//   - id is out of range, or the PMU is nil;
//   - counter is out of range for that PMU, or the register is nil;
//   - the read fails.
//
// The bound comes from the PMU's actual CounterValues length rather than a
// hard-coded 4, so a PMU with fewer counters cannot cause an index panic.
func getPMUCounter(pmu []*imcUncorePMU, id, counter uint32) uint64 {
	// Compare as uint64 so large ids cannot turn negative via int on 32-bit.
	if uint64(id) >= uint64(len(pmu)) || pmu[id] == nil {
		return 0
	}
	values := pmu[id].CounterValues
	if uint64(counter) >= uint64(len(values)) || values[counter] == nil {
		return 0
	}
	v, err := values[counter].Read()
	if err != nil {
		// The signature has no error channel, so a failed read reads as 0.
		return 0
	}
	return v
}

// getMCCounter reads counter `counter` of memory channel `channel` on this
// socket. It yields 0 when the channel or counter does not exist or when the
// read fails.
func (p *pciCFGUncore) getMCCounter(channel, counter uint32) uint64 {
	pmus := p.IMCPmu
	return getPMUCounter(pmus, channel, counter)
}

// programIMC programs every iMC PMU of p with the event configuration
// mccntcfg. For each PMU it:
//  1. calls initFreeze(extraIMC);
//  2. enables the fixed counter (DRAM clocks) and then resets it while keeping
//     it enabled;
//  3. writes the per-counter events via program.
//
// extraIMC is UncPmonUnitCtlRsv on Skylake-SP and UncPmonUnitCtlFrzEn on
// other models. The first error aborts programming.
func (p *pciCFGUncore) programIMC(mccntcfg mcCntConfig) error {
	extraIMC := uint32(UncPmonUnitCtlFrzEn)
	if p.cpuModel == Skx {
		extraIMC = UncPmonUnitCtlRsv
	}

	for _, pmu := range p.IMCPmu {
		if err := pmu.initFreeze(extraIMC); err != nil {
			return err
		}
		if err := pmu.FixedCounterControl.Write(McChPciPmonFixedCtlEn); err != nil {
			return err
		}
		if err := pmu.FixedCounterControl.Write(McChPciPmonFixedCtlEn + McChPciPmonFixedCtlRst); err != nil {
			return err
		}
		if err := pmu.program(mccntcfg, extraIMC); err != nil {
			return err
		}
	}
	return nil
}

// program writes the event configuration mccntcfg into the PMU's counter
// control registers. For each i it writes to control register i first the bare
// enable bit, then the enable bit OR-ed with mccntcfg[i]. Nil control registers
// are skipped, matching upstream Intel PCM.
//
// When extraIMC is non-zero, resetUnfreeze(extraIMC) runs afterwards.
//
// It returns an error instead of panicking, before writing anything, when
// mccntcfg holds more entries than the PMU has counter control registers.
func (imcpmu *imcUncorePMU) program(mccntcfg mcCntConfig, extraIMC uint32) error {
	if len(mccntcfg) > len(imcpmu.CounterControls) {
		return fmt.Errorf("can't program %d iMC counters: PMU has only %d counter controls",
			len(mccntcfg), len(imcpmu.CounterControls))
	}
	for i, cfg := range mccntcfg {
		ctl := imcpmu.CounterControls[i]
		if ctl == nil {
			continue
		}
		if err := ctl.Write(McChPciPmonCtlEn); err != nil {
			return err
		}
		if err := ctl.Write(McChPciPmonCtlEn | cfg); err != nil {
			return err
		}
	}

	if extraIMC != 0 {
		return imcpmu.resetUnfreeze(extraIMC)
	}
	return nil
}

// writeUnitControl writes value to the unit control register of every iMC PMU,
// stopping at and returning the first write error.
func (p *pciCFGUncore) writeUnitControl(value uint32) error {
	for _, pmu := range p.IMCPmu {
		if err := pmu.UnitControl.Write(value); err != nil {
			return err
		}
	}
	return nil
}

// freezeCounters writes UncPmonUnitCtlFrz plus a model-specific extra bit
// pattern to every iMC PMU's unit control register. The extra pattern is
// UncPmonUnitCtlRsv on Skylake-SP and UncPmonUnitCtlFrzEn on other models.
func (p *pciCFGUncore) freezeCounters() error {
	ctl := UncPmonUnitCtlFrz
	switch p.cpuModel {
	case Skx:
		ctl += UncPmonUnitCtlRsv
	default:
		ctl += UncPmonUnitCtlFrzEn
	}
	return p.writeUnitControl(uint32(ctl))
}

// unfreezeCounters writes the model-specific unit-control pattern without the
// freeze bit to every iMC PMU. The pattern is UncPmonUnitCtlRsv on Skylake-SP
// and UncPmonUnitCtlFrzEn elsewhere.
func (p *pciCFGUncore) unfreezeCounters() error {
	ctl := uint32(UncPmonUnitCtlFrzEn)
	if p.cpuModel == Skx {
		ctl = uint32(UncPmonUnitCtlRsv)
	}
	return p.writeUnitControl(ctl)
}

// programMemoryMetrics programs the iMC counters for memory read and write
// traffic: event McChPciPmonCtlEventNotKnl with umask 3 in the read slot and
// umask 12 in the write slot. Every other counter slot is left at 0. Knights
// Landing is rejected as unimplemented. Rank-level and PMM metrics are not
// implemented either.
func (p *pciCFGUncore) programMemoryMetrics() error {
	if p.cpuModel == Knl {
		return fmt.Errorf("Can't program metrics for cpu model == %d: unimplemented", p.cpuModel)
	}

	cfg := make(mcCntConfig, MaxCounters)
	// Presumably CAS_COUNT.RD / CAS_COUNT.WR; confirm against the Intel
	// uncore performance monitoring reference for each supported model.
	cfg[eventRead] = McChPciPmonCtlEventNotKnl + McChPCIPmonCtlUmask(3)
	cfg[eventWrite] = McChPciPmonCtlEventNotKnl + McChPCIPmonCtlUmask(12)
	return p.programIMC(cfg)
}

// ProgramUncoreMemoryMetrics programs the memory-controller read/write
// counters on every socket's uncore.
//
// It returns a "permission denied" error when no MSR handles or no uncore PCI
// configurations are present, presumably because the process lacked the
// privileges to open them. Otherwise it returns the first per-socket
// programming error, if any.
func (p *Pcm) ProgramUncoreMemoryMetrics() error {
	haveAccess := len(p.msrs) != 0 && len(p.pciCfgUncores) != 0
	if !haveAccess {
		return errors.New("permission denied")
	}
	for _, uncore := range p.pciCfgUncores {
		if err := uncore.programMemoryMetrics(); err != nil {
			return err
		}
	}
	return nil
}
