package handlers

import (
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"net/http/pprof"
	"sync"
	"time"

	"github.com/gofrs/uuid"
	"github.com/gorilla/websocket"
	"go.uber.org/zap"

	"code.justin.tv/edge/network"
	"code.justin.tv/websocket-edge/server/internal/environment"
	"code.justin.tv/websocket-edge/server/internal/gqlsubs"
	"code.justin.tv/websocket-edge/server/internal/logs"
	"code.justin.tv/websocket-edge/server/internal/metrics"
	"code.justin.tv/websocket-edge/server/internal/queue"
	"code.justin.tv/websocket-edge/server/internal/wsconnection"
	"code.justin.tv/websocket-edge/server/protocol"
)

// upgrader turns incoming HTTP requests (on /v1/connect and /wsecho) into WebSocket
// connections.
//
// Browsers send an Origin header, and when one is present alongside a Host header,
// gorilla/websocket by default refuses the upgrade if the two do not match. Clients
// connect from a large number of potential origins, so we follow the same approach as
// GraphQL and accept every origin.
//
// NOTE(review): accepting every origin disables gorilla's cross-origin protection; if
// this endpoint ever authenticates via cookies it becomes open to cross-site WebSocket
// hijacking — confirm auth is not cookie-based.
var upgrader = wsconnection.Uwrapper{
	Upgrader: &websocket.Upgrader{
		CheckOrigin: func(*http.Request) bool { return true },
	},
}

var ok = []byte{'O', 'K'} // body written by the /health endpoint

// Metric names emitted by the handlers in this file.
const (
	// WebSocket connection lifecycle (/v1/connect and the per-connection cleanup hook).
	metricAcceptConnDuration        = "AcceptConnDuration"
	metricAcceptConnErr             = "AcceptConnError"
	metricAcceptConnSuccess         = "AcceptConnSuccess"
	metricCloseConnSuccess          = "CloseConnSuccess"
	metricDisconnectNotificationErr = "DisconnectNotificationErr"

	// Service-to-client payload ingestion (/v1/send).
	metricPayloadIngestDuration = "PayloadIngestDuration"
	metricPayloadIngestErr      = "PayloadIngestError"
	metricPayloadIngestSuccess  = "PayloadIngestSuccess"

	// Diagnostic endpoints (/wsecho and /health).
	metricEcho      = "Echo"
	metricHealthErr = "HealthCheckErr"
)

// handler serves the edge HTTP endpoints (WebSocket connect, payload send, echo, health,
// pprof) and keeps track of every WebSocket connection it has accepted.
type handler struct {
	logger      logs.Logger      // base logger; per-connection loggers are derived from it with a sessID field
	statter     metrics.Statter  // metrics sink for counters, timings and the connection-count poller
	hostIP      string           // value of environment.HostAddress(), captured once in New
	conns       *connectionMap   // live connections keyed by session ID
	mux         *http.ServeMux   // routes registered in New; ServeHTTP delegates to it
	disconnectQ queue.Queue      // receives a disconnect event for each session removed by the cleanup hook
	subsClient  gqlsubs.Client   // handed to every wsconnection created by acceptWebsocketConnection
	shutdown    chan interface{} // a receive on it (send or close) triggers cleanupAllConns; returned by New
}

// connectionMap holds the live WebSocket connections keyed by session ID.
// The embedded RWMutex guards conns: readers take RLock, writers take Lock.
type connectionMap struct {
	conns map[string]wsconnection.Connection
	sync.RWMutex
}

// New creates a handler, registers its HTTP routes on an internal ServeMux, and starts a
// poller that reports the number of tracked WebSocket connections every 10 seconds.
//
// It returns the handler together with its shutdown channel. Sending on, or closing, that
// channel makes the handler gracefully close every tracked connection; once that cleanup
// has finished, "ok" is sent on handlerShutdownDone. The send blocks until it is received
// (unless the channel is buffered), so the caller must receive from handlerShutdownDone or
// the cleanup goroutine never exits.
//
// New calls logger.Fatal if environment.HostAddress returns an empty address.
//
// NOTE(review): /wsecho and the /debug/pprof/* routes share this mux with the public
// endpoints; make sure they are not reachable from outside.
func New(
	logger logs.Logger,
	statter metrics.Statter,
	disconnectQ queue.Queue,
	subsClient gqlsubs.Client,
	handlerShutdownDone chan interface{}) (*handler, chan interface{}) {
	hostAddr := environment.HostAddress()
	if hostAddr == "" {
		logger.Fatal("could not get instance metadata")
	}

	h := &handler{
		logger:  logger,
		statter: statter,
		hostIP:  hostAddr,
		conns: &connectionMap{
			conns: make(map[string]wsconnection.Connection),
		},
		mux:         http.NewServeMux(),
		disconnectQ: disconnectQ,
		subsClient:  subsClient,
		shutdown:    make(chan interface{}),
	}

	h.handle("/v1/connect", h.acceptWebsocketConnection)
	h.handle("/wsecho", h.echo)
	h.handle("/v1/send", h.payloadIngest)
	h.handle("/health", h.health)
	h.handle("/debug/pprof/", pprof.Index)
	h.handle("/debug/pprof/cmdline", pprof.Cmdline)
	h.handle("/debug/pprof/profile", pprof.Profile)
	h.handle("/debug/pprof/symbol", pprof.Symbol)
	h.handle("/debug/pprof/trace", pprof.Trace)

	h.statter.StartPoller("NumWebsocketConnections", h.countConnections, 10*time.Second)

	// Shutdown watcher: waits for the trigger, drains all connections, then reports done.
	go func() {
		<-h.shutdown
		h.cleanupAllConns()

		// Println (not Printf) so the message is newline-terminated and not glued to
		// whatever is written to stdout next.
		fmt.Println("connection cleanup done")
		handlerShutdownDone <- "ok"
	}()

	return h, h.shutdown
}

