package disk

import (
	"io/ioutil"
	"os"
	"strconv"

	"path"
	"strings"
	"syscall"

	"go.uber.org/zap"

	"a.yandex-team.ru/infra/rsm/diskmanager/internal/ilog"
	"a.yandex-team.ru/infra/rsm/diskmanager/internal/utils"
	"a.yandex-team.ru/infra/rsm/diskmanager/pkg/bugon"
	"a.yandex-team.ru/infra/rsm/diskmanager/pkg/lvm"
	"a.yandex-team.ru/infra/rsm/diskmanager/pkg/mountinfo"
	"a.yandex-team.ru/infra/rsm/diskmanager/pkg/sysfs"
	"a.yandex-team.ru/infra/rsm/diskmanager/pkg/udev"
)

// AllDisks enumerates every entry of /sys/class/block and builds a Disk for
// each of them, sharing a single mountinfo snapshot between all devices.
//
// If nonzero is true, devices whose Size is zero are omitted from the result.
// If final is true, only devices with IsFinal set are returned.
//
// Frontend relations (FrontDisks/FrontMounts) are resolved against the full,
// unfiltered device set before filtering. Otherwise, asking for final disks
// would discard the partitions / dm / md devices stacked on top of them, and
// the returned disks would never report their frontends.
//
// An error while building any single device aborts the whole enumeration.
func AllDisks(nonzero bool, final bool) ([]*Disk, error) {
	files, err := ioutil.ReadDir("/sys/class/block")
	if err != nil {
		return nil, ErrIntSystem.Wrap(err)
	}
	mounts, err := mountinfo.Self()
	if err != nil {
		return nil, ErrIntSystem.Wrap(err)
	}
	// Read the mount table once; every makeDisk call reuses it.
	sysCache := SystemInfo{mounts: mounts}

	all := make([]*Disk, 0, len(files))
	for _, info := range files {
		disk, err := NewDiskFromCache(info.Name(), "", sysCache)
		if err != nil {
			return nil, err
		}
		all = append(all, disk)
	}

	// Resolve frontends over every device, not just the ones being returned.
	for _, d := range all {
		d.frontendsIdentify(all)
	}

	disks := make([]*Disk, 0, len(all))
	for _, d := range all {
		if nonzero && d.Size == 0 {
			continue
		}
		if final && !d.IsFinal {
			continue
		}
		disks = append(disks, d)
	}
	return disks, nil
}

// identifyMount attaches mount information to d.
//
// It scans mounts (or the result of mountinfo.Self() when mounts is nil) for
// the first entry whose device number equals d.Dev and whose Root is "/", and
// copies that entry into d.MInfo and its FsType into d.FsType.
//
// When the resulting FsType is "ext4", the errors_count attribute of the ext4
// sysfs object named d.Name is read into d.FsErrCount; a non-zero count sets
// d.Error to ErrIO.
//
// A failure to read mountinfo is wrapped in ErrIntSystem; a failure to read
// errors_count is returned unwrapped.
func (d *Disk) identifyMount(mounts []mountinfo.MountInfo) error {
	var err error
	if mounts == nil {
		if mounts, err = mountinfo.Self(); err != nil {
			return ErrIntSystem.Wrap(err)
		}
	}

	for _, mi := range mounts {
		if d.Dev == utils.Mkdev(mi.Major, mi.Minor) && mi.Root == "/" {
			bugon.BUGON(d.MInfo.Mountpoint != "", "Must be empty at this point")
			d.MInfo = mi
			d.FsType = mi.FsType
			break
		}
	}

	if d.FsType != "ext4" {
		return nil
	}
	errCounter := sysfs.FS.Object("ext4").SubObject(d.Name).Attribute("errors_count")
	if d.FsErrCount, err = errCounter.ReadUint64(); err != nil {
		return err
	}
	if d.FsErrCount != 0 {
		d.Error = ErrIO
	}
	return nil
}

