package main

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"math/rand"
	"os"
	"os/exec"
	"path"
	"path/filepath"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"

	"a.yandex-team.ru/infra/rtc/ext4jresizer/pkg/fstab"
)

// Tool locations, timeouts and sizing knobs used by the journal resizer.
const (
	// e2ProgTimeout bounds every tune2fs / e2fsck invocation.
	e2ProgTimeout = 15 * time.Minute

	// ytPathDefault is the mount-point prefix that marks a YT data disk in fstab.
	ytPathDefault = "/yt/"

	tune2fsPathDefault = "/sbin/tune2fs"
	fsckPathDefault    = "/sbin/e2fsck"

	// stateDumpPath receives the JSON-encoded per-device state.
	stateDumpPath = "/var/lib/yandex-yt-ext4jresizer/state.json"

	oneGiB = 1 << 30

	// Rotational media: default journal size (GiB) and the multiple of that
	// size that must be free before a new journal is created.
	defaultRotationalJSizeGiB = 32
	rotationalReserveFactor   = 2.3

	// Non-rotational media: default journal size (GiB) and free-space multiple.
	defaultNonRotationalJSize  = 4
	nonRotationalReserveFactor = 3
)

// debug enables logging of executed commands and their output.
var debug = flag.Bool("debug", false, "enable debugging output")

// Journal sizes (GiB) requested from tune2fs, selected by device type.
var (
	rotationJSizeGiB    = flag.Uint("rotationalJSizeGiB", defaultRotationalJSizeGiB, "rotational device journal size in GiB")
	nonRotationJSizeGiB = flag.Uint("nonRotationalJSizeGiB", defaultNonRotationalJSize, "non-rotational device journal size in GiB")
)

// WorkerState tracks the progress of the journal re-creation on one block
// device. It is also the unit written to the JSON state dump.
type WorkerState struct {
	// DevFile is the resolved path of the device node being processed.
	DevFile string `json:"dev_file"`
	// FreeBytes and TotalBytes come from `tune2fs -l` (blocks * block size).
	FreeBytes  uint64 `json:"free_bytes"`
	TotalBytes uint64 `json:"total_bytes"`
	// Rotational selects the journal size and free-space reserve factor.
	Rotational bool `json:"rotational"`
	// LastError is the error of the most recent step (nil on success).
	// It is serialized as its message string, or null; see MarshalJSON.
	LastError error `json:"last_error"`
	// LastOutput is the combined output of the most recent external command.
	LastOutput []byte `json:"last_output"`
	// Skipped is set when journal removal failed, so creation is not attempted.
	Skipped          bool   `json:"skip"`
	PreFSCKDone      bool   `json:"pre_fsck_done"`
	SpaceOk          bool   `json:"space_ok"`
	JournalDestroyed bool   `json:"journal_destroyed"`
	JournalCreated   bool   `json:"journal_created"`
	JournalOffset    uint64 `json:"journal_offset"`
	JournalSize      uint64 `json:"journal_size"`
	PostFSCKDone     bool   `json:"post_fsck_done"`
}

// MarshalJSON encodes w like the default encoder, except that last_error
// holds the error message (or null). An error interface value has no
// exported fields, so encoding/json would otherwise emit {} and lose it.
func (w *WorkerState) MarshalJSON() ([]byte, error) {
	// plain has WorkerState's fields but not its methods, which prevents
	// json.Marshal from recursing back into this MarshalJSON.
	type plain WorkerState

	var lastErr *string
	if w.LastError != nil {
		msg := w.LastError.Error()
		lastErr = &msg
	}
	// The outer LastError is shallower than the embedded one, so it wins
	// the "last_error" key.
	return json.Marshal(struct {
		plain
		LastError *string `json:"last_error"`
	}{plain: plain(*w), LastError: lastErr})
}

// NewWorker returns a WorkerState for the given device path. No error is
// recorded and all progress flags are false.
func NewWorker(device string) *WorkerState {
	ws := new(WorkerState)
	ws.DevFile = device
	return ws
}

