package main

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"flag"
	"fmt"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"os"
	"os/signal"
	"runtime/pprof"
	"strconv"
	"syscall"
	"time"

	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"

	"a.yandex-team.ru/library/go/maxprocs"
	"a.yandex-team.ru/security/osquery/osquery-sender/config"
	"a.yandex-team.ru/security/osquery/osquery-sender/handlers"
	"a.yandex-team.ru/security/osquery/osquery-sender/sendmgr"
	"a.yandex-team.ru/security/osquery/osquery-sender/syslogparsing"
	"a.yandex-team.ru/security/osquery/osquery-sender/util"
)

const (
	// waitBalancerTimeout is how long the SIGTERM handler sleeps after disabling
	// the ping endpoint, before it starts shutting the HTTP server down.
	waitBalancerTimeout = 10 * time.Second

	// shutdownTimeout bounds the graceful shutdown of the HTTP server.
	shutdownTimeout = 10 * time.Second

	// maxTimeoutBeforeRequest is handed to util.ConcurrentRequestLimiter
	// (presumably the longest a request may wait for a free slot — confirm in util).
	maxTimeoutBeforeRequest = 5 * time.Second

	// Limits used when the config leaves the corresponding setting at 0.
	defaultMaxConcurrentRequests    = 250
	defaultMaxConcurrentConnections = 2500
)

// signalHandler subscribes to SIGHUP, SIGTERM and SIGUSR1 and handles them in a
// background goroutine (run through util.RunWithLabels with the pprof label
// name=signal-handler) for the lifetime of the process:
//
//   - SIGUSR1 writes a heap profile to "mem.pprof" in the current working directory.
//   - SIGTERM disables the ping endpoint, sleeps waitBalancerTimeout so the balancer
//     stops opening new connections, then gracefully shuts down e, bounded by
//     shutdownTimeout.
//   - SIGHUP is subscribed to (which, per os/signal, stops it from terminating the
//     process) but is otherwise ignored.
func signalHandler(e *echo.Echo) {
	signalChan := make(chan os.Signal, 1)
	signal.Notify(signalChan, syscall.SIGHUP, syscall.SIGTERM, syscall.SIGUSR1)
	go util.RunWithLabels(pprof.Labels("name", "signal-handler"), func() {
		// signalChan is never closed, so this loop runs until the process exits.
		for s := range signalChan {
			switch s {
			case syscall.SIGUSR1:
				f, err := os.Create("mem.pprof")
				if err != nil {
					// Skip this dump rather than writing to (and closing) a nil *os.File.
					log.Printf("ERROR: creating heap profile: %v\n", err)
					break
				}
				if err := pprof.WriteHeapProfile(f); err != nil {
					log.Printf("ERROR: writing heap profile: %v\n", err)
				}
				if err := f.Close(); err != nil {
					log.Printf("ERROR: closing heap profile: %v\n", err)
				}

			case syscall.SIGTERM:
				// Wait until balancer stops opening new connections with our service.
				handlers.DisablePing()
				time.Sleep(waitBalancerTimeout)

				ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
				err := e.Shutdown(ctx)
				cancel()
				if err != nil {
					log.Printf("ERROR: graceful shutdown error: %v\n", err)
				}
			}
		}
	})
}

// parseTLSConfig collects the primary TLS certificate/key paths from conf,
// followed by any additional ones. The returned slices are index-aligned:
// keyFiles[i] is expected to be the key for certFiles[i]. It returns an error
// (and nil slices) when the numbers of certificates and keys differ.
func parseTLSConfig(conf *config.SenderConfig) (certFiles []string, keyFiles []string, err error) {
	certs := append([]string{conf.TLSCertificateFile}, conf.AdditionalTLSCertificateFiles...)
	keys := append([]string{conf.TLSKeyFile}, conf.AdditionalTLSKeyFiles...)
	if len(certs) != len(keys) {
		return nil, nil, errors.New("additional_keys must match additional_certs")
	}
	return certs, keys, nil
}

