package server

import (
	"context"
	"fmt"
	"strings"
	"syscall"

	"golang.org/x/sys/unix"

	"go.uber.org/zap"

	"a.yandex-team.ru/infra/porto/plugins/portostatd/internal"
	"a.yandex-team.ru/infra/porto/plugins/portostatd/pkg/diskstat"

	rpcpb "a.yandex-team.ru/infra/porto/plugins/portostatd/portostatd_rpc"

	diskman_api "a.yandex-team.ru/infra/diskmanager/proto"
	porto_api "a.yandex-team.ru/infra/porto/proto"
)

// getDeviceStats reads the system-wide disk statistics with
// diskstat.ReadSystem and keeps only the entries whose device key is present
// in lvmStatsByDev. On a read failure it returns whatever ReadSystem produced
// together with the wrapped error.
func getDeviceStats(lvmStatsByDev map[diskstat.Device]*internal.LvmVolumeStats) (diskstat.DiskStatMap, error) {
	all, readErr := diskstat.ReadSystem()
	if readErr != nil {
		return all, fmt.Errorf("failed to get disk stats: %w", readErr)
	}

	// Filter in place; deleting map entries while ranging over the map is
	// permitted by the Go spec.
	for device := range all {
		if _, wanted := lvmStatsByDev[device]; wanted {
			continue
		}
		delete(all, device)
	}
	return all, nil
}

// prevDeviceStats holds the filtered per-device disk statistics captured by
// the previous successful extractLvmStats call. It is the baseline used to
// turn the cumulative disk counters into per-interval deltas.
// NOTE(review): this package-level state is read and written without any
// visible synchronization; concurrent extractLvmStats calls would race.
var prevDeviceStats diskstat.DiskStatMap

// extractLvmStats builds LVM volume statistics grouped by slot.
//
// lvmVolumes are the volumes known to diskmanager. Their mount paths are
// matched against the "place" property of portoVolumes. For every container
// link of a matching Porto volume whose target looks like "/mnt/<name>[/...]",
// the volume is attributed to that container under the volume name <name>.
//
// Space usage comes from statfs(2) on the mount path. I/O counters come from
// the system disk statistics of the backing device. They are reported as the
// difference from the previous call and stay zero on the first call. Results
// are keyed by slot name, i.e. the first "/"-separated component of the
// container name.
//
// The returned error is currently always nil. Failures on individual volumes
// or on reading disk statistics are logged and skipped.
func extractLvmStats(lvmVolumes []*diskman_api.Volume, portoVolumes []*porto_api.TVolumeDescription) (map[string]*internal.LvmStats, error) {
	ctLvmVolumes := collectContainerLvmVolumes(lvmVolumes, portoVolumes)
	lvmStatsByMntPath, lvmStatsByDev := statLvmVolumes(ctLvmVolumes, len(lvmVolumes))

	currDeviceStats, err := getDeviceStats(lvmStatsByDev)
	if err != nil {
		zap.S().Errorf("Failed to get device stats: %v", err)
	} else {
		// Deltas need a baseline: on the first successful call only record it.
		if len(prevDeviceStats) != 0 {
			for dev, stats := range currDeviceStats {
				lvmStats := lvmStatsByDev[dev]
				if lvmStats == nil {
					// getDeviceStats only keeps devices from lvmStatsByDev, so this is
					// purely defensive.
					continue
				}
				prevStats, ok := prevDeviceStats[dev]
				if !ok {
					// Device not seen last time: diff against an all-zero record.
					prevStats = &diskstat.DiskStatRecord{}
				}
				fillLvmIOStats(lvmStats, stats, prevStats)
			}
		}
		prevDeviceStats = currDeviceStats
	}

	// slotName -> lvm_stats
	ctLvmStats := make(map[string]*internal.LvmStats, len(ctLvmVolumes))
	for ctName, mntPaths := range ctLvmVolumes {
		volStats := make([]internal.LvmVolumeStats, 0, len(mntPaths))
		for mntPath := range mntPaths {
			// statLvmVolumes omits mount paths it could not stat (e.g. a volume
			// destroyed in the meantime); dereferencing a missing entry would
			// panic with a nil pointer dereference.
			vs, ok := lvmStatsByMntPath[mntPath]
			if !ok {
				continue
			}
			volStats = append(volStats, *vs)
		}
		slotName := strings.SplitN(ctName, "/", 2)[0]
		if ctLvmStats[slotName] == nil {
			ctLvmStats[slotName] = &internal.LvmStats{VolStats: volStats}
		} else {
			ctLvmStats[slotName].VolStats = append(ctLvmStats[slotName].VolStats, volStats...)
		}
	}

	return ctLvmStats, nil
}

// collectContainerLvmVolumes maps each Porto container name to the LVM mount
// paths linked into it (ctName -> mntPath -> volName). A Porto volume is
// considered only when its "place" property equals a diskmanager mount path.
// A link is considered only when its target has the form
// "/mnt/<volName>[/...]". The first such link seen for a (container, mount
// path) pair determines volName.
func collectContainerLvmVolumes(lvmVolumes []*diskman_api.Volume, portoVolumes []*porto_api.TVolumeDescription) map[string]map[string]string {
	lvmMntPaths := make(map[string]struct{}, len(lvmVolumes))
	for _, v := range lvmVolumes {
		lvmMntPaths[v.Status.MountPath] = struct{}{}
	}

	ctLvmVolumes := map[string]map[string]string{} // ctName -> (mntPath -> volName)
	for _, v := range portoVolumes {
		place := v.Properties["place"]
		if _, ok := lvmMntPaths[place]; !ok {
			continue
		}
		for _, link := range v.Links {
			if link.Container == nil || link.Target == nil {
				continue
			}
			dirs := strings.SplitN(*link.Target, "/", 4) // "/mnt/hdd0/logs" -> ["", "mnt", "hdd0", "logs"]
			if len(dirs) < 3 || dirs[1] != "mnt" {
				continue
			}
			ctName := *link.Container
			if ctLvmVolumes[ctName] == nil {
				ctLvmVolumes[ctName] = map[string]string{}
			}
			if _, ok := ctLvmVolumes[ctName][place]; !ok {
				ctLvmVolumes[ctName][place] = dirs[2] // "hdd0"
			}
		}
	}
	return ctLvmVolumes
}

