package main

import (
	"context"
	"fmt"
	stdlog "log"
	"net/http"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/ec2metadata"
	"github.com/aws/aws-sdk-go/aws/endpoints"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/ec2"
	awsECS "github.com/aws/aws-sdk-go/service/ecs"
	"github.com/pkg/errors"

	identifier "code.justin.tv/amzn/TwitchProcessIdentifier"
	"code.justin.tv/edge/go-statsd-proxy/hashring"
	"code.justin.tv/edge/go-statsd-proxy/internal/build"
	"code.justin.tv/edge/go-statsd-proxy/internal/ecs"
	"code.justin.tv/edge/go-statsd-proxy/internal/fingerprint"
	"code.justin.tv/edge/go-statsd-proxy/internal/health"
	"code.justin.tv/edge/go-statsd-proxy/internal/hostconfig"
	"code.justin.tv/edge/go-statsd-proxy/internal/observe"
	"code.justin.tv/edge/go-statsd-proxy/proxy"
)

// main wires the statsd proxy together. It builds the logger, AWS session and
// metrics, then parses and validates configuration. It then starts the UDP
// collector, the downstream forwarders/validators and the health-check
// endpoint. Finally it blocks until a termination signal arrives.
//
// When both an upstream ECS cluster and service are configured, the set of
// downstreams is rebuilt every time the cluster poller reports a new host list.
func main() {
	shutdownCtx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Include information about the build in log messages.
	// This information is provided at build time using -ldflags.
	ctx := build.Context{
		Commit:  commit,
		Date:    date,
		Version: version,
	}

	// Create a structured logger that writes to the specified output and contains the build context.
	log, err := observe.NewLogger(os.Stdout.Name(), ctx)
	if err != nil {
		stdlog.Fatalf("unable to init logger: %s", err)
	}

	// exit encloses the logger to provide a convenience function for logging an error
	// then exiting the process with code 1.
	exit := func(msg string, err error) {
		log.Error(msg, err)
		os.Exit(1)
	}

	// Establish a new AWS session.
	sess, err := getSession(os.Getenv("AWS_REGION"))
	if err != nil {
		exit("unable to initialize AWS session", err)
	}

	// Grab some metadata about where the application is running on ECS infrastructure.
	metaURI := os.Getenv("ECS_CONTAINER_METADATA_URI_V4")
	meta, err := getMetadata(metaURI, sess)
	if err != nil {
		exit("unable to fetch ECS metadata", err)
	}

	// This uniquely identifies the process, and is used in logs and metrics.
	pid := identifier.ProcessIdentifier{
		Service:  os.Args[0],
		Region:   meta.Region,
		Stage:    os.Getenv("ENVIRONMENT"),
		Version:  version,
		Machine:  meta.InstanceID,
		LaunchID: meta.TaskID,
	}

	// Spin up a new TwitchTelemetry observer that flushes to CloudWatch.
	metrics, err := observe.NewMetrics(pid, sess, log)
	if err != nil {
		exit("unable to initialize metrics", err)
	}

	observer := observe.Observer{
		Logger:  log,
		Metrics: metrics,
	}

	// Parse the command-line configuration.
	c, err := parse(os.Args)
	if err != nil {
		exit("unable to parse configuration", err)
	}

	// Validate configuration
	if err := validate(c); err != nil {
		exit("invalid configuration", err)
	}

	log.Info("loaded configuration", "values", c)

	collector, err := proxy.NewUDPCollector(c.StatsPort, c.BufferPoolSize, c.BufferBytes, observer)
	if err != nil {
		exit("unable to initialize UDP collector", err)
	}

	addresses := c.Addresses

	// reconfigureMu guards active and shuttingDown. It also serializes every
	// call to startDownstreams, which mutates the unsynchronized package-level
	// forwarderGen counter.
	var (
		reconfigureMu sync.Mutex
		active        *downstreams   // generation currently receiving traffic
		shuttingDown  bool           // set once shutdown begins; blocks further reconfiguration
		retiring      sync.WaitGroup // previous generations being stopped in the background
	)

	// Hold the lock until the first generation is installed. The cluster
	// poller started by setupHostsFromCluster runs in its own goroutine and may
	// call reconfigure at any moment. Without the lock, that call could:
	//   - see a nil active generation and panic while stopping it;
	//   - race with the assignment below;
	//   - have its new generation overwritten and leaked.
	reconfigureMu.Lock()

	if c.UpstreamCluster != "" && c.UpstreamService != "" {
		reconfigure := func(hosts []string) {
			reconfigureMu.Lock()
			defer reconfigureMu.Unlock()

			// Never start new forwarders once shutdown has begun.
			if shuttingDown {
				return
			}

			next, err := startDownstreams(hosts, c, collector, observer)
			if err != nil {
				log.Warn("We were asked to reconfigure but could not", "err", err)
				return
			}
			previous := active
			active = next

			// Stop the previous downstreams now we've got working ones. This is
			// done in the background so the poller is not blocked, but it is
			// tracked so shutdown can wait for it to finish.
			if previous != nil {
				retiring.Add(1)
				go func() {
					defer retiring.Done()
					previous.Stop()
				}()
			}
		}

		addresses, err = setupHostsFromCluster(shutdownCtx, c, observer, sess, reconfigure)
		if err != nil {
			exit("unable to initialize hosts from cluster", err)
		}
	}

	active, err = startDownstreams(addresses, c, collector, observer)
	if err != nil {
		exit("could not start downstreams", err)
	}
	reconfigureMu.Unlock()

	go collector.Run()
	go healthcheck(c.HealthPort, log)

	// Wait for a signal to exit.
	sigC := make(chan os.Signal, 1)
	signal.Notify(sigC, os.Interrupt, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGHUP)

	sig := <-sigC

	log.Info("shutting down", "signal", sig)
	cancel() // stopping the outer context gives things a chance to gracefully shut down

	reconfigureMu.Lock()
	shuttingDown = true
	active.Stop()
	reconfigureMu.Unlock()

	// Let any generation that was mid-retirement finish stopping as well.
	// No new retirements can be added now that shuttingDown is set.
	retiring.Wait()
}

