package dns

import (
	"sync"
	"testing"
	"time"

	"code.justin.tv/devhub/e2ml/libs/peering"
	"github.com/stretchr/testify/assert"
)

// testSource is a stub list source used by the tests. It counts the number
// of listener registrations it has been told about (added) and the number of
// lists that reported being closed (closed).
type testSource struct {
	// closed is bumped by onClosed; added is bumped by onAdded.
	closed, added int
}

// create builds a new list named name that reports back to this source.
func (s *testSource) create(name string) peering.ClosableServerList {
	created := newList(s, name)
	return created
}

// onAdded counts one listener registration; the listener itself is ignored.
func (s *testSource) onAdded(_ peering.Listener) {
	s.added++
}

// onClosed counts one closed list and always reports success.
func (s *testSource) onClosed(_ *serverList) error {
	s.closed++
	return nil
}

// testListener records the set of peer names it has been notified about.
// Each notification calls wg.Done, so tests Add the expected notification
// count before triggering events.
type testListener struct {
	found map[string]struct{} // currently-known peer names, used as a set
	mutex sync.Mutex          // guards found
	wg    sync.WaitGroup      // counts outstanding expected notifications
}

// newTestListener returns a testListener whose found set is initialized and
// empty, ready to receive notifications.
func newTestListener() *testListener {
	l := new(testListener)
	l.found = map[string]struct{}{}
	return l
}

// OnPeerAdded records name in the found set and marks one expected
// notification as delivered on wg.
//
// The unlock is deferred so that a panicking wg.Done (a negative WaitGroup
// counter, i.e. an unexpected extra notification) cannot leave the mutex
// held and deadlock later callbacks if the panic is recovered upstream.
func (l *testListener) OnPeerAdded(name string) {
	l.mutex.Lock()
	defer l.mutex.Unlock()
	l.found[name] = struct{}{}
	l.wg.Done()
}

// OnPeerRemoved deletes name from the found set and marks one expected
// notification as delivered on wg.
//
// The unlock is deferred so that a panicking wg.Done (a negative WaitGroup
// counter, i.e. an unexpected extra notification) cannot leave the mutex
// held and deadlock later callbacks if the panic is recovered upstream.
func (l *testListener) OnPeerRemoved(name string) {
	l.mutex.Lock()
	defer l.mutex.Unlock()
	delete(l.found, name)
	l.wg.Done()
}

// TestServerList covers server list construction, local-name handling,
// listener notification on insert/update/add/remove, and close propagation
// back to the owning source.
//
// Every notification delivered to a testListener calls wg.Done. Each subtest
// therefore Adds the expected notification count and Waits before inspecting
// x.found. Waiting makes the checks independent of whether delivery is
// synchronous, and gives the unlocked map reads a happens-before edge. An
// unexpected extra notification panics with a negative WaitGroup counter.
func TestServerList(t *testing.T) {
	t.Run("should forward items automatically", func(t *testing.T) {
		list := NewServerList("localhost", "localhost", time.Hour)
		x := newTestListener()
		x.wg.Add(1)
		list.AddListener(x)
		x.wg.Wait()
		assert.Len(t, x.found, 1) // IPv4
		assert.NoError(t, list.Close())
	})

	t.Run("should update local name correctly", func(t *testing.T) {
		list := NewServerList("localhost", "localhost", time.Hour)
		assert.Equal(t, "localhost", list.LocalName())
		other := list.WithLocal("other")
		assert.Equal(t, "other", other.LocalName())
		other.Close()
		// Closing the derived list must not change the parent's local name.
		assert.Equal(t, "localhost", list.LocalName())
		assert.NoError(t, list.Close())
	})

	t.Run("should notify listeners correctly on insert", func(t *testing.T) {
		src := &testSource{}
		list := src.create("localhost")
		x := newTestListener()
		list.AddListener(x)
		assert.Equal(t, 1, src.added)
	})

	t.Run("should notify listeners correctly on update", func(t *testing.T) {
		src := createSource("localhost")
		list := src.create("localhost")
		x := newTestListener()
		x.wg.Add(2)
		src.update([]string{"a", "b"})
		list.AddListener(x)
		x.wg.Wait()
		assert.Len(t, x.found, 2)
		x.wg.Add(2)
		src.update([]string{})
		x.wg.Wait()
		assert.Empty(t, x.found)
	})

	t.Run("should notify src correctly on close", func(t *testing.T) {
		src := &testSource{}
		list := src.create("localhost")
		assert.NoError(t, list.Close())
		assert.Equal(t, 1, src.closed)
	})

	t.Run("should forward adds/removes to listeners", func(t *testing.T) {
		src := &testSource{}
		list := newList(src, "localhost")
		x := newTestListener()
		list.AddListener(x)

		// Only one notification is expected: the list's own local name
		// ("localhost") is not forwarded to listeners.
		x.wg.Add(1)
		list.onAdded("localhost")
		list.onAdded("value")
		x.wg.Wait()
		assert.Equal(t, map[string]struct{}{"value": struct{}{}}, x.found)

		x.wg.Add(1)
		list.onRemoved("localhost")
		list.onRemoved("value")
		x.wg.Wait()
		assert.Empty(t, x.found)

		// After removal the listener must receive nothing. A stray call would
		// also panic on the now-zero WaitGroup counter.
		list.RemoveListener(x)
		list.onAdded("value")
		assert.Empty(t, x.found)
	})
}
