package main

import (
	"context"
	"crypto/rand"
	"fmt"
	"os"
	"runtime/pprof"
	"time"

	"github.com/alexflint/go-arg"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	awsS3 "github.com/aws/aws-sdk-go/service/s3"

	identifier "code.justin.tv/amzn/TwitchProcessIdentifier"
	"code.justin.tv/amzn/TwitchProcessIdentifier/expvars"
	"code.justin.tv/devhub/lib-lifecycle/src/lifecycle"
	"code.justin.tv/video/autoprof/profs3"

	"code.justin.tv/devhub/e2ml/libs/discovery/host"
	"code.justin.tv/devhub/e2ml/libs/discovery/protocol"
	"code.justin.tv/devhub/e2ml/libs/logging"
	"code.justin.tv/devhub/e2ml/libs/metrics"
	"code.justin.tv/devhub/e2ml/libs/setup"
	"code.justin.tv/devhub/e2ml/libs/stream"
	"code.justin.tv/devhub/e2ml/libs/stream/registry"
	"code.justin.tv/devhub/e2ml/libs/ticket"
	"code.justin.tv/devhub/e2ml/libs/ticket/counting"
	"code.justin.tv/devhub/e2ml/libs/ticket/jwt"
	"code.justin.tv/devhub/e2ml/libs/timeout"
	"code.justin.tv/devhub/e2ml/libs/websocket"
	"code.justin.tv/devhub/e2ml/services/threshold"
)

// Service-to-service (s2s) token parameters passed to setup.AuthSource in main.
const (
	// s2sIssuer identifies this service as the issuer for both the client
	// (pathfinder) and host (greeter) connections.
	s2sIssuer = "eml-threshold"

	// s2sClientAudience is the audience used for the client-side auth source.
	s2sClientAudience = "eml-pathfinder"

	// s2sHostAudience is the audience used for the host-side auth source.
	s2sHostAudience = "eml-greeter"
)

// Config is the threshold service configuration. main populates it with
// arg.MustParse, so every field can be supplied as a command-line flag and,
// where the `arg` tag lists an `env:` name, as an environment variable.
type Config struct {
	metrics.GoArgSource
	EnvName string `arg:"--env-name,env:ENV_NAME" help:"Current environment: [local,dev,prod]"`

	// Process-wide settings: TLS material, logging, metrics and profiling.
	CertFile      string        `arg:"--cert-file,env:CERT_FILE" help:"(optional) file containing TLS server cert"`
	Insecure      bool          `help:"FOR LOCAL USE ONLY: allow connections to untrusted servers"`
	KeyFile       string        `arg:"--key-file,env:KEY_FILE" help:"(optional) file containing TLS server key"`
	Log           logging.Level `arg:"env" default:"info" help:"Logging level: [trace,debug,info,warning,error]"`
	MetricsHost   string        `arg:"--metrics-host,env:METRICS_HOST" default:"localhost" help:"hostname for metric reporting"`
	MetricsMethod string        `arg:"--metrics-method,env:METRICS_METHOD" default:"none" help:"method for metric collection [none,logged,twitchtelemetry]"`
	MetricsPeriod time.Duration `arg:"--metrics-period,env:METRICS_PERIOD" default:"10s" help:"period between metrics reporting ticks"`
	ProfilingFile string        `arg:"--profiling-dest-file,env:PROFILING_DEST_FILE" help:"(optional) file target for CPU profile data"`

	// Client side: auth-expiry checks on served connections and the remote
	// registry (pathfinder) connection.
	AuthCheckPeriod   time.Duration      `arg:"--auth-check-duration,env:AUTH_CHECK_DURATION" default:"10s" help:"time between checks for credential expiration"`
	AuthWarningPeriod time.Duration      `arg:"--auth-warning-duration,env:AUTH_WARNING_DURATION" default:"5s" help:"advance notice given before a kick due to auth expiration"`
	ClientTimeout     time.Duration      `arg:"--client-timeout,env:CLIENT_TIMEOUT" default:"10s" help:"time before discovery requests timeout"`
	ClientAuthMethod  string             `arg:"--client-auth-method,env:CLIENT_AUTH_METHOD" default:"s2s" help:"authorization method for client connections"`
	ClientURL         setup.URL          `arg:"--client-url,env:CLIENT_URL" default:"https://pathfinder:8001" help:"url for client connection"`
	ClientPort        int                `arg:"--client-port,env:CLIENT_PORT" default:"3002" help:"port for client connections"`
	ClientS2sSecret   stream.OpaqueBytes `arg:"--client-s2s-secret,env:CLIENT_S2S_SECRET" help:"secret for client s2s auth"`

	// Tickets and message limits. TicketMethod must match a case in
	// createTicketStore.
	TicketMethod     string        `arg:"--ticket-method,env:TICKET_METHOD" default:"jwt" help:"method used for connection tickets [counting, jwt]"`
	TicketDuration   time.Duration `arg:"--ticket-duration,env:TICKET_DURATION" default:"5s" help:"time that tickets issued by this instance are valid"`
	MaxMessageLength uint16        `arg:"--max-message-length,env:MAX_MESSAGE_LENGTH" default:"10240" help:"maximum allowed message length including headers"`

	// Host side: discovery reporting to the greeter.
	// NOTE(review): the HostHeartbeat flag is spelled "--host-heartbest". It is
	// kept as-is so existing invocations keep working; HOST_HEARTBEAT is the
	// correctly spelled environment variable.
	HostAuthMethod string             `arg:"--host-auth-method,env:HOST_AUTH_METHOD" default:"s2s" help:"authorization method for host connections"`
	HostURL        setup.URL          `arg:"--host-url,env:HOST_URL" default:"wss://greeter:3010" help:"url for host connection"`
	HostName       setup.URL          `arg:"--host-name,env:HOST_NAME" default:"ws://threshold:3002" help:"host name this instance will report to discovery"`
	HostHeartbeat  time.Duration      `arg:"--host-heartbest,env:HOST_HEARTBEAT" default:"250ms" help:"minimum time before status reports to the host"`
	HostTimeout    time.Duration      `arg:"--host-timeout,env:HOST_TIMEOUT" default:"10s" help:"duration before a host request is assumed to have timed out"`
	HostWindowSize int                `arg:"--host-window-size,env:HOST_WINDOW_SIZE" default:"40" help:"number of samples used to track host load"`
	HostS2sSecret  stream.OpaqueBytes `arg:"--host-s2s-secret,env:HOST_S2S_SECRET" help:"secret for host s2s auth"`

	// Autoprof profile upload; disabled when empty.
	AutoprofBucket string `arg:"--autoprof-bucket,env:AUTOPROF_BUCKET" help:"S3 bucket name to upload autoprof profiles"`
}

