package main

import (
	"bytes"
	"context"
	"crypto/tls"
	"crypto/x509"
	"flag"
	"fmt"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"os"
	"strings"
	"time"

	goji "goji.io"
	"goji.io/pat"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/ec2metadata"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3iface"
	"github.com/cactus/go-statsd-client/statsd"
	"github.com/pkg/errors"

	"code.justin.tv/common/gometrics"
	"code.justin.tv/video/lvsapi/internal/auth"
	"code.justin.tv/video/lvsapi/internal/awsutils"
	"code.justin.tv/video/lvsapi/internal/caching"
	"code.justin.tv/video/lvsapi/internal/digestion"
	"code.justin.tv/video/lvsapi/internal/logging"
	"code.justin.tv/video/lvsapi/internal/metrics"
	"code.justin.tv/video/lvsapi/internal/server"
	"code.justin.tv/video/lvsapi/internal/usher"
	"code.justin.tv/video/lvsapi/internal/utils"
	"code.justin.tv/video/lvsapi/internal/viewcounts"
	"code.justin.tv/video/lvsapi/streamkey"
	usherrpc "code.justin.tv/video/usherapi/rpc/usher"
)

// Command-line flags for the service. envMap (below) associates each flag
// name with an environment variable; both are handed to utils.ValidateFlags
// from mustParseFlags. The client-cert, server-cert and server-key flags must
// be set together or not at all (enforced by enableTLS).
var (
	port                      = flag.String("port", "8080", "api service port")
	tlsPort                   = flag.String("tls-port", "8443", "port over which to serve tls requests")
	clientCert                = flag.String("client-cert", "", "path to client tls certificate")
	serverCert                = flag.String("server-cert", "", "path to server tls public certificate ")
	serverKey                 = flag.String("server-key", "", "path to server tls private key")
	strictAuth                = flag.Bool("strict-auth", true, "strict-auth disables auth via http headers, and only accepts mutual tls auth")
	secretS3Bucket            = flag.String("secret-s3-bucket", "", "s3 bucket to use for streamkey secret storage")
	secretS3Prefix            = flag.String("secret-s3-prefix", "", "s3 bucket prefix to use when determining streamkey secret key names")
	secretS3Duration          = flag.Duration("secret-s3-duration", 60*time.Second, "specifies how long to cache secrets in between requests to s3")
	twitchEnv                 = flag.String("twitch-env", "production", "production/staging/dev")
	cacheExpiration           = flag.Duration("cache-expiration", 5*time.Second, "specifies how long to cache usher response")
	purgeInterval             = flag.Duration("purge-interval", 60*time.Minute, "interval to periodically cleanup expired entries from cache")
	statsdServer              = flag.String("statsd-server", "graphite.internal.justin.tv:8125", "The address of the statsd server where we want to send stats")
	disableMWS                = flag.Bool("disable-mws", false, "when set to true metrics will not be sent to MWS")
	digestionEndpoint         = flag.String("digestion-endpoint", "digestion.video.justin.tv,digestion-cmh01.video.justin.tv", "Comma-separated URLs for the Digestion service which manages stream state")
	viewcountsCacheExpiration = flag.Duration("viewcounts-cache-expiration", 30*time.Second, "specifies how long to cache viewcounts response")
	viewCountsApiBaseUrl      = flag.String("viewcounts-base-url", "", "specifies the base url for the viewcounts api")
)

