package main

import (
	"context"
	"net/http"

	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/cloudwatch"
	"github.com/aws/aws-sdk-go/service/s3"

	"code.justin.tv/eventbus/controlplane/eventstreams"
	"code.justin.tv/eventbus/controlplane/infrastructure"
	"code.justin.tv/eventbus/controlplane/infrastructure/routing"
	rpcinfra "code.justin.tv/eventbus/controlplane/infrastructure/rpc"
	"code.justin.tv/eventbus/controlplane/internal/auditlog"
	"code.justin.tv/eventbus/controlplane/internal/autoprof"
	"code.justin.tv/eventbus/controlplane/internal/clients/cloudformation"
	"code.justin.tv/eventbus/controlplane/internal/clients/kms"
	ldapManager "code.justin.tv/eventbus/controlplane/internal/clients/ldap"
	"code.justin.tv/eventbus/controlplane/internal/clients/servicecatalog"
	"code.justin.tv/eventbus/controlplane/internal/clients/slack"
	"code.justin.tv/eventbus/controlplane/internal/clients/sns"
	"code.justin.tv/eventbus/controlplane/internal/clients/sqs"
	"code.justin.tv/eventbus/controlplane/internal/clients/sts"
	"code.justin.tv/eventbus/controlplane/internal/db"
	"code.justin.tv/eventbus/controlplane/internal/db/observability"
	"code.justin.tv/eventbus/controlplane/internal/db/postgres"
	"code.justin.tv/eventbus/controlplane/internal/e2eaccounts"
	"code.justin.tv/eventbus/controlplane/internal/environment"
	"code.justin.tv/eventbus/controlplane/internal/featureflags"
	"code.justin.tv/eventbus/controlplane/internal/ldap"
	"code.justin.tv/eventbus/controlplane/internal/logger"
	"code.justin.tv/eventbus/controlplane/internal/metrics"
	"code.justin.tv/eventbus/controlplane/internal/metrics/clients/eventstreamstats"
	"code.justin.tv/eventbus/controlplane/internal/s2s"
	"code.justin.tv/eventbus/controlplane/internal/validator"
	"code.justin.tv/eventbus/controlplane/rpc"
	"code.justin.tv/eventbus/controlplane/services"
	"code.justin.tv/eventbus/controlplane/subscriptions"
	"code.justin.tv/eventbus/controlplane/targets"
	"github.com/twitchtv/twirp"
	"go.uber.org/zap"
)

// port is the listen address handed to http.ListenAndServe in main. The host
// part is empty, so the server listens on all available interfaces.
var port = ":8888"

