package main

import (
	"flag"
	"fmt"
	"log"
	"net/url"
	"os"
	"os/signal"
	"strconv"
	"syscall"
	"time"

	"code.justin.tv/devhub/lib-lifecycle/src/lifecycle"
	"code.justin.tv/devhub/e2ml/libs/logging"
	"code.justin.tv/devhub/e2ml/libs/peering"
	"code.justin.tv/devhub/e2ml/libs/peering/hardcoded"
	"code.justin.tv/devhub/e2ml/libs/session/devnull"
	"code.justin.tv/devhub/e2ml/libs/timeout"
	"code.justin.tv/devhub/e2ml/libs/websocket"
)

// Defaults used by main for each setting when neither the command-line flag
// nor the matching environment variable provides a value.
const (
	// defaultPort is the fallback for the -port flag / FLOCK_START_PORT.
	defaultPort      = 13000
	// defaultCount is the fallback for the -count flag / FLOCK_COUNT.
	defaultCount     = 5
	// defaultHeartbeat is the fallback for the -heartbeat flag /
	// FLOCK_HEARTBEAT; main parses it with time.ParseDuration.
	defaultHeartbeat = "100ms"
	// defaultLogLevel is the fallback for the -log flag / LOG.
	defaultLogLevel  = "info"
	// defaultUsePing is the fallback for the -usePing flag / FLOCK_USE_PING.
	defaultUsePing   = true
)

// checkEnv returns the value of the environment variable key, or def when the
// variable is not present in the environment. A variable that is present but
// empty counts as set, so the empty string is returned in that case.
func checkEnv(key, def string) string {
	fromEnv, present := os.LookupEnv(key)
	if !present {
		return def
	}
	return fromEnv
}

// checkInt returns the environment variable key parsed as a base-10 integer,
// or def when the variable is not present. The value is parsed with a 32-bit
// size limit even though the result is carried in an int64. A value that
// does not parse triggers panicWithMessage with the offending key and value.
func checkInt(key string, def int64) int64 {
	raw, present := os.LookupEnv(key)
	if !present {
		return def
	}
	parsed, err := strconv.ParseInt(raw, 10, 32)
	if err != nil {
		panicWithMessage("Illegal value for "+key+": "+raw, err)
	}
	return parsed
}

// checkBool returns the environment variable key parsed with
// strconv.ParseBool, or def when the variable is not present. A value that
// does not parse triggers panicWithMessage with the offending key and value.
func checkBool(key string, def bool) bool {
	raw, present := os.LookupEnv(key)
	if !present {
		return def
	}
	parsed, err := strconv.ParseBool(raw)
	if err != nil {
		panicWithMessage("Illegal value for "+key+": "+raw, err)
	}
	return parsed
}

// panicMessage is the value this program panics with on a configuration
// error: a human-readable message paired with the error that caused it.
type panicMessage struct {
	msg string // description of what went wrong
	err error  // underlying cause; may be nil
}

// String renders the panic as "msg: cause". When no cause was recorded only
// the message is returned, so formatting the panic value can never itself
// fail with a nil dereference.
func (p *panicMessage) String() string {
	if p.err == nil {
		return p.msg
	}
	return p.msg + ": " + p.err.Error()
}

// panicWithMessage panics with a *panicMessage built from msg and err.
func panicWithMessage(msg string, err error) { panic(&panicMessage{msg, err}) }