// startSNI serves e over TLS on address, offering every (certFiles[i], keyFiles[i])
// pair as a candidate server certificate; crypto/tls chooses one per handshake
// (for example by the client's SNI server name). The listener is installed via
// setConnectionLimit so at most maxConnectionLimit connections are served
// concurrently.
//
// It blocks while the server runs and returns e.StartServer's error. Before
// starting, it returns an error if the slices differ in length, are empty, or if
// any certificate/key pair cannot be loaded. e.TLSServer.TLSConfig is replaced
// only after every pair has loaded successfully.
func startSNI(e *echo.Echo, address string, certFiles []string, keyFiles []string, maxConnectionLimit int) error {
	if len(certFiles) != len(keyFiles) {
		// The loop below indexes keyFiles by the position in certFiles.
		return fmt.Errorf("got %d TLS certificate files but %d key files", len(certFiles), len(keyFiles))
	}
	if len(certFiles) == 0 {
		return errors.New("no TLS certificate files configured")
	}

	tlsConfig := &tls.Config{
		Certificates: make([]tls.Certificate, 0, len(certFiles)),
	}
	for i, certFile := range certFiles {
		keyFile := keyFiles[i]
		certPEM, err := ioutil.ReadFile(certFile)
		if err != nil {
			return fmt.Errorf("reading TLS certificate: %w", err)
		}
		keyPEM, err := ioutil.ReadFile(keyFile)
		if err != nil {
			return fmt.Errorf("reading TLS key: %w", err)
		}
		pair, err := tls.X509KeyPair(certPEM, keyPEM)
		if err != nil {
			return fmt.Errorf("loading TLS key pair %q / %q: %w", certFile, keyFile, err)
		}
		// From tls/common.go:
		//
		// > Note: if there are multiple Certificates, and they don't have the
		// > optional field Leaf set, certificate selection will incur a significant
		// > per-handshake performance cost.
		//
		// Copied from Certificate.leaf(). X509KeyPair only succeeds when the chain
		// has at least one certificate, so index 0 is safe.
		pair.Leaf, err = x509.ParseCertificate(pair.Certificate[0])
		if err != nil {
			return fmt.Errorf("parsing leaf certificate of %q: %w", certFile, err)
		}
		tlsConfig.Certificates = append(tlsConfig.Certificates, pair)
	}

	// Copied from Echo.startTLS()
	s := e.TLSServer
	s.TLSConfig = tlsConfig
	s.Addr = address
	if !e.DisableHTTP2 {
		s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2")
	}

	// Must run after TLSConfig is populated: setConnectionLimit builds the TLS
	// listener from e.TLSServer.TLSConfig.
	setConnectionLimit(e, address, maxConnectionLimit)

	return e.StartServer(s)
}

// runHealthcheckServer starts a background HTTP server that answers GET / with
// handlers.Ping.
//
// The port comes from the HEALTHCHECK_HTTP_PORT environment variable when set
// (the process exits if it is not a number), otherwise from conf.HealthcheckPort.
// If neither is set, the server is not started. If the server ever stops, the
// process exits.
func runHealthcheckServer(conf *config.SenderConfig) {
	var port int
	if portStr := os.Getenv("HEALTHCHECK_HTTP_PORT"); portStr != "" {
		var err error
		port, err = strconv.Atoi(portStr)
		if err != nil {
			log.Fatalf("HEALTHCHECK_HTTP_PORT must be a number: '%s': %v\n", portStr, err)
		}
	} else {
		if conf.HealthcheckPort == 0 {
			log.Printf("Healthcheck port not configured, disabling")
			return
		}
		port = conf.HealthcheckPort
	}

	hc := echo.New()
	hc.GET("/", handlers.Ping)
	go util.RunWithLabels(pprof.Labels("name", "healthcheck-server"), func() {
		// Unlike the management server, this binds to ALL interfaces. It only
		// exposes handlers.Ping and is presumably probed from outside the host.
		// NOTE(review): the endpoint has no authorization/TLS — confirm that
		// exposing it on every interface is intended.
		err := hc.Start(fmt.Sprintf(":%d", port))
		log.Fatalf("ERROR: running healthcheck server: %v\n", err)
	})
}