// main is the entry point for the eventbus control-plane HTTP server.
//
// It performs these steps in order:
//   - reads the environment config;
//   - constructs the AWS session plus the Postgres, Slack, STS, SQS, SNS and
//     CloudFormation clients;
//   - builds every RPC service;
//   - mounts each Twirp handler behind a shared middleware chain;
//   - serves on port.
//
// Every initialization failure is reported through log.Fatal.
func main() {
	ctx := context.Background()
	log := logger.FromContext(ctx)

	// End-to-end runs initialize test-account credentials first. This uses a
	// default AWS session, not the one built from config below.
	if environment.IsEndToEndTest() {
		err := e2eaccounts.InitializeCredentials(session.Must(session.NewSession()))
		if err != nil {
			log.Fatal("could not initialize e2e creds", zap.Error(err))
		}
	}

	config, err := environment.Read()
	if err != nil {
		log.Fatal("could not generate application config", zap.Error(err))
	}

	sess := session.Must(session.NewSession(config.AWSConfig()))

	if config.AutoprofEnabled {
		bucket := config.AutoprofBucketName
		if bucket == "" {
			log.Fatal("autoprof enabled but no bucket provided")
		}
		autoprof.Start(s3.New(sess), bucket, log)
	}

	dbConn, err := dbClient(config)
	if err != nil {
		log.Fatal("could not initialize db client", zap.Error(err))
	}

	slackConn, err := slackClient(config)
	if err != nil {
		log.Fatal("could not initialize slack client", zap.Error(err))
	}

	// Initialize metrics, if enabled
	if config.CloudwatchEnabled {
		if err := metrics.InitializeProcessIdentifier("eventbus-httpserver"); err != nil {
			log.Fatal("error initializing process identifier", zap.Error(err))
		}
		metrics.InitializeMetrics(log.Logger)
		metrics.StartGoStatsCollection(log.Logger)
		metrics.InitializeReporting(dbConn, log)
	}

	stsManager := sts.NewManager(sess, config.Environment)

	eventBusAccountID, err := stsManager.GetBaseAWSAccountID(ctx)
	if err != nil {
		log.Fatal("Could not determine aws account id", zap.Error(err))
	}

	sqsManager := sqs.NewManager(sess, stsManager)
	snsManager := sns.NewManager(sess, stsManager, config.EncryptionAtRestKeyARN, log)
	cfnManager := cloudformation.NewManager(sess, stsManager)

	queueValidator := validator.NewQueueValidator(sqsManager, eventBusAccountID, config.EncryptionAtRestKeyARN)

	infraService := &infrastructure.InfrastructureService{
		DB:                          dbConn,
		AuthorizedFieldGrantManager: kms.NewManager(sess, config.AuthorizedFieldKeyARN),
		Slack:                       slackConn,
		RouteConfigActions: &routing.ConfigActions{
			S3: s3.New(sess),
		},
		SNSManager:    snsManager,
		DisableGrants: config.DisableGrants,
		Logger:        log,
	}

	catalog := servicecatalog.New()

	eventStreamsService := &eventstreams.EventStreamsService{
		DB:             dbConn,
		ServiceCatalog: catalog,
	}
	if config.CloudwatchEnabled {
		cw := cloudwatch.New(sess)
		eventStreamsService.Stats = eventstreamstats.NewCloudwatchClient(cw)
	} else {
		eventStreamsService.Stats = &eventstreamstats.RandomClient{} // generate fake statistics
	}

	eventTypesService := &eventstreams.EventTypesService{
		DB: dbConn,
	}

	featureFlagsService := &featureflags.FeatureFlagsService{
		DB: dbConn,
	}

	auditLogsService := &auditlog.AuditLogService{
		DB: dbConn,
	}

	targetsService := &targets.TargetsService{
		DB:             dbConn,
		RootSession:    sess,
		QueueValidator: queueValidator,
		SQSManager:     sqsManager,
	}

	subscriptionsService := &subscriptions.SubscriptionsService{
		DB:         dbConn,
		SNSManager: snsManager,
	}

	servicesService := &services.ServicesService{
		DB:                      dbConn,
		ServiceCatalog:          catalog,
		EncryptionAtRestManager: kms.NewManager(sess, config.EncryptionAtRestKeyARN),
		CloudformationManager:   cfnManager,
		DisableGrants:           config.DisableGrants,
		LDAPManager:             ldapManager.New(config.LDAPURL, config.LDAPUsersBaseDN, config.LDAPGroupsBaseDN),
	}

	// Each server gets hooks labelled with its own API name, so error logging
	// and metrics are attributed to the right service.
	servicesHandler := rpc.NewServicesServer(servicesService, twirpServerHooks("Services"))
	eventStreamsHandler := rpc.NewEventStreamsServer(eventStreamsService, twirpServerHooks("EventStreams"))
	eventTypesHandler := rpc.NewEventTypesServer(eventTypesService, twirpServerHooks("EventTypes"))
	featureFlagsHandler := rpc.NewFeatureFlagsServer(featureFlagsService, twirpServerHooks("FeatureFlags"))
	auditLogsHandler := rpc.NewAuditLogsServer(auditLogsService, twirpServerHooks("AuditLogs"))
	targetsHandler := rpc.NewTargetsServer(targetsService, twirpServerHooks("Targets"))
	subscriptionsHandler := rpc.NewSubscriptionsServer(subscriptionsService, twirpServerHooks("Subscriptions"))
	infraHandler := rpcinfra.NewInfrastructureServer(infraService, twirpServerHooks("Infrastructure"))
	healthcheckHandler := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}

	// chainMiddlewares makes the first element the outermost wrapper. For each
	// request, LDAP runs first, then request logging, then S2S auth (if enabled).
	middlewares := []middleware{ldap.Middleware, log.Middleware()}
	if !config.S2SDisabled {
		s2sMiddleware, err := s2s.Middleware(log, config.S2SServiceName)
		if err != nil {
			log.Fatal("failed to initialize s2s middleware", zap.Error(err))
		}
		middlewares = append(middlewares, s2sMiddleware)
	}
	chainedMiddlewares := chainMiddlewares(middlewares...)

	mux := http.NewServeMux()
	mux.Handle(rpc.EventStreamsPathPrefix, chainedMiddlewares(eventStreamsHandler))
	mux.Handle(rpc.EventTypesPathPrefix, chainedMiddlewares(eventTypesHandler))
	mux.Handle(rpc.FeatureFlagsPathPrefix, chainedMiddlewares(featureFlagsHandler))
	mux.Handle(rpc.AuditLogsPathPrefix, chainedMiddlewares(auditLogsHandler))
	mux.Handle(rpc.ServicesPathPrefix, chainedMiddlewares(servicesHandler))
	mux.Handle(rpc.TargetsPathPrefix, chainedMiddlewares(targetsHandler))
	mux.Handle(rpc.SubscriptionsPathPrefix, chainedMiddlewares(subscriptionsHandler))
	mux.Handle(rpcinfra.InfrastructurePathPrefix, chainedMiddlewares(infraHandler))
	// The health check is mounted outside the middleware chain, so it needs
	// neither LDAP nor S2S.
	mux.HandleFunc("/health", healthcheckHandler)

	log.Info("Listening for requests on 0.0.0.0", zap.String("port", port))
	// NOTE(review): http.ListenAndServe configures no read/write timeouts.
	// Consider an explicit http.Server with ReadHeaderTimeout.
	if err = http.ListenAndServe(port, mux); err != nil {
		log.Fatal("Listen and serve error", zap.Error(err))
	}
}

