package membw

import (
	"fmt"

	"a.yandex-team.ru/infra/gopcm/pkg/pcm"
)

// MemBW measures memory bandwidth from pcm's per-socket memory-controller
// uncore counters. It keeps two snapshots, each a tick count plus per-socket
// counters: the "before" snapshot marks the start of the current interval and
// the "after" snapshot is filled in by Update. The two are swapped after every
// Update so the newest sample becomes the next baseline.
type MemBW struct {
	// Tick counts (from pcm.TickCount) at the start and end of the interval.
	beforeTime, afterTime uint64

	// Per-socket memory-controller counter snapshots matching the times above.
	beforeStates, afterStates []*pcm.MCCounters

	// Result of the most recent bandwidth calculation; nil until Update runs.
	data *pcm.MemData

	// Handle used to program and read the counters.
	pcm *pcm.Pcm
}

// Init programs the uncore memory metrics on p and records an initial
// per-socket memory-controller counter snapshot together with the current tick
// count. It returns a MemBW whose first Update measures from that snapshot.
// Data returns nil until Update has been called. Errors from programming the
// metrics or reading the tick count are returned unchanged.
func Init(p *pcm.Pcm) (*MemBW, error) {
	if err := p.ProgramUncoreMemoryMetrics(); err != nil {
		return nil, err
	}

	// Baseline counters are read before the tick count, as the first half of
	// the initial snapshot.
	before := make([]*pcm.MCCounters, p.NumSockets())
	for socket := range before {
		before[socket] = p.GetMCUncoreCounters(socket)
	}

	start, err := p.TickCount()
	if err != nil {
		return nil, err
	}

	return &MemBW{
		pcm:          p,
		beforeTime:   start,
		afterTime:    start + 1, // keeps after > before until the first Update
		beforeStates: before,
		afterStates:  make([]*pcm.MCCounters, p.NumSockets()),
	}, nil
}

// Update takes a fresh sample of the tick count and the per-socket
// memory-controller counters. It computes bandwidth over the interval since the
// previous sample and stores the result for Data and BW. The fresh sample then
// becomes the baseline for the next call.
//
// If reprogram is true and the computed overall bandwidth is zero, the uncore
// memory metrics are programmed again (NOTE(review): presumably because another
// agent reconfigured the counters; confirm). After a successful reprogram the
// counters and tick count are sampled again. This way the next interval is
// measured against a post-reprogram baseline rather than the pre-reprogram
// snapshot, which may not be comparable with later readings.
//
// Errors from TickCount or ProgramUncoreMemoryMetrics are returned unchanged.
// On error the before/after snapshots are not swapped.
func (m *MemBW) Update(reprogram bool) error {
	// sample records the current tick count and per-socket counters into the
	// "after" slots.
	sample := func() error {
		at, err := m.pcm.TickCount()
		if err != nil {
			return err
		}
		m.afterTime = at
		for s := 0; s < m.pcm.NumSockets(); s++ {
			m.afterStates[s] = m.pcm.GetMCUncoreCounters(s)
		}
		return nil
	}

	if err := sample(); err != nil {
		return err
	}

	m.data = m.pcm.CalculateMemoryBW(m.beforeStates, m.afterStates, m.afterTime-m.beforeTime)

	if reprogram && m.data.Overall() == 0 {
		if err := m.pcm.ProgramUncoreMemoryMetrics(); err != nil {
			return err
		}
		// Re-baseline: counter values read before reprogramming may not be
		// comparable with values read after it.
		if err := sample(); err != nil {
			return err
		}
	}

	// The "after" sample becomes the next interval's "before"; the old
	// "before" buffers are reused for the next "after" sample.
	m.beforeTime, m.afterTime = m.afterTime, m.beforeTime
	m.beforeStates, m.afterStates = m.afterStates, m.beforeStates
	return nil
}

// summ returns the sum of the first NumSockets entries of a per-socket array.
// Entries for sockets beyond NumSockets are ignored.
func (m *MemBW) summ(memDataCounters [pcm.MaxNumSockets]float64) float64 {
	total := 0.0
	for i := 0; i < m.pcm.NumSockets(); i++ {
		total += memDataCounters[i]
	}
	return total
}

// Data returns the bandwidth data computed by the most recent Update, or nil
// if no Update has computed any yet.
func (m *MemBW) Data() *pcm.MemData {
	return m.data
}

// BW prints the bandwidth computed by the most recent Update to stdout. For
// each socket it prints the read/write figures of every channel that saw
// traffic, followed by a per-socket summary. It then prints system-wide read,
// write and combined totals. Figures are labelled MB/s (NOTE(review): unit
// taken from the original labels; confirm against pcm.CalculateMemoryBW).
//
// Before the first Update there is no data, so a notice is printed instead of
// dereferencing a nil *pcm.MemData.
func (m *MemBW) BW() {
	data := m.Data()
	if data == nil {
		fmt.Println("No memory bandwidth data yet: call Update first")
		return
	}

	systemReadBW := data.Reads()
	systemWriteBW := data.Writes()

	for s := 0; s < m.pcm.NumSockets(); s++ {
		fmt.Printf("Socket %d:\n", s)
		for channel := 0; channel < pcm.MaxChannels; channel++ {
			rv := data.ImcReadSktChan[s][channel]
			wv := data.ImcWriteSktChan[s][channel]
			// Only channels that saw traffic are listed.
			if rv != 0 || wv != 0 {
				fmt.Printf("\t- Channel %d,\tRead: %f MB/s,\tWrite: %f MB/s\n", channel, rv, wv)
			}
		}
		fmt.Printf("\t===================================================\n")
		fmt.Printf("\t- Summary %d:\tRead: %f MB/s\tWrite: %f MB/s\n", s, data.ImcReadSkt[s], data.ImcWriteSkt[s])
		fmt.Println()
	}

	systemBW := systemReadBW + systemWriteBW
	fmt.Printf("System:\n\t- SysRead: %f MB/s\n\t- SysWrite: %f MB/s\n\t- SysOverall: %f MB/s\n", systemReadBW, systemWriteBW, systemBW)
}
