package main

import (
	"bufio"
	"compress/gzip"
	"flag"
	"fmt"
	"io"
	"log"
	"os"
	"path"
	"path/filepath"
	"sort"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"
)

// =================================================================================================

// Default values for the command-line flags registered in main.
var (
	// DefaultLogsRotationSizeMegabytes is the minimum log size, in MiB
	// (multiplied by 1<<20), at which a log is rotated (-s).
	DefaultLogsRotationSizeMegabytes = 100.0

	// DefaultLogsRotationCount is the number of log files to keep (-c).
	DefaultLogsRotationCount = 10

	// DefaultLogsMaxLifetime is the maximum archive age: 60 days (-t).
	DefaultLogsMaxLifetime = 60 * 24 * time.Hour

	// DefaultLogsSignal is the signal sent to processes that hold a rotated
	// log open; 0 means no signal is sent (-k).
	DefaultLogsSignal uint = 0

	// DefaultLogsCompressionLevel is the gzip compression level (-z).
	DefaultLogsCompressionLevel = 3
)

// =================================================================================================

// FileStat pairs a path with the stat(2) result obtained for it. fUser keeps
// these per inode so that open file descriptors found under /proc can be
// matched back to the requested paths by device and inode number.
type FileStat struct {
	// Path is the path exactly as the caller supplied it.
	Path string
	// Stat is the result of syscall.Stat on Path (symlinks are followed).
	Stat syscall.Stat_t
}

// readDir lists every entry of directory dir in directory order. Unlike
// os.ReadDir, the entries are not sorted by name. The directory handle is
// closed before returning.
func readDir(dir string) (entries []os.DirEntry, err error) {
	handle, err := os.Open(dir)
	if err != nil {
		return nil, err
	}
	defer handle.Close()

	entries, err = handle.ReadDir(-1)
	return entries, err
}

// fUser reports which processes currently have each of paths open, in the
// spirit of fuser(1).
//
// Every path that can be stat'ed is recorded by inode and device. Then every
// numeric entry of /proc is visited, and each /proc/<pid>/fd/<n> link is
// stat'ed and matched against the recorded (device, inode) pairs.
//
// The returned map has a key for every input path, including paths that could
// not be stat'ed; their slice is empty. Each slice lists matching PIDs in
// /proc directory order, and each PID appears at most once per path. Processes
// whose fd directory cannot be read are skipped silently. The only error
// returned is a failure to read /proc itself.
func fUser(paths []string) (map[string][]uint32, error) {
	result := make(map[string][]uint32, len(paths))
	inodes := map[uint64][]FileStat{}
	for _, path := range paths {
		var stat syscall.Stat_t
		if syscall.Stat(path, &stat) == nil {
			inodes[stat.Ino] = append(inodes[stat.Ino], FileStat{Path: path, Stat: stat})
		}
		result[path] = []uint32{}
	}

	entries, err := readDir("/proc")
	if err != nil {
		return nil, err
	}
	for _, entry := range entries {
		if !entry.IsDir() {
			continue
		}
		pid, err := strconv.ParseUint(entry.Name(), 10, 32)
		if err != nil {
			continue // not a process directory
		}
		fdDir := fmt.Sprintf("/proc/%d/fd", pid)
		fds, err := readDir(fdDir)
		if err != nil {
			continue
		}
		for _, fd := range fds {
			var stat syscall.Stat_t
			// syscall.Stat follows the fd link, so stat describes the open file.
			if syscall.Stat(fdDir+"/"+fd.Name(), &stat) != nil {
				continue
			}
			for _, fileStat := range inodes[stat.Ino] {
				if fileStat.Stat.Dev != stat.Dev {
					continue
				}
				pids := result[fileStat.Path]
				// A process may hold the same file through several
				// descriptors, and a path may be listed twice in paths.
				// Processes are visited one at a time, so any duplicate
				// is necessarily the most recently appended PID.
				if n := len(pids); n > 0 && pids[n-1] == uint32(pid) {
					continue
				}
				result[fileStat.Path] = append(pids, uint32(pid))
			}
		}
	}
	return result, nil
}

