package dnscache

import (
	"context"
	"sync"
	"testing"
	"time"

	. "github.com/smartystreets/goconvey/convey"
)

// TestDNSCacheClose exercises Close:
//   - before start it reports errCacheRefreshLoopNotStarted;
//   - the first Close after start succeeds and stops the refresh loop,
//     which is observed through the testOnClose hook;
//   - later Close calls report errCacheAlreadyClosed;
//   - of two concurrent Close calls, exactly one wins;
//   - DialContext still works on a closed cache.
func TestDNSCacheClose(t *testing.T) {
	Convey("Create Cache with testHook to test ... ", t, func() {
		closec := make(chan struct{}, 1)
		cache := &Cache{
			testOnClose: func() {
				closec <- struct{}{}
			},
		}
		So(cache.Close(), ShouldResemble, errCacheRefreshLoopNotStarted)
		So(cache.isClosed(), ShouldEqual, false)
		cache.start()

		// waitForLoopExit blocks until testOnClose fires. If the refresh loop
		// never stops, it flags the test instead of hanging it forever.
		waitForLoopExit := func() {
			select {
			case <-closec:
			case <-time.After(defaultRefreshInterval):
				t.Error("Timed out waiting for cache updateLoop to close")
			}
		}

		Convey("Closing cache second time should return an error", func() {
			err := cache.Close()
			So(err, ShouldBeNil)
			So(cache.isClosed(), ShouldEqual, true)
			waitForLoopExit()

			err = cache.Close()
			So(err, ShouldResemble, errCacheAlreadyClosed)
			So(cache.isClosed(), ShouldEqual, true)
		})

		Convey("Close should be goroutine safe", func() {
			So(cache.isClosed(), ShouldEqual, false)
			var err1, err2 error
			// race two goroutines to close the cache
			var wg sync.WaitGroup
			wg.Add(2)
			go func() {
				defer wg.Done()
				err1 = cache.Close()
			}()

			go func() {
				defer wg.Done()
				err2 = cache.Close()
			}()

			waitForLoopExit()
			wg.Wait() // avoid race on reading/writing err{1|2}

			So(cache.isClosed(), ShouldEqual, true)

			// Exactly one Close must win, and the loser must see the cache as
			// already closed. Checking both sides rules out outcomes such as
			// (someOtherErr, errCacheAlreadyClosed), which would otherwise pass.
			if err1 == nil {
				So(err2, ShouldResemble, errCacheAlreadyClosed)
			} else {
				So(err1, ShouldResemble, errCacheAlreadyClosed)
				So(err2, ShouldBeNil)
			}
		})

		Convey("Dial succeeds even if cache is closed", func() {
			So(cache.isClosed(), ShouldEqual, false)
			err := cache.Close()
			waitForLoopExit()
			So(err, ShouldBeNil)
			So(cache.isClosed(), ShouldEqual, true)
			conn, err := cache.DialContext(context.Background(), "tcp", "google.com:80")
			So(err, ShouldBeNil)
			So(conn, ShouldNotBeNil)
			_ = conn.Close()
		})
	})
}