// main wires up and runs the threshold service.
//
// Startup: parse Config from flags/environment, optionally start a CPU
// profile, create the logger and lifecycle manager, start metrics and Go
// runtime stats, and optionally launch the autoprof S3 uploader. It then
// builds the websocket clients, remote registry, ticket store, host reporter,
// rate-limited status, threshold server, and finally the client-facing
// websocket service.
//
// Shutdown runs once mgr.ListenForInterrupt returns, in this order:
//  1. Mark the instance unavailable and close discovery reporting.
//  2. Wait briefly so discovery peers see the change.
//  3. Shut down the server and stop the ticket ticker.
//  4. Stop the listener and wait up to 10s for connections to drain.
//
// The deferred mgr.ExecuteAll then runs the remaining registered hooks.
func main() {
	var config Config
	arg.MustParse(&config)

	if config.ProfilingFile != "" {
		f, err := os.Create(config.ProfilingFile)
		if err != nil {
			setup.PanicWithMessage("Unable to start profiler", err)
		}
		// Defers run LIFO: StopCPUProfile (registered below) flushes the
		// profile before the file is closed here.
		defer func() { _ = f.Close() }()
		if err := pprof.StartCPUProfile(f); err != nil {
			setup.PanicWithMessage("Unable to StartCPUProfile", err)
		}
		defer pprof.StopCPUProfile()
	}

	logger := setup.Logger(config.Log)

	mgr := lifecycle.NewManager()
	// Safety net so every registered hook gets a chance to run on exit; the
	// error is reported through the default error reporter set below.
	defer func() { _ = mgr.ExecuteAll() }()
	lifecycle.SetDefaultPanicReporter(func(key interface{}, err error) { logger(logging.Error, err) })
	lifecycle.SetDefaultErrorReporter(func(key interface{}, err error) {
		logger(logging.Error, fmt.Sprintf("Shutdown error for %+v: %v", key, err))
	})

	tPid := identifier.ProcessIdentifier{
		Machine:  config.MetricsTaskARN,
		Service:  "E2MLThreshold",
		Stage:    config.EnvName,
		Substage: "primary",
		Region:   "us-west-2",
	}

	// NOTE: this local shadows the metrics package for the rest of main.
	metrics := setup.MetricsTracker(config.MetricsMethod, tPid, logger)
	mgr.RegisterHook(metrics, func() error {
		err := metrics.Close()
		time.Sleep(time.Second) // allow metrics to drain
		return err
	})
	mgr.TickUntilClosed(metrics.Tick, config.MetricsPeriod)
	metrics.Count("lifecycle", []string{"action:startup"}).Add(1)

	gostats := setup.GoStatsCollector(tPid, 30*time.Second, logger)
	gostats.Start()
	mgr.RegisterHook(gostats, func() error {
		gostats.Stop()
		return nil
	})

	if config.AutoprofBucket != "" {
		logger(logging.Info, "Setup autoprof")
		expvars.Publish(&tPid) // each bundle includes the contents of /debug/vars, used on the app name (differentiates "dev" from "prod")
		// The collector runs on context.Background and is never cancelled; it
		// lives for the lifetime of the process.
		go func() {
			sess, err := session.NewSession(&aws.Config{
				Region: aws.String("us-west-2"),
			})
			if err != nil {
				// Without a session awsS3.New would be handed a nil session;
				// log and give up on autoprof instead of crashing the service.
				logger(logging.Error, "failed to create AWS session for autoprof:", err.Error())
				return
			}

			autoprofCollector := profs3.Collector{
				S3:       awsS3.New(sess),
				S3Bucket: config.AutoprofBucket,
				OnError: func(err error) error {
					logger(logging.Error, "profs3.Collector.Run:", err.Error())
					return nil
				},
			}
			err = autoprofCollector.Run(context.Background())
			if err != nil {
				logger(logging.Error, "failed to run autoprof collector:", err.Error())
			}
		}()
	}

	tls := setup.TLSConfig(config.CertFile, config.KeyFile, config.Insecure)

	// NOTE(review): the "... close last" / "after ..." comments below assume mgr
	// runs hooks in reverse registration order; confirm against lib-lifecycle.

	// host connections should close last
	hostClients := websocket.NewClientFactory(setup.Unwrap(config.HostURL), &websocket.Settings{
		Certs:     tls,
		Lifecycle: mgr,
		Logger:    logger,
		Timeout:   timeout.NewPrecomputedSampler(100, 10*time.Second, 1*time.Second),
	})

	// after registry connections ...
	clientClients := websocket.NewClientResolver(&websocket.Settings{
		Certs:     tls,
		Lifecycle: mgr,
		Logger:    logger,
		Timeout:   timeout.NewPrecomputedSampler(100, 10*time.Second, 1*time.Second),
	})

	// after the registry ...
	reg := registry.NewRemote( // connections will be checked by pathfinder
		setup.AuthSource("client", config.ClientAuthMethod, s2sIssuer, s2sClientAudience, config.ClientS2sSecret, logger),
		threshold.NewHTTPBroker(setup.Unwrap(config.ClientURL), tls, logger),
		clientClients,
		config.ClientTimeout,
		metrics,
		logger,
	)
	mgr.RegisterHook(reg, reg.Close)

	// after the ticket system ...
	logger(logging.Info, "Granting tickets in", config.TicketMethod, "mode")
	tickets := createTicketStore(config.TicketMethod, config.TicketDuration, metrics)
	mgr.RegisterHook(tickets, tickets.Close)
	mgrTicketsTicker := mgr.TickUntilClosed(tickets.Tick, config.TicketDuration)

	// after audience reporting ...
	reporter, err := host.NewReporter(
		hostClients,
		setup.Unwrap(config.HostName),
		stream.AddressSourceMap{stream.AnyAddress.Key(): stream.None},
		tickets,
		onAddressRequest,
		setup.AuthSource("host", config.HostAuthMethod, s2sIssuer, s2sHostAudience, config.HostS2sSecret, logger),
		metrics,
		logger,
	)
	if err != nil {
		setup.PanicWithMessage("Unable to create reporter", err)
	}
	mgr.RegisterHook(reporter, reporter.Close)
	mgrReporterTicker := mgr.TickUntilClosed(reporter.Tick, config.HostTimeout)

	// after the status reporting loop has wrapped up
	status := host.NewRateLimitedStatus(reporter, protocol.Available)
	mgr.RegisterHook(status, status.Close) // necessary if we panic

	// support both validation and redemption at this point; choose one or the
	// other for threshold in the future
	redeemer := ticket.NewCompositeRedeemer(map[stream.AuthMethod]ticket.Redeemer{
		stream.Validation:  reporter,
		stream.Reservation: tickets,
	})

	// after the core logic/audience has completed
	server := threshold.NewServer(reg, redeemer, config.AuthWarningPeriod, metrics, logger)
	mgr.RegisterHook(server, server.Shutdown)
	mgrServerTicker := mgr.TickUntilClosed(server.Tick, config.AuthCheckPeriod)

	// after the status reporting loop has wrapped up
	status.Start(server.LoadFactor, config.HostWindowSize)
	mgrStatusTicker := mgr.TickUntilClosed(status.Tick, config.HostHeartbeat)

	// after the audience listener has stopped
	logic, err := websocket.NewServiceFactory(config.ClientPort, &websocket.Settings{
		Certs:            tls,
		Lifecycle:        mgr,
		Logger:           logger,
		Timeout:          timeout.NewPrecomputedSampler(100, 10*time.Second, 1*time.Second),
		MaxMessageLength: int(config.MaxMessageLength),
	})(server.Factory())
	if err == nil {
		err = logic.Start()
	}
	if err != nil {
		setup.PanicWithMessage("Unable to start service", err)
	}

	// Wait for SIGTERM signal
	sig := mgr.ListenForInterrupt()
	logger(logging.Info, "Exiting on signal:", sig)
	metrics.Count("lifecycle", []string{"action:drain"}).Add(1)

	// mark ourselves unavailable, gracefully close discovery
	logger(logging.Warning, ">>> Exiting ... mark ourselves unavailable, gracefully close discovery")
	mgrStatusTicker.Close()
	mgrReporterTicker.Close()
	_ = mgr.ExecuteHook(status)
	_ = mgr.ExecuteHook(reporter)

	// HACK: wait for discovery peers to get update
	logger(logging.Warning, ">>> Exiting ... wait 2 seconds for discovery peers to get update")
	time.Sleep(2 * time.Second)

	// notify clients they need to reconnect, stop accepting connections
	logger(logging.Warning, ">>> Exiting ... close server")
	mgrServerTicker.Close()
	_ = mgr.ExecuteHook(server)
	mgrTicketsTicker.Close()

	logger(logging.Warning, ">>> Exiting ... calling logic.Stop and wait for draining connections")
	logic.Stop()
	logic.WaitForDrainingConnections(time.Now().Add(10 * time.Second))
	metrics.Count("lifecycle", []string{"action:shutdown"}).Add(1)
	logger(logging.Info, "Exited drain completed after signal:", sig)
	logger(logging.Warning, ">>> Exiting ... done")
}