// runManagementServer starts a background HTTP server on localhost:conf.ManagementPort
// with GET endpoints for ping, metrics (/unistat and /solomon, each calling
// mgr.UpdateMetrics before rendering), the handlers.Profile/Heap/Gorotune/Trace
// endpoints, and connection/request peer listings.
//
// It does nothing when conf.ManagementPort is 0. If the server ever stops, the
// process exits.
func runManagementServer(conf *config.SenderConfig, mgr *sendmgr.SendMgr) {
	if conf.ManagementPort == 0 {
		log.Printf("Management port not configured, disabling")
		return
	}

	// GET routes, registered in this order.
	routes := []struct {
		path    string
		handler echo.HandlerFunc
	}{
		{"/", handlers.Ping},
		{"/unistat", func(c echo.Context) error {
			mgr.UpdateMetrics()
			return handlers.Unistat(c)
		}},
		{"/solomon", func(c echo.Context) error {
			mgr.UpdateMetrics()
			return handlers.Solomon(c)
		}},
		{"/profile", func(c echo.Context) error { return handlers.Profile(c) }},
		{"/heap", func(c echo.Context) error { return handlers.Heap(c) }},
		{"/goroutine", func(c echo.Context) error { return handlers.Gorotune(c) }},
		{"/trace", func(c echo.Context) error { return handlers.Trace(c) }},
		{"/connection-peers", handlers.ConnectionPeers},
		{"/request-peers", handlers.RequestPeers},
	}

	mgmt := echo.New()
	for _, r := range routes {
		mgmt.GET(r.path, r.handler)
	}

	go util.RunWithLabels(pprof.Labels("name", "management-server"), func() {
		// Bind only to localhost, because the endpoints do not have any authorization/TLS.
		listenAddr := fmt.Sprintf("localhost:%d", conf.ManagementPort)
		err := mgmt.Start(listenAddr)
		log.Fatalf("ERROR: running management server: %v\n", err)
	})
}

