package cmd

import (
	"context"
	"fmt"
	"log"
	"net"
	"net/http"
	"net/url"
	"os"
	"os/signal"
	rt "runtime"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/grpc-ecosystem/grpc-gateway/runtime"
	"github.com/spf13/cobra"

	"golang.org/x/crypto/ssh/terminal"
	"golang.org/x/net/http2"
	"golang.org/x/net/http2/h2c"

	"google.golang.org/grpc"

	"a.yandex-team.ru/library/go/core/log/zap/logrotate"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"

	"a.yandex-team.ru/infra/porto/plugins/portostatd/internal/server"
	rpcpb "a.yandex-team.ru/infra/porto/plugins/portostatd/portostatd_rpc"
)

const (
	// PortostatdSockPath is the filesystem path of the unix socket the daemon
	// listens on; PortostatdLogPath is the file the zap logger writes to.
	PortostatdSockPath = "/var/run/portostatd.sock"
	PortostatdLogPath  = "/var/log/portostatd.log"

	// Mode, owner UID and group GID applied to the socket file right after it
	// has been created (see setSockOwnerPerms).
	DefaultSockPerms = 0o666
	DefaultSockOwner = 0
	DefaultSockGroup = 1333 // porto

	// maxProcs is the value handed to runtime.GOMAXPROCS on startup.
	maxProcs = 4
)

// Settings of the `start` subcommand. They are bound to command-line flags
// in init() and read when the command runs.
var (
	// cacheUpdateInterval is the pause, in seconds, between stats cache updates.
	cacheUpdateInterval uint64

	// httpServerGracefulShutdownTimeout is how long, in seconds, the HTTP
	// server is given to shut down gracefully.
	httpServerGracefulShutdownTimeout uint64

	// debug switches the logger to debug level.
	debug bool
)

// makeZapLogger builds the daemon logger.
//
// Records are written in console format (ISO8601 timestamps, capitalized
// level names) to PortostatdLogPath through a logrotate sink created with
// SIGHUP (presumably the signal on which the sink reopens the file - confirm
// against the logrotate package). When stdout is a terminal, every record is
// additionally duplicated to stdout.
//
// The minimum level is Debug when debug is true and Info otherwise.
// An error is returned if the log path cannot be parsed or the sink cannot
// be created.
func makeZapLogger(debug bool) (*zap.Logger, error) {
	logURL, err := url.ParseRequestURI(PortostatdLogPath)
	if err != nil {
		return nil, err
	}
	fileSink, err := logrotate.NewLogrotateSink(logURL, syscall.SIGHUP)
	if err != nil {
		return nil, err
	}

	minLevel := zapcore.InfoLevel
	if debug {
		minLevel = zapcore.DebugLevel
	}
	enabler := zap.NewAtomicLevelAt(minLevel)

	cfg := zap.NewProductionEncoderConfig()
	cfg.EncodeTime = zapcore.ISO8601TimeEncoder
	cfg.EncodeLevel = zapcore.CapitalLevelEncoder
	enc := zapcore.NewConsoleEncoder(cfg)

	fileCore := zapcore.NewCore(enc, fileSink, enabler)
	if !terminal.IsTerminal(int(os.Stdout.Fd())) {
		// Not attached to a terminal: the log file is the only destination.
		return zap.New(fileCore), nil
	}

	// Interactive run: mirror everything to stdout as well.
	stdoutCore := zapcore.NewCore(enc, zapcore.Lock(os.Stdout), enabler)
	return zap.New(zapcore.NewTee(fileCore, stdoutCore)), nil
}

