package supervisor

import (
	"context"
	"fmt"
	"net"
	"os"
	"sync"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/security/skotty/skotty/internal/stdenv"
	"a.yandex-team.ru/security/skotty/skotty/internal/version"
	"a.yandex-team.ru/security/skotty/skotty/pkg/skottyctl/skottyrpc"
)

// Compile-time assertion that *Supervisor implements skottyrpc.SupervisorServer.
var _ skottyrpc.SupervisorServer = (*Supervisor)(nil)

// Supervisor is the gRPC control server exposed over a unix socket. It
// implements skottyrpc.SupervisorServer and dispatches restart and key-reload
// requests to the handlers supplied at construction time.
type Supervisor struct {
	// addr is the unix socket path passed to net.Listen.
	addr        string
	// log receives diagnostics; NewSupervisor defaults it to a no-op logger.
	log         log.Logger
	// ac is started in the background by ListenAndServe and stopped by Shutdown.
	ac          *AgentChecker
	// onRestart is invoked by the Restart RPC; nil means restart is unsupported.
	onRestart   RestartHandler
	// onKeyReload holds the handlers run, in registration order, by ReloadKeys.
	onKeyReload []KeyReloadHandler
	// grpc is the underlying gRPC server.
	grpc        *grpc.Server
	// state is the status reported by the Status RPC; RPC handlers access it under mu.
	state       skottyrpc.Status
	// done is closed when ListenAndServe returns; Shutdown waits on it.
	done        chan struct{}
	// mu guards state. UpdateStartupTTY also holds it while updating the environment.
	mu          sync.Mutex
}

// NewSupervisor builds a Supervisor that will listen on the unix socket addr.
//
// Recognized options configure the logger, append key-reload handlers and set
// the restart handler; unrecognized option types are ignored here. The full
// option list is also forwarded to NewAgentChecker. The gRPC server is created
// with the error-logging interceptor and unix-socket credentials built from
// the current process uid/gid.
func NewSupervisor(addr string, opts ...Option) *Supervisor {
	sv := &Supervisor{
		addr: addr,
		log:  &nop.Logger{},
		done: make(chan struct{}),
	}

	for i := range opts {
		switch o := opts[i].(type) {
		case loggerOption:
			sv.log = o.logger
		case keyReloadHandlerOption:
			sv.onKeyReload = append(sv.onKeyReload, o.fn)
		case restartHandlerOption:
			sv.onRestart = o.fn
		}
	}

	sv.ac = NewAgentChecker(opts...)

	serverOpts := []grpc.ServerOption{
		grpc.UnaryInterceptor(sv.errHandler),
		grpc.Creds(newUnixSocketCredentials(os.Getuid(), os.Getgid())),
	}
	sv.grpc = grpc.NewServer(serverOpts...)

	return sv
}

// ListenAndServe binds the unix socket at s.addr, starts the agent checker in
// a background goroutine, registers the Supervisor service and serves gRPC
// requests until the server is stopped.
//
// Any file already present at s.addr is removed first (errors ignored). The
// done channel is closed on every return path, which is what lets Shutdown
// complete; because of that close, this method must not be called twice on the
// same Supervisor.
//
// It returns a wrapped error if the socket cannot be bound, otherwise whatever
// grpc.Server.Serve returns.
func (s *Supervisor) ListenAndServe() error {
	// Best effort: a stale socket file would make net.Listen fail.
	_ = os.Remove(s.addr)
	defer close(s.done)

	listener, err := net.Listen("unix", s.addr)
	if err != nil {
		return fmt.Errorf("listen fail: %w", err)
	}

	defer func() {
		_ = listener.Close()
	}()

	go s.ac.Start()

	skottyrpc.RegisterSupervisorServer(s.grpc, s)

	s.log.Info("supervisor server started", log.String("addr", s.addr))

	// state is guarded by mu everywhere else; take the lock here as well so
	// every access follows the same rule.
	s.mu.Lock()
	s.state = skottyrpc.Status_S_OK
	s.mu.Unlock()

	return s.grpc.Serve(listener)
}

// Shutdown stops the gRPC server, shuts down the agent checker and removes
// the socket file, then waits for ListenAndServe to return.
//
// The server is first drained gracefully. If ctx is cancelled before in-flight
// RPCs finish, the server is stopped forcibly so that the caller's deadline is
// honoured instead of blocking indefinitely inside GracefulStop.
//
// It returns nil once ListenAndServe has returned, or ctx.Err() if ctx is done
// first.
func (s *Supervisor) Shutdown(ctx context.Context) error {
	drained := make(chan struct{})
	go func() {
		defer close(drained)
		s.grpc.GracefulStop()
	}()

	select {
	case <-drained:
	case <-ctx.Done():
		// Graceful drain did not finish in time: close all connections now.
		// Stop also unblocks the pending GracefulStop call above.
		s.grpc.Stop()
	}

	s.ac.Shutdown()
	// Best effort: the socket file may already be gone.
	_ = os.Remove(s.addr)

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-s.done:
		return nil
	}
}