// backendsIdentify fills d.Backends and d.FinalBackends from sysfs.
//
// d.Backends receives the entry names of the device's own "slaves" directory
// (d.SysfsPath + "/slaves"), appended after any backends already present.
//
// d.FinalBackends is computed by a depth-first walk starting at d.Name:
//   - a node with a "partition" attribute is replaced by the name that
//     parentDirName derives from its /sys/class/block symlink target
//     (presumably the containing disk);
//   - a node with a non-empty "slaves" directory is replaced by its slaves;
//   - any other node is a leaf and is recorded once, unless it is d itself.
func (d *Disk) backendsIdentify() error {
	const classBlock = "/sys/class/block/"

	if ownSlaves := d.SysfsPath + "/slaves"; pathExists(ownSlaves) {
		entries, err := ioutil.ReadDir(ownSlaves)
		if err != nil {
			return err
		}
		for _, e := range entries {
			d.Backends = append(d.Backends, e.Name())
		}
	}

	// LIFO stack of device names still to be expanded.
	pending := []string{d.Name}
	for len(pending) > 0 {
		last := len(pending) - 1
		cur := pending[last]
		pending = pending[:last]

		if pathExists(classBlock + cur + "/partition") {
			target, err := os.Readlink(classBlock + cur)
			if err != nil {
				return err
			}
			pending = append(pending, parentDirName(target))
			continue
		}

		if slavesDir := classBlock + cur + "/slaves"; pathExists(slavesDir) {
			entries, err := ioutil.ReadDir(slavesDir)
			if err != nil {
				return err
			}
			for _, e := range entries {
				pending = append(pending, e.Name())
			}
			if len(entries) > 0 {
				continue
			}
		}

		// Leaf reached: record it unless it is the starting device.
		if cur == d.Name {
			continue
		}
		if _, seen := utils.Find(d.FinalBackends, cur); !seen {
			d.FinalBackends = append(d.FinalBackends, cur)
		}
	}
	return nil
}

// frontendsIdentify records which devices in disks are stacked on top of d.
//
// d's own mountpoint, if any, is appended to d.FrontMounts first. Then every
// disk whose FinalBackends contains d.Name is appended to d.FrontDisks, and
// its mountpoint, if any, to d.FrontMounts. Results are appended, so calling
// this twice on the same Disk duplicates entries.
func (d *Disk) frontendsIdentify(disks []*Disk) {
	addMount := func(mountpoint string) {
		if mountpoint != "" {
			d.FrontMounts = append(d.FrontMounts, mountpoint)
		}
	}

	addMount(d.MInfo.Mountpoint)
	for _, other := range disks {
		if _, stacked := utils.Find(other.FinalBackends, d.Name); stacked {
			d.FrontDisks = append(d.FrontDisks, other.Name)
			addMount(other.MInfo.Mountpoint)
		}
	}
}