// envMap maps each command-line flag name to the environment variable
// associated with it. It is passed to utils.ValidateFlags in mustParseFlags;
// presumably an environment variable supplies the value of a flag that was
// not given on the command line (TODO confirm precedence in utils.ValidateFlags).
var envMap = map[string]string{
	"port":                        "LVSAPI_PORT",
	"tls-port":                    "LVSAPI_TLS_PORT",
	"client-cert":                 "LVSAPI_CLIENT_CERT",
	"server-cert":                 "LVSAPI_SERVER_CERT",
	"server-key":                  "LVSAPI_SERVER_KEY",
	"strict-auth":                 "LVSAPI_STRICT_AUTH",
	"secret-s3-bucket":            "LVSAPI_SECRET_S3_BUCKET",
	"secret-s3-prefix":            "LVSAPI_SECRET_S3_PREFIX",
	"secret-s3-duration":          "LVSAPI_SECRET_S3_DURATION",
	"twitch-env":                  "TWITCH_ENV",
	"cache-expiration":            "LVSAPI_CACHE_EXPIRATION",
	"purge-interval":              "LVSAPI_PURGE_INTERVAL",
	"statsd-server":               "LVSAPI_STATSD_SERVER",
	"disable-mws":                 "LVSAPI_DISABLE_MWS",
	"digestion-endpoint":          "DIGESTION_HOST",
	"viewcounts-cache-expiration": "VIEWCOUNTS_CACHE_EXPIRATION",
	"viewcounts-base-url":         "VIEWCOUNTS_API_BASE_URL",
}

// usherBackends are the base URLs of the usher video-api deployments;
// configureHandler creates one usher twirp client per entry and hands them
// all to usher.NewUsher.
// NOTE(review): these are hard-coded and use plain http; confirm whether they
// should be configurable per environment like the digestion endpoint.
var usherBackends = []string{
	"http://video-api.sjc02.justin.tv",
	"http://video-api.cmh01.justin.tv",
}

// s3api is a package-level S3 client built from a default AWS session during
// package initialization. It is not referenced in this part of the file;
// presumably it is declared as the s3iface.S3API interface so tests can
// substitute a mock (TODO confirm).
// NOTE(review): session.New is deprecated in aws-sdk-go; session.NewSession
// reports configuration errors immediately instead of deferring them to the
// first request.
var s3api s3iface.S3API = s3.New(session.New())

// mustNotErr is a start-up guard: a nil error is ignored, anything else
// aborts the program by panicking with that exact error value.
func mustNotErr(err error) {
	if err == nil {
		return
	}
	panic(err)
}

// enableTLS reports whether the HTTPS listener should be started.
//
// The -client-cert, -server-cert and -server-key flags are all-or-nothing:
// when none is set TLS is disabled, when all three are set TLS is enabled,
// and any partial combination is rejected with an error.
func enableTLS() (bool, error) {
	set := 0
	for _, path := range []string{*clientCert, *serverCert, *serverKey} {
		if path != "" {
			set++
		}
	}

	switch set {
	case 0:
		return false, nil
	case 3:
		return true, nil
	default:
		return false, fmt.Errorf("client-cert, server-cert, server-key options must all be set, or all be unset")
	}
}

// configureServer builds the plain-HTTP server that listens on -port and
// serves h.
//
// When -strict-auth is false, h is wrapped with auth.FromHeader so that
// callers may authenticate via HTTP headers. With strict auth (the default),
// header-based auth is disabled and only the mutual-TLS listener can
// authenticate clients.
//
// ReadHeaderTimeout bounds how long a client may take to send its request
// headers. Without it, slow or stalled connections (Slowloris) can hold
// server resources indefinitely. Body reads and handler run time are
// deliberately left unbounded, as before.
func configureServer(h http.Handler) *http.Server {
	if !*strictAuth {
		// set up the http only auth mechanism (auth from headers)
		h = auth.FromHeader(h)
	}

	return &http.Server{
		Addr:              net.JoinHostPort("", *port),
		Handler:           h,
		ReadHeaderTimeout: 10 * time.Second,
	}
}