// downstreams is one generation of the proxy's outbound side. It holds the
// forwarders built for a particular host list, plus the validator goroutines
// started over them. A generation is replaced wholesale when the host list
// changes.
type downstreams struct {
	generation  int               // value of forwarderGen when this generation was started
	validatorWg *sync.WaitGroup   // counts validator goroutines still running
	cancel      func()            // asks the validators to stop
	forwarders  []proxy.Forwarder // every forwarder owned by this generation
	observer    observe.Observer  // used for shutdown debug logging
}

// Stop tears the generation down in order. Validators are cancelled and
// awaited first, and only then are the forwarders stopped. Ordering it this
// way presumably guarantees no validator is still writing to a stopped
// forwarder. Stop blocks until every step has completed.
func (d *downstreams) Stop() {
	debug := func(step string) {
		d.observer.Debug("downstreams shutdown: "+step, "downstream_gen", d.generation)
	}

	debug("cancelling validators")
	d.cancel()

	debug("waiting for validators to stop forwarding")
	d.validatorWg.Wait()

	debug("stopping forwarders")
	for i := range d.forwarders {
		d.forwarders[i].Stop()
	}

	debug("complete")
}

// forwarderGen counts the downstream generations started so far. Each call to
// startDownstreams increments it. It is not synchronized, so callers must
// serialize their calls to startDownstreams.
var forwarderGen int

// startDownstreams builds and starts a new generation of downstreams for the
// given addresses. It does the following:
//   - creates one UDP forwarder per address and starts each one;
//   - arranges the forwarders on a consistent-hashing ring;
//   - starts c.NumWorkers validators that read from collector and use the ring.
//
// If any step fails, everything started so far is stopped before the wrapped
// error is returned. A partially built generation is therefore never leaked.
// On success the caller owns the returned value and must eventually call Stop
// on it.
func startDownstreams(addresses []string, c configuration, collector *proxy.UDPCollector, observer observe.Observer) (*downstreams, error) {
	// Points each forwarder occupies on the consistent-hashing ring.
	const ringReplicas = 20

	forwarderGen++
	gen := forwarderGen
	observer.Info("Starting downstreams", "generation", gen, "hosts", fingerprint.Summary(addresses, true))

	// Kick everything off.
	rawForwarders, err := getUDPForwarders(addresses, c.MaxPacketBytes, observer)
	if err != nil {
		return nil, errors.Wrap(err, "unable to initialize forwarders")
	}

	for _, forwarder := range rawForwarders {
		go forwarder.Run()
	}

	// Build the generation up front. On any later failure, d.Stop() then tears
	// down exactly what has been started: it cancels and awaits the validators
	// started so far, then stops the already-running forwarders.
	validatorCtx, cancel := context.WithCancel(context.Background())
	d := &downstreams{
		generation:  gen,
		validatorWg: &sync.WaitGroup{},
		cancel:      cancel,
		forwarders:  rawForwarders,
		observer:    observer,
	}

	ring, err := hashring.MakeRing(rawForwarders, addresses, ringReplicas)
	if err != nil {
		d.Stop()
		return nil, errors.Wrap(err, "unable to make consistent-hashing ring")
	}

	for i := 0; i < c.NumWorkers; i++ {
		v, err := proxy.NewStatsdValidator(collector, ring, c.Hash, observer, validatorCtx.Done())
		if err != nil {
			d.Stop()
			return nil, errors.Wrap(err, "unable to initialize validator")
		}

		d.validatorWg.Add(1)
		go func(v proxy.Validator) {
			defer d.validatorWg.Done()
			v.Run()
		}(v)
	}

	return d, nil
}

