package main

import (
	"crypto/rand"
	"fmt"
	"math/big"
	"net/url"
	"os"
	"runtime/pprof"
	"time"

	"github.com/alexflint/go-arg"

	identifier "code.justin.tv/amzn/TwitchProcessIdentifier"
	"code.justin.tv/devhub/e2ml/libs/discovery/host"
	"code.justin.tv/devhub/e2ml/libs/discovery/protocol"
	"code.justin.tv/devhub/e2ml/libs/logging"
	"code.justin.tv/devhub/e2ml/libs/metrics"
	"code.justin.tv/devhub/e2ml/libs/setup"
	"code.justin.tv/devhub/e2ml/libs/stream"
	"code.justin.tv/devhub/e2ml/libs/ticket"
	"code.justin.tv/devhub/e2ml/libs/ticket/jwt"
	"code.justin.tv/devhub/e2ml/libs/timeout"
	"code.justin.tv/devhub/e2ml/libs/websocket"
	"code.justin.tv/devhub/e2ml/services/source"
	"code.justin.tv/devhub/lib-lifecycle/src/lifecycle"
)

// Issuer and audience values that main hands to setup.AuthSource when building the
// authorization source for the host connection.
const (
	s2sAudience = "eml-pathfinder"
	s2sIssuer   = "eml-source"
)

// Config holds the command-line and environment settings for this service. main fills it
// in with arg.MustParse; each field's `arg` tag names its flag and environment variable,
// while the `default` and `help` tags supply the default value and usage text.
type Config struct {
	// Embedded metrics argument source; presumably supplies the MetricsTaskARN field that
	// main reads — confirm against the metrics package.
	metrics.GoArgSource
	EnvName string `arg:"--env-name,env:ENV_NAME" help:"Current environment: [local,dev,prod]"`

	// Process-level settings: TLS material, logging, metrics reporting and CPU profiling.
	CertFile      string        `arg:"--cert-file,env:CERT_FILE" help:"(optional) file containing TLS server cert"`
	Insecure      bool          `help:"FOR LOCAL USE ONLY: allow connections to untrusted servers"`
	KeyFile       string        `arg:"--key-file,env:KEY_FILE" help:"(optional) file containing TLS server key"`
	Log           logging.Level `arg:"env" default:"info" help:"Logging level: [trace,debug,info,warning,error]"`
	MetricsHost   string        `arg:"--metrics-host,env:METRICS_HOST" default:"localhost" help:"hostname for metric reporting"`
	MetricsMethod string        `arg:"--metrics-method,env:METRICS_METHOD" default:"none" help:"method for metric collection [none,logged,twitchtelemetry]"`
	MetricsPeriod time.Duration `arg:"--metrics-period,env:METRICS_PERIOD" default:"10s" help:"period between metrics reporting ticks"`
	ProfilingFile string        `arg:"--profiling-dest-file,env:PROFILING_DEST_FILE" help:"(optional) file target for CPU profile data"`

	// Settings for the client-facing server and the tickets issued by this instance.
	AuthCheckPeriod   time.Duration `arg:"--auth-check-duration,env:AUTH_CHECK_DURATION" default:"10s" help:"time between checks for credential expiration"`
	ClientPort        int           `arg:"--client-port,env:CLIENT_PORT" default:"3003" help:"port for client connections"`
	TicketDuration    time.Duration `arg:"--ticket-duration,env:TICKET_DURATION" default:"5s" help:"time that tickets issued by this instance are valid"`
	HistoryExpiration time.Duration `arg:"--history-expiration,env:HISTORY_EXPIRATION" default:"5m" help:"minimum time before an unused history expires"`

	// Rate limiting parameters; main passes both to source.NewServer.
	RateLimitCount  uint8         `arg:"--rate-limit-count,env:RATE_LIMIT_COUNT" default:"10" help:"number of samples in rate limit buffer"`
	RateLimitPeriod time.Duration `arg:"--rate-limit-period,env:RATE_LIMIT_PERIOD" default:"6s" help:"shortest duration to allow rate samples to fill"`

	// Settings for the connection to the discovery host.
	// NOTE(review): the HostHeartbeat flag is spelled "--host-heartbest", which looks like a
	// typo for "--host-heartbeat". Renaming it would break existing invocations that use the
	// misspelled flag (the HOST_HEARTBEAT environment variable is unaffected) — confirm
	// before changing.
	HostAuthMethod string             `arg:"--host-auth-method,env:HOST_AUTH_METHOD" default:"s2s" help:"authorization method for host connections"`
	HostURL        string             `arg:"--host-url,env:HOST_URL" default:"wss://pathfinder" help:"url for host connection"`
	HostName       string             `arg:"--host-name,env:HOST_NAME" default:"ws://source:3003" help:"host name this instance will report to discovery"`
	HostHeartbeat  time.Duration      `arg:"--host-heartbest,env:HOST_HEARTBEAT" default:"250ms" help:"minimum time before status reports to the host"`
	HostTimeout    time.Duration      `arg:"--host-timeout,env:HOST_TIMEOUT" default:"10s" help:"duration before a host request is assumed to have timed out"`
	HostWindowSize int                `arg:"--host-window-size,env:HOST_WINDOW_SIZE" default:"40" help:"number of samples used to track host load"`
	HostS2sSecret  stream.OpaqueBytes `arg:"--host-s2s-secret,env:HOST_S2S_SECRET" help:"secret for host s2s auth"`
}