// startCmd implements the `start` subcommand: it runs the portostatd daemon.
//
// The command
//   - initializes logging and limits GOMAXPROCS to maxProcs;
//   - creates the portostatd server, fills its stats cache once and starts
//     the periodic cache updater;
//   - serves both native gRPC and the grpc-gateway JSON API from one HTTP
//     server listening on the unix socket PortostatdSockPath;
//   - on SIGTERM, SIGINT or SIGQUIT stops the gRPC server, closes the
//     portostatd server, shuts the HTTP server down and stops the updater.
//
// Unrecoverable startup or shutdown errors terminate the process via a
// fatal log call.
var startCmd = &cobra.Command{
	Use:   "start",
	Short: "Start grpc-gateway server",
	Run: func(cmd *cobra.Command, args []string) {
		var wg sync.WaitGroup

		logger, err := makeZapLogger(debug)
		if err != nil {
			log.Fatalf("Failed to init zap logger: %v", err)
		}
		_ = zap.ReplaceGlobals(logger)

		zap.S().Info("Starting...")
		// We don't use concurrency too much and
		// we don't want too much overhead from managing threads.
		oldMaxProc := rt.GOMAXPROCS(maxProcs)
		zap.S().Debugf("Set GOMAXPROCS to %d from %d", maxProcs, oldMaxProc)
		portostatdServer, err := server.InitPortostatdServer()
		if err != nil {
			zap.S().Fatalf("%v", err)
		}
		// A failed initial update is logged but not fatal: the updater
		// started below tries again after cacheUpdateInterval seconds.
		err = portostatdServer.PerformStatsCacheUpdate()
		if err != nil {
			zap.S().Errorf("Initial cache update failed: %v", err)
		}

		cacheCtx, cacheCancel := context.WithCancel(context.Background())
		defer cacheCancel()

		wg.Add(1)
		go cacheUpdateLoop(cacheCtx, &wg, portostatdServer, cacheUpdateInterval)

		// Unary interceptor that logs every error returned by a gRPC handler.
		grpcHandler := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
			resp, err := handler(ctx, req)
			if err != nil {
				zap.S().Errorf("%v", err)
			}
			return resp, err
		}
		grpcServer := grpc.NewServer(grpc.UnaryInterceptor(grpcHandler))
		rpcpb.RegisterPortostatdServiceServer(grpcServer, portostatdServer)

		// In order to create the socket in Listen() later, we need to make
		// sure that there is no stale socket at the desired path.
		err = unlinkStaleSocket()
		if err != nil {
			zap.S().Fatalf("Failed to unlink staled socket: %v", err)
		}

		// custom marshaller is needed in order to NOT omit 'nil' values in response json,
		// as 'golang/protobuf' hardcodes 'omitempty'
		//
		// (however, this doesn't help with printing raw gRPC responses - 'nil' values are omitted
		// anyway, but direct access to them works - they are zero values of their type)
		//
		// such behavior is supposed to become default in grpc-gateway v2:
		// https://github.com/grpc-ecosystem/grpc-gateway/issues/233
		customMarshaller := &runtime.JSONPb{
			OrigName:     true,
			EmitDefaults: true, // disable 'omitempty'
		}
		// Gateway error handler: log the error, then let the default handler
		// write the response.
		protoErrHandler := func(ctx context.Context, mux *runtime.ServeMux, marshaler runtime.Marshaler, w http.ResponseWriter, req *http.Request, err error) {
			if err != nil {
				zap.S().Errorf("%v", err)
			}
			// Forward the real request instead of nil: whether the default
			// handler inspects it depends on the grpc-gateway version in use,
			// and a nil request would panic in versions that read its headers.
			runtime.DefaultHTTPError(ctx, mux, marshaler, w, req, err)
		}
		router := runtime.NewServeMux(
			runtime.WithMarshalerOption(runtime.MIMEWildcard, customMarshaller),
			runtime.WithProtoErrorHandler(protoErrHandler),
		)

		err = rpcpb.RegisterPortostatdServiceHandlerServer(context.Background(), router, portostatdServer)
		if err != nil {
			zap.S().Fatalf("Failed to register gateway: %v", err)
		}

		listener, err := net.Listen("unix", PortostatdSockPath)
		if err != nil {
			zap.S().Fatalf("Listen failed: %v", err)
		}

		err = setSockOwnerPerms()
		if err != nil {
			zap.S().Fatalf("Failed to set socket ownership or permissions: %v", err)
		}

		httpServer := http.Server{
			Handler: httpGrpcRouter(grpcServer, router),
		}

		// signal.Notify never blocks when delivering a signal, so the channel
		// must be buffered: with an unbuffered channel a signal that arrives
		// while nobody is receiving would be silently dropped.
		stop := make(chan os.Signal, 1)
		signal.Notify(stop, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT)

		wg.Add(1)
		// goroutine that performs graceful shutdown of the daemon
		go func() {
			defer wg.Done()
			defer signal.Stop(stop)

			sig := <-stop
			zap.S().Infof("Got %s signal", sig)

			zap.S().Info("Stopping grpc server...")
			grpcServer.GracefulStop()

			// NOTE(review): the cache updater is still running at this point
			// and may call PerformStatsCacheUpdate on the closed server until
			// cacheCancel() below - confirm that this is safe.
			err := portostatdServer.Close()
			if err != nil {
				zap.S().Warnf("Got some error on closing portostatdServer client: %v\n", err)
			}

			zap.S().Info("Shutting down http server...")
			ctx, cancel := context.WithTimeout(context.Background(), time.Duration(httpServerGracefulShutdownTimeout)*time.Second)
			defer cancel()

			err = httpServer.Shutdown(ctx)
			if err != nil {
				zap.S().Fatalf("Shutdown failed: %v", err)
			}
			// Successful Shutdown() closes all open listeners,
			// thus the unix socket is closed now and no longer exists on FS,
			// so there is no need to unlink it here.

			zap.S().Info("Stopping cache updater...")
			cacheCancel()
		}()

		zap.S().Info("Start serving requests...")
		err = httpServer.Serve(listener)
		if err != nil && err != http.ErrServerClosed { // ErrServerClosed - server was shutdown
			zap.S().Fatalf("Serve failed: %v", err)
		}

		// wait for termination of all goroutines
		zap.S().Info("Exiting...")
		wg.Wait()
	},
}

