package pcm

import (
	"errors"
	"fmt"

	"a.yandex-team.ru/infra/gopcm/pkg/cpuinfo"
	"a.yandex-team.ru/infra/gopcm/pkg/msr"
)

// MSR addresses used by this package. The values are taken from
// `https://github.com/opcm/pcm`, which in turn takes them from the
// "Intel(R) 64 and IA-32 Architectures Software Developer's Manual".
const (
	// IA32TimeStampCounter is the address of the time-stamp counter MSR.
	IA32TimeStampCounter = 0x10
	// PlatformInfoAddr is the address of the platform-info MSR; bits 15:8 of
	// its value are used to derive the nominal frequency.
	PlatformInfoAddr = 0xCE
	// InstRetiredAny is the address of the INST_RETIRED.ANY counter MSR.
	InstRetiredAny = 0x309
	// CPUClkUnhaltedThread is the address of the CPU_CLK_UNHALTED.THREAD
	// counter MSR.
	CPUClkUnhaltedThread = 0x30A
)

// Pcm holds what is needed to read core (MSR) and uncore (PCI config space)
// performance counters. Create it with Init and release it with Close.
type Pcm struct {
	// CPU family and model as reported by cpuinfo for the first listed CPU.
	cpuFamily, cpuModel int64
	// CPU count (cpuinfo NumCPU) and socket count (cpuinfo NumSockets).
	numCores, numSockets int32
	// One MSR handle per CPU, indexed by CPU number.
	msrs []*msr.MsrHandle
	// One PCI-config uncore object per CPU socket, indexed by socket.
	pciCfgUncores []*pciCFGUncore
	// Cached nominal frequency in Hz; 0 until detected successfully.
	nominalFreq uint64
}

// MsrCounterState is a snapshot of per-CPU counters filled by
// ReadMsrCounters; GetIPC compares two such snapshots.
type MsrCounterState struct {
	// Raw values of INST_RETIRED.ANY, CPU_CLK_UNHALTED.THREAD and the TSC.
	instRetired, cpuClkUnhaltedThread, tsc uint64
}

// ReadMsrCounters reads the INST_RETIRED.ANY and CPU_CLK_UNHALTED.THREAD
// counters and the TSC through the MSR handle with index msrID and stores
// them in state.
//
// All three reads are attempted even if one fails. A counter whose read
// failed is stored as 0. The returned error lists every failure, separated
// by "; ", so callers can log it and still use the remaining values. An
// msrID outside the range of open MSR handles returns an error and leaves
// state untouched.
func (p *Pcm) ReadMsrCounters(msrID int, state *MsrCounterState) error {
	if msrID < 0 || msrID >= len(p.msrs) {
		return fmt.Errorf("msr id %d is out of range [0, %d)", msrID, len(p.msrs))
	}
	handle := p.msrs[msrID]

	var msg string
	// valueOrZero records a failed read in msg and substitutes 0 for its value.
	valueOrZero := func(what string, v uint64, err error) uint64 {
		if err == nil {
			return v
		}
		if msg != "" {
			msg += "; "
		}
		msg += what + ": " + err.Error()
		return 0
	}

	inst, instErr := handle.Read(InstRetiredAny)
	clk, clkErr := handle.Read(CPUClkUnhaltedThread)
	tsc, tscErr := p.getTSC(msrID)

	state.instRetired = valueOrZero("read INST_RETIRED.ANY", inst, instErr)
	state.cpuClkUnhaltedThread = valueOrZero("read CPU_CLK_UNHALTED.THREAD", clk, clkErr)
	state.tsc = valueOrZero("read TSC", tsc, tscErr)

	if msg != "" {
		return errors.New(msg)
	}
	return nil
}