// main loads the configuration named by --config, registers the sender's HTTP
// routes, starts the syslog, healthcheck and management servers, and then serves
// the main HTTP(S) endpoint until it stops.
//
// The listen port is QLOUD_HTTP_PORT when set (falling back to 80 if it is not a
// number), otherwise conf.Port, which is then mandatory. The process exits with
// status 1 on configuration/startup errors, and also when the main server stops
// for any reason other than a graceful shutdown.
func main() {
	log.SetFlags(log.Lshortfile | log.Ldate | log.Ltime)
	log.SetPrefix("[osquery-sender] ")
	if os.Getenv("ENV") == "DEPLOY" {
		maxprocs.AdjustYP()
	} else {
		maxprocs.AdjustAuto()
	}

	configFilename := flag.String("config", "", "Configuration file")
	flag.Parse()

	if *configFilename == "" {
		log.Println("Specify: --config filename")
		os.Exit(1)
	}

	conf, err := config.FromFile(*configFilename)
	if err != nil {
		log.Println(err.Error())
		os.Exit(1)
	}

	var port int
	if portStr := os.Getenv("QLOUD_HTTP_PORT"); portStr != "" {
		port, err = strconv.Atoi(portStr)
		if err != nil {
			// Keep the historical fallback to 80, but make it visible.
			log.Printf("WARNING: QLOUD_HTTP_PORT is not a number (%q), using port 80: %v\n", portStr, err)
			port = 80
		}
	} else {
		port = conf.Port
		if port == 0 {
			log.Fatalf("Port must be specified")
		}
	}

	s := echo.New()
	if conf.EnableDebug {
		s.Use(middleware.Logger())
	}

	setupIPExtractor(s)

	// A zero limit in the config means "use the built-in default".
	maxConcurrentConnections := defaultMaxConcurrentConnections
	if conf.MaxConcurrentConnections != 0 {
		maxConcurrentConnections = conf.MaxConcurrentConnections
	}

	maxConcurrentRequests := defaultMaxConcurrentRequests
	if conf.MaxConcurrentRequests != 0 {
		maxConcurrentRequests = conf.MaxConcurrentRequests
	}
	s.Use(util.ConcurrentRequestLimiter(maxConcurrentRequests, maxTimeoutBeforeRequest))
	log.Printf("Max concurrent requests: %d\n", maxConcurrentRequests)

	// start
	log.Println("starting up sender part")

	mgr := sendmgr.NewSendMgr(conf)
	// Runs on a normal return only: os.Exit skips deferred calls.
	defer mgr.Stop()

	// "/enroll" and "/" share one handler, as do "/log" and "/logger".
	enrollHandler := func(c echo.Context) error {
		return handlers.Enroll(c, conf, mgr)
	}
	logHandler := func(c echo.Context) error {
		return handlers.Log(c, conf, mgr)
	}
	s.POST("/enroll", enrollHandler)
	s.POST("/", enrollHandler)
	s.POST("/logger", logHandler)
	s.POST("/log", logHandler)
	s.GET("/ping", handlers.Ping)
	s.GET("/unistat", func(c echo.Context) error {
		mgr.UpdateMetrics()
		return handlers.Unistat(c)
	})
	s.GET("/solomon", func(c echo.Context) error {
		mgr.UpdateMetrics()
		return handlers.Solomon(c)
	})

	err = startSyslogServers(conf.Syslog, conf.HostsConfig, mgr)
	if err != nil {
		log.Println(err.Error())
		os.Exit(1)
	}

	signalHandler(s)

	runHealthcheckServer(conf)
	runManagementServer(conf, mgr)

	log.Println("starting up the server")
	address := fmt.Sprintf(":%d", port)
	var serveErr error
	if conf.EnableTLS {
		certFiles, keyFiles, err := parseTLSConfig(conf)
		if err != nil {
			log.Println(err.Error())
			os.Exit(1)
		}
		serveErr = startSNI(s, address, certFiles, keyFiles, maxConcurrentConnections)
	} else {
		setConnectionLimit(s, address, maxConcurrentConnections)
		serveErr = s.Start(address)
	}

	// http.ErrServerClosed is the expected result of the graceful e.Shutdown that
	// signalHandler issues on SIGTERM. Any other error means the server failed to
	// start (e.g. unreadable TLS material) or stopped unexpectedly, so the process
	// must not report success.
	failed := serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed)
	if failed {
		log.Println(serveErr.Error())
	}

	// NOTE(review): per the net/http Server.Shutdown docs, Serve returns
	// ErrServerClosed as soon as Shutdown *starts*. That means main can reach this
	// point and exit while signalHandler is still draining in-flight requests.
	// Waiting for the drain to finish would require signalHandler to report completion.
	log.Printf("shutting down")
	if failed {
		// os.Exit skips the deferred mgr.Stop, so run it explicitly first.
		mgr.Stop()
		os.Exit(1)
	}
}

// setupIPExtractor makes s take the client IP from the X-Forwarded-For header,
// trusting loopback and every IPv4 (0.0.0.0/0) and IPv6 (::/0) address as a proxy
// hop. The process exits if either range fails to parse.
func setupIPExtractor(s *echo.Echo) {
	// NOTE: We trust the request data already, no big deal if we trust one more header.
	mustParseCIDR := func(cidr, fatalFormat string) *net.IPNet {
		_, ipNet, parseErr := net.ParseCIDR(cidr)
		if parseErr != nil {
			log.Fatalf(fatalFormat, parseErr)
		}
		return ipNet
	}

	everyIPv4 := mustParseCIDR("0.0.0.0/0", "ERROR: could not parse ipv4 trusted nets: %v\n")
	everyIPv6 := mustParseCIDR("::/0", "ERROR: could not parse ipv6 trusted nets: %v\n")

	trustOptions := []echo.TrustOption{
		echo.TrustLoopback(true),
		echo.TrustIPRange(everyIPv4),
		echo.TrustIPRange(everyIPv6),
	}
	s.IPExtractor = echo.ExtractIPFromXFFHeader(trustOptions...)
}