// Status reports the supervisor's current state together with the pid of this
// process and the full version string. It never returns an error.
func (s *Supervisor) Status(_ context.Context, _ *skottyrpc.StatusRequest) (*skottyrpc.StatusReply, error) {
	// Snapshot the state under the lock; the rest of the reply does not depend on it.
	s.mu.Lock()
	current := s.state
	s.mu.Unlock()

	reply := &skottyrpc.StatusReply{
		Status:  current,
		Pid:     int32(os.Getpid()),
		Version: version.Full(),
	}
	return reply, nil
}

// Restart switches the supervisor into the RESTARTING state and invokes the
// configured restart handler.
//
// It returns an Unimplemented status error when no restart handler was
// configured, an "invalid state" error when the supervisor is not in S_OK, and
// otherwise the error returned by the handler (if any).
func (s *Supervisor) Restart(_ context.Context, _ *skottyrpc.RestartRequest) (*skottyrpc.RestartReply, error) {
	if s.onRestart == nil {
		return nil, status.Error(codes.Unimplemented, "no restart handler configured")
	}

	// Check and transition in one critical section so that two concurrent
	// requests cannot both observe S_OK and both start a restart.
	s.mu.Lock()
	state := s.state
	if state == skottyrpc.Status_S_OK {
		s.state = skottyrpc.Status_S_RESTARTING
	}
	s.mu.Unlock()

	if state != skottyrpc.Status_S_OK {
		return nil, fmt.Errorf("invalid state: %s", state)
	}

	s.log.Info("restarting")

	// NOTE(review): the state stays RESTARTING even when the handler fails;
	// confirm whether it should be rolled back to S_OK in that case.
	if err := s.onRestart(); err != nil {
		return nil, err
	}
	return &skottyrpc.RestartReply{}, nil
}

// ReloadKeys runs every registered key-reload handler in registration order.
//
// The supervisor must be in S_OK; otherwise an "invalid state" error is
// returned. While handlers run the state is KEY_RELOADING. Handler failures
// are logged and do not abort the remaining handlers or fail the RPC.
func (s *Supervisor) ReloadKeys(_ context.Context, _ *skottyrpc.ReloadKeysRequest) (*skottyrpc.ReloadKeysReply, error) {
	// Check and transition in one critical section so that concurrent
	// requests cannot both observe S_OK and run the handlers in parallel.
	s.mu.Lock()
	state := s.state
	if state == skottyrpc.Status_S_OK {
		s.state = skottyrpc.Status_S_KEY_RELOADING
	}
	s.mu.Unlock()

	if state != skottyrpc.Status_S_OK {
		return nil, fmt.Errorf("invalid state: %s", state)
	}

	s.log.Info("reload keys")
	for _, f := range s.onKeyReload {
		if err := f(); err != nil {
			s.log.Error("fail during key reloading", log.Error(err))
		}
	}

	s.mu.Lock()
	// Only undo our own marker: if the state was changed in the meantime
	// (e.g. to RESTARTING), it must not be overwritten with S_OK.
	if s.state == skottyrpc.Status_S_KEY_RELOADING {
		s.state = skottyrpc.Status_S_OK
	}
	s.mu.Unlock()

	s.log.Info("keys reloaded")
	return &skottyrpc.ReloadKeysReply{}, nil
}

// UpdateStartupTTY copies environment variables from the request into this
// process's environment. Only names listed in stdenv.StartupEnvs are applied;
// every other entry is silently skipped.
//
// It returns an error as soon as one variable cannot be set; variables applied
// before the failure stay set.
func (s *Supervisor) UpdateStartupTTY(_ context.Context, req *skottyrpc.UpdateStartupTTYRequest) (*skottyrpc.UpdateStartupTTYReply, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Allow-list lookup table built from the startup environment names.
	permitted := make(map[string]bool)
	for _, key := range stdenv.StartupEnvs {
		permitted[key] = true
	}

	for _, kv := range req.Env {
		if !permitted[kv.Name] {
			continue
		}

		err := os.Setenv(kv.Name, kv.Value)
		if err != nil {
			return nil, fmt.Errorf("unable to set environment variable %s: %w", kv.Name, err)
		}
	}

	return &skottyrpc.UpdateStartupTTYReply{}, nil
}

// errHandler is the unary server interceptor installed by NewSupervisor. It
// calls the real handler and logs any error together with the full method
// name; the response and error are passed through unchanged.
func (s *Supervisor) errHandler(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	out, callErr := handler(ctx, req)
	if callErr == nil {
		return out, nil
	}

	s.log.Error("request error", log.String("method", info.FullMethod), log.Error(callErr))
	return out, callErr
}