// configureTLSServer builds the HTTPS server that listens on -tls-port and
// serves h.
//
// The server presents the key pair loaded from -server-cert/-server-key.
// Client certificates are verified against the PEM bundle read from
// -client-cert. Presenting one is optional at the TLS layer
// (VerifyClientCertIfGiven), and h is wrapped with auth.FromCertificate so
// that authentication is derived from the verified client certificate.
//
// It returns an error if the client bundle cannot be read or contains no
// usable PEM certificates, or if the server key pair cannot be loaded.
func configureTLSServer(h http.Handler) (*http.Server, error) {
	clientPEM, err := ioutil.ReadFile(*clientCert)
	if err != nil {
		return nil, errors.Wrapf(err, "failed to read specified client cert: %s", *clientCert)
	}
	clientCAs := x509.NewCertPool()
	if !clientCAs.AppendCertsFromPEM(clientPEM) {
		return nil, fmt.Errorf("invalid client cert specified: %s", *clientCert)
	}

	cert, err := tls.LoadX509KeyPair(*serverCert, *serverKey)
	if err != nil {
		return nil, errors.Wrap(err, "failed to read server tls keypair")
	}

	// No BuildNameToCertificate call: it is deprecated since Go 1.14, and with
	// a single certificate the server always presents Certificates[0] anyway.
	tlsConfig := &tls.Config{
		Certificates: []tls.Certificate{cert},
		ClientCAs:    clientCAs,
		ClientAuth:   tls.VerifyClientCertIfGiven,
	}

	return &http.Server{
		Addr: net.JoinHostPort("", *tlsPort),
		// set up the https only auth mechanism (auth from certificate)
		Handler:   auth.FromCertificate(h),
		TLSConfig: tlsConfig,
		// Bound header read time so slow clients (Slowloris) cannot pin
		// connections indefinitely.
		ReadHeaderTimeout: 10 * time.Second,
	}, nil
}

// awsSession returns an AWS session whose region is taken from the EC2
// instance metadata service.
//
// A bootstrap session is created only to query instance metadata. When the
// process is not running on EC2 the lookup fails, the error is discarded,
// and the returned session is configured with an empty region string.
func awsSession() (*session.Session, error) {
	bootstrap, err := session.NewSession()
	if err != nil {
		return nil, errors.Wrap(err, "failed to create aws session")
	}

	// Deliberately ignored: a metadata failure only means we are off EC2.
	region, _ := ec2metadata.New(bootstrap).Region()

	regional, err := session.NewSession(&aws.Config{Region: aws.String(region)})
	if err != nil {
		return nil, errors.Wrap(err, "failed to create aws session")
	}
	return regional, nil
}