// RunWithTimeout runs cmd with args and returns its combined stdout and
// stderr. The process is killed if it is still running after timeout.
//
// When the deadline is what stopped the command, the returned error says
// so and wraps the underlying exec error. Otherwise the caller would only
// see an opaque "signal: killed".
func RunWithTimeout(cmd string, args []string, timeout time.Duration) ([]byte, error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	if *debug {
		log.Printf("Running %s with args '%s' and timeout %s", cmd, strings.Join(args, " "), timeout)
	}
	out, err := exec.CommandContext(ctx, cmd, args...).CombinedOutput()
	if err != nil && ctx.Err() == context.DeadlineExceeded {
		return out, fmt.Errorf("%s timed out after %s: %w", cmd, timeout, err)
	}
	return out, err
}

// getFsDetails fills in FreeBytes, TotalBytes and Rotational for w.DevFile.
//
// Sizes come from the "Block size:", "Block count:" and "Free blocks:"
// lines of `tune2fs -l`. The rotational flag comes from
// isDeviceRotational. It returns an error if any of the following happens:
//   - tune2fs fails;
//   - the output cannot be read;
//   - a matched value is not an unsigned integer;
//   - the block size or block count is missing;
//   - the rotational flag cannot be determined.
func (w *WorkerState) getFsDetails() error {
	out, err := RunWithTimeout(tune2fsPathDefault, []string{"-l", w.DevFile}, e2ProgTimeout)
	if *debug {
		log.Printf("[%s] Got output: %s", w.DevFile, out)
	}
	if err != nil {
		return err
	}

	var blockSize, freeBlocks, totalBlocks uint64
	scanner := bufio.NewScanner(bytes.NewReader(out))
	for scanner.Scan() {
		fields := strings.Fields(scanner.Text())
		// Lines of interest look like "<word> <word>: <number>". Shorter
		// lines are skipped so the fields[1]/fields[2] accesses cannot panic.
		if len(fields) < 3 {
			continue
		}
		var dst *uint64
		switch {
		case fields[0] == "Block" && fields[1] == "size:":
			dst = &blockSize
		case fields[0] == "Block" && fields[1] == "count:":
			dst = &totalBlocks
		case fields[0] == "Free" && fields[1] == "blocks:":
			dst = &freeBlocks
		default:
			continue
		}
		v, perr := strconv.ParseUint(fields[2], 10, 64)
		if perr != nil {
			return fmt.Errorf("parsing %q in tune2fs output for %s: %w", fields[0]+" "+fields[1], w.DevFile, perr)
		}
		*dst = v
	}
	// Scanner errors only become visible once Scan has returned false.
	if serr := scanner.Err(); serr != nil {
		return fmt.Errorf("reading tune2fs output for %s: %w", w.DevFile, serr)
	}
	if blockSize == 0 || totalBlocks == 0 {
		return fmt.Errorf("tune2fs -l output for %s lacks block size or block count", w.DevFile)
	}

	w.FreeBytes = freeBlocks * blockSize
	w.TotalBytes = totalBlocks * blockSize

	rot, err := isDeviceRotational(w.DevFile)
	if err != nil {
		return err
	}
	w.Rotational = rot
	return nil
}

// journalFits reports whether the filesystem has enough free space for the
// new journal. Free bytes must exceed the configured journal size (GiB)
// times a per-media reserve factor:
//   - rotational disks use -rotationalJSizeGiB with rotationalReserveFactor;
//   - non-rotational disks use -nonRotationalJSizeGiB with
//     nonRotationalReserveFactor.
//
// These are the same sizes createJournal requests for each media type.
func (w *WorkerState) journalFits() bool {
	jSizeGiB, reserve := float64(*rotationJSizeGiB), rotationalReserveFactor
	if !w.Rotational {
		jSizeGiB, reserve = float64(*nonRotationJSizeGiB), nonRotationalReserveFactor
	}
	return float64(w.FreeBytes) > reserve*jSizeGiB*oneGiB
}