// makeDisk builds a Disk describing a single block device.
//
// The device is selected by the first matching rule:
//   - devpath starts with "/dev": the block special file at devpath (stat
//     follows symlinks); the device it refers to (st_rdev) is used;
//   - devpath contains ':': devpath is parsed as "major:minor";
//   - otherwise: dev is a name under /sys/class/block whose "dev" attribute
//     provides "major:minor".
//
// Devices with major 0 become MediaKindVirtual and devices without a
// /sys/dev/block entry become MediaKindUnregistered. For both, only the device
// number, synthetic name and mount information are filled in.
//
// sysInfo.mounts, when non-nil, is used instead of re-reading mountinfo.
// sysInfo.disks, when non-nil, is used to resolve frontends.
func makeDisk(dev string, devpath string, sysInfo SystemInfo) (*Disk, error) {
	ll := ilog.Log()

	disk := &Disk{Kind: MediaKindUnknown}

	switch {
	case strings.HasPrefix(devpath, "/dev"):
		stat := syscall.Stat_t{}
		if err := syscall.Stat(devpath, &stat); err != nil {
			return nil, ErrNotFound.Wrap(err)
		}
		// st_dev is the device of the filesystem holding the node (devtmpfs).
		// The block device the node refers to is st_rdev, which is only
		// meaningful for block special files.
		if stat.Mode&syscall.S_IFMT != syscall.S_IFBLK {
			return nil, ErrNotFound.Wrap(syscall.ENOTBLK)
		}
		disk.Dev = stat.Rdev
	case strings.ContainsRune(devpath, ':'):
		devNum, err := utils.ParseMajMin(devpath)
		if err != nil {
			return nil, ErrIntSystem.Wrap(err)
		}
		disk.Dev = devNum
	default:
		obj := sysfs.Class.Object("block").SubObject(dev)
		dPath, err := obj.Attribute("dev").Read()
		if err != nil {
			ll.Error("not exits", zap.String("path", string(obj)))
			return nil, ErrNotFound.Wrap(err)
		}
		devNum, err := utils.ParseMajMin(dPath)
		if err != nil {
			return nil, ErrIntSystem.Wrap(err)
		}
		disk.Dev = devNum
	}
	bugon.BUGON(disk.Dev == 0, "Invalid dev")
	disk.MajorMinor = utils.FormatMajorMinor(disk.Dev)

	sysfsDev := sysfs.Dev.Object("block").SubObject(disk.MajorMinor)
	if utils.Major(disk.Dev) == 0 {
		disk.Kind = MediaKindVirtual
		disk.Name = string(disk.Kind) + "_" + disk.MajorMinor
	} else if !sysfsDev.Exists() {
		disk.Kind = MediaKindUnregistered
		disk.Name = string(disk.Kind) + "_" + disk.MajorMinor
	} else {
		// The last element of the /sys/dev/block/<maj:min> link target is the
		// kernel device name.
		name, err := sysfsDev.Readlink()
		if err != nil {
			return nil, err
		}
		disk.Name = path.Base(name)
		disk.DevPath = "/dev/" + disk.Name
		disk.SysfsPath = sysfs.Class.Object("block").SubObject(disk.Name).String()
	}
	if err := disk.identifyMount(sysInfo.mounts); err != nil {
		return nil, ErrIntSystem.Wrap(err)
	}
	// Virtual and unregistered devices have no sysfs node to inspect further.
	if disk.SysfsPath == "" {
		return disk, nil
	}

	sysfsObj := sysfs.Class.Object("block").SubObject(disk.Name)
	sectors, err := sysfsObj.Attribute("size").ReadUint64()
	if err != nil {
		return nil, err
	}
	disk.Size = sectors << 9 // "size" is in 512-byte sectors

	// Request-queue attributes live on the device that owns the queue: the
	// whole disk for a partition, the device itself otherwise.
	queueOwner := sysfsObj
	if partition := sysfsObj.Attribute("partition"); partition.Exists() {
		disk.Kind = MediaKindPartition
		// Best effort: a failed read leaves the zero value.
		pidx, _ := partition.ReadInt()
		start, _ := sysfsObj.Attribute("start").ReadUint64()
		disk.Partition = uint(pidx)
		disk.PartStart = start << 9 // "start" is in 512-byte sectors
		parent, err := sysfsObj.Readlink()
		if err != nil {
			return nil, err
		}
		disk.PartDisk = parentDirName(parent)
		disk.Backends = []string{disk.PartDisk}
		queueOwner = sysfs.Class.Object("block").SubObject(disk.PartDisk)
	} else {
		// Whole devices take the first Kind whose value prefixes the name.
		baseKinds := []MediaKind{
			MediaKindScsiDisk,
			MediaKindVirtIO,
			MediaKindMD,
			MediaKindDeviceMapper,
			MediaKindNVME,
			MediaKindRAMDev,
			MediaKindNullBlk,
			MediaKindLoopdev}

		for _, kind := range baseKinds {
			if strings.HasPrefix(disk.Name, string(kind)) {
				disk.Kind = kind
				break
			}
		}
	}
	if val, err := queueOwner.Attribute("queue/rotational").ReadBool(); err == nil {
		disk.IsRotational = val
	}
	// discard_granularity is reported by the kernel in bytes, not sectors.
	if val, err := queueOwner.Attribute("queue/discard_granularity").ReadUint64(); err == nil {
		disk.DiscardBlockSize = val
	}

	uDisk := udev.BlockDev(disk.MajorMinor)
	udevEntry, err := uDisk.Read()
	if err != nil {
		return nil, ErrIntSystem.Wrap(err)
	}
	disk.UdevEnv = udevEntry.Env
	disk.UdevLinks = udevEntry.Links
	disk.UdevInitTS = udevEntry.InitTS
	disk.PartUUID = disk.UdevEnv["ID_PART_ENTRY_UUID"]
	disk.PartTableUUID = disk.UdevEnv["ID_PART_TABLE_UUID"]

	disk.FsUUID = disk.UdevEnv["ID_FS_UUID"]
	disk.FsLabel = disk.UdevEnv["ID_FS_LABEL"]
	// A filesystem type taken from the mount table wins over udev's probe.
	if disk.FsType == "" {
		disk.FsType = disk.UdevEnv["ID_FS_TYPE"]
	}
	if sysInfo.disks != nil {
		disk.frontendsIdentify(sysInfo.disks)
	}
	if err := disk.backendsIdentify(); err != nil {
		return nil, err
	}

	switch disk.Kind {
	case MediaKindScsiDisk:
		disk.IsFinal = true
		if err := disk.ScsiIdentify(); err != nil {
			return nil, err
		}
	case MediaKindNVME:
		disk.IsFinal = true
		if err := disk.NVMEIdentify(); err != nil {
			return nil, err
		}
	case MediaKindVirtIO:
		disk.IsFinal = true
		if err := disk.VirtIOIdentify(); err != nil {
			return nil, err
		}
	case MediaKindLoopdev:
		disk.IsFinal = true
		if err := disk.LoopdevIdentify(); err != nil {
			return nil, err
		}
	case MediaKindMD:
		if err := disk.MDIdentify(); err != nil {
			return nil, err
		}
	case MediaKindDeviceMapper:
		if err := disk.DeviceMapperIdentify(); err != nil {
			return nil, err
		}
	}
	return disk, nil
}