// Init detects the CPU model, opens one MSR handle per CPU and creates the
// uncore objects for every socket.
//
// If opening the MSR handles or creating the uncore objects fails, everything
// opened so far is closed before the error is returned. A failed Init
// therefore does not leak handles.
func Init() (*Pcm, error) {
	p := &Pcm{}

	// release closes whatever was opened before a failure. Close errors are
	// ignored because the initialization error is the one worth reporting.
	release := func() {
		for _, unc := range p.pciCfgUncores {
			_ = unc.Close()
		}
		for _, m := range p.msrs {
			_ = m.Close()
		}
		p.pciCfgUncores = nil
		p.msrs = nil
	}

	// NOTE(review): checkModel (model alias normalization + support check) is
	// not called from this function; confirm whether that is intentional.
	if err := p.detectModel(); err != nil {
		return nil, err
	}
	if err := p.initMSR(); err != nil {
		release()
		return nil, err
	}
	if err := p.initUncoreObjects(); err != nil {
		release()
		return nil, err
	}
	return p, nil
}

// Close disables core counting and closes every MSR handle, then closes every
// uncore object.
//
// Every handle is processed even if an earlier one fails, so one failing CPU
// or socket does not leak the rest. The first error encountered is returned.
func (p *Pcm) Close() error {
	var firstErr error
	keepFirst := func(err error) {
		if err != nil && firstErr == nil {
			firstErr = err
		}
	}

	for _, m := range p.msrs {
		keepFirst(m.Write(IA32CrPerfGlobalCtrl, 0)) // disable core counting
		keepFirst(m.Close())
	}
	// disable uncore counting
	for _, unc := range p.pciCfgUncores {
		keepFirst(unc.Close())
	}
	return firstErr
}

// detectModel reads cpuinfo and records the family and model of the first
// listed CPU, the CPU count and the socket count.
//
// It returns an error if cpuinfo cannot be read or lists no CPUs. The
// original code would have panicked on an empty list.
func (p *Pcm) detectModel() error {
	cpuInfo, err := cpuinfo.ReadCPUInfo(cpuinfo.CPUInfoPath)
	if err != nil {
		return fmt.Errorf("reading %v: %w", cpuinfo.CPUInfoPath, err)
	}
	if len(cpuInfo.CPUs) == 0 {
		return fmt.Errorf("no CPUs found in %v", cpuinfo.CPUInfoPath)
	}
	first := cpuInfo.CPUs[0]
	p.cpuModel = first.Model
	p.cpuFamily = first.Family
	p.numCores = int32(cpuInfo.NumCPU())
	p.numSockets = int32(cpuInfo.NumSockets())
	return nil
}

// checkModel maps model aliases onto the canonical model constant used by the
// rest of the package. It then returns an error if the resulting model is
// not supported.
func (p *Pcm) checkModel() error {
	switch p.cpuModel {
	case Nehalem:
		p.cpuModel = NehalemEp
	case Atom2:
		p.cpuModel = Atom
	case HaswellUlt, Haswell2:
		p.cpuModel = Haswell
	case BroadwellXeonE3:
		p.cpuModel = Broadwell
	case SklUy:
		p.cpuModel = Skl
	case Kbl1:
		p.cpuModel = Kbl
	}

	if !isSupported(p.cpuModel) {
		return fmt.Errorf("CPU Model %d is unsupported", p.cpuModel)
	}
	return nil
}

// initMSR opens one MSR handle for each of the p.numCores CPUs and appends it
// to p.msrs.
//
// On failure it returns an error that wraps the underlying cause. Handles
// opened before the failure remain in p.msrs so the caller can close them.
func (p *Pcm) initMSR() error {
	for cpu := 0; cpu < int(p.numCores); cpu++ {
		// Cores are not checked for being online: in our deployments all cores
		// are expected to be online.
		handle, err := msr.Init(cpu)
		if err != nil {
			return fmt.Errorf("could not init MSR for cpu %d, check if 'msr' kernel module is loaded: %w", cpu, err)
		}
		p.msrs = append(p.msrs, handle)
	}
	return nil
}

// initUncoreObjects creates one PCI-config uncore object per socket, if the
// CPU model supports it and MSR handles are open.
//
// Models with client memory-controller counters are reported as not yet
// implemented. Any other model returns an error.
func (p *Pcm) initUncoreObjects() error {
	haveMSRs := len(p.msrs) != 0
	switch {
	case hasPCICFGUncore(p.cpuModel) && haveMSRs:
		for socket := 0; socket < int(p.numSockets); socket++ {
			pcicfg, err := initUncore(socket, p)
			if err != nil {
				return fmt.Errorf("init uncore for socket %d: %w", socket, err)
			}
			p.pciCfgUncores = append(p.pciCfgUncores, pcicfg)
		}
		// initUncorePMUDirect() is unimplemented; it does not seem to be needed for now.
		return nil
	case hasClientMCCounters(p.cpuModel) && haveMSRs:
		return errors.New("client MCCounters are not implemented yet")
	default:
		return fmt.Errorf("can not init uncore objects for cpu model %d", p.cpuModel)
	}
}