// countConnections returns how many WebSocket connections are currently tracked, as a
// float64 so it can be sampled by the NumWebsocketConnections poller started in New.
func (h *handler) countConnections() float64 {
	m := h.conns
	m.RLock()
	defer m.RUnlock()
	return float64(len(m.conns))
}

// cleanupAllConns calls GracefulClose concurrently on every tracked connection and blocks
// until the WaitGroup handed to those calls is released. Each GracefulClose is expected to
// call Done exactly once. This function adds no timeout of its own; any bound on how long
// a close may take is up to GracefulClose.
//
// The connection map is locked only long enough to take a snapshot. Holding the write
// lock while waiting, as an earlier version did, deadlocks if closing a connection
// re-enters the map: the cleanup hook built by cleanupHook takes h.conns.Lock to remove
// its session. It also stalls countConnections and payloadIngest for the entire shutdown.
func (h *handler) cleanupAllConns() {
	h.conns.RLock()
	snapshot := make([]wsconnection.Connection, 0, len(h.conns.conns))
	for _, conn := range h.conns.conns {
		snapshot = append(snapshot, conn)
	}
	h.conns.RUnlock()

	var wg sync.WaitGroup
	wg.Add(len(snapshot))
	for _, conn := range snapshot {
		// The method value (receiver included) is evaluated here, so each goroutine
		// closes its own connection regardless of loop-variable semantics.
		go conn.GracefulClose(&wg)
	}
	wg.Wait()
	h.logger.Info("cleaned up all conns")
}

// ServeHTTP makes handler an http.Handler by routing every request through the
// ServeMux populated in New.
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	router := h.mux
	router.ServeHTTP(w, r)
}

// handle registers f for the subtree pattern path+"/" and for the exact pattern path, in
// that order, so that e.g. "/health" and "/health/x" reach the same handler.
//
// NOTE(review): for a path that already ends in "/" (such as "/debug/pprof/") this also
// registers the pattern "path//", which ServeMux never matches because it cleans request
// paths first; harmless but redundant.
func (h *handler) handle(path string, f http.HandlerFunc) {
	for _, pattern := range [...]string{path + "/", path} {
		h.mux.HandleFunc(pattern, f)
	}
}

// acceptWebsocketConnection handles /v1/connect.
//
// It assigns the request a new v4 UUID session ID and upgrades it to a WebSocket. The
// connection is wrapped in a wsconnection whose cleanup hook removes the session again,
// then registered in the connection map so /v1/send can route messages to it.
//
// A UUID failure answers 500. An upgrade failure is only logged and counted here
// (NOTE(review): presumably the upgrader has already written an error response —
// confirm). An unparseable client IP is logged and replaced with an empty net.IP.
//
// NOTE(review): the connection is added to the map only after wsconnection.New returns.
// If New starts goroutines that can run the cleanup hook immediately (e.g. an instant
// client disconnect), the delete could precede the insert and leave a stale entry — confirm.
func (h *handler) acceptWebsocketConnection(w http.ResponseWriter, r *http.Request) {
	began := time.Now()

	u, uuidErr := uuid.NewV4()
	if uuidErr != nil {
		h.logger.Error("could not create sessID", zap.Error(uuidErr))
		http.Error(w, "internal server error", http.StatusInternalServerError)
		h.statter.IncrementErr(metricAcceptConnErr, "UUIDGeneration")
		return
	}
	sessionID := u.String()
	sessLog := h.logger.With(zap.String("sessID", sessionID))

	ws, upgradeErr := upgrader.Upgrade(w, r, nil)
	if upgradeErr != nil {
		sessLog.Error("error upgrading connection to WebSocket", zap.Error(upgradeErr))
		h.statter.IncrementErr(metricAcceptConnErr, "UpgradeConn")
		return
	}

	ip, ipErr := network.ClientIP(r, []*net.IPNet{})
	if ipErr != nil {
		sessLog.Error("Could not parse IP", zap.Error(ipErr))
		ip = net.IP{}
	}

	onClose := h.cleanupHook(sessionID)
	h.addConnection(sessionID, wsconnection.New(ws, sessionID, ip, onClose, sessLog, h.statter, h.subsClient))

	h.statter.Increment(metricAcceptConnSuccess)
	h.statter.Timing(metricAcceptConnDuration, time.Since(began))
}