// NewDiskFromCache builds a Disk for the device identified by dev or devpath,
// reusing whatever system state is cached in sysInfo. The request and its
// outcome (the disk and any error) are logged at debug level.
func NewDiskFromCache(dev string, devpath string, sysInfo SystemInfo) (*Disk, error) {
	logger := ilog.Log()

	result, buildErr := makeDisk(dev, devpath, sysInfo)
	logger.Debug("makeDisk",
		zap.String("req_dev", dev),
		zap.Any("disk", result),
		zap.Error(buildErr),
	)
	return result, buildErr
}

// NewDisk builds a Disk for the block device named dev, with no cached
// system information (the mount table is read on demand).
func NewDisk(dev string) (*Disk, error) {
	var noCache SystemInfo
	return NewDiskFromCache(dev, "", noCache)
}

// ScsiIdentify fills Model, Serial, Firmware and WWN from the udev properties
// of a SCSI disk.
//
// ID_WWN_WITH_EXTENSION is preferred and is stored with a "wwn-" prefix.
// Otherwise ID_WWN is used verbatim. It returns ErrBadWWN when the resulting
// WWN is empty.
//
// NOTE(review): only the extended form gets the "wwn-" prefix, so the WWN
// format depends on which udev key is present. Confirm this is intended.
func (d *Disk) ScsiIdentify() error {
	env := d.UdevEnv
	d.Model, d.Serial, d.Firmware = env["ID_MODEL"], env["ID_SERIAL_SHORT"], env["ID_REVISION"]

	wwn := env["ID_WWN"]
	if ext, ok := env["ID_WWN_WITH_EXTENSION"]; ok {
		wwn = "wwn-" + ext
	}
	d.WWN = wwn

	if wwn == "" {
		return ErrBadWWN
	}
	return nil
}