// middleware decorates an http.Handler and returns the decorated handler.
//
// Pattern reference: https://hackernoon.com/simple-http-middleware-with-go-79a4ad62889b
type middleware func(http.Handler) http.Handler

// chainMiddlewares composes mw into a single middleware.
//
// mw[0] ends up as the outermost wrapper and the final handler as the
// innermost. A request therefore passes through mw[0], mw[1], ..., and then
// reaches the final handler. With no middlewares, the returned middleware
// hands back the final handler unchanged.
func chainMiddlewares(mw ...middleware) middleware {
	return func(final http.Handler) http.Handler {
		wrapped := final
		// Walk mw from the back so the earliest entry is applied last (outermost).
		for offset := range mw {
			wrapped = mw[len(mw)-1-offset](wrapped)
		}
		return wrapped
	}
}

// twirpServerHooks builds the Twirp server hooks attached to an RPC server.
// It chains two hook sets, in this order:
//   - an error hook (logger.TwirpErrorHook);
//   - the hooks returned by metrics.TwirpMiddleware for the given api name.
func twirpServerHooks(api string) *twirp.ServerHooks {
	return twirp.ChainHooks(
		&twirp.ServerHooks{Error: logger.TwirpErrorHook},
		metrics.TwirpMiddleware(api),
	)
}

// dbClient opens the Postgres-backed db.DB from config.
//
// The reader and writer endpoints share the database name and the SSL mode.
// The mode is "require", or "disable" when config.PostgresDisableSSL is set.
//
// The resulting connection is always wrapped with observability.WithLogging.
// When CloudWatch is enabled, it is also wrapped with observability.WithMetrics.
// If postgres.New fails, its error is returned unchanged.
func dbClient(config environment.Config) (db.DB, error) {
	sslMode := "require"
	if config.PostgresDisableSSL {
		sslMode = "disable"
	}

	// Fields common to both endpoints; each copy then gets its own credentials.
	shared := postgres.ConnectionConfig{
		Dbname:  config.PostgresDBName,
		Sslmode: sslMode,
	}

	reader := shared
	reader.Username = config.PostgresReaderUsername
	reader.Password = config.PostgresReaderPassword
	reader.Hostname = config.PostgresReaderHostname

	writer := shared
	writer.Username = config.PostgresWriterUsername
	writer.Password = config.PostgresWriterPassword
	writer.Hostname = config.PostgresWriterHostname

	pg, err := postgres.New(postgres.Config{
		Reader: reader,
		Writer: writer,
	})
	if err != nil {
		return nil, err
	}

	var conn db.DB = observability.WithLogging(pg)
	if config.CloudwatchEnabled {
		conn = observability.WithMetrics(conn)
	}
	return conn, nil
}

// slackClient returns the Slack implementation selected by config.
//
// With no SlackChannel configured, it returns a no-op &slack.Noop{}. Otherwise
// it builds a client from SlackToken and SlackChannel via slack.New. Any
// construction error is returned unchanged, with a nil client.
func slackClient(config environment.Config) (slack.Slack, error) {
	if config.SlackChannel == "" {
		return &slack.Noop{}, nil
	}

	client, err := slack.New(&slack.Config{
		Token:   config.SlackToken,
		Channel: config.SlackChannel,
	})
	if err != nil {
		return nil, err
	}
	return client, nil
}
