package systemd

import (
	"bufio"
	"bytes"
	"context"
	"fmt"
	"time"

	"a.yandex-team.ru/infra/hostctl/internal/executil"
	"a.yandex-team.ru/infra/hostctl/internal/systemd/persist"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/slices"
)

// cmds groups the systemctl command templates built by NewSystemctl.
// Most systemctl methods derive a per-call command from a template with
// AddArg(unit name); ReloadDaemon executes its template as is.
// NOTE(review): this relies on executil.Cmd.AddArg returning a new Cmd rather
// than mutating the shared template — confirm in executil.
type cmds struct {
	Enable       *executil.Cmd // systemctl enable
	Disable      *executil.Cmd // systemctl disable
	Restart      *executil.Cmd // systemctl restart
	Reload       *executil.Cmd // systemctl reload
	Start        *executil.Cmd // systemctl start
	Stop         *executil.Cmd // systemctl stop
	ReloadDaemon *executil.Cmd // systemctl daemon-reload
	Status       *executil.Cmd // systemctl show; output parsed by parseShowOutput
	IsEnabled    *executil.Cmd // systemctl is-enabled; used by unitFileState
}

// DefaultTimeout is the timeout this file applies to systemctl calls that are
// not given a caller context (enable, disable, show and is-enabled).
const DefaultTimeout = 30 * time.Second

const (
	// binSystemctl is the absolute path of the systemctl binary to execute.
	binSystemctl = "/bin/systemctl"
	// systemctlDaemonReloadTimeout is intentionally far larger than
	// DefaultTimeout: on overloaded hosts daemon-reload can take a very long time.
	systemctlDaemonReloadTimeout = 5 * time.Minute
)

// NewSystemctl returns a Systemd implementation that manages units by running
// /bin/systemctl in --system mode. Unit revisions are recorded in p and every
// executed command is logged at debug level to l.
func NewSystemctl(p persist.Persist, l log.Logger) Systemd {
	// systemctlCmd builds the base invocation `/bin/systemctl --system <verb>`.
	systemctlCmd := func(verb string) *executil.Cmd {
		return executil.NewCmd(binSystemctl, "--system", verb).WithExecutor(executil.Execute)
	}
	return &systemctl{
		cmds: &cmds{
			// --no-reload: enable/disable skip the implicit daemon-reload;
			// Enable/Disable issue it separately with a longer timeout.
			Enable:  systemctlCmd("enable").AddArg("--no-reload"),
			Disable: systemctlCmd("disable").AddArg("--no-reload"),
			// --no-block: return once the job is queued instead of waiting for it.
			Restart: systemctlCmd("restart").AddArg("--no-block"),
			Reload:  systemctlCmd("reload").AddArg("--no-block"),
			Start:   systemctlCmd("start").AddArg("--no-block"),
			// stop intentionally waits for the job to finish.
			Stop:         systemctlCmd("stop"),
			ReloadDaemon: systemctlCmd("daemon-reload"),
			// The option is --no-pager; the former "--no-page" only worked
			// through getopt's unambiguous-prefix abbreviation of long options.
			Status:    systemctlCmd("show").AddArg("--no-pager"),
			IsEnabled: systemctlCmd("is-enabled"),
		},
		p: p,
		l: l,
	}
}

// systemctl is the Systemd implementation returned by NewSystemctl; it runs
// the systemctl binary for every operation.
type systemctl struct {
	cmds *cmds           // command templates built by NewSystemctl
	p    persist.Persist // saves, removes and checks per-unit revisions
	l    log.Logger      // receives a debug line for each executed command
}

// Restart runs the restart command for unit u under ctx and, if it succeeds,
// records revID as the unit's current revision. Any error from systemctl is
// returned as is and the revision is left untouched.
func (s *systemctl) Restart(ctx context.Context, u *Unit, revID string) error {
	restartCmd := s.cmds.Restart.AddArg(u.FullName())
	s.l.Debugf("Executing '%s'...", restartCmd.String())
	if err := restartCmd.ExecCtx(ctx); err != nil {
		return err
	}
	return s.p.SaveRevision(persistName(u), revID)
}