// setConnectionLimit binds a listener on address (network s.ListenerNetwork) and
// installs it on s, wrapped by util.NewLimitListener so that at most
// maxConcurrentConnections connections are handled concurrently. The process
// exits if the address cannot be bound.
//
// The limited listener always becomes s.Listener (plain HTTP). It is also wrapped
// into s.TLSListener when s.TLSServer.TLSConfig is already set, so for TLS this
// must be called after the TLS config has been populated.
func setConnectionLimit(s *echo.Echo, address string, maxConcurrentConnections int) {
	// We enable both concurrent *request* limiting and concurrent *connection* limiting. This allows us to survive
	// both the high number of new connections and high request rate in already established connections.
	listener, err := net.Listen(s.ListenerNetwork, address)
	if err != nil {
		log.Fatalf("ERROR: configuring listener on %s: %v", address, err)
	}
	// Wrap the socket once so the plain and TLS views share a single limiter,
	// instead of two independent limiters over the same socket.
	limited := util.NewLimitListener(listener, maxConcurrentConnections)
	s.Listener = limited
	// tls.NewListener requires a non-nil config; on the plain-HTTP path there is
	// none, and the TLS listener would be unusable, so skip it.
	if s.TLSServer.TLSConfig != nil {
		s.TLSListener = tls.NewListener(limited, s.TLSServer.TLSConfig)
	}
	log.Printf("Max concurrent connections: %d\n", maxConcurrentConnections)
}

// startSyslogServers starts one syslog server per entry of syslogConfig.Servers.
// It does nothing when syslogConfig is nil. Entries without an EventHandler name
// send events to mgr through a shared syslogparsing.SendMgrHandler; named handlers
// are resolved with syslogparsing.GetEventHandler.
//
// When syslog is enabled, splunkHosts must contain syslogparsing.PseudoDstHostname.
// The first failure is returned, tagged with the index of the failing server.
// Servers started before the failure are not stopped here.
func startSyslogServers(syslogConfig *config.SyslogConfig, splunkHosts map[string]*config.HostConfig, mgr *sendmgr.SendMgr) error {
	if syslogConfig == nil {
		return nil
	}
	if _, ok := splunkHosts[syslogparsing.PseudoDstHostname]; !ok {
		return errors.New("hosts config must contain '" + syslogparsing.PseudoDstHostname + "' if syslog is enabled")
	}

	defaultEventHandler := &syslogparsing.SendMgrHandler{Mgr: mgr}
	for i, serverConfig := range syslogConfig.Servers {
		// Give each server its own copy. StartSysLogServer receives &serverConfig,
		// and before Go 1.22 the range variable is reused across iterations. If the
		// pointer is retained, every server would otherwise see the last entry.
		serverConfig := serverConfig

		var eventHandler syslogparsing.EventHandler = defaultEventHandler
		if serverConfig.EventHandler != "" {
			eventHandler = syslogparsing.GetEventHandler(serverConfig.EventHandler)
			// NOTE(review): assumes GetEventHandler returns nil for an unknown
			// name — confirm; if it never does, this check is a no-op.
			if eventHandler == nil {
				return fmt.Errorf("syslog server #%d: unknown event handler %q", i, serverConfig.EventHandler)
			}
		}
		if _, err := syslogparsing.StartSysLogServer(&serverConfig, eventHandler); err != nil {
			return fmt.Errorf("starting syslog server #%d: %w", i, err)
		}
	}
	return nil
}
