package utils

import (
	"fmt"
	"io/ioutil"
	"os"
	"os/exec"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"
	"unsafe"

	"go.uber.org/zap"
	"golang.org/x/sys/unix"

	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/ilog"
)

// Default restart-policy values.
// NOTE(review): presumably meant to be passed by callers as the
// restartsNrLimit / restartsExpiration arguments of NewService; they are not
// referenced anywhere in the visible part of this file — confirm.
const (
	// DefaultServiceRestartsNrLimit is the default number of automatic restarts.
	DefaultServiceRestartsNrLimit    = 5
	// DefaultServiceRestartsExpiration is the default restart-accounting period.
	DefaultServiceRestartsExpiration = time.Hour
)

// Service supervises a daemonizing command: cmd is started and waited for
// synchronously, after which the long-lived child is identified by the pid
// stored in pidFilePath. A goroutine spawned by Start restarts the service
// when the child exits and handles Stop requests.
type Service struct {
	// cmd is the launcher ("service parent") command; doStart runs it to completion.
	cmd   *exec.Cmd
	// child is the process whose pid was read from the pid-file; nil until the
	// first successful pid-file read in doStart.
	child *os.Process

	// pidFilePath is the file doStart reads the child pid from.
	pidFilePath string

	// restartsCount is the number of automatic restarts counted by the
	// supervising goroutine; it is reset to 0 by a restart that happens after
	// restartsExpiration has elapsed since lastRestartTime.
	restartsCount   uint
	// lastRestartTime is the time of the last automatic restart (NewService
	// initializes it to the creation time).
	lastRestartTime time.Time

	// restartsNrLimit is the restartsCount value at which automatic restarts
	// stop, unless restartsExpiration has already elapsed.
	restartsNrLimit    uint
	// restartsExpiration is the period after which the restart count expires.
	restartsExpiration time.Duration

	// wait carries the result of waiting for child (sent by doWaitChild).
	wait chan error
	// done is signalled by Stop to ask the supervising goroutine to kill child.
	done chan bool
	// kill carries the result of that kill attempt back to Stop.
	kill chan error

	// lastCritErr holds the most recent critical error (nil while the service
	// is considered running); it is returned by GetLastError.
	lastCritErr *SyncedError
}

// SyncedError is an error slot protected by a mutex, so that one goroutine
// can store an error while another one reads it.
type SyncedError struct {
	mu  sync.Mutex
	err error
}

// newSyncedError allocates a SyncedError whose initial content is the
// "service hasn't been started yet" error.
func newSyncedError() *SyncedError {
	fresh := new(SyncedError)
	fresh.err = fmt.Errorf("service hasn't been started yet")
	return fresh
}

// setSyncedError replaces the stored error with newErr (which may be nil)
// while holding the lock.
func (se *SyncedError) setSyncedError(newErr error) {
	se.mu.Lock()
	defer se.mu.Unlock()

	se.err = newErr
}

// getSyncedError returns the currently stored error (possibly nil), read
// while holding the lock.
func (se *SyncedError) getSyncedError() error {
	se.mu.Lock()
	current := se.err
	se.mu.Unlock()

	return current
}

// NewService builds a Service for the command name with arguments args.
// Nothing is executed until Start is called. The restart counter begins at
// zero and the last-restart timestamp is set to the moment of creation.
func NewService(name string, args []string, pidFilePath string, restartsNrLimit uint, restartsExpiration time.Duration) *Service {
	const noRestartsYet uint = 0
	createdAt := time.Now()

	return doNewService(name, args, pidFilePath, restartsNrLimit, restartsExpiration, noRestartsYet, createdAt)
}

// doNewService builds a Service with explicit restart accounting
// (restartsCount, lastRestartTime), which lets a restart carry the counters
// over from the previous instance. It creates fresh unbuffered wait/done/kill
// channels and a fresh lastCritErr.
func doNewService(name string, args []string, pidFilePath string, restartsNrLimit uint, restartsExpiration time.Duration, restartsCount uint, lastRestartTime time.Time) *Service {
	launcher := exec.Command(name, args...)
	// Setpgid: start the launcher in its own process group.
	launcher.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}

	return &Service{
		cmd:                launcher,
		pidFilePath:        pidFilePath,
		restartsCount:      restartsCount,
		lastRestartTime:    lastRestartTime,
		restartsNrLimit:    restartsNrLimit,
		restartsExpiration: restartsExpiration,
		wait:               make(chan error),
		done:               make(chan bool),
		kill:               make(chan error),
		lastCritErr:        newSyncedError(),
	}
}