// addConnection stores conn under sessID in the connection map, overwriting any entry
// already present for that ID.
func (h *handler) addConnection(sessID string, conn wsconnection.Connection) {
	m := h.conns
	m.Lock()
	defer m.Unlock() // deferred so a panic recovered by net/http cannot leave the lock held
	m.conns[sessID] = conn
}

// cleanupHook builds the callback that acceptWebsocketConnection hands to
// wsconnection.New for sessID (presumably run when that connection closes).
//
// The callback:
//   - removes sessID from the connection map;
//   - emits a disconnect event on disconnectQ, logging and counting a failure without
//     retrying it;
//   - always counts a successful close.
func (h *handler) cleanupHook(sessID string) func() {
	forget := func() {
		h.conns.Lock()
		defer h.conns.Unlock()
		delete(h.conns.conns, sessID)
	}

	return func() {
		forget()

		if emitErr := h.disconnectQ.EmitDisconnectEvent(sessID); emitErr != nil {
			h.logger.Error("Error emitting disconnect event", zap.Error(emitErr), zap.String("sessID", sessID))
			h.statter.Increment(metricDisconnectNotificationErr)
		}

		h.statter.Increment(metricCloseConnSuccess)
	}
}

// payloadIngest handles /v1/send. It decodes one protocol.ServiceToClientMessage from the
// JSON request body and forwards it to the live connection named by its SessionID.
//
// It responds 400 on malformed JSON, 404 when no such session is tracked, and 500 when
// forwarding fails. On success it writes nothing, so net/http sends an empty 200.
func (h *handler) payloadIngest(w http.ResponseWriter, r *http.Request) {
	began := time.Now()

	var msg protocol.ServiceToClientMessage
	if err := json.NewDecoder(r.Body).Decode(&msg); err != nil {
		h.logger.Error("JSON decode error", zap.Error(err))
		http.Error(w, "JSON decode error", http.StatusBadRequest)
		h.statter.IncrementErr(metricPayloadIngestErr, "DecodeJSON")
		return
	}

	// "found" rather than "ok" so the package-level ok body is not shadowed.
	h.conns.RLock()
	conn, found := h.conns.conns[msg.SessionID]
	h.conns.RUnlock()
	if !found {
		h.logger.Error("connection not found", zap.String("sessID", msg.SessionID))
		http.Error(w, "connection not found", http.StatusNotFound)
		h.statter.IncrementErr(metricPayloadIngestErr, "ConnNotFound")
		return
	}

	if err := conn.Forward(msg); err != nil {
		h.logger.Error("could not forward message", zap.Error(err))
		http.Error(w, "could not forward message", http.StatusInternalServerError)
		h.statter.IncrementErr(metricPayloadIngestErr, "ForwardingMsg")
		return
	}

	h.statter.Increment(metricPayloadIngestSuccess)
	h.statter.Timing(metricPayloadIngestDuration, time.Since(began))
}

// health is a bare liveness probe. It writes the body "OK" (with the implicit 200 status)
// without checking any dependency, so it answers OK even when the service is unhealthy.
// A failed write is only counted under metricHealthErr.
func (h *handler) health(w http.ResponseWriter, _ *http.Request) {
	if _, writeErr := w.Write(ok); writeErr != nil {
		h.statter.Increment(metricHealthErr)
	}
}

// echo upgrades the request to a WebSocket and writes every message it receives back to
// the sender with the same message type, until a read or write fails. It is useful as a
// connectivity check and should not be exposed to the world, since it reflects arbitrary
// client input.
//
// A client closing with 1000 (normal closure) or 1001 (going away) is the ordinary way an
// echo session ends. Those closes are logged at debug level instead of as errors, so every
// disconnect does not produce an error log. Any other read failure is still logged as an
// error.
func (h *handler) echo(w http.ResponseWriter, r *http.Request) {
	h.statter.Increment(metricEcho)
	c, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		h.logger.Error("upgrade:", zap.Error(err))
		return
	}
	defer c.Close()
	for {
		mt, message, err := c.ReadMessage()
		if err != nil {
			if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
				h.logger.Debug("recv:", zap.Error(err))
			} else {
				h.logger.Error("recv:", zap.Error(err))
			}
			return
		}
		h.logger.Debug("recv", zap.String("msg", string(message)))
		if err := c.WriteMessage(mt, message); err != nil {
			h.logger.Error("write:", zap.Error(err))
			return
		}
	}
}
