package dialer

import (
	"crypto/tls"
	"io/ioutil"
	"net"
	"strconv"
	"strings"
	"testing"
	"time"
)

// TestTCP checks the happy path of TCP: the returned dialer connects to a local
// listener and everything the server writes before closing the connection can be
// read back and equals the expected payload.
func TestTCP(t *testing.T) {
	payload := "happy-path"

	listener := mustListenTCP(t)
	done := handleIncomingConnection(t, listener, 0, payload)
	// The server goroutine logs through t, and logging after the test has
	// completed panics. Close the listener first so a pending Accept is
	// unblocked, then wait for the goroutine to finish. This also runs when the
	// test bails out early through t.Fatalf.
	defer func() {
		listener.Close()
		<-done
	}()

	d, err := TCP("127.0.0.1", mustGetListenerPort(t, listener))
	if err != nil {
		t.Fatalf("error mismatch: got: %v, want: no error", err)
	}

	rwc, err := d.Dial()
	if err != nil {
		t.Fatalf("dial error mismatch: got: %v, want: no error", err)
	}

	// The server closes the connection after writing, so ReadAll returns once
	// the whole payload has been received.
	b, err := ioutil.ReadAll(rwc)
	if err != nil {
		t.Fatalf("read all error mismatch: got: %v, want: no error", err)
	}

	if string(b) != payload {
		t.Errorf("response mismatch: got: %s, want: %s", string(b), payload)
	}
}

// TestTCPDialer exercises TCPDialer with a nil net.Dialer, a zero-value
// net.Dialer, and a net.Dialer carrying a very short timeout.
//
// connectionSleep is the delay the server applies before writing the payload.
// wantTimeout marks cases in which a dial error is acceptable provided it is a
// timeout. A successful dial is not treated as a failure even when wantTimeout
// is set; in that case the payload is still verified.
func TestTCPDialer(t *testing.T) {
	tests := []struct {
		name            string
		netDialer       *net.Dialer
		connectionSleep time.Duration
		wantTimeout     bool
	}{
		{
			name: "nil-dialer",
		},
		{
			name:      "default-dialer",
			netDialer: &net.Dialer{},
		},
		{
			name:            "specific-dialer-with-timeout",
			netDialer:       &net.Dialer{Timeout: time.Millisecond},
			connectionSleep: time.Millisecond * 100,
			wantTimeout:     true,
		},
	}

	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			listener := mustListenTCP(t)
			done := handleIncomingConnection(t, listener, test.connectionSleep, test.name)
			// The server goroutine logs through t, and logging after the subtest
			// has completed panics. Close the listener first so a pending Accept
			// is unblocked, then wait for the goroutine to finish. This matters
			// most on the timeout path, where the subtest returns while the
			// server may still be sleeping.
			defer func() {
				listener.Close()
				<-done
			}()

			d, err := TCPDialer("127.0.0.1", mustGetListenerPort(t, listener), test.netDialer)
			if err != nil {
				t.Fatalf("error mismatch: got: %v, want: no error", err)
			}

			rwc, err := d.Dial()
			if err != nil {
				if !test.wantTimeout {
					t.Fatalf("dial error mismatch: got: %v, want: no error", err)
				}
				if !isTimeoutError(err) {
					t.Errorf("dial error mismatch: got: %v, want: timeout", err)
				}
				return
			}

			b, err := ioutil.ReadAll(rwc)
			if err != nil {
				t.Fatalf("read all error mismatch: got: %v, want: no error", err)
			}

			if string(b) != test.name {
				t.Errorf("response mismatch: got: %s, want: %s", string(b), test.name)
			}
		})
	}
}

// TestTLS checks the happy path of TLS: the returned dialer connects to a local
// TLS listener (certificate verification is skipped on the client side) and the
// payload written by the server can be read back unchanged.
func TestTLS(t *testing.T) {
	payload := "happy-path"

	listener := mustListenTLS(t)
	done := handleIncomingConnection(t, listener, 0, payload)
	// The server goroutine logs through t, and logging after the test has
	// completed panics. Close the listener first so a pending Accept is
	// unblocked, then wait for the goroutine to finish. This also runs when the
	// test bails out early through t.Fatalf.
	defer func() {
		listener.Close()
		<-done
	}()

	d, err := TLS("127.0.0.1", mustGetListenerPort(t, listener), &tls.Config{InsecureSkipVerify: true})
	if err != nil {
		t.Fatalf("error mismatch: got: %v, want: no error", err)
	}

	rwc, err := d.Dial()
	if err != nil {
		t.Fatalf("dial error mismatch: got: %v, want: no error", err)
	}

	// The server closes the connection after writing, so ReadAll returns once
	// the whole payload has been received.
	b, err := ioutil.ReadAll(rwc)
	if err != nil {
		t.Fatalf("read all error mismatch: got: %v, want: no error", err)
	}

	if string(b) != payload {
		t.Errorf("response mismatch: got: %s, want: %s", string(b), payload)
	}
}

