package main

import (
	"context"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"os"
	"strconv"
	"sync"
	"time"

	log "github.com/Sirupsen/logrus"
	"github.com/aws/aws-sdk-go/aws/client"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/gorilla/handlers"
	"github.com/pkg/errors"
	"github.com/soheilhy/cmux"
	"github.com/zenazn/goji/graceful"
	"goji.io/pat"
	"google.golang.org/grpc"
	"google.golang.org/grpc/grpclog"

	"code.justin.tv/common/config"
	"code.justin.tv/dta/rockpaperscissors/client/web"
	"code.justin.tv/dta/rockpaperscissors/internal/api"
	"code.justin.tv/dta/rockpaperscissors/internal/ingestqueueconsumer"
	"code.justin.tv/dta/rockpaperscissors/internal/taskmanager"
	"code.justin.tv/foundation/twitchserver"
)

// init registers this binary's configuration keys together with their
// default values, and points both the gRPC and twitchserver loggers at the
// logrus standard logger.
func init() {
	defaults := map[string]string{
		"enable-ingest-queue-consumer": "true",
		"http-access-log":              "",
	}
	config.Register(defaults)

	logger := log.StandardLogger()
	grpclog.SetLogger(logger)
	twitchserver.SetLogger(logger)
}

// bindOrDie opens a TCP listener on address. A bind failure is treated as
// fatal: the wrapped error is logged via log.Fatal.
func bindOrDie(address string) net.Listener {
	listener, bindErr := net.Listen("tcp", address)
	if bindErr == nil {
		return listener
	}
	log.Fatal(errors.Wrapf(bindErr, "Could not bind to port %s", address))
	// Only reached if log.Fatal does not terminate the process; net.Listen
	// yields a nil listener on error, so nil matches the original behavior.
	return nil
}