// GetMCUncoreCounters returns a snapshot of the memory-controller uncore
// counters of the given socket. The counters are frozen while they are read.
//
// Channels beyond the capacity of MCCounters are ignored rather than causing
// an index-out-of-range panic. The method panics if socket is not a valid
// index into the initialized uncore objects.
func (p *Pcm) GetMCUncoreCounters(socket int) *MCCounters {
	pcicfg := p.pciCfgUncores[socket]
	mc := &MCCounters{}

	// Freeze/unfreeze errors are ignored because this method has no error
	// return. NOTE(review): if freezing fails, the snapshot may be taken from
	// running counters. Unfreeze is deferred so it also runs if a read panics.
	_ = pcicfg.freezeCounters()
	defer func() { _ = pcicfg.unfreezeCounters() }()

	channels := pcicfg.getNumChannels()
	if channels > len(mc) {
		channels = len(mc)
	}
	for channel := 0; channel < channels; channel++ {
		for cnt := 0; cnt < MaxCounters; cnt++ {
			mc[channel][cnt] = pcicfg.getMCCounter(uint32(channel), uint32(cnt))
		}
	}
	return mc
}

// NumSockets returns the socket count recorded by detectModel during Init.
func (p *Pcm) NumSockets() int {
	sockets := int(p.numSockets)
	return sockets
}

// NumCores returns the CPU count recorded by detectModel during Init.
func (p *Pcm) NumCores() int {
	cores := int(p.numCores)
	return cores
}

// detectNominalFrequency computes the nominal core frequency in Hz. It uses
// the platform-info MSR read through the first MSR handle.
//
// It returns 0 and a nil error when no MSR handles are open. It returns 0 and
// the read error when the MSR cannot be read.
func (p *Pcm) detectNominalFrequency() (uint64, error) {
	if len(p.msrs) == 0 {
		return 0, nil
	}
	info, err := p.msrs[0].Read(PlatformInfoAddr)
	if err != nil {
		return 0, err
	}

	// Bus clock in Hz: 100 MHz for the models listed here, 133.33 MHz for
	// every other model.
	busFreq := uint64(133333333)
	switch p.cpuModel {
	case SandyBridge, Jaketown, Ivytown, Haswellx, BdxDe, Bdx, IvyBridge,
		Haswell, Broadwell, Avoton, ApolloLake, Denverton, Skl, Kbl, Icl,
		Knl, Skx:
		busFreq = 100000000
	}

	// Bits 15:8 of MSR_PLATFORM_INFO hold the maximum non-turbo ratio (Intel
	// SDM). That ratio times the bus clock is the nominal frequency.
	ratio := (info >> 8) & 0xFF
	return ratio * busFreq, nil
}

// NominalFrequency returns the nominal core frequency in Hz. It detects the
// value on first use and caches it.
//
// If detection fails, the error is printed and 0 is returned. Nothing is
// cached in that case, so the next call retries.
func (p *Pcm) NominalFrequency() uint64 {
	if p.nominalFreq != 0 {
		return p.nominalFreq
	}
	freq, err := p.detectNominalFrequency()
	if err != nil {
		fmt.Println(err)
		return 0
	}
	p.nominalFreq = freq
	return freq
}

// TickCount returns the TSC of CPU 0 converted to milliseconds. The result is
// floor(tsc*1000/nominalFrequency). It assumes the TSC advances at the
// nominal frequency.
//
// It returns an error if the TSC cannot be read or the nominal frequency is
// unknown (0). The original code would have panicked on division by zero.
func (p *Pcm) TickCount() (uint64, error) {
	const msPerSecond = 1000

	tsc, err := p.getTSC(0)
	if err != nil {
		return 0, err
	}
	freq := p.NominalFrequency()
	if freq == 0 {
		return 0, errors.New("nominal CPU frequency is unknown")
	}
	// Split into whole seconds and remainder so that msPerSecond*tsc cannot
	// overflow uint64 on long-running hosts. rem < freq, so rem*msPerSecond
	// stays small. The result equals floor(tsc*msPerSecond/freq).
	secs, rem := tsc/freq, tsc%freq
	return secs*msPerSecond + rem*msPerSecond/freq, nil
}