// setSubreaper sets the PR_SET_CHILD_SUBREAPER attribute of the calling
// process via prctl(2).
func setSubreaper() error {
	const enable uintptr = 1

	return unix.Prctl(unix.PR_SET_CHILD_SUBREAPER, enable, 0, 0, 0)
}

// isSubreaper reports whether the PR_GET_CHILD_SUBREAPER attribute of the
// calling process is set (non-zero), as returned by prctl(2).
func isSubreaper() (bool, error) {
	var flag uintptr

	// prctl stores the current setting at the address passed as arg2.
	err := unix.Prctl(unix.PR_GET_CHILD_SUBREAPER, uintptr(unsafe.Pointer(&flag)), 0, 0, 0)
	if err != nil {
		return false, err
	}

	enabled := int(flag) != 0
	return enabled, nil
}

// doWaitNonBlock is doWait with the non-blocking mode switched on.
func doWaitNonBlock(pid int) (int, error) {
	const nonBlocking = true

	return doWait(pid, nonBlocking)
}

// doWait calls wait4(2) for pid and returns the pid and error reported by
// unix.Wait4. With nonBlock set, WNOHANG is passed so the call does not block.
// The wait status itself is discarded.
func doWait(pid int, nonBlock bool) (int, error) {
	flags := 0
	if nonBlock {
		flags |= unix.WNOHANG
	}

	var ws unix.WaitStatus
	return unix.Wait4(pid, &ws, flags, nil)
}

// doWaitChild blocks in wait4(2) until the child changes state and then
// publishes the wait error (nil on success) on s.wait. A successful wait that
// reports an unexpected pid is only logged.
func (s *Service) doWaitChild() {
	var ws unix.WaitStatus

	reaped, err := unix.Wait4(s.child.Pid, &ws, 0, nil)
	if err == nil {
		if want := s.child.Pid; reaped != want {
			mismatch := fmt.Errorf("returned pid (%d) is other than expected (child pid: %d)", reaped, want)
			ilog.Log().Error("wait()", zap.Error(mismatch))
			// TODO: return?
		}
	}

	s.wait <- err
}