// Start runs the start command for unit u under ctx and, if it succeeds,
// records revID as the unit's current revision. A systemctl failure is
// returned unchanged and nothing is persisted.
func (s *systemctl) Start(ctx context.Context, u *Unit, revID string) error {
	startCmd := s.cmds.Start.AddArg(u.FullName())
	s.l.Debugf("Executing '%s'...", startCmd.String())
	if execErr := startCmd.ExecCtx(ctx); execErr != nil {
		return execErr
	}
	return s.p.SaveRevision(persistName(u), revID)
}

// Stop runs the stop command for unit u under ctx and, if it succeeds,
// removes the unit's recorded revision. Unlike Start, Restart and Reload, the
// stop template is built without --no-block, so the call waits for the job.
func (s *systemctl) Stop(ctx context.Context, u *Unit) error {
	stopCmd := s.cmds.Stop.AddArg(u.FullName())
	s.l.Debugf("Executing '%s'...", stopCmd.String())
	if err := stopCmd.ExecCtx(ctx); err != nil {
		return err
	}
	return s.p.RemoveRevision(persistName(u))
}

// Enable reloads the systemd manager configuration and then enables unit u,
// bounding the enable call with DefaultTimeout. A daemon-reload failure aborts
// before the unit is touched.
func (s *systemctl) Enable(u *Unit) error {
	// daemon-reload runs as its own call so it gets its much larger timeout.
	if reloadErr := s.ReloadDaemon(); reloadErr != nil {
		return reloadErr
	}
	enableCmd := s.cmds.Enable.AddArg(u.FullName())
	ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
	defer cancel()
	s.l.Debugf("Executing '%s'...", enableCmd.String())
	return enableCmd.ExecCtx(ctx)
}

// Disable reloads the systemd manager configuration and then disables unit u,
// bounding the disable call with DefaultTimeout. A daemon-reload failure aborts
// before the unit is touched.
func (s *systemctl) Disable(u *Unit) error {
	// daemon-reload runs as its own call so it gets its much larger timeout.
	if reloadErr := s.ReloadDaemon(); reloadErr != nil {
		return reloadErr
	}
	disableCmd := s.cmds.Disable.AddArg(u.FullName())
	ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
	defer cancel()
	s.l.Debugf("Executing '%s'...", disableCmd.String())
	return disableCmd.ExecCtx(ctx)
}

// Reload runs the reload command for unit u under ctx and, if it succeeds,
// records revID as the unit's current revision.
func (s *systemctl) Reload(ctx context.Context, u *Unit, revID string) error {
	reloadCmd := s.cmds.Reload.AddArg(u.FullName())
	s.l.Debugf("Executing '%s'...", reloadCmd.String())
	if err := reloadCmd.ExecCtx(ctx); err != nil {
		return err
	}
	return s.p.SaveRevision(persistName(u), revID)
}

// ReloadDaemon runs `systemctl daemon-reload`, bounded by
// systemctlDaemonReloadTimeout rather than DefaultTimeout.
// NOTE(review): the shared template is executed directly (no AddArg copy);
// confirm executil.Cmd tolerates concurrent ExecCtx on one instance.
func (s *systemctl) ReloadDaemon() error {
	ctx, cancel := context.WithTimeout(context.Background(), systemctlDaemonReloadTimeout)
	defer cancel()
	reloadCmd := s.cmds.ReloadDaemon
	s.l.Debugf("Executing '%s'...", reloadCmd.String())
	return reloadCmd.ExecCtx(ctx)
}