// main wires up a group of peering managers on consecutive localhost ports
// and keeps them running until the process receives SIGINT, SIGTERM or
// SIGUSR2.
//
// Every setting can be supplied as a command-line flag; the flag defaults
// come from the environment variables FLOCK_START_PORT, FLOCK_COUNT,
// FLOCK_HEARTBEAT, LOG and FLOCK_USE_PING, falling back to the default*
// constants. Invalid settings abort the program with a panic before any
// manager is started.
//
// With -usePing the timeout window is narrowed to 90%-110% of the heartbeat
// and no messages are broadcast; otherwise the first manager broadcasts a
// text message and the last manager a binary message on every heartbeat.
func main() {
	port := flag.Int64("port", checkInt("FLOCK_START_PORT", defaultPort), "port for incoming connections")
	count := flag.Int64("count", checkInt("FLOCK_COUNT", defaultCount), "number of ports to use")
	heartbeatString := flag.String("heartbeat", checkEnv("FLOCK_HEARTBEAT", defaultHeartbeat), "time between messages")
	logLevel := flag.String("log", checkEnv("LOG", defaultLogLevel), "set logging level: [trace,debug,info,warning,error]")
	pingTest := flag.Bool("usePing", checkBool("FLOCK_USE_PING", defaultUsePing), "timeout instead of sending messages on heartbeat")

	flag.Parse()

	logger := loggerFuncWithLevel(*logLevel)

	heartbeat, err := time.ParseDuration(*heartbeatString)
	if err != nil {
		panicWithMessage("Illegal heartbeat", err)
	}
	// time.NewTicker panics on a non-positive interval; reject it here, before
	// any manager is running, and with a message that names the setting.
	if heartbeat <= 0 {
		panicWithMessage("Illegal heartbeat", fmt.Errorf("must be positive, got %v", heartbeat))
	}
	// A negative count would make the slice allocation below panic, and a
	// zero count would leave first/last nil for the broadcast loop.
	if *count < 1 {
		panicWithMessage("Illegal count", fmt.Errorf("must be at least 1, got %d", *count))
	}

	mgr := lifecycle.NewManager()
	// Run every registered shutdown hook on the way out. The returned error
	// is discarded here.
	// NOTE(review): presumably failures are surfaced through the default
	// error reporter installed below - confirm.
	defer func() { _ = mgr.ExecuteAll() }()
	lifecycle.SetDefaultPanicReporter(func(key interface{}, err error) { logger(logging.Error, err) })
	lifecycle.SetDefaultErrorReporter(func(key interface{}, err error) {
		logger(logging.Error, fmt.Sprintf("Shutdown error for %+v: %v", key, err))
	})

	// Timeout window handed to the sampler. Named minTimeout/maxTimeout so
	// the builtin min and max functions are not shadowed.
	minTimeout := 10 * time.Second
	maxTimeout := 25 * time.Second
	if *pingTest {
		// In ping mode the window is centred on the heartbeat (+/-10%).
		minTimeout = time.Duration(float64(heartbeat) * 0.9)
		maxTimeout = time.Duration(float64(heartbeat) * 1.1)
	}

	settings := &websocket.Settings{
		Logger:    logger,
		Lifecycle: mgr,
		Timeout:   timeout.NewPrecomputedSampler(100, minTimeout, maxTimeout),
	}

	// Each node is named after the port it listens on: port, port+1, ...
	nodes := make([]string, *count)
	for i := range nodes {
		nodes[i] = strconv.FormatInt(*port+int64(i), 10)
	}

	list := hardcoded.NewServerList("", nodes)
	mgr.RegisterHook(list, list.Close)

	// Start one manager per node, remembering the first and the last one so
	// the heartbeat loop can broadcast from both ends.
	var first peering.Manager
	var last peering.Manager
	for i, n := range nodes {
		last = peering.NewManager(
			devnull.Factory,
			websocket.NewClientResolver(settings),
			websocket.NewServiceFactory(int(*port)+i, settings),
			list.WithLocal(n),
			// Node names are port numbers, so the address is derived directly.
			func(name string) (*url.URL, error) { return url.Parse("ws://localhost:" + name) },
			logger,
		)
		last.Start()
		if first == nil {
			first = last
		}
	}

	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGUSR2)
	ticker := time.NewTicker(heartbeat)
	defer ticker.Stop()
	if *pingTest {
		ticker.Stop() // don't send normal heartbeats
	}

	for {
		select {
		case sig, ok := <-sigChan:
			signal.Stop(sigChan)
			if ok {
				logger(logging.Info, "Exiting on signal:", sig)
			}
			return
		case <-ticker.C:
			// will send client -> server
			first.BroadcastText(fmt.Sprintf("%v says hi", first.LocalName()))
			// will send server -> client
			last.BroadcastBinary([]byte{1, 2, 3, 4})
		}
	}
}

// loggerFuncWithLevel builds the program's logging function: log.Println
// wrapped in a logging filter for the level named by logLevelStr. It panics
// when logging.ParseLevel does not recognise logLevelStr.
func loggerFuncWithLevel(logLevelStr string) logging.Function {
	if level, valid := logging.ParseLevel(logLevelStr); valid {
		return logging.NewFilter(level, log.Println).Log
	}
	panic("invalid arg value: log: " + logLevelStr)
}