// doStart launches the service once and records its long-lived child process.
//
// Sequence:
//  1. If a child was recorded earlier and no longer exists (kill(pid, 0)
//     reports ESRCH), *s is overwritten with a freshly built Service. With
//     transferRestarts the restart counter and last-restart time are carried
//     over; otherwise they are reset through NewService. Overwriting *s also
//     replaces the wait/done/kill channels and lastCritErr. If the child is
//     still alive, an error is returned.
//  2. The current process is made a child subreaper if it is not one already.
//  3. The launcher command ("service parent") is started and waited for; it
//     must exit with code 0.
//  4. The child pid is read from the pid-file (first whitespace-separated
//     token) and stored in s.child.
//  5. One non-blocking wait(-1) is performed to reap a possible intermediate
//     process of a double fork; reaping the child itself means the service
//     has already died.
//  6. A goroutine (doWaitChild) is started to wait for the child.
//
// It returns nil on success and a descriptive error otherwise.
func (s *Service) doStart(transferRestarts bool) error {
	// create new cmd if old one was started and already finished
	if s.child != nil { // was started
		errno := syscall.Kill(s.child.Pid, syscall.Signal(0))
		if errno == syscall.ESRCH { // already finished
			if transferRestarts {
				*s = *doNewService(s.cmd.Args[0], s.cmd.Args[1:], s.pidFilePath, s.restartsNrLimit, s.restartsExpiration, s.restartsCount, s.lastRestartTime)
			} else {
				*s = *NewService(s.cmd.Args[0], s.cmd.Args[1:], s.pidFilePath, s.restartsNrLimit, s.restartsExpiration)
			}
			ilog.Log().Debug("created new Service")
		} else if errno == nil {
			return fmt.Errorf("service already started")
		} else {
			return fmt.Errorf("failed to send child a signal, pid: %d, errno: %v", s.child.Pid, errno)
		}
	}

	isSubr, err := isSubreaper()
	if err != nil {
		return fmt.Errorf("failed to get subreaper, err: %v", err)
	}
	if !isSubr {
		err = setSubreaper()
		if err != nil {
			return fmt.Errorf("failed to set subreaper, err: %v", err)
		}
		ilog.Log().Debug("set subreaper", zap.Int("pid", os.Getpid()))
	}

	err = s.cmd.Start()
	if err != nil {
		return fmt.Errorf("failed to start service parent, err: %v", err)
	}
	ilog.Log().Debug("started service parent", zap.Int("pid", s.cmd.Process.Pid), zap.String("cmd", s.cmd.String()))

	err = s.cmd.Wait()
	if err != nil {
		return fmt.Errorf("failed to Wait() for service parent, err: %v", err)
	}
	if ret := s.cmd.ProcessState.ExitCode(); ret != 0 {
		return fmt.Errorf("service parent failed, ret: %d", ret)
	}
	ilog.Log().Debug("waited service parent", zap.Int("pid", s.cmd.Process.Pid), zap.String("cmd", s.cmd.String()))

	data, err := ioutil.ReadFile(s.pidFilePath)
	if err != nil {
		return fmt.Errorf("failed to read pid-file: %s, err: %v", s.pidFilePath, err)
	}
	// https://st.yandex-team.ru/HOSTMAN-975#60817c582228a321224b5323
	// parsing potentially smth like "254218\n2"
	fields := strings.Fields(string(data))
	if len(fields) == 0 {
		// An empty or whitespace-only pid-file yields no tokens; indexing
		// fields[0] would panic.
		return fmt.Errorf("failed to parse child pid from pid-file: %s, err: pid-file is empty", s.pidFilePath)
	}
	dataStr := fields[0]

	pid, err := strconv.Atoi(dataStr)
	if err != nil {
		return fmt.Errorf("failed to parse child pid from pid-file: %s, data: %s, err: %v", s.pidFilePath, dataStr, err)
	}
	if pid <= 0 {
		// kill(2) and wait4(2) treat 0 and negative pids as process groups or
		// "all processes"; never let such a value become the child pid.
		return fmt.Errorf("invalid child pid in pid-file: %s, pid: %d", s.pidFilePath, pid)
	}
	ilog.Log().Debug("got child pid from pid-file", zap.Int("pid", pid), zap.String("cmd", s.cmd.String()))

	child, err := os.FindProcess(pid)
	if err != nil {
		return fmt.Errorf("failed to find service child process from its pid: %d, err: %v", pid, err)
	}
	s.child = child

	// try to wait for intermediate process from double fork:
	// * if there is no such process (pid == 0) - it's OK, probably service does only one fork
	// * if there is a pid - compare it with pid from pid-file in case service died prematurely
	pid, err = doWaitNonBlock(-1)
	if err != nil {
		return fmt.Errorf("failed to wait for intermediate process from double fork, pid: %d, err: %v", pid, err)
	}
	if pid == s.child.Pid {
		// err is always nil here, so it is not part of the message.
		return fmt.Errorf("service died prematurely, pid: %d", pid)
	} else if pid > 0 {
		ilog.Log().Debug("waited intermediate process from double fork", zap.Int("pid", pid), zap.String("cmd", s.cmd.String()))
	}

	go s.doWaitChild()

	return nil
}