// configureHandler assembles the router shared by the HTTP and HTTPS
// listeners.
//
// Every request passes through the logging middleware. GET /debug/health is a
// static liveness check. All other paths are served by the server package,
// wired to:
//   - a streamkey secret source. With -secret-s3-bucket set it is backed by
//     S3. Otherwise a secret is generated in-process, which is refused for
//     the production, production-updated and staging environments.
//   - an MWS metrics client, started only when MWS is not disabled and an AWS
//     session carrying credentials is available.
//   - one usher twirp client per entry in usherBackends, fronted by an
//     in-memory cache.
//   - an S3 region checker, the digestion API and a viewcounts cache.
//
// Configuration failures are reported via logger.Fatal (presumably
// terminating the process — confirm in the logging package).
func configureHandler(logger logging.Logger, s statsd.Statter) http.Handler {
	router := goji.NewMux()
	router.Use(logging.Middleware(logger))

	// simple "is this server on" healthcheck
	health := func(w http.ResponseWriter, r *http.Request) {
		_, _ = w.Write([]byte("We Gucci"))
	}
	router.Handle(pat.Get("/debug/health"), http.HandlerFunc(health))

	sess, sessErr := awsSession()

	var secretSource streamkey.SecretSource
	if bucket := *secretS3Bucket; bucket != "" {
		// The S3-backed source needs a working session; the session error
		// only matters on this branch.
		if sessErr != nil {
			logger.WithError(sessErr).Fatal("failed to create aws session")
		}
		secretSource = streamkey.NewS3SecretSource(streamkey.S3SecretSourceConfig{
			S3:           s3.New(sess),
			Bucket:       bucket,
			Prefix:       *secretS3Prefix,
			CacheTimeout: *secretS3Duration,
		})
	} else {
		switch *twitchEnv {
		case "production", "production-updated", "staging":
			logger.Fatal("S3 Bucket for streamkey secrets missing")
		}

		logger.Debug("Autogenerating secret as no s3 bucket was specified")
		secret, err := streamkey.GenerateSecret()
		if err != nil {
			logger.WithError(err).Fatal("failed to generate secret")
		}
		secretSource = &streamkey.SingleSecretSource{Secret: secret}
	}

	// Prefer the session's credentials. Otherwise fall back to the SDK's
	// shared-credentials provider with its default file and profile.
	hasCreds := sess != nil && sess.Config != nil && sess.Config.Credentials != nil
	creds := credentials.NewSharedCredentials("", "")
	if hasCreds {
		creds = sess.Config.Credentials
	}
	mws := metrics.New("lvsapi", "", *twitchEnv, creds)
	if hasCreds && !*disableMWS {
		mws.Start()
	}

	var usherClients []usher.Client
	for _, addr := range usherBackends {
		client, err := usherrpc.NewUsherClient(addr, http.DefaultClient)
		if err != nil {
			logger.WithError(err).Fatal("failed to create usher twirp client")
		}
		usherClients = append(usherClients, client)
	}

	// Construct the server's dependencies in the same order as before.
	viewcountsBackend := viewcounts.NewInMemoryViewcountsCache(viewcountsCacheExpiration, purgeInterval, *viewCountsApiBaseUrl, s)
	usherAPI := usher.NewUsher(usherClients, s)
	usherCache := caching.NewInMemoryCache(cacheExpiration, purgeInterval, usherAPI, s)
	regionChecker := awsutils.NewS3RegionChecker()
	digestionAPI := digestion.NewDigestionAPI(*digestionEndpoint)

	// all other paths go to pkg server's handler
	api := server.New(secretSource, usherCache, regionChecker, s, mws, digestionAPI, viewcountsBackend)
	router.Handle(pat.New("/*"), api)

	return router
}

// runServer starts s on a background goroutine and returns a channel that
// delivers exactly one value: the error the listener returned when it
// stopped. The channel is closed afterwards. A server with a non-nil
// TLSConfig is started with ListenAndServeTLS("", ""), so its certificates
// must already be present in that config.
func runServer(s *http.Server) chan error {
	// Buffered so the goroutine can always deliver its result and exit, even
	// if nobody is receiving.
	result := make(chan error, 1)

	go func() {
		defer close(result)

		var err error
		if s.TLSConfig == nil {
			err = s.ListenAndServe()
		} else {
			err = s.ListenAndServeTLS("", "")
		}
		result <- err
	}()

	return result
}

// mustParseFlags parses the command line and runs utils.ValidateFlags over
// envMap, panicking if validation itself errors. It then lower-cases
// -twitch-env in place. Flags reported as missing are tolerated in other
// environments but cause a panic in production, production-updated and
// staging.
func mustParseFlags() {
	flag.Parse()

	missing, err := utils.ValidateFlags(envMap)
	mustNotErr(err)

	env := strings.ToLower(*twitchEnv)
	*twitchEnv = env

	if len(missing) == 0 {
		return
	}
	switch env {
	case "production", "production-updated", "staging":
		panic(fmt.Errorf("Missing flags: %v", missing))
	}
}