// VirtIOIdentify assigns a synthetic WWN of the form "virtio-<name>".
// virtio block devices expose no unique persistent identifier, so the result
// is only as stable as the kernel device name. It never fails.
func (d *Disk) VirtIOIdentify() error {
	const wwnPrefix = "virtio-"
	d.WWN = wwnPrefix + d.Name
	return nil
}

// LoopdevIdentify assigns a synthetic WWN of the form
// "<name>-wwn-ts<hex udev init timestamp>". Loop devices expose no unique
// persistent identifier, so the name plus udev's init timestamp is used to
// tell successive incarnations apart. It never fails.
func (d *Disk) LoopdevIdentify() error {
	tsHex := strconv.FormatUint(d.UdevInitTS, 16)
	d.WWN = d.Name + "-wwn-ts" + tsHex
	return nil
}

// MDIdentify derives the WWN of an MD (software RAID) device from its udev
// MD_UUID property as "md-uuid-<uuid>". It returns ErrBadWWN when the
// property is absent.
func (d *Disk) MDIdentify() error {
	uuid, ok := d.UdevEnv["MD_UUID"]
	if !ok {
		return ErrBadWWN
	}
	d.WWN = "md-uuid-" + uuid
	return nil
}

// DeviceMapperIdentify derives the WWN of a device-mapper device from udev.
//
// When DM_UUID is set, it becomes the WWN verbatim. If that UUID starts with
// "LVM-", the device is reclassified as MediaKindLVolume and LvGroup/LvName
// are taken from DM_VG_NAME/DM_LV_NAME. A missing value sets d.Error to
// lvm.ErrVolumeGroupNotFound or lvm.ErrLogicalVolumeNotFound respectively
// (the latter wins if both are missing).
//
// Without DM_UUID, DM_NAME is used as "dm-name-<name>". If neither property
// is present, ErrBadWWN is returned.
func (d *Disk) DeviceMapperIdentify() error {
	if uuid, ok := d.UdevEnv["DM_UUID"]; ok {
		// Check the UUID itself; d.WWN is not populated yet at this point.
		if strings.HasPrefix(uuid, "LVM-") {
			var found bool
			if d.LvGroup, found = d.UdevEnv["DM_VG_NAME"]; !found {
				d.Error = lvm.ErrVolumeGroupNotFound
			}
			if d.LvName, found = d.UdevEnv["DM_LV_NAME"]; !found {
				d.Error = lvm.ErrLogicalVolumeNotFound
			}
			d.Kind = MediaKindLVolume
		}
		d.WWN = uuid
		return nil
	}
	if name, ok := d.UdevEnv["DM_NAME"]; ok {
		d.WWN = "dm-name-" + name
		return nil
	}
	return ErrBadWWN
}

// NVMEIdentify fills WWN, Model, Serial and Firmware of an NVMe namespace
// from sysfs.
//
// The "wwid" attribute is read from the device's own sysfs directory. Model,
// serial and firmware revision are read from its "device" subdirectory
// (presumably the controller; confirm). The reads are best effort and happen
// only when "wwid" exists. ErrBadWWN is returned when no WWN was obtained.
func (d *Disk) NVMEIdentify() error {
	obj := sysfs.NewObject(d.SysfsPath)
	if wwid := obj.Attribute("wwid"); wwid.Exists() {
		d.WWN, _ = wwid.Read()
		d.Model, _ = obj.Attribute("device/model").Read()
		d.Serial, _ = obj.Attribute("device/serial").Read()
		d.Firmware, _ = obj.Attribute("device/firmware_rev").Read()
	}
	if d.WWN == "" {
		return ErrBadWWN
	}
	return nil
}