// getUDPForwarders creates one UDP forwarder per address. The result is in
// the same order as addresses, and empty input yields an empty, non-nil
// slice. maxPacketSize and o are passed through unchanged to every forwarder.
// The first construction error is returned as-is, together with a nil slice.
//
// NOTE(review): forwarders created before a failing address are simply
// dropped. Confirm whether proxy.NewUDPForwarder acquires resources (e.g.
// sockets) that should be released in that case.
func getUDPForwarders(addresses []string, maxPacketSize int, o observe.Observer) ([]proxy.Forwarder, error) {
	result := make([]proxy.Forwarder, 0, len(addresses))
	for _, address := range addresses {
		fwd, err := proxy.NewUDPForwarder(address, maxPacketSize, o)
		if err != nil {
			return nil, err
		}
		result = append(result, fwd)
	}
	return result, nil
}

// getSession creates an AWS session. Credential-chain errors are verbose, and
// STS uses its regional endpoint. When region is empty, no region is set
// explicitly and the SDK's own resolution is left to pick one.
func getSession(region string) (*session.Session, error) {
	var regionOverride *string
	if region != "" {
		regionOverride = aws.String(region)
	}

	return session.NewSession(&aws.Config{
		Region:                        regionOverride,
		CredentialsChainVerboseErrors: aws.Bool(true),
		STSRegionalEndpoint:           endpoints.RegionalSTSEndpoint,
	})
}

// getMetadata delegates to ecs.GetMetadata. It passes the given metadata URI
// and an EC2 instance-metadata client built from sess.
func getMetadata(uri string, sess *session.Session) (*ecs.Metadata, error) {
	return ecs.GetMetadata(uri, ec2metadata.New(sess))
}

// healthcheck serves a health-check endpoint that the load balancer can use
// to determine health. It listens on the given TCP port on all interfaces.
// It blocks until the listener fails, so it is meant to run in its own
// goroutine. A listener failure is logged but does not stop the process.
func healthcheck(port int, log *observe.Logger) {
	// The standard Go mux is a bit goofy in how it treats paths.
	// To avoid redirects from "/path" to "/path/", we have to specify both paths.
	mux := http.NewServeMux()

	// TODO(tony): Specify just one route for the health check.
	//
	// I believe the current configuration for this application sets the
	// load balancer to check against /metrics.
	//
	// Since this simplified endpoint doesn't return any metrics, we should probably
	// call it /health or something similar.
	//
	// We specify both for now to provide an easy migration path.
	for _, path := range []string{"/metrics", "/metrics/", "/health", "/health/"} {
		mux.HandleFunc(path, health.Check)
	}

	// http.ListenAndServe uses a zero-value Server, which never times out.
	// Explicit limits stop slow or stalled clients from holding connections
	// and goroutines open indefinitely.
	srv := &http.Server{
		Addr:              fmt.Sprintf(":%d", port),
		Handler:           mux,
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       10 * time.Second,
		WriteTimeout:      10 * time.Second,
		IdleTimeout:       60 * time.Second,
	}
	if err := srv.ListenAndServe(); err != http.ErrServerClosed {
		log.Error("healthcheck listener", err)
	}
}

// setupHostsFromCluster resolves the statsd hosts for the configured upstream
// ECS cluster/service and returns them.
//
// Before returning, it starts a hostconfig.Poller in a background goroutine.
// The poller is bound to shutdownCtx and receives reconfigure as its
// callback. If the initial lookup fails, the error is returned unchanged and
// no poller is started.
func setupHostsFromCluster(shutdownCtx context.Context, c configuration, observer observe.Observer, sess *session.Session, reconfigure func([]string)) ([]string, error) {
	target := &hostconfig.ServiceConfig{
		ClusterName: c.UpstreamCluster,
		ServiceName: c.UpstreamService,
	}
	ecsAPI, ec2API := awsECS.New(sess), ec2.New(sess)

	hosts, err := hostconfig.GetStatsHosts(shutdownCtx, observer, ecsAPI, ec2API, target)
	if err != nil {
		return nil, err
	}

	watcher := hostconfig.Poller{
		Conf:     target,
		ECS:      ecsAPI,
		EC2:      ec2API,
		Observer: observer,
	}
	go watcher.Poll(shutdownCtx, reconfigure)

	return hosts, nil
}