// isDeviceRotational reports whether the disk behind the device node bDev
// is rotational, according to /sys/block/<disk>/queue/rotational.
//
// Names starting with "nvme" are reported as non-rotational without
// consulting sysfs. For other names, sysfs is first queried with the name
// as-is, which covers whole disks such as sda, dm-0 or md127. If that fails,
// trailing digits are stripped so that a partition like sda1 maps to sda.
func isDeviceRotational(bDev string) (bool, error) {
	name := path.Base(bDev)
	if strings.HasPrefix(name, "nvme") {
		return false, nil
	}
	if rot, err := readSysfsRotational(name); err == nil {
		return rot, nil
	}
	return readSysfsRotational(trimPartitionSuffix(name))
}

// trimPartitionSuffix strips trailing decimal digits (sdb12 -> sdb). A name
// consisting only of digits is returned unchanged.
func trimPartitionSuffix(name string) string {
	if disk := strings.TrimRight(name, "0123456789"); disk != "" {
		return disk
	}
	return name
}

// readSysfsRotational parses /sys/block/<disk>/queue/rotational ("1" or "0",
// optionally followed by whitespace). Empty or unexpected content is
// reported as an error instead of panicking or silently meaning "false".
func readSysfsRotational(disk string) (bool, error) {
	p := path.Join("/sys/block", disk, "queue", "rotational")
	buf, err := os.ReadFile(p)
	if err != nil {
		return false, err
	}
	switch strings.TrimSpace(string(buf)) {
	case "1":
		return true, nil
	case "0":
		return false, nil
	}
	return false, fmt.Errorf("unexpected content %q in %s", buf, p)
}

// fsckDeclinedFix matches the "Fix? no" answer e2fsck prints when it
// declines a repair. The "?" is escaped; unescaped, it would make the "x"
// optional instead of matching a literal question mark. The pattern is
// unanchored: a bare ^ without (?m) only matches at the very start of the
// whole output.
var fsckDeclinedFix = regexp.MustCompile(`Fix\? no`)

// runFsck runs a forced, verbose e2fsck on w.DevFile and returns its
// combined output. With dryRun it uses -n (no changes); otherwise it uses
// -p (automatic safe repairs).
//
// Any non-zero exit status is returned as an error. That includes e2fsck's
// status 1 ("errors corrected"). Output containing "Fix? no" is also
// treated as an error, even when e2fsck exited 0.
func (w *WorkerState) runFsck(dryRun bool) ([]byte, error) {
	mode := "-vfp"
	if dryRun {
		mode = "-vfn"
	}

	out, err := RunWithTimeout(fsckPathDefault, []string{mode, w.DevFile}, e2ProgTimeout)
	if *debug {
		log.Printf("[%s] Got output: %s", w.DevFile, out)
	}
	if err != nil {
		return out, err
	}

	// A zero exit status alone is not trusted: a declined fix means the
	// filesystem is not clean.
	if fsckDeclinedFix.Match(out) {
		return out, fmt.Errorf("file system is not clean")
	}
	return out, nil
}

// removeJournal clears the has_journal feature on w.DevFile via
// `tune2fs -O ^has_journal` and returns tune2fs's combined output.
func (w *WorkerState) removeJournal() ([]byte, error) {
	return RunWithTimeout(
		tune2fsPathDefault,
		[]string{"-O", "^has_journal", w.DevFile},
		e2ProgTimeout,
	)
}

