//go:build !race
// +build !race

package wswriter

import (
	"errors"
	"log"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
	"time"

	data "code.justin.tv/amzn/streamlogclient/data/v1"
	"github.com/golang/protobuf/proto"
	"github.com/gorilla/websocket"
)

// TestServer bundles an httptest.Server with a WaitGroup used to signal that
// the server's handler has been reached. StartWebsocketServer creates the
// WaitGroup with a count of one and releases it from its handler.
type TestServer struct {
	*httptest.Server
	wg *sync.WaitGroup
}

// Wait blocks until the server's WaitGroup counter drops to zero. For servers
// built by StartWebsocketServer that means a request has reached the handler.
func (s *TestServer) Wait() {
	s.wg.Wait()
}

// StartWebsocketServer starts an httptest server that upgrades GET requests
// to websocket connections and calls fn with each connection. The connection
// is closed when fn returns; non-GET requests get a 405.
//
// The returned server's URL has its "http" scheme rewritten to "ws" so it can
// be passed directly to a websocket client. The server's Wait method returns
// once the first request has reached the handler; later requests (for example
// a client reconnecting) do not touch the WaitGroup again.
func StartWebsocketServer(fn func(*websocket.Conn)) (server *TestServer) {
	wg := &sync.WaitGroup{}
	wg.Add(1)

	// The handler can run more than once if the client reconnects. Calling
	// wg.Done a second time would panic with a negative WaitGroup counter,
	// so only the first request releases Wait.
	var firstRequest sync.Once

	// Built once and shared: gorilla documents Upgrader methods as safe for
	// concurrent use.
	upgrader := websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		// Accept any Origin header; this server is only used by tests.
		CheckOrigin: func(r *http.Request) bool {
			return true
		},
	}

	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		firstRequest.Do(wg.Done)

		if r.Method != http.MethodGet {
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}

		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			log.Println(err)
			return
		}
		defer conn.Close()

		// Do the custom logic for test
		fn(conn)
	}))

	ts.URL = strings.Replace(ts.URL, "http", "ws", 1)
	return &TestServer{ts, wg}
}

// DecodeEvent reads the next websocket message from conn and unmarshals it,
// as a protobuf, into event.
//
// It returns the read error if ReadMessage fails, an error if the message is
// not a binary frame, and otherwise the result of proto.Unmarshal.
func DecodeEvent(conn *websocket.Conn, event *data.Event) error {
	typ, buffer, err := conn.ReadMessage()
	// Check the read error first. On failure the returned message type is not
	// a real frame type, so testing it first would hide the actual error
	// (e.g. a close frame or EOF) behind "Not a binary message".
	if err != nil {
		return err
	}
	if typ != websocket.BinaryMessage {
		return errors.New("Not a binary message")
	}

	return proto.Unmarshal(buffer, event)
}

// TestSimpleStreamLog logs numMessages integers through a writer. On the
// server side it checks that each one arrives as its own event with metric
// "Kappa", exactly one record, and the logged value, in order.
func TestSimpleStreamLog(t *testing.T) {
	numMessages := 10

	// NOTE(review): this handler runs on a server goroutine, and
	// httptest.Server.Close does not wait for hijacked (websocket)
	// connections. Failures reported after the test returns may therefore be
	// lost — confirm whether the test should wait for the handler to finish.
	ts := StartWebsocketServer(func(conn *websocket.Conn) {
		var event data.Event

		for i := 0; i < numMessages; i++ {
			err := DecodeEvent(conn, &event)
			if err != nil {
				t.Error(err)
				t.Errorf("Only received %v of %v messages", i, numMessages)
				return
			}

			if length := len(event.Records); length != 1 {
				t.Errorf("%v records when there should be 1", length)
				// Indexing Records[0] below would panic on an empty event;
				// the failure is already recorded, so move on.
				if length == 0 {
					continue
				}
			}

			record := event.Records[0]

			if event.Metric != "Kappa" {
				t.Errorf("Incorrect key: %s", event.Metric)
			}

			if val := int(record.I); val != i {
				t.Errorf("Incorrect value: %v, Should be %v", val, i)
			}
		}
	})
	defer ts.Close()

	cfg := Settings{
		WebsocketEndpoint: ts.URL,
		QueueSize:         numMessages,
	}
	streamlog, err := New(cfg)
	if err != nil {
		t.Fatal(err)
	}
	defer streamlog.Close()

	// Wait until our connections have been established before signalling close
	ts.Wait()

	for i := 0; i < numMessages; i++ {
		_ = streamlog.Log("live_user_FrankerZ", "Kappa", i)
	}
}