// TestDNSCacheOnErrCallback checks when the OnErr hook fires:
//   - when a host cannot be resolved;
//   - when every cached IP for a host fails to connect. In that case
//     DialContext is still expected to return a working connection
//     (the "hail mary" path).
func TestDNSCacheOnErrCallback(t *testing.T) {
	// callbackTimeout bounds how long we wait for OnErr, so a missing
	// callback fails the test instead of hanging it.
	const callbackTimeout = 10 * time.Second

	Convey("OnErr called if...  ", t, func() {
		errCh := make(chan error, 1)
		cache := &Cache{
			OnErr: func(err error, host string) {
				// Non-blocking send: each leaf only needs the first error. A
				// blocking send into the 1-slot buffer would wedge the calling
				// goroutine if OnErr were invoked more than once.
				select {
				case errCh <- err:
				default:
				}
			},
		}
		cache.start()
		Reset(func() { _ = cache.Close() })

		// nextErr returns the next error passed to OnErr. If none arrives
		// within callbackTimeout, it flags the test and returns nil.
		nextErr := func() error {
			select {
			case err := <-errCh:
				return err
			case <-time.After(callbackTimeout):
				t.Error("timed out waiting for OnErr callback to trigger")
				return nil
			}
		}

		Convey("host lookup fails", func() {
			host := "twitch.impossibleTLD"
			conn, err := cache.DialContext(context.Background(), "tcp", host+":http")
			So(conn, ShouldBeNil)
			So(err, ShouldNotBeNil)
			So(nextErr(), ShouldNotBeNil)
		})

		Convey("all cached ips are bad", func() {
			host := "twitch.tv"
			insertHostWithIPs(cache, host, []string{"127.0.0.1:6234"})
			conn, err := cache.DialContext(context.Background(), "tcp", host+":http")
			// Release the connection however this scope exits. The nil guard
			// matters: if the fallback dial failed, conn is a nil interface and
			// calling Close on it would panic and hide the assertion failure.
			defer func() {
				if conn != nil {
					_ = conn.Close()
				}
			}()
			So(nextErr(), ShouldNotBeNil)

			Convey("verify hail mary worked", func() {
				So(conn, ShouldNotBeNil)
				So(err, ShouldBeNil)
			})
		})
	})
}

// TestDNSCacheUpsert checks when OnCacheUpsert fires:
//   - after a periodic refresh (cache.Every);
//   - on a cache miss resolved through DialContext.
// It also checks that the hook does not fire when a refresh lookup fails;
// in that case the previously cached IPs must still be returned.
func TestDNSCacheUpsert(t *testing.T) {
	Convey("Cache is Updated ....", t, func() {
		upsertCh := make(chan string, 1) // host that was upsert
		updateDuration := 500 * time.Millisecond
		slack := 150 * time.Millisecond
		cache := &Cache{
			Every: updateDuration,
			OnCacheUpsert: func(host string, oldAddrs []string, newAddrs []string, lookupTime time.Duration) {
				// Non-blocking send: each leaf only inspects the first upsert.
				// A blocking send into the 1-slot buffer would wedge the refresh
				// goroutine if the hook fired again before the leaf finished.
				select {
				case upsertCh <- host:
				default:
				}
			},
		}
		cache.start()
		Reset(func() { _ = cache.Close() })

		Convey("cache.Every seconds", func() {
			host := "twitch.tv"
			insertHostWithIPs(cache, host, []string{"1.1.1.1"})
			// Allow one refresh interval plus scheduling slack.
			timeout := updateDuration + slack
			select {
			case <-time.After(timeout):
				t.Errorf("cache was not updated within %s", timeout.String())
			case h := <-upsertCh:
				So(h, ShouldEqual, host)
			}
		})

		Convey("On a cache miss", func() {
			conn, err := cache.DialContext(context.Background(), "tcp", "google.com:https")
			So(err, ShouldBeNil)
			So(conn, ShouldNotBeNil)
			defer func() { _ = conn.Close() }()

			select {
			case host := <-upsertCh:
				So(host, ShouldEqual, "google.com")
			case <-time.After(updateDuration):
				t.Error("timed out waiting for OnCacheUpsert callback to trigger")
			}
		})

		Convey("Only on successful host to ip lookup", func() {
			host := "twitch.impossibleTLD"
			testIPs := []string{"1.1.1.1"}
			insertHostWithIPs(cache, host, testIPs)
			// Wait through at least one failing refresh; the stale entry
			// must survive and no upsert may be reported.
			select {
			case <-time.After(2 * updateDuration):
				ips, err := cache.lookup(context.Background(), host)
				So(err, ShouldBeNil)
				So(ips, ShouldResemble, testIPs)
			case h := <-upsertCh:
				t.Errorf("host = %s, should not trigger OnCacheUpsert callback", h)
			}
		})
	})
}

// insertHostWithIPs sets c's cached addresses for host to ips directly,
// without performing any lookup. The write is done while holding c.mu.
// The ips slice is stored as-is (not copied).
func insertHostWithIPs(c *Cache, host string, ips []string) {
	c.mu.Lock()
	defer c.mu.Unlock()

	entries := c.cache
	entries[host] = ips
}