// Start launches the service and, on success, spawns a supervising goroutine.
//
// The goroutine reacts to two events:
//   - the child exiting (s.wait): the service is restarted while restarts are
//     still allowed, i.e. Stop has not killed it and either restartsCount is
//     below restartsNrLimit or restartsExpiration has elapsed since
//     lastRestartTime; otherwise the goroutine records an error and exits;
//   - a Stop request (s.done): the child is sent SIGKILL and the outcome is
//     reported back on s.kill.
//
// Every failure is also stored in lastCritErr (see GetLastError); a
// successful start or restart clears it.
func (s *Service) Start() error {
	if startErr := s.doStart(false); startErr != nil {
		wrapped := fmt.Errorf("failed to start service, err: %v", startErr)
		s.lastCritErr.setSyncedError(wrapped)
		return wrapped
	}
	ilog.Log().Info("service started", zap.String("cmd", s.cmd.String()))
	s.lastCritErr.setSyncedError(nil)

	go func() {
		// allowRestart is cleared once a Stop request has killed the child.
		allowRestart := true

		for {
			// s.wait / s.done are re-read on every iteration: doStart may
			// have replaced *s, including its channels.
			select {
			case exitErr := <-s.wait:
				ilog.Log().Info("service exited", zap.Int("pid", s.child.Pid), zap.String("cmd", s.cmd.String()), zap.Error(exitErr))

				mayRestart := allowRestart && (s.restartsCount < s.restartsNrLimit || time.Since(s.lastRestartTime) >= s.restartsExpiration)
				if !mayRestart {
					finalErr := fmt.Errorf("service exited and won't restart")
					if s.restartsCount == s.restartsNrLimit {
						finalErr = fmt.Errorf("service restarts exceeded the limit, %v", finalErr)
					}

					s.lastCritErr.setSyncedError(finalErr)
					ilog.Log().Info(finalErr.Error(), zap.String("cmd", s.cmd.String()))

					return
				}

				if restartErr := s.doStart(true); restartErr != nil {
					restartErr = fmt.Errorf("failed to restart service, err: %v", restartErr)
					s.lastCritErr.setSyncedError(restartErr)
					ilog.Log().Error("doStart()", zap.Error(restartErr), zap.String("cmd", s.cmd.String()))
					return
				}

				// A restart inside the expiration window is counted; one
				// after the window has expired resets the counter instead.
				if time.Since(s.lastRestartTime) < s.restartsExpiration {
					s.restartsCount++
				} else {
					s.restartsCount = 0
					ilog.Log().Debug("reset expired restarts count", zap.Duration("expiration period", s.restartsExpiration), zap.String("cmd", s.cmd.String()))
				}
				s.lastRestartTime = time.Now()

				s.lastCritErr.setSyncedError(nil)
				ilog.Log().Info("service restarted", zap.Uint("restarts count", s.restartsCount), zap.String("cmd", s.cmd.String()))
				if s.restartsCount == s.restartsNrLimit {
					ilog.Log().Info("next service restart will exceed the limit", zap.Uint("restarts count", s.restartsCount),
						zap.Uint("restarts limit", s.restartsNrLimit), zap.String("cmd", s.cmd.String()))
				}

			case <-s.done:
				ilog.Log().Info("stopping service", zap.String("cmd", s.cmd.String()))

				killErr := syscall.Kill(s.child.Pid, syscall.SIGKILL)
				if killErr == nil {
					allowRestart = false
					s.lastCritErr.setSyncedError(fmt.Errorf("service has been stopped"))
				} else {
					ilog.Log().Error("failed to kill service processes", zap.Int("pid", s.child.Pid), zap.Error(killErr))
				}
				// Report the outcome to the Stop call waiting on s.kill.
				s.kill <- killErr
			}
		}
	}()

	return nil
}

// Stop asks the supervising goroutine to SIGKILL the service child and
// returns the result of that kill attempt. It fails right away when no child
// has been recorded yet or when the recorded child no longer exists.
func (s *Service) Stop() error {
	switch {
	case s.child == nil:
		return fmt.Errorf("service has never been started")
	// Signal 0 only probes for existence; ESRCH means the process is gone.
	// With Go 1.16+ this could be written as:
	//   s.child.Signal(syscall.Signal(0)) == os.ErrProcessDone
	case s.child != nil && syscall.Kill(s.child.Pid, syscall.Signal(0)) == syscall.ESRCH:
		return fmt.Errorf("service already stopped")
	}

	// Hand the request to the supervising goroutine and wait for its answer.
	s.done <- true
	killErr := <-s.kill
	return killErr
}

// GetLastError returns the most recently recorded critical error of the
// service, or nil if the last start or restart succeeded.
func (s *Service) GetLastError() error {
	lastErr := s.lastCritErr.getSyncedError()
	return lastErr
}