// createTicketStore builds the ticket.Store selected by method (the
// --ticket-method option). Tickets are valid for duration, and metrics is
// passed through to ticket.NewStore.
//
// Supported methods:
//   - "jwt": uses an HS512 JWT factory keyed with a random 64-byte secret that
//     is generated per process and never shared.
//   - "counting": uses counting.NewFactory. "counter", the spelling listed in
//     the option's help text, is accepted as an alias.
//
// Any other method panics, because the service cannot run without a ticket
// store.
func createTicketStore(method string, duration time.Duration, metrics metrics.Tracker) ticket.Store {
	switch method {
	case "jwt":
		// secret should be unique to this instance and is never shared
		secret := make([]byte, 64)
		if _, err := rand.Read(secret); err != nil {
			setup.PanicWithMessage("Unable to seed jwt secret", err)
		}
		return ticket.NewStore(jwt.NewHS512Factory(secret), duration, metrics)
	case "counting", "counter":
		return ticket.NewStore(counting.NewFactory(), duration, metrics)
	}
	panic(fmt.Sprintf("Unknown ticket method: %q", method))
}

// onAddressRequest is the address-request callback that main passes to
// host.NewReporter.
//
// If creds allow listening on or sending to addr, it grants the single scope
// stream.AnyAddress, so each client uses exactly one socket. Otherwise it
// returns stream.ErrForbidden.
func onAddressRequest(addr stream.Address, creds stream.Credentials) (stream.AddressScopes, error) {
	permitted := creds.CanListen(addr) || creds.CanSend(addr)
	if !permitted {
		return nil, stream.ErrForbidden
	}
	return stream.AddressScopes{stream.AnyAddress}, nil
}