// gzipFile compresses path into "<path>.gz" at the given gzip level (see
// compress/gzip for valid levels). The source file is left in place.
//
// Data is written to a temporary "<path>.gz-YYYY.MM.DD-hh.mm.ss" file, fully
// flushed (including the gzip trailer) and fsynced, then renamed to
// "<path>.gz". A partial archive therefore never appears under the final name.
// On any failure the temporary file is removed and the error is returned.
func gzipFile(path string,
	level int) error {

	fileIn, err := os.Open(path)
	if err != nil {
		return err
	}
	// Read-only handle: a Close error cannot lose data, so it is ignored.
	defer fileIn.Close()

	dstFile := fmt.Sprintf("%s.gz", path)
	tmpFile := fmt.Sprintf("%s.gz-%s", path, time.Now().Format("2006.01.02-15.04.05"))
	fileOut, err := os.Create(tmpFile)
	if err != nil {
		return err
	}

	err = writeGzip(fileOut, fileIn, filepath.Base(path), level)
	// Close exactly once and keep the first error; a failing Close on the
	// output may mean written data did not reach the file.
	if closeErr := fileOut.Close(); err == nil {
		err = closeErr
	}
	if err == nil {
		err = os.Rename(tmpFile, dstFile)
	}
	if err != nil {
		_ = os.Remove(tmpFile) // best effort: don't leave partial archives behind
		return err
	}
	return nil
}

// writeGzip compresses everything read from src into out as one gzip member.
// The header carries name (when the gzip format can represent it), the
// comment "logrotator compressor" and the current time. It writes the gzip
// trailer, flushes all buffering and fsyncs out, but does not close out.
func writeGzip(out *os.File, src io.Reader, name string, level int) error {
	w := bufio.NewWriter(out)
	z, err := gzip.NewWriterLevel(w, level)
	if err != nil {
		return err
	}
	// Gzip header strings may only contain U+0001..U+00FF; gzip.Writer fails
	// on anything else. The name is optional metadata, so drop it rather than
	// fail the whole archive.
	z.Name = name
	for _, r := range name {
		if r == 0 || r > 0xff {
			z.Name = ""
			break
		}
	}
	z.Comment = "logrotator compressor"
	z.ModTime = time.Now()

	if _, err := io.Copy(z, src); err != nil {
		return err
	}
	// Close (not Flush) emits the final deflate block and the CRC-32/size
	// trailer. These land in w, which must be flushed afterwards.
	if err := z.Close(); err != nil {
		return err
	}
	if err := w.Flush(); err != nil {
		return err
	}
	// findAndRotateLogs deletes the uncompressed source once archiving
	// succeeds, so make the archive durable first.
	return out.Sync()
}

// findAndRotateLogs rotates every file in logFiles whose size is at least
// logsRotationSizeBytes. It then renumbers and prunes the gzip archives of
// every file in logFiles.
//
// Oversized logs are rotated concurrently, one goroutine per file (see
// rotateLogFile). After all of them finish, rotateArchives runs for each log.
// Archives are deleted once older than logsMaxLifetime. At most
// logsRotationCount-1 numbered archives survive renumbering, because the
// uncompressed "<file>.1" occupies one of the logsRotationCount slots; counts
// of 0 or 1 keep no archives. logsCompressionLevel is passed to gzipFile.
// When logsSignal > 0, that signal is sent to processes that had a rotated log
// open.
//
// Per-file failures are logged and skipped. A failure to scan /proc (fUser)
// terminates the program via log.Fatal.
func findAndRotateLogs(logFiles []string,
	logsRotationSizeBytes int64,
	logsRotationCount int,
	logsMaxLifetime time.Duration,
	logsCompressionLevel int,
	logsSignal uint) {

	maxArchives := logsRotationCount - 1 // number of gzipped archives is one file short
	oldestFileTime := time.Now().Add(-logsMaxLifetime)

	// Look up who holds each log and its ".1" predecessor open, once, before
	// anything is renamed.
	filesForPids := make([]string, 0, 2*len(logFiles))
	filesForPids = append(filesForPids, logFiles...)
	for _, file := range logFiles {
		filesForPids = append(filesForPids, fmt.Sprintf("%s.1", file))
	}
	filesToPids, err := fUser(filesForPids)
	if err != nil {
		log.Fatal(err)
	}

	var wg sync.WaitGroup
	for _, file := range logFiles {
		stat, err := os.Stat(file)
		if err != nil || stat.Size() < logsRotationSizeBytes {
			continue
		}
		wg.Add(1)
		// file is passed as an argument so every goroutine gets its own
		// copy, independent of the module's loop-variable semantics.
		go func(file string) {
			defer wg.Done()
			rotateLogFile(file, filesToPids, logsCompressionLevel, logsSignal)
		}(file)
	}
	wg.Wait()

	log.Printf("Trying to rotate archives")
	for _, file := range logFiles {
		rotateArchives(file, maxArchives, oldestFileTime)
	}
}