// getTSC reads the time-stamp counter MSR through the MSR handle with index
// msrID.
//
// Reading the TSC on Atom models other than Avoton is not implemented and
// returns an error. An out-of-range msrID also returns an error instead of
// panicking. On any error the returned value is 0.
func (p *Pcm) getTSC(msrID int) (uint64, error) {
	if p.isAtom() && p.cpuModel != Avoton {
		return 0, fmt.Errorf("can't get TSC for cpu model == %d: unimplemented", p.cpuModel)
	}
	if msrID < 0 || msrID >= len(p.msrs) {
		return 0, fmt.Errorf("msr id %d is out of range [0, %d)", msrID, len(p.msrs))
	}
	tsc, err := p.msrs[msrID].Read(IA32TimeStampCounter)
	if err != nil {
		return 0, err
	}
	return tsc, nil
}

// GetIPC returns instructions retired per unhalted core cycle between two
// snapshots. It returns 0 when no cycles elapsed.
//
// NOTE(review): the unsigned subtraction assumes the hardware counters did
// not wrap at a width narrower than 64 bits between the snapshots. Confirm
// the counter width.
func (p *Pcm) GetIPC(before, after *MsrCounterState) float64 {
	cycles := after.cpuClkUnhaltedThread - before.cpuClkUnhaltedThread
	if cycles == 0 {
		return 0
	}
	retired := after.instRetired - before.instRetired
	return float64(retired) / float64(cycles)
}

// isAtom reports whether the detected CPU model is one of the Atom-class
// models listed below.
func (p *Pcm) isAtom() bool {
	switch p.cpuModel {
	case Atom, Atom2, Centerton, Baytrail, Avoton, Cherrytrail, ApolloLake, Denverton:
		return true
	default:
		return false
	}
}

// toBW converts a count of memory events observed over elapsedTime into a
// bandwidth in MB/s, where MB is 10^6 bytes.
//
// Each event is counted as 64 bytes (presumably one cache line; confirm
// against the uncore event definitions). elapsedTime is expected in
// milliseconds, the unit TickCount produces. A zero elapsedTime yields 0
// instead of +Inf/NaN.
func toBW(nEvents, elapsedTime uint64) float64 {
	if elapsedTime == 0 {
		return 0
	}
	const bytesPerEvent = 64
	megabytes := float64(nEvents) * bytesPerEvent / 1000000
	seconds := float64(elapsedTime) / 1000
	return megabytes / seconds
}

// CalculateMemoryBW computes per-channel and per-socket memory read/write
// bandwidth from two sets of memory-controller counter snapshots.
//
// beforeStates and afterStates are indexed by socket. elapsedTime is the time
// between the two snapshots, in the unit toBW expects. Sockets that lack a
// snapshot in either slice are skipped instead of causing an
// index-out-of-range panic.
func (p *Pcm) CalculateMemoryBW(beforeStates, afterStates []*MCCounters, elapsedTime uint64) *MemData {
	md := &MemData{}

	sockets := p.NumSockets()
	if len(beforeStates) < sockets {
		sockets = len(beforeStates)
	}
	if len(afterStates) < sockets {
		sockets = len(afterStates)
	}

	for s := 0; s < sockets; s++ {
		before, after := beforeStates[s], afterStates[s]
		for channel := 0; channel < MaxChannels; channel++ {
			readBW := toBW(MCDiffCounters(channel, eventRead, before, after), elapsedTime)
			writeBW := toBW(MCDiffCounters(channel, eventWrite, before, after), elapsedTime)

			md.ImcReadSktChan[s][channel] = readBW
			md.ImcWriteSktChan[s][channel] = writeBW
			md.ImcReadSkt[s] += readBW
			md.ImcWriteSkt[s] += writeBW
		}
	}
	return md
}