// TestTLSDialer exercises TLSDialer with a nil net.Dialer, a zero-value
// net.Dialer, and a net.Dialer carrying a very short timeout. Every case skips
// certificate verification on the client side.
//
// connectionSleep is the delay the server applies before writing the payload.
// wantTimeout marks cases in which a dial error is acceptable provided it is a
// timeout. A successful dial is not treated as a failure even when wantTimeout
// is set; in that case the payload is still verified.
func TestTLSDialer(t *testing.T) {
	tests := []struct {
		name            string
		config          *tls.Config
		netDialer       *net.Dialer
		connectionSleep time.Duration
		wantTimeout     bool
	}{
		{
			name:   "nil-dialer",
			config: &tls.Config{InsecureSkipVerify: true},
		},
		{
			name:      "default-dialer",
			config:    &tls.Config{InsecureSkipVerify: true},
			netDialer: &net.Dialer{},
		},
		{
			name:            "specific-dialer-with-timeout",
			config:          &tls.Config{InsecureSkipVerify: true},
			netDialer:       &net.Dialer{Timeout: time.Millisecond},
			connectionSleep: time.Millisecond * 100,
			wantTimeout:     true,
		},
	}

	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			listener := mustListenTLS(t)
			done := handleIncomingConnection(t, listener, test.connectionSleep, test.name)
			// Waiting on done prevents a panic from logging in the server
			// goroutine after the subtest has completed. The listener is closed
			// before waiting so that an Accept which never received a connection
			// is unblocked instead of hanging the test.
			defer func() {
				listener.Close()
				<-done
			}()

			dialer := TLSDialer("127.0.0.1", mustGetListenerPort(t, listener), test.config, test.netDialer)

			rwc, err := dialer.Dial()
			if err != nil {
				// Any error is a failure unless a timeout was expected, and an
				// expected timeout must really be a timeout. Comparing the two
				// booleans for equality would silently accept a non-timeout
				// error in the cases that expect no error at all.
				if !test.wantTimeout {
					t.Fatalf("dial error mismatch: got: %v, want: no error", err)
				}
				if !isTimeoutError(err) {
					t.Errorf("dial error mismatch: got: %v, want: timeout", err)
				}
				return
			}

			b, err := ioutil.ReadAll(rwc)
			if err != nil {
				t.Fatalf("read all error mismatch: got: %v, want: no error", err)
			}

			if string(b) != test.name {
				t.Errorf("response mismatch: got: %s, want: %s", string(b), test.name)
			}
		})
	}
}

// isTimeoutError reports whether err is itself a net.Error whose Timeout method
// returns true. Every other error, including nil, yields false; wrapped errors
// are not unwrapped.
func isTimeoutError(err error) bool {
	netErr, isNetErr := err.(net.Error)
	return isNetErr && netErr.Timeout()
}

// mustListenTLS starts a TLS listener on an ephemeral loopback port, serving the
// certificate and key stored under testdata. The test is aborted if the key pair
// cannot be loaded or the listener cannot be created. The caller is responsible
// for closing the returned listener.
func mustListenTLS(t *testing.T) net.Listener {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	cert, err := tls.LoadX509KeyPair("testdata/server-cert.pem", "testdata/server-cert.key")
	if err != nil {
		t.Fatalf("failed to load server certs: %v", err)
	}

	// Bind to loopback only: the tests in this file dial 127.0.0.1, so there is
	// no need to expose the test server on every interface.
	listener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{Certificates: []tls.Certificate{cert}})
	if err != nil {
		t.Fatalf("failed to listen for tls server: %v", err)
	}

	return listener
}

// mustListenTCP starts a plain TCP listener on an ephemeral loopback port and
// aborts the test if the listener cannot be created. The caller is responsible
// for closing the returned listener.
func mustListenTCP(t *testing.T) net.Listener {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	// Bind to loopback only: the tests in this file dial 127.0.0.1, so there is
	// no need to expose the test server on every interface.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("failed to listen for tcp server: %v", err)
	}
	return listener
}

// mustGetListenerPort returns the port the listener is bound to, parsed from the
// text that follows the last ':' in the listener's address string. The test is
// aborted if that text is not a decimal integer.
func mustGetListenerPort(t *testing.T, listener net.Listener) uint16 {
	addr := listener.Addr().String()

	// LastIndex yields -1 when no ':' is present, in which case the whole
	// address is handed to Atoi.
	portText := addr[strings.LastIndex(addr, ":")+1:]

	port, err := strconv.Atoi(portText)
	if err != nil {
		t.Fatalf("failed to determine port of address %s: %v", addr, err)
	}
	return uint16(port)
}

// handleIncomingConnection serves exactly one connection from listener in a
// background goroutine. Once a client has been accepted, the goroutine sleeps for
// the given duration, writes payload to the client, and closes the connection.
// If Accept fails, a warning is logged and nothing is written.
//
// The returned channel is closed when the goroutine has finished, so callers can
// wait on it before their test returns.
func handleIncomingConnection(t *testing.T, listener net.Listener, sleep time.Duration, payload string) chan struct{} {
	finished := make(chan struct{})

	serveOnce := func() {
		// Signal completion on every exit path, after the connection is closed.
		defer close(finished)

		conn, acceptErr := listener.Accept()
		if acceptErr != nil {
			t.Logf("warning: failed to accept incoming connection: %v", acceptErr)
			return
		}
		defer conn.Close()

		t.Logf("accepted connection from %v", conn.RemoteAddr())
		t.Logf("waiting %v to allow testing of custom timeouts", sleep)
		time.Sleep(sleep)

		_, writeErr := conn.Write([]byte(payload))
		if writeErr != nil {
			t.Logf("failed to send payload to connection: %v", writeErr)
		}
	}
	go serveOnce()

	return finished
}