// createJournal adds a journal to w.DevFile with `tune2fs -J`.
//
// The size comes from -rotationalJSizeGiB or -nonRotationalJSizeGiB,
// depending on w.Rotational. If the second quarter of the filesystem is
// larger than the journal, the journal is placed at a random offset so that
// it lies entirely within that quarter. Otherwise no location is passed and
// tune2fs uses its default placement.
//
// JournalSize records the requested size in bytes. JournalOffset records
// the requested offset in bytes, or 0 when tune2fs chooses the location.
func (w *WorkerState) createJournal() ([]byte, error) {
	jSizeGiB := *rotationJSizeGiB
	if !w.Rotational {
		jSizeGiB = *nonRotationJSizeGiB
	}
	// tune2fs takes the journal size in megabytes only.
	jSizeMiB := uint64(jSizeGiB) * 1024
	jSizeBytes := jSizeMiB * 1024 * 1024

	w.JournalSize = jSizeBytes
	w.JournalOffset = 0

	opts := fmt.Sprintf("size=%d", jSizeMiB)
	quarter := w.TotalBytes / 4
	// Compare before subtracting: the values are unsigned and the
	// difference would wrap around.
	if quarter > jSizeBytes {
		r := rand.New(rand.NewSource(time.Now().UnixNano()))
		offset := quarter + uint64(r.Int63n(int64(quarter-jSizeBytes)))
		// tune2fs reads location as a block number unless it carries a
		// unit suffix (K, M, G, T, or s for sectors). "B" is not one of them
		// and would be misparsed, so the byte offset is expressed in KiB.
		offsetKiB := offset / 1024
		w.JournalOffset = offsetKiB * 1024
		opts += fmt.Sprintf(",location=%dK", offsetKiB)
	}

	return RunWithTimeout(tune2fsPathDefault, []string{"-J", opts, w.DevFile}, e2ProgTimeout)
}

// json returns w encoded with encoding/json, or nil if encoding fails.
func (w *WorkerState) json() []byte {
	encoded, err := json.Marshal(w)
	if err != nil {
		return nil
	}
	return encoded
}

// logErrorsAndExit logs every worker that has a LastError. If at least one
// exists, it writes the state dump and terminates the process with
// log.Fatal. It returns normally only when no worker has an error.
func logErrorsAndExit(states map[string]*WorkerState) {
	logErrors(states)
	for _, st := range states {
		if st.LastError == nil {
			continue
		}
		dumpResult(states)
		log.Fatal("Refusing to continue")
	}
}

// logErrors logs each worker whose LastError is set, together with its last
// command output and its JSON state. It never terminates the process.
func logErrors(workers map[string]*WorkerState) {
	for dev, st := range workers {
		if st.LastError == nil {
			continue
		}
		log.Printf("Unexpected error on %s: %v\noutput:\n%s\nworker state:\n%s", dev, st.LastError, st.LastOutput, st.json())
	}
}

// dumpResult writes states as JSON to stateDumpPath, creating the parent
// directory if needed.
//
// It is best-effort: failures are logged but never abort the caller, since
// it also runs on the way to log.Fatal. The state is marshalled before the
// file is touched, so a marshalling failure cannot truncate a previous dump.
func dumpResult(states map[string]*WorkerState) {
	buf, err := json.Marshal(states)
	if err != nil {
		log.Printf("Failed to serialize state for %s: %v", stateDumpPath, err)
		return
	}
	if err := os.MkdirAll(filepath.Dir(stateDumpPath), 0o755); err != nil {
		log.Printf("Failed to create directory for %s: %v", stateDumpPath, err)
		return
	}
	if err := os.WriteFile(stateDumpPath, buf, 0o644); err != nil {
		log.Printf("Failed to write state to %s: %v", stateDumpPath, err)
	}
}