// main wires the source service together and runs it until an interrupt arrives.
//
// Startup: parse configuration, optionally start CPU profiling, then create the logger,
// lifecycle manager, metrics tracker, Go stats collector, host client factory, ticket store,
// host reporter, rate limited status, source server and finally the client-facing websocket
// service. Any setup failure is reported through setup.PanicWithMessage.
//
// Shutdown: once mgr.ListenForInterrupt returns, the status and reporter hooks are executed
// first, then (after a fixed pause) the server hook, after which the websocket service is
// stopped and given up to ten seconds to drain. The deferred mgr.ExecuteAll runs on the way
// out of main.
func main() {
	var config Config
	arg.MustParse(&config)

	if config.ProfilingFile != "" {
		f, err := os.Create(config.ProfilingFile)
		if err != nil {
			setup.PanicWithMessage("Unable to start profiler", err)
		}
		// Deferred calls run last-in-first-out, so the profile is stopped (and its writes
		// completed) before the destination file is closed.
		defer func() { _ = f.Close() }() // best effort: the process is exiting
		if err := pprof.StartCPUProfile(f); err != nil {
			setup.PanicWithMessage("Unable to StartCPUProfile", err)
		}
		defer pprof.StopCPUProfile()
	}

	logger := setup.Logger(config.Log)

	mgr := lifecycle.NewManager()
	defer func() { _ = mgr.ExecuteAll() }()
	lifecycle.SetDefaultPanicReporter(func(key interface{}, err error) { logger(logging.Error, err) })
	lifecycle.SetDefaultErrorReporter(func(key interface{}, err error) {
		logger(logging.Error, fmt.Sprintf("Shutdown error for %+v: %v", key, err))
	})

	tPid := identifier.ProcessIdentifier{
		Machine:  config.MetricsTaskARN,
		Service:  "E2MLSource",
		Stage:    config.EnvName,
		Substage: "primary",
		Region:   "us-west-2",
	}

	metrics := setup.MetricsTracker(config.MetricsMethod, tPid, logger)
	mgr.RegisterHook(metrics, func() error {
		err := metrics.Close()
		time.Sleep(time.Second) // allow metrics to drain
		return err
	})
	mgr.TickUntilClosed(metrics.Tick, config.MetricsPeriod)
	metrics.Count("lifecycle", []string{"action:startup"}).Add(1)

	gostats := setup.GoStatsCollector(tPid, 30*time.Second, logger)
	gostats.Start()
	mgr.RegisterHook(gostats, func() error {
		gostats.Stop()
		return nil
	})

	hostURL, err := url.Parse(config.HostURL)
	if err != nil {
		setup.PanicWithMessage("Illegal url specified for host", err)
	}

	hostName, err := url.Parse(config.HostName)
	if err != nil {
		setup.PanicWithMessage("Illegal url specified for public address", err)
	}

	tls := setup.TLSConfig(config.CertFile, config.KeyFile, config.Insecure)

	// reporter connections should close last
	hostClients := websocket.NewClientFactory(hostURL, &websocket.Settings{
		Certs:     tls,
		Lifecycle: mgr,
		Logger:    logger,
		Timeout:   timeout.NewPrecomputedSampler(100, 10*time.Second, 1*time.Second),
	})

	// after ticket management ...
	tickets := createTicketStore(config.TicketDuration, metrics)
	mgr.RegisterHook(tickets, tickets.Close)
	mgr.TickUntilClosed(tickets.Tick, config.TicketDuration)

	// after address reporting ...
	reporter, err := host.NewReporter(
		hostClients,
		hostName,
		stream.AddressSourceMap{stream.AnyAddress.Key(): stream.None},
		tickets,
		onAddressRequest,
		setup.AuthSource("host", config.HostAuthMethod, s2sIssuer, s2sAudience, config.HostS2sSecret, logger),
		metrics,
		logger,
	)
	if err != nil {
		setup.PanicWithMessage("Unable to create reporter", err)
	}
	mgr.RegisterHook(reporter, reporter.Close)
	mgr.TickUntilClosed(reporter.Tick, config.HostTimeout)

	// after the status reporting loop has wrapped up
	status := host.NewRateLimitedStatus(reporter, protocol.Available|protocol.Source)
	mgr.RegisterHook(status, status.Close) // necessary if we panic

	// after the core logic/audience has completed
	server := source.NewServer(
		onSourceIDRequest,
		tickets,
		reporter,
		config.RateLimitCount,
		config.RateLimitPeriod,
		config.HistoryExpiration,
		config.AuthCheckPeriod,
		metrics,
		logger,
	)
	mgr.RegisterHook(server, server.Shutdown)
	mgr.TickUntilClosed(server.Tick, config.AuthCheckPeriod)

	// after the status reporting loop has wrapped up
	status.Start(server.LoadFactor, config.HostWindowSize)
	mgr.TickUntilClosed(status.Tick, config.HostHeartbeat)

	// after the audience listener has stopped
	logic, err := websocket.NewServiceFactory(config.ClientPort, &websocket.Settings{
		Certs:     tls,
		Lifecycle: mgr, // allow clients to drain, but stop logic on exit
		Logger:    logger,
		Timeout:   timeout.NewPrecomputedSampler(100, 10*time.Second, 1*time.Second),
	})(server.Factory())
	if err == nil {
		err = logic.Start()
	}
	if err != nil {
		setup.PanicWithMessage("Unable to create service", err)
	}

	// Blocks here until the lifecycle manager reports an interrupt.
	logger(logging.Info, "Exiting on signal:", mgr.ListenForInterrupt())
	metrics.Count("lifecycle", []string{"action:drain"}).Add(1)

	// mark ourselves unavailable, gracefully close discovery
	_ = mgr.ExecuteHook(status)
	logger(logging.Info, "Dropping from discovery")
	_ = mgr.ExecuteHook(reporter)

	// HACK: wait for discovery peers to get update
	time.Sleep(2 * time.Second)

	// notify clients they need to reconnect, stop accepting connections
	_ = mgr.ExecuteHook(server)
	logic.Stop()
	logger(logging.Info, "Draining client connections")
	logic.WaitForDrainingConnections(time.Now().Add(10 * time.Second))
	metrics.Count("lifecycle", []string{"action:shutdown"}).Add(1)
}