// TestLogAfterClosed checks that Log succeeds on an open writer and panics
// once the writer has been closed.
func TestLogAfterClosed(t *testing.T) {
	ts := StartWebsocketServer(func(conn *websocket.Conn) {
		// Keep consuming events until reading fails (e.g. the client closed).
		var event data.Event
		for DecodeEvent(conn, &event) == nil {
			// Discard the event; only draining matters here.
		}
	})
	defer ts.Close()

	streamlog, err := New(Settings{
		WebsocketEndpoint: ts.URL,
		QueueSize:         5,
	})
	if err != nil {
		t.Fatal(err)
	}
	ts.Wait()

	if err = streamlog.Log("live_user_PogChamp", "Kippa", 42); err != nil {
		t.Fatal(err)
	}

	streamlog.Close()

	// Logging after Close is expected to panic; capture whether it did.
	panicked := func() (didPanic bool) {
		defer func() {
			didPanic = recover() != nil
		}()
		_ = streamlog.Log("live_user_PogChamp", "SwiftRage", 666)
		return false
	}()
	if !panicked {
		t.Fatal("No error when logging to closed streamlog")
	}
}

// TestReconnect checks that the writer keeps accepting Log calls after its
// server goes away, and connects to a second server once its endpoint is
// changed to point there.
//
// NOTE(review): assigning streamlog.WebsocketEndpoint while the writer is
// running appears to be unsynchronized. That is presumably why this file
// carries the !race build tag — confirm before enabling -race.
func TestReconnect(t *testing.T) {
	// First server that we'll close after sending a message
	ts1 := StartWebsocketServer(func(conn *websocket.Conn) {
		var event data.Event
		err := DecodeEvent(conn, &event)
		if err != nil {
			return
		}
	})
	// ts1 is closed explicitly mid-test. This deferred Close only matters
	// when a t.Fatal fires earlier; httptest.Server.Close ignores a repeat
	// call on an already-closed server.
	defer ts1.Close()

	// We'll switch to this server after waiting a bit to simulate
	// the server coming back online
	ts2 := StartWebsocketServer(func(conn *websocket.Conn) {
		var event data.Event
		err := DecodeEvent(conn, &event)
		if err != nil {
			return
		}
		err = DecodeEvent(conn, &event)
		if err != nil {
			return
		}
	})
	defer ts2.Close()

	cfg := Settings{
		WebsocketEndpoint: ts1.URL,
		QueueSize:         5,
	}
	streamlog, err := New(cfg)
	if err != nil {
		t.Fatal(err)
	}
	// Deferred after ts2, so it runs first: writer closes before the server,
	// matching the original teardown order.
	defer streamlog.Close()
	ts1.Wait()

	err = streamlog.Log("live_user_PogChamp", "Kippa", 42)
	if err != nil {
		t.Fatal(err)
	}

	ts1.Close()
	time.Sleep(time.Millisecond) // Wait a bit for everything to close
	err = streamlog.Log("live_user_PogChamp", "Kippa", 43)
	if err != nil {
		t.Fatal(err)
	}

	err = streamlog.Log("live_user_PogChamp", "Kippa", 44)
	if err != nil {
		t.Fatal(err)
	}

	// Wait some time before correcting URL
	time.Sleep(time.Second)
	streamlog.WebsocketEndpoint = ts2.URL
	// Blocks until the writer reaches ts2; hangs if it never reconnects.
	ts2.Wait()
}

// TestHighVolume pushes a large number of events through one writer at a
// steady pace and reports every error Log returns.
func TestHighVolume(t *testing.T) {
	const numMessages = 100000

	ts := StartWebsocketServer(func(conn *websocket.Conn) {
		// Consume up to numMessages events, stopping early on a read failure.
		for received := 0; received < numMessages; received++ {
			var event data.Event
			if err := DecodeEvent(conn, &event); err != nil {
				return
			}
		}
	})
	defer ts.Close()

	streamlog, err := New(Settings{
		WebsocketEndpoint: ts.URL,
		QueueSize:         1000,
	})
	if err != nil {
		t.Fatal(err)
	}
	defer streamlog.Close()

	ts.Wait()

	// Requested pause between messages: one second split evenly across them.
	interval := time.Second / numMessages
	for i := 0; i < numMessages; i++ {
		if err := streamlog.Log("live_user_PogChamp", "Kappa", i); err != nil {
			t.Error(err)
		}
		time.Sleep(interval)
	}
}