// loggedHTTPHandler wraps h with gorilla's CombinedLoggingHandler so every
// request is written to an access log. The destination is taken from the
// "http-access-log" setting:
//   - ""  : the log is discarded
//   - "-" : the log goes to stdout
//   - any other value is a file path, opened for appending and created with
//     mode 0644 if missing. Failing to open it is fatal.
func loggedHTTPHandler(h http.Handler) http.Handler {
	dest := config.Resolve("http-access-log")

	var out io.Writer
	switch dest {
	case "":
		out = ioutil.Discard
	case "-":
		out = os.Stdout
	default:
		f, err := os.OpenFile(dest, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
		if err != nil {
			log.Fatal(errors.Wrapf(err, "Error opening logfile %q", dest))
		}
		// The file is intentionally never closed here; it backs the handler
		// for as long as the handler is in use.
		out = f
	}
	return handlers.CombinedLoggingHandler(out, h)
}

// grpcServerTask serves the gRPC API (built by api.NewAPIServer from
// AWSSession) on Listener. It is driven through Start, Stop and Done.
type grpcServerTask struct {
	Listener   net.Listener
	AWSSession client.ConfigProvider
	apiServer  *grpc.Server
	doneChan   chan error
}

// Start constructs the gRPC API server and begins serving on t.Listener in a
// background goroutine. Failure to construct the server is fatal.
func (t *grpcServerTask) Start() {
	var err error
	t.apiServer, err = api.NewAPIServer(t.AWSSession)
	if err != nil {
		log.Fatal(err)
	}

	// Buffered so the serving goroutine can always deliver its result and
	// exit, even if nobody is receiving from Done() at that moment.
	t.doneChan = make(chan error, 1)

	go func(t *grpcServerTask) {
		err := t.apiServer.Serve(t.Listener)
		log.Info("gRPC server stopped: ", err)
		t.doneChan <- err
		close(t.doneChan)
	}(t)
}

// Stop asks the gRPC server to stop via GracefulStop. It is a no-op if Start
// has not yet created the server.
func (t *grpcServerTask) Stop() {
	if t.apiServer == nil {
		return
	}
	t.apiServer.GracefulStop()
}

// Done returns a channel that receives the error returned by Serve and is
// then closed. The channel is nil until Start has been called.
func (t *grpcServerTask) Done() <-chan error {
	return t.doneChan
}

// httpServerTask serves HTTP on Listener. Requests under /api/ go to the API
// gateway mux and everything else goes to the web UI mux. All requests pass
// through the access-log wrapper.
type httpServerTask struct {
	Listener      net.Listener
	ServerConfig  *twitchserver.ServerConfig
	cancelGateway func()
	handler       http.Handler
	doneChan      chan error
}

// Start builds the gateway and web handlers and begins serving on t.Listener
// in a background goroutine. Failure to build the gateway mux is fatal.
func (t *httpServerTask) Start() {
	ctx := context.Background()
	ctx, t.cancelGateway = context.WithCancel(ctx)

	// Also cancel the gateway's context from the graceful post-shutdown hook.
	graceful.PostHook(t.cancelGateway)

	// The gateway is pointed back at this server's own port on localhost.
	// The original "localhost"+Addr assumes Addr is ":port"; SplitHostPort
	// also handles an explicit host such as "0.0.0.0:8000", which plain
	// concatenation would turn into "localhost0.0.0.0:8000".
	gatewayTarget := "localhost" + t.ServerConfig.Addr
	if _, port, err := net.SplitHostPort(t.ServerConfig.Addr); err == nil {
		gatewayTarget = net.JoinHostPort("localhost", port)
	}

	gatewayMux, err := api.NewGatewayMux(ctx, gatewayTarget)
	if err != nil {
		log.Fatal(err)
	}

	server := twitchserver.NewServer()
	server.Handle(pat.New("/api/*"), gatewayMux)
	server.Handle(pat.New("/*"), web.NewMux())
	t.handler = loggedHTTPHandler(server)

	twitchserver.AddDefaultSignalHandlers()

	// Buffered so the serving goroutine can always deliver its result and
	// exit, even if nobody is receiving from Done() at that moment.
	t.doneChan = make(chan error, 1)

	go func(t *httpServerTask) {
		err := twitchserver.Serve(t.Listener, t.handler, t.ServerConfig)
		log.Info("HTTP server stopped: ", err)
		t.doneChan <- err
		close(t.doneChan)
	}(t)
}

// Stop calls twitchserver.Shutdown with a 30-second timeout and then cancels
// the gateway's context. The cancellation is skipped if Start has not run.
func (t *httpServerTask) Stop() {
	twitchserver.Shutdown(30 * time.Second)
	if t.cancelGateway != nil {
		t.cancelGateway()
	}
}

// Done returns a channel that receives the error returned by
// twitchserver.Serve and is then closed. The channel is nil until Start has
// been called.
func (t *httpServerTask) Done() <-chan error {
	return t.doneChan
}

// queueWorkerTask runs the ingest queue consumer. It has no serving goroutine
// of its own. Nothing is ever sent on its Done channel; the channel is closed
// once Stop has stopped the consumer.
type queueWorkerTask struct {
	AWSSession     client.ConfigProvider
	ServerConfig   *twitchserver.ServerConfig
	ingestConsumer *ingestqueueconsumer.IngestQueueConsumer
	stopOnce       sync.Once
	doneChan       chan error
}

// Start creates the ingest queue consumer, targeting this server's own port
// on localhost, and starts it. Failure to start is fatal. Stop is also
// registered as a graceful pre-shutdown hook.
func (t *queueWorkerTask) Start() {
	// Create the done channel first so Stop never sees a consumer without a
	// channel to close.
	t.doneChan = make(chan error)

	// The original "localhost"+Addr assumes Addr is ":port"; SplitHostPort
	// also handles an explicit host such as "0.0.0.0:8000".
	target := "localhost" + t.ServerConfig.Addr
	if _, port, err := net.SplitHostPort(t.ServerConfig.Addr); err == nil {
		target = net.JoinHostPort("localhost", port)
	}

	// Start background goroutine that processes the SQS queue.
	t.ingestConsumer = ingestqueueconsumer.New(t.AWSSession, target)
	if err := t.ingestConsumer.Start(); err != nil {
		log.Fatal(err)
	}

	graceful.PreHook(t.Stop)
}

// Stop stops the consumer and closes the Done channel. Only the first call
// has an effect, because it is guarded by stopOnce. Calling Stop before Start
// is a no-op and does not consume the once guard.
func (t *queueWorkerTask) Stop() {
	if t.ingestConsumer == nil {
		return // Start has not run; nothing to stop yet.
	}
	t.stopOnce.Do(func() {
		t.ingestConsumer.Stop()
		close(t.doneChan)
	})
}

// Done returns a channel that is closed once Stop has run. The channel is nil
// until Start has been called.
func (t *queueWorkerTask) Done() <-chan error {
	return t.doneChan
}

// muxedListenerTask runs the cmux dispatcher (MuxedListener). The dispatcher
// accepts connections on RootListener and hands each one to whichever matched
// listener claims it. Stopping closes RootListener, which ends Serve.
type muxedListenerTask struct {
	MuxedListener cmux.CMux
	RootListener  net.Listener
	stopOnce      sync.Once
	doneChan      chan error
}

// Start runs the dispatcher's Serve loop in a background goroutine.
func (t *muxedListenerTask) Start() {
	// Buffered so the dispatcher goroutine can always deliver its result and
	// exit, even if nobody is receiving from Done() at that moment.
	t.doneChan = make(chan error, 1)

	go func(t *muxedListenerTask) {
		err := t.MuxedListener.Serve()
		log.Info("Connection dispatcher closed: ", err)
		t.doneChan <- err
		close(t.doneChan)
	}(t)
}

// Stop closes RootListener, which makes the dispatcher's Serve return. Only
// the first call has an effect.
func (t *muxedListenerTask) Stop() {
	t.stopOnce.Do(func() {
		_ = t.RootListener.Close() // Ignore error
	})
}

// Done returns a channel that receives the error returned by Serve and is
// then closed. The channel is nil until Start has been called.
func (t *muxedListenerTask) Done() <-chan error {
	return t.doneChan
}

// run parses configuration, binds the single shared listening socket, wires
// up every task and blocks until the task manager's Wait returns. Its return
// value is used by main as the process exit status. It is currently always 0,
// because setup failures exit through log.Fatal instead.
func run() int {
	if err := config.Parse(); err != nil {
		log.Fatal(err)
	}
	consumeIngestQueue, err := strconv.ParseBool(
		config.MustResolve("enable-ingest-queue-consumer"))
	if err != nil {
		log.Fatalf("Can't parse enable-ingest-queue-consumer flag: %v", err)
	}

	log.Infof("Configured for environment: %s", config.Environment())

	sessionOpts := session.Options{SharedConfigState: session.SharedConfigEnable}
	awsSession, err := session.NewSessionWithOptions(sessionOpts)
	if err != nil {
		log.Fatalf("Failed to create AWS session: %v", err)
	}

	serverConfig := twitchserver.NewConfig()

	// Connections are dispatched below the HTTP layer, so gRPC and plain
	// HTTP/1 traffic can share one port. Matchers are tried in the order they
	// are registered, so the gRPC matcher must come first.
	rootListener := bindOrDie(serverConfig.Addr)
	dispatcher := cmux.New(rootListener)
	grpcListener := dispatcher.Match(
		cmux.HTTP2HeaderField("content-type", "application/grpc"))
	httpListener := dispatcher.Match(cmux.HTTP1Fast())

	manager := taskmanager.New()

	// Each task's Stop is deferred as it is added, so on return the tasks
	// are stopped in reverse order of registration.
	grpcTask := &grpcServerTask{Listener: grpcListener, AWSSession: awsSession}
	defer grpcTask.Stop()
	manager.AddTask(grpcTask)

	httpTask := &httpServerTask{Listener: httpListener, ServerConfig: serverConfig}
	defer httpTask.Stop()
	manager.AddTask(httpTask)

	if consumeIngestQueue {
		queueTask := &queueWorkerTask{AWSSession: awsSession, ServerConfig: serverConfig}
		defer queueTask.Stop()
		manager.AddTask(queueTask)
	}

	dispatchTask := &muxedListenerTask{MuxedListener: dispatcher, RootListener: rootListener}
	defer dispatchTask.Stop()
	manager.AddTask(dispatchTask)

	manager.StartTasks()
	manager.Wait()

	log.Info("Exiting...")
	return 0
}

// main runs the server and exits the process with the status run returns.
func main() {
	exitCode := run()
	os.Exit(exitCode)
}