// statLvmVolumes calls statfs(2) and stat(2) on every distinct mount path in
// ctLvmVolumes. It returns the resulting space stats indexed by mount path and
// by the (major, minor) device number backing the path.
//
// A path that cannot be stat'ed is omitted from both maps. The error is logged,
// except ENOENT from statfs, which is expected for destroyed volumes.
// sizeHint pre-sizes the result maps.
func statLvmVolumes(ctLvmVolumes map[string]map[string]string, sizeHint int) (map[string]*internal.LvmVolumeStats, map[diskstat.Device]*internal.LvmVolumeStats) {
	// mntPath -> volStats
	byMntPath := make(map[string]*internal.LvmVolumeStats, sizeHint)
	// "maj:min" -> volStats
	byDev := make(map[diskstat.Device]*internal.LvmVolumeStats, sizeHint)

	for _, volumes := range ctLvmVolumes {
		for mntPath, volName := range volumes {
			if _, seen := byMntPath[mntPath]; seen {
				continue
			}

			var statfs syscall.Statfs_t
			if err := syscall.Statfs(mntPath, &statfs); err != nil {
				// ENOENT if volume destroyed
				if err != syscall.ENOENT {
					zap.S().Errorf("Failed statfs('%v'): %v", mntPath, err)
				}
				continue
			}

			var stat syscall.Stat_t
			if err := syscall.Stat(mntPath, &stat); err != nil {
				zap.S().Errorf("Failed stat('%v'): %v", mntPath, err)
				continue
			}

			bSize := (uint64)(statfs.Bsize)
			volStats := &internal.LvmVolumeStats{
				VolName:        volName,
				AvailableBytes: statfs.Bavail * bSize,
				UsedBytes:      (statfs.Blocks - statfs.Bfree) * bSize,
			}

			byMntPath[mntPath] = volStats
			byDev[diskstat.Device{
				Maj: unix.Major(stat.Dev),
				Min: unix.Minor(stat.Dev),
			}] = volStats
		}
	}
	return byMntPath, byDev
}

// fillLvmIOStats stores in lvmStats the I/O activity between two disk
// statistics snapshots (prev, then curr) of the same device.
//
// Byte counts use a fixed 512-byte sector size. A counter that went backwards
// (e.g. after the device was recreated) leaves its field unset. IoInProgress
// is copied from curr as-is.
func fillLvmIOStats(lvmStats *internal.LvmVolumeStats, curr, prev *diskstat.DiskStatRecord) {
	if curr.ReadSectors >= prev.ReadSectors {
		lvmStats.ReadBytes = (curr.ReadSectors - prev.ReadSectors) * 512
	}
	if curr.WriteSectors >= prev.WriteSectors {
		lvmStats.WriteBytes = (curr.WriteSectors - prev.WriteSectors) * 512
	}

	if curr.ReadSuccess >= prev.ReadSuccess {
		lvmStats.ReadOps = curr.ReadSuccess - prev.ReadSuccess
	}
	if curr.WriteSuccess >= prev.WriteSuccess {
		lvmStats.WriteOps = curr.WriteSuccess - prev.WriteSuccess
	}
	lvmStats.IoInProgress = curr.IoInProgress

	// Await as computed by sysstat: time spent on reads, writes and discards
	// divided by the number of completed operations.
	// https://github.com/sysstat/sysstat/blob/e7295b25434b6082bd8a6ad9f7f02cad63b2ce49/rd_stats.c#L386
	prevIONum := prev.WriteSuccess + prev.ReadSuccess + prev.DiscardSuccess
	currIONum := curr.WriteSuccess + curr.ReadSuccess + curr.DiscardSuccess
	prevIOTime := prev.TimeInRead + prev.TimeInWrite + prev.TimeInDiscard
	currIOTime := curr.TimeInRead + curr.TimeInWrite + curr.TimeInDiscard
	// If the counters moved backwards (reset), the difference is meaningless
	// and would wrap for unsigned counters; report 0 as for an idle interval.
	if currIONum > prevIONum && currIOTime >= prevIOTime {
		lvmStats.Await = (currIOTime - prevIOTime) / (currIONum - prevIONum)
	} else {
		lvmStats.Await = 0
	}
}

// doGetLvmStatsCached answers a GetLvmStat request by delegating to
// internal.GetLvmStatsStorage with the container name carried in the request.
func doGetLvmStatsCached(req *rpcpb.GetLvmStatRequest) (*rpcpb.GetLvmStatsResponse, error) {
	ctName := req.CtName
	return internal.GetLvmStatsStorage(ctName)
}

// GetLvmStats serves the GetLvmStats RPC. ctx is not consulted; the request
// is answered by doGetLvmStatsCached.
func (s *PortostatdServer) GetLvmStats(ctx context.Context, req *rpcpb.GetLvmStatRequest) (*rpcpb.GetLvmStatsResponse, error) {
	resp, err := doGetLvmStatsCached(req)
	return resp, err
}