// main re-creates the ext4 journal on every distinct block device that
// fstab mounts under ytPathDefault. It runs these phases in order:
//
//  1. read filesystem geometry and the rotational flag;
//  2. check that the new journal fits into free space;
//  3. e2fsck in preen mode;
//  4. remove the existing journal;
//  5. create the new journal (skipped for devices whose removal failed);
//  6. e2fsck again.
//
// All phases except 2 run concurrently across devices. Errors in phases 1-3
// and 6 abort the run after dumping state. Errors in phases 4-5 are logged
// and the run continues. A device whose journal was removed but not
// re-created is reported as an error at the end, so a successful final
// e2fsck cannot hide it.
func main() {
	flag.Parse()
	fsTab, err := fstab.ParseFromPath(fstab.DefaultPath)
	if err != nil {
		log.Fatalf("Failed to parse fstab at path %s: %v", fstab.DefaultPath, err)
	}

	parts := make(map[string]*WorkerState)
	for _, e := range fsTab {
		if !strings.HasPrefix(e.Dir, ytPathDefault) {
			continue
		}
		realDev, err := resolveDevice(e.DevName)
		if err != nil {
			log.Fatalf("Failed to get real device path for %s: %v", e.DevName, err)
		}
		if _, ok := parts[realDev]; !ok {
			parts[realDev] = NewWorker(realDev)
		}
	}

	for k := range parts {
		log.Printf("Detected YT disk: %s", k)
	}

	runPhase(parts, "", nil, func(w *WorkerState) {
		w.LastError = w.getFsDetails()
	})
	logErrorsAndExit(parts)

	for _, w := range parts {
		if !w.journalFits() {
			w.LastError = fmt.Errorf("new journal does not fit to %s: free space %d", w.DevFile, w.FreeBytes)
		} else {
			w.SpaceOk = true
		}
	}
	logErrorsAndExit(parts)

	runPhase(parts, "Running fsck", nil, func(w *WorkerState) {
		w.LastOutput, w.LastError = w.runFsck(false)
		if w.LastError == nil {
			w.PreFSCKDone = true
		}
	})
	logErrorsAndExit(parts)

	runPhase(parts, "Removing journal", nil, func(w *WorkerState) {
		w.LastOutput, w.LastError = w.removeJournal()
		if *debug {
			log.Printf("[%s] Got output: %s", w.DevFile, w.LastOutput)
		}
		if w.LastError == nil {
			w.JournalDestroyed = true
		} else {
			w.Skipped = true
		}
	})
	logErrors(parts)

	skipped := func(w *WorkerState) bool { return w.Skipped }
	runPhase(parts, "Creating journal", skipped, func(w *WorkerState) {
		w.LastOutput, w.LastError = w.createJournal()
		if *debug {
			log.Printf("[%s] Got output: %s", w.DevFile, w.LastOutput)
		}
		if w.LastError == nil {
			w.JournalCreated = true
		}
	})
	logErrors(parts)

	runPhase(parts, "Running fsck", nil, func(w *WorkerState) {
		w.LastOutput, w.LastError = w.runFsck(false)
		if w.LastError == nil {
			w.PostFSCKDone = true
		}
	})

	// The final e2fsck overwrites LastError. Without this check, a failed
	// journal creation would be forgotten and the disk left without a
	// journal while the run reports "Done".
	for _, w := range parts {
		if w.LastError == nil && w.JournalDestroyed && !w.JournalCreated {
			w.LastError = fmt.Errorf("journal on %s was removed but not re-created", w.DevFile)
		}
	}
	logErrorsAndExit(parts)
	log.Printf("Done")
	dumpResult(parts)
}

// resolveDevice follows every symlink in devName, whether the target is
// relative or absolute and however many links are chained. It returns the
// absolute path of the underlying device node.
func resolveDevice(devName string) (string, error) {
	resolved, err := filepath.EvalSymlinks(devName)
	if err != nil {
		return "", err
	}
	return filepath.Abs(resolved)
}

// runPhase calls fn concurrently for every worker not rejected by skip
// (a nil skip selects all workers) and waits for all calls to finish.
// When action is non-empty, "<action> on <device>" is logged before each
// worker starts.
func runPhase(workers map[string]*WorkerState, action string, skip func(*WorkerState) bool, fn func(*WorkerState)) {
	var wg sync.WaitGroup
	for _, w := range workers {
		if skip != nil && skip(w) {
			continue
		}
		if action != "" {
			log.Printf("%s on %s", action, w.DevFile)
		}
		wg.Add(1)
		go func(w *WorkerState) {
			defer wg.Done()
			fn(w)
		}(w)
	}
	wg.Wait()
}