// rotateLogFile rotates one oversized log. If "<file>.1" exists, it is
// gzipped, every process listed for it in filesToPids is sent SIGKILL, and it
// is removed. Then file is renamed to "<file>.1". When logsSignal > 0, that
// signal is sent to every process listed for file. Any failure is logged and
// stops the rotation of this file. filesToPids is only read, so concurrent
// calls may share it.
func rotateLogFile(file string, filesToPids map[string][]uint32, compressionLevel int, logsSignal uint) {
	oneFile := fmt.Sprintf("%s.1", file)
	if _, err := os.Stat(oneFile); err == nil {
		log.Printf("Start archiving %s", oneFile)
		if err := gzipFile(oneFile, compressionLevel); err != nil {
			log.Printf("Failed to archive %s: %v", oneFile, err)
			return
		}
		for _, pid := range filesToPids[oneFile] {
			log.Printf("Sending signal=9 to %d: uses %s", pid, oneFile)
			if err := syscall.Kill(int(pid), syscall.SIGKILL); err != nil {
				log.Printf("Failed to send signal=9 to %d: %v", pid, err)
			}
		}
		log.Printf("Removing archived %s", oneFile)
		if err := os.Remove(oneFile); err != nil {
			log.Printf("Failed to remove archived %s: %v", oneFile, err)
			return
		}
	}
	log.Printf("Moving %s to %s", file, oneFile)
	if err := os.Rename(file, oneFile); err != nil {
		log.Printf("Failed to move %s: %v", file, err)
		return
	}
	if logsSignal > 0 {
		signal := syscall.Signal(logsSignal)
		for _, pid := range filesToPids[file] {
			log.Printf("Sending signal=%d to %d: uses %s", logsSignal, pid, file)
			if err := syscall.Kill(int(pid), signal); err != nil {
				log.Printf("Failed to send signal=%d to %d: %v", logsSignal, pid, err)
			}
		}
	}
}

// rotateArchives maintains the "<base>.<n>.gz" archives next to file.
//
// Archives modified before oldestFileTime are deleted. If "<base>.1.gz" is
// present, the remaining archives are processed from the highest number down.
// Those at position >= maxArchives in ascending order are deleted, and the
// others are renamed n -> n+1, which frees number 1 for the next archive.
// Errors are logged and skipped.
func rotateArchives(file string, maxArchives int, oldestFileTime time.Time) {
	const gzSuffix = ".gz"
	dir := path.Dir(file)
	dotFileName := path.Base(file) + "."

	files, err := readDir(dir)
	if err != nil {
		log.Printf("Failed to search for files like %s: %v", file, err)
		return
	}
	needRotate := false
	archNums := []uint32{}
	for _, f := range files {
		archName := f.Name()
		// The length check rejects names such as "<base>.gz", which satisfy
		// both prefix and suffix tests (they overlap) but would make the
		// number slice below panic.
		if f.IsDir() || len(archName) < len(dotFileName)+len(gzSuffix) ||
			!strings.HasPrefix(archName, dotFileName) || !strings.HasSuffix(archName, gzSuffix) {
			continue
		}
		info, err := f.Info()
		if err != nil {
			continue
		}
		fullArchName := path.Join(dir, archName)
		if info.ModTime().Before(oldestFileTime) {
			log.Printf("Removing too old file %s: modification time %v", fullArchName, info.ModTime())
			if err := os.Remove(fullArchName); err != nil {
				log.Printf("Failed to remove too old file %s: %v", fullArchName, err)
			} else {
				// Deleted: it must not take part in the renumbering below.
				continue
			}
		}
		archNum, err := strconv.ParseUint(archName[len(dotFileName):len(archName)-len(gzSuffix)], 10, 32)
		if err != nil {
			log.Printf("Failed to enumerate archive file %s: %v", fullArchName, err)
			continue
		}
		archNums = append(archNums, uint32(archNum))
		if archNum == 1 {
			needRotate = true
		}
	}
	if !needRotate {
		return
	}

	sort.Slice(archNums, func(i, j int) bool { return archNums[i] < archNums[j] })
	// Highest first, so that n+1 is already free when n is renamed.
	for i := len(archNums) - 1; i >= 0; i-- {
		fullArchName := fmt.Sprintf("%s.%d.gz", file, archNums[i])
		if i >= maxArchives {
			log.Printf("Removing last file %s", fullArchName)
			if err := os.Remove(fullArchName); err != nil {
				log.Printf("Failed to remove last file %s: %v", fullArchName, err)
			}
			continue
		}
		dstArchName := fmt.Sprintf("%s.%d.gz", file, archNums[i]+1)
		log.Printf("Moving %s to %s", fullArchName, dstArchName)
		if err := os.Rename(fullArchName, dstArchName); err != nil {
			log.Printf("Failed to move %s to %s: %v", fullArchName, dstArchName, err)
		}
	}
}

