package handlers

import (
	"sync"
	"testing"
	"time"

	"code.justin.tv/websocket-edge/server/internal/gqlsubs"
	"code.justin.tv/websocket-edge/server/internal/logs"
	"code.justin.tv/websocket-edge/server/internal/metrics"
	"code.justin.tv/websocket-edge/server/internal/queue"
	"code.justin.tv/websocket-edge/server/protocol"
)

// wellBehavedConn is a mock connection whose GracefulClose completes
// immediately by calling wg.Done(); its other methods do nothing and succeed.
type wellBehavedConn struct{}

// Forward ignores the message and reports success.
func (c *wellBehavedConn) Forward(wm protocol.ServiceToClientMessage) error {
	return nil
}

// GracefulClose marks this connection as done on wg right away.
func (c *wellBehavedConn) GracefulClose(wg *sync.WaitGroup) {
	wg.Done()
}

// SendSessionID does nothing and reports success.
func (c *wellBehavedConn) SendSessionID() error {
	return nil
}

// neverCloseConn is a mock connection whose GracefulClose never returns and
// never calls wg.Done(); its other methods do nothing and succeed.
type neverCloseConn struct{}

// Forward ignores the message and reports success.
func (c *neverCloseConn) Forward(wm protocol.ServiceToClientMessage) error {
	return nil
}

// GracefulClose parks the calling goroutine forever without touching wg.
func (c *neverCloseConn) GracefulClose(wg *sync.WaitGroup) {
	// An empty select has no cases that can ever proceed, so it blocks for good.
	select {}
}

// SendSessionID does nothing and reports success.
func (c *neverCloseConn) SendSessionID() error {
	return nil
}

// TestHandler exercises the handler's shutdown path. Each subtest registers
// mock connections, sends on the cleanup channel returned by New, and then
// watches the shutdown-done channel passed to New to see whether the handler
// reports completion within a short window.
func TestHandler(t *testing.T) {
	// shutdownWait bounds how long each subtest waits for the shutdown signal.
	const shutdownWait = 100 * time.Millisecond

	t.Run("blocks on conn close", func(t *testing.T) {
		conn := &neverCloseConn{}

		handlerShutdownDone := make(chan interface{})
		h, startCleanup := New(&logs.Noop{}, &metrics.Noop{}, &queue.Noop{}, &gqlsubs.Noop{}, handlerShutdownDone)
		h.addConnection("abcd", conn)
		startCleanup <- true

		// This should time out since the conn never returns from GracefulClose().
		select {
		case <-handlerShutdownDone:
			t.Fatal("handler signalled shutdown done although the connection never returned from GracefulClose")
		case <-time.After(shutdownWait):
			// Expected: no shutdown signal arrived within the window.
		}
	})

	t.Run("cleans up all connections", func(t *testing.T) {
		c1 := &wellBehavedConn{}
		c2 := &wellBehavedConn{}

		handlerShutdownDone := make(chan interface{})
		h, startCleanup := New(&logs.Noop{}, &metrics.Noop{}, &queue.Noop{}, &gqlsubs.Noop{}, handlerShutdownDone)
		h.addConnection("abcd", c1)
		h.addConnection("efgh", c2)

		startCleanup <- true

		// The fact that it finished indicates that GracefulClose was called
		// and the wait group finished.
		select {
		case <-handlerShutdownDone:
			// Expected: shutdown completed.
		case <-time.After(shutdownWait):
			t.Fatalf("handler did not signal shutdown done within %v", shutdownWait)
		}
	})
}