// reqLogger wraps base with a debug logger that writes each request's headers
// and body to the standard logger before delegating to base.
//
// The body is read fully into memory and replaced with an in-memory reader,
// so base still sees the complete payload. If the body cannot be read, the
// request is answered with 500 Internal Server Error and base is not called.
//
// Credential-carrying headers listed in redactedHeaders are masked in the log
// output only; base receives the original, unmodified headers.
// NOTE(review): the body is still logged verbatim and read without a size
// limit. Confirm payloads cannot carry secrets (e.g. stream keys) and are
// small. If auth.FromHeader reads a custom header, add it to redactedHeaders.
func reqLogger(base http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, err := ioutil.ReadAll(r.Body)
		if err != nil {
			// Broken HTTP connection
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		log.Printf("request headers: %+v", redactHeaders(r.Header))
		log.Printf("request body: %s", body)

		// Replay the consumed body for the wrapped handler.
		r.Body = ioutil.NopCloser(bytes.NewReader(body))
		base.ServeHTTP(w, r)
	})
}

// redactedHeaders lists the canonical names of headers whose values carry
// credentials and must never be written to the request log.
var redactedHeaders = []string{"Authorization", "Proxy-Authorization", "Cookie"}

// redactHeaders returns a shallow copy of h in which every header named in
// redactedHeaders has its values replaced by a placeholder. h is not modified.
func redactHeaders(h http.Header) http.Header {
	out := make(http.Header, len(h))
	for name, values := range h {
		out[name] = values
	}
	for _, name := range redactedHeaders {
		if _, ok := out[name]; ok {
			out[name] = []string{"[REDACTED]"}
		}
	}
	return out
}

// main parses configuration, builds the shared handler, and serves it over
// plain HTTP and, when the TLS flags are set, also over HTTPS. If either
// listener stops, main logs the error and makes a graceful attempt (bounded
// to 10 seconds) to shut down the other listener. It then exits with
// status 1.
func main() {
	mustParseFlags()

	logger := logging.New(logging.Config{})
	logger.Info("LiveVideoService API")

	// Create the statsd client and pass it around to the configureHandler
	s, err := metrics.NewStatter(fmt.Sprintf("lvsapi.%s", *twitchEnv), *statsdServer)
	if err != nil {
		logger.WithError(err).Errorf("Failed to connect to statsd: %v", err)
	}

	// Collect Golang metrics periodically
	gometrics.Monitor(s, 10*time.Second)

	// handler to be used by both http and https listener
	handler := reqLogger(configureHandler(logger, s))

	// Add security response headers
	handler = auth.AddSecurityHeader(handler)

	// Validate TLS options and build the HTTPS server before any listener
	// starts. A misconfiguration then aborts start-up without first serving
	// plain HTTP.
	tlsEnabled, err := enableTLS()
	if err != nil {
		logger.WithError(err).Fatal("invalid tls options specified")
	}
	var httpsServer *http.Server
	if tlsEnabled {
		httpsServer, err = configureTLSServer(handler)
		if err != nil {
			logger.WithError(err).Fatal("failed to configure https server")
		}
	}

	httpServer := configureServer(handler)
	httpErrCh := runServer(httpServer)

	// Without TLS httpsErrCh stays nil. Receiving from a nil channel blocks
	// forever, so the select below then waits on the HTTP listener only.
	var httpsErrCh chan error
	if httpsServer != nil {
		httpsErrCh = runServer(httpsServer)
	}

	// shutdown attempts a graceful stop of srv for at most 10 seconds. It runs
	// in its own function so the deferred cancel fires before os.Exit, which
	// skips pending defers.
	shutdown := func(srv *http.Server, failMsg string) {
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		if err := srv.Shutdown(ctx); err != nil {
			logger.WithError(err).Error(failMsg)
		}
	}

	select {
	case err := <-httpsErrCh:
		logger.WithError(err).Error("https listener failed")
		// shut down the http server if the https server fails
		shutdown(httpServer, "failed to shut down http server gracefully")
	case err := <-httpErrCh:
		logger.WithError(err).Error("http listener failed")
		// shut down the https server if the http server fails
		if httpsServer != nil {
			shutdown(httpsServer, "failed to shut down https server gracefully")
		}
	}
	os.Exit(1)
}