// createTicketStore builds the ticket store for this process. The signing secret is 64
// bytes read from crypto/rand on every call and is passed to jwt.NewHS512Factory; duration
// and metrics are forwarded unchanged to ticket.NewStore. A failure to read random bytes is
// reported through setup.PanicWithMessage.
func createTicketStore(duration time.Duration, metrics metrics.Tracker) ticket.Store {
	const secretSize = 64

	key := make([]byte, secretSize)
	_, err := rand.Read(key)
	if err != nil {
		setup.PanicWithMessage("Unable to seed jwt secret", err)
	}

	signer := jwt.NewHS512Factory(key)
	return ticket.NewStore(signer, duration, metrics)
}

// onAddressRequest answers an address request with a scope containing exactly the requested
// address, provided creds may listen on it or send to it; otherwise it returns
// stream.ErrForbidden. Source returns exact matches for incoming addresses to give traffic
// the best chance of being balanced.
func onAddressRequest(addr stream.Address, creds stream.Credentials) (stream.AddressScopes, error) {
	// Guard clause: reject when neither capability is granted. CanSend is only consulted
	// when CanListen reports false.
	if !creds.CanListen(addr) && !creds.CanSend(addr) {
		return nil, stream.ErrForbidden
	}
	return stream.AddressScopes{addr}, nil
}

// onSourceIDRequest produces a random source id using crypto/rand. rand.Int draws uniformly
// from [0, bound), so with a bound of 1<<32 - 1 the ids range from 0 to 1<<32 - 2.
// NOTE(review): the value 1<<32 - 1 is never produced — confirm whether excluding it is
// intentional. A failure to generate a value is reported through setup.PanicWithMessage.
func onSourceIDRequest() stream.SourceID {
	const exclusiveBound = 1<<32 - 1

	id, err := rand.Int(rand.Reader, big.NewInt(exclusiveBound))
	if failed := err != nil || !id.IsUint64(); failed {
		setup.PanicWithMessage("Unable to generate source id", err)
	}
	return stream.SourceID(id.Uint64())
}