// Status collects the state of unit u. Properties come from `systemctl show`,
// UnitFileState comes from `systemctl is-enabled`, and Outdated reports whether
// revID differs from the revision last persisted for the unit. Any failing
// step aborts and returns its error with a nil status.
func (s *systemctl) Status(u *Unit, revID string) (*UnitStatus, error) {
	showCmd := s.cmds.Status.AddArg(u.FullName())
	ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
	defer cancel()
	s.l.Debugf("Executing '%s'...", showCmd.String())
	if err := showCmd.ExecCtx(ctx); err != nil {
		return nil, err
	}
	status, err := parseShowOutput(showCmd.Stdout())
	if err != nil {
		return nil, err
	}
	// The UnitFileState property printed by `systemctl show` is only accurate
	// after a daemon-reload, so query it separately via is-enabled.
	fileState, err := s.unitFileState(u)
	if err != nil {
		return nil, err
	}
	status.UnitFileState = fileState
	upToDate, err := s.p.IsCurrent(persistName(u), revID)
	if err != nil {
		return nil, err
	}
	status.Outdated = !upToDate
	return status, nil
}

// unitFileState runs `systemctl is-enabled` for unit u and converts its
// stdout (minus one trailing newline) into a UnitFileState.
//
// is-enabled exits non-zero for several legitimate states (see the is-enabled
// table in `man systemctl`), so on a failed exit the output is checked, in
// order, against:
//   - BadUnitFileStates: returned as a state with a nil error;
//   - UnitFileStateBad: returned together with an error;
//   - "No such file or directory" on stderr: reported as UnitFileStateDisabled;
//   - anything else: UnitFileStateUnknown plus the execution error.
func (s *systemctl) unitFileState(u *Unit) (UnitFileState, error) {
	isEnabledCmd := s.cmds.IsEnabled.AddArg(u.FullName())
	ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
	defer cancel()
	s.l.Debugf("Executing '%s'...", isEnabledCmd.String())
	execErr := isEnabledCmd.ExecCtx(ctx)
	state := string(bytes.TrimSuffix(isEnabledCmd.Stdout(), []byte("\n")))
	if execErr == nil {
		return UnitFileState(state), nil
	}
	switch {
	case slices.ContainsString(BadUnitFileStates, state):
		return UnitFileState(state), nil
	case state == string(UnitFileStateBad):
		return UnitFileStateBad, fmt.Errorf("%s returns '%s', stderr: %s", isEnabledCmd.String(), isEnabledCmd.Stdout(), isEnabledCmd.Stderr())
	case bytes.Contains(isEnabledCmd.Stderr(), []byte("No such file or directory")):
		// The unit file is gone, i.e. the state was requested for a removed unit.
		return UnitFileStateDisabled, nil
	default:
		return UnitFileStateUnknown, execErr
	}
}

// parseShowOutput parses the `Key=Value` lines printed by `systemctl show`,
// keeps only the properties listed in requiredProps and optionalProps, and
// hands them to decodeProps to build a UnitStatus.
//
// Only the first '=' separates key from value, so values that themselves
// contain '=' are preserved instead of the whole line being dropped. If a
// property appears more than once, the last occurrence wins. An error is
// returned if the output cannot be scanned or if decodeProps rejects the
// collected properties.
func parseShowOutput(statusBytes []byte) (*UnitStatus, error) {
	// Build the lookup set once per call. The previous code re-ran
	// append(requiredProps, optionalProps...) for every line, which allocated
	// repeatedly and could write into requiredProps' backing array.
	wanted := make(map[string]struct{}, len(requiredProps)+len(optionalProps))
	for _, prop := range requiredProps {
		wanted[string(prop)] = struct{}{}
	}
	for _, prop := range optionalProps {
		wanted[string(prop)] = struct{}{}
	}

	props := make(map[string]interface{})
	s := bufio.NewScanner(bytes.NewReader(statusBytes))
	// By default bufio.Scanner gives up on lines longer than 64 KiB, and some
	// properties (e.g. ExecStart=, Environment=) can exceed that. The whole
	// output is already in memory, so allow a line to be as long as all of it.
	s.Buffer(make([]byte, 0, 4096), len(statusBytes)+1)
	for s.Scan() {
		kv := bytes.SplitN(s.Bytes(), []byte("="), 2)
		if len(kv) != 2 {
			continue
		}
		key := string(kv[0])
		if _, ok := wanted[key]; ok {
			props[key] = string(kv[1])
		}
	}
	if err := s.Err(); err != nil {
		return nil, fmt.Errorf("scanning systemctl show output: %w", err)
	}
	return decodeProps(props)
}