// =================================================================================================

// main parses the flags and expands every positional argument as a glob
// pattern. Each matching file is rotated once via findAndRotateLogs.
//
// Usage: <prog> [flags] <files or patterns to rotate>
func main() {
	var logsRotationSizeMegabytes float64
	var logsRotationCount int
	var logsMaxLifetime time.Duration
	var logsCompressionLevel int
	var logsSignal uint

	flag.Float64Var(&logsRotationSizeMegabytes, "s", DefaultLogsRotationSizeMegabytes, "min log size to rotate, MB")
	flag.IntVar(&logsRotationCount, "c", DefaultLogsRotationCount, "number of log files to keep")
	flag.IntVar(&logsCompressionLevel, "z", DefaultLogsCompressionLevel, "gzip compression level to use (0-9)")
	// -t is a time.Duration flag, so it needs a unit; a bare number of days
	// is rejected by the flag parser.
	flag.DurationVar(&logsMaxLifetime, "t", DefaultLogsMaxLifetime, "max logs lifetime, as a duration (e.g. 1440h = 60 days)")
	flag.UintVar(&logsSignal, "k", DefaultLogsSignal, "send this signal to fuser, 0 == no signal")
	flag.Usage = func() {
		out := flag.CommandLine.Output()
		_, _ = fmt.Fprintf(out, "Usage of %s:\n", os.Args[0])
		_, _ = fmt.Fprintf(out, "  %s [flags] <files or patterns to rotate>\n\nFlags:\n", path.Base(os.Args[0]))
		flag.PrintDefaults()
	}
	flag.Parse()
	log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)

	if logsRotationCount < 0 {
		log.Fatal("Set rotation count to 0 or greater")
	}
	// Catch bad levels up front. Otherwise gzip.NewWriterLevel fails for
	// every archive, and each failure aborts that file's rotation.
	if logsCompressionLevel < gzip.NoCompression || logsCompressionLevel > gzip.BestCompression {
		log.Fatalf("Set compression level between %d and %d", gzip.NoCompression, gzip.BestCompression)
	}

	// Overlapping patterns can match the same file more than once. Keep the
	// first occurrence only; otherwise two goroutines would rotate the same
	// file concurrently.
	var files []string
	seen := make(map[string]bool)
	for _, pat := range flag.Args() {
		patFiles, err := filepath.Glob(pat)
		if err != nil {
			log.Fatal(err)
		}
		for _, f := range patFiles {
			key := filepath.Clean(f)
			if seen[key] {
				continue
			}
			seen[key] = true
			files = append(files, f)
		}
	}
	if len(files) == 0 {
		log.Fatalf("Nothing to rotate! (%v)", flag.Args())
	}
	logsRotationSizeBytes := int64(logsRotationSizeMegabytes * (1 << 20))

	log.Printf("Rotating files %v", files)
	findAndRotateLogs(files, logsRotationSizeBytes, logsRotationCount, logsMaxLifetime, logsCompressionLevel, logsSignal)
}