// httpGrpcRouter returns a single handler that serves both native gRPC and
// plain HTTP traffic.
//
// A request is passed to grpcServer when it is HTTP/2 and its Content-Type
// header contains "application/grpc"; every other request goes to
// httpHandler. The result is wrapped with h2c.NewHandler and a default
// http2.Server.
func httpGrpcRouter(grpcServer *grpc.Server, httpHandler http.Handler) http.Handler {
	dispatch := func(w http.ResponseWriter, r *http.Request) {
		zap.S().Debugf("Got %s request", r.URL)

		isGRPC := r.ProtoMajor == 2 &&
			strings.Contains(r.Header.Get("Content-Type"), "application/grpc")
		if isGRPC {
			grpcServer.ServeHTTP(w, r)
			return
		}
		httpHandler.ServeHTTP(w, r)
	}

	return h2c.NewHandler(http.HandlerFunc(dispatch), &http2.Server{})
}

// cacheUpdateLoop refreshes the stats cache of serv until ctx is cancelled.
//
// It waits updateInterval seconds, calls serv.PerformStatsCacheUpdate, and
// repeats; the wait is measured from the end of the previous update. Update
// errors are logged and do not stop the loop. wg.Done is called on return,
// so the caller must have done wg.Add(1) before starting the goroutine.
func cacheUpdateLoop(ctx context.Context, wg *sync.WaitGroup, serv *server.PortostatdServer, updateInterval uint64) {
	defer wg.Done()

	interval := time.Duration(updateInterval) * time.Second

	// One reusable timer instead of time.After on every iteration: nothing
	// is allocated per cycle, and the pending timer is released as soon as
	// the loop exits instead of lingering until it would have fired.
	timer := time.NewTimer(interval)
	defer timer.Stop()

	for {
		select {
		case <-timer.C:
			if err := serv.PerformStatsCacheUpdate(); err != nil {
				zap.S().Errorf("Failed to update cache: %v", err)
			}
			// The timer has fired and its channel was drained by this case,
			// so it can be re-armed directly.
			timer.Reset(interval)
		case <-ctx.Done():
			return
		}
	}
}

// init attaches the `start` subcommand to the root command and declares
// its command-line flags together with their defaults.
func init() {
	rootCmd.AddCommand(startCmd)

	flags := startCmd.Flags()
	flags.Uint64Var(
		&cacheUpdateInterval,
		"cache-update-interval",
		5,
		"metrics cache update interval (in seconds)",
	)
	flags.Uint64Var(
		&httpServerGracefulShutdownTimeout,
		"http-server-graceful-shutdown-timeout",
		10,
		"amount of time (in seconds) given to http server to shutdown gracefully",
	)
	flags.BoolVar(&debug, "debug", false, "enable debug output")
}

// unlinkStaleSocket makes sure PortostatdSockPath is free, so that a new
// socket can be created there afterwards.
//
// Outcomes:
//   - nothing exists at the path: nothing to do, nil is returned;
//   - a socket exists at the path: it is removed;
//   - anything else exists at the path: an error is returned, because
//     unknown files must not be unlinked;
//   - Stat() fails for any other reason: an error is returned.
//
// TODO: there is definitely a TOCTOU here (the path may change between the
// Stat() and Remove() calls).
func unlinkStaleSocket() error {
	info, statErr := os.Stat(PortostatdSockPath)
	if os.IsNotExist(statErr) {
		// The path is free; the socket can be created later.
		return nil
	}
	if statErr != nil {
		return fmt.Errorf("failed to stat socket path: %w", statErr)
	}

	if info.Mode()&os.ModeSocket == 0 {
		return fmt.Errorf("some file already exists at path %q and isn't a socket", PortostatdSockPath)
	}

	if rmErr := os.Remove(PortostatdSockPath); rmErr != nil {
		return fmt.Errorf("failed to remove socket: %w", rmErr)
	}

	zap.S().Info("Unlinked staled socket")
	return nil
}

// setSockOwnerPerms applies the default ownership and mode to the socket
// file at PortostatdSockPath: first owner DefaultSockOwner and group
// DefaultSockGroup, then permissions DefaultSockPerms.
//
// The first failing step is returned as a wrapped error; the mode is not
// touched if changing the ownership fails.
func setSockOwnerPerms() error {
	err := os.Chown(PortostatdSockPath, DefaultSockOwner, DefaultSockGroup)
	if err != nil {
		return fmt.Errorf("chown failed to set socket owner or group: %w", err)
	}

	err = os.Chmod(PortostatdSockPath, DefaultSockPerms)
	if err != nil {
		return fmt.Errorf("chmod failed to set socket permissions: %w", err)
	}

	return nil
}
