package pipeclient

import (
	"context"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"sync"
	"testing"

	"code.justin.tv/rhys/nursery/cmd/multicp/conn"
	"code.justin.tv/rhys/nursery/cmd/multicp/netpipe"
	"code.justin.tv/rhys/nursery/cmd/multicp/pipeserver"
	"golang.org/x/net/nettest"
)

// TestClientWrite runs the golang.org/x/net/nettest conformance suite against
// a hybrid conn pair. Bytes written to c1 travel through the Client ->
// pipeserver.Server pipe and are read from c2. Bytes written to c2 travel over
// a plain in-memory net.Pipe back to c1. Only the client-to-server direction
// exercises the pipe. Subtest "b" repeats the suite with the two ends swapped.
func TestClientWrite(t *testing.T) {
	// fn builds one connected pair in the shape nettest.TestConn expects.
	fn := func() (c1, c2 net.Conn, stop func(), err error) {
		ctx := context.Background()
		ctx, cancel := context.WithCancel(ctx)

		// wg tracks the two forwarding goroutines below and the closer
		// goroutines started by stop.
		var wg sync.WaitGroup

		const target = "target"

		// NewConn hands back the server-side end of the pipe. The one-slot
		// buffer lets the callback return before fn receives from conns.
		conns := make(chan net.Conn, 1)
		server := &pipeserver.Server{
			NewConn: func(ctx context.Context, target string, c net.Conn) error {
				conns <- c
				return nil
			},
		}
		client := &Client{dest: server}

		pipe, err := client.Open(ctx, target)
		if err != nil {
			cancel()
			return nil, nil, nil, err
		}
		// NOTE(review): blocks until NewConn has delivered a conn; assumes a
		// successful Open always results in a NewConn call — confirm.
		c := <-conns

		// In-memory pipe carrying the c2 -> c1 direction.
		np1, np2 := net.Pipe()

		// c1's Writer is writer1. This goroutine runs writer1.WriteTo(pipe),
		// which presumably forwards c1's writes into the client pipe (see
		// conn.TimeoutWriter). When that returns, it closes pipe and cancels ctx.
		writer1 := &conn.TimeoutWriter{}
		wg.Add(1)
		go func() {
			defer wg.Done()
			defer cancel()
			writer1.WriteTo(pipe)
			pipe.Close()
		}()

		// client's net.Conn: reads come from np1, writes go to writer1.
		c1 = &conn.ReadWriter{
			Reader:  np1,
			Writer:  writer1,
			Closers: []io.Closer{np1, pipe},
			Local:   np1.LocalAddr(),
			Remote:  np1.RemoteAddr(),
		}

		// c2's Reader is reader2. This goroutine runs reader2.ReadFrom(c),
		// which presumably feeds whatever the server side receives on c into
		// c2's reads.
		// NOTE(review): after ReadFrom returns this closes pipe (the client
		// end), not c — confirm this is intended and not copied from the
		// writer goroutine above.
		reader2 := &conn.TimeoutReader{}
		wg.Add(1)
		go func() {
			defer wg.Done()
			defer cancel()
			reader2.ReadFrom(c)
			pipe.Close()
		}()

		// server's net.Conn: reads come from reader2, writes go to np2.
		c2 = &conn.ReadWriter{
			Reader:  reader2,
			Writer:  np2,
			Closers: []io.Closer{np2, c},
			Local:   np2.LocalAddr(),
			Remote:  np2.RemoteAddr(),
		}

		// stop cancels ctx and closes every endpoint concurrently. It then
		// waits for those closers and for the two forwarding goroutines.
		// np1/np2 also appear in c1/c2's Closers, so they may be closed twice.
		stop = func() {
			cancel()
			for _, c := range []io.Closer{np1, np2, c1, c2} {
				wg.Add(1)
				go func(c io.Closer) { defer wg.Done(); c.Close() }(c)
			}
			wg.Wait()
		}

		return c1, c2, stop, nil
	}

	t.Run("a", func(t *testing.T) {
		nettest.TestConn(t, fn)
	})
	// Same suite with the roles of the two ends swapped.
	t.Run("b", func(t *testing.T) {
		nettest.TestConn(t, func() (c1, c2 net.Conn, stop func(), err error) {
			c1, c2, stop, err = fn()
			return c2, c1, stop, err
		})
	})
}

// TestClient runs the golang.org/x/net/nettest conformance suite against the
// net.Conn pair produced by an in-process Client talking directly to a
// pipeserver.Server. The client end is PipeConn(pipe). The server end is the
// net.Conn handed to the server's NewConn callback. Subtest "b" repeats the
// suite with the two ends swapped.
func TestClient(t *testing.T) {
	makePipe := func() (c1, c2 net.Conn, stop func(), err error) {
		ctx, cancel := context.WithCancel(context.Background())

		// NewConn hands the server-side end back to us. The one-slot buffer
		// lets the callback return before we receive from the channel.
		serverSide := make(chan net.Conn, 1)
		srv := &pipeserver.Server{
			NewConn: func(_ context.Context, _ string, sc net.Conn) error {
				serverSide <- sc
				return nil
			},
		}

		pipe, openErr := (&Client{dest: srv}).Open(ctx, "target")
		if openErr != nil {
			cancel()
			return nil, nil, nil, openErr
		}
		var clientConn net.Conn = PipeConn(pipe)
		serverConn := <-serverSide

		stop = func() {
			cancel()
			// Close both ends concurrently, presumably so that a Close which
			// waits on its peer cannot stall the other one.
			var wg sync.WaitGroup
			closers := []io.Closer{clientConn, serverConn}
			wg.Add(len(closers))
			for _, cl := range closers {
				go func(cl io.Closer) {
					defer wg.Done()
					cl.Close()
				}(cl)
			}
			wg.Wait()
		}
		return clientConn, serverConn, stop, nil
	}

	t.Run("a", func(t *testing.T) { nettest.TestConn(t, makePipe) })
	t.Run("b", func(t *testing.T) {
		// Same pair, roles swapped.
		nettest.TestConn(t, func() (net.Conn, net.Conn, func(), error) {
			first, second, stop, err := makePipe()
			return second, first, stop, err
		})
	})
}

// TestClientNet runs the golang.org/x/net/nettest conformance suite against a
// conn pair whose traffic crosses a real HTTP hop. The Client reaches the
// pipeserver.Server through the netpipe Twirp protobuf client and server,
// which are mounted on an httptest server. Subtest "b" repeats the suite with
// the two ends swapped.
func TestClientNet(t *testing.T) {
	fn := func() (c1, c2 net.Conn, stop func(), err error) {
		ctx := context.Background()
		ctx, cancel := context.WithCancel(ctx)

		var wg sync.WaitGroup

		const target = "target"

		// NewConn delivers the server-side end of the new pipe. The one-slot
		// buffer lets the callback return before fn receives from conns.
		conns := make(chan net.Conn, 1)
		server := &pipeserver.Server{
			NewConn: func(ctx context.Context, target string, c net.Conn) error {
				conns <- c
				return nil
			},
		}

		mux := http.NewServeMux()
		twirpServer := netpipe.NewNetPipeServer(server, nil)
		mux.Handle(netpipe.NetPipePathPrefix, twirpServer)
		srv := httptest.NewServer(mux)

		client := &Client{dest: netpipe.NewNetPipeProtobufClient(srv.URL, http.DefaultClient)}

		pipe, err := client.Open(ctx, target)
		if err != nil {
			cancel()
			// stop is not returned on this path, so shut the HTTP test
			// server down here. Otherwise its listener and goroutines
			// outlive the failed setup.
			srv.Close()
			return nil, nil, nil, err
		}
		c1 = PipeConn(pipe)
		c2 = <-conns

		// stop cancels ctx and closes both ends concurrently. After they are
		// closed, it shuts down the HTTP test server.
		stop = func() {
			cancel()
			for _, c := range []io.Closer{c1, c2} {
				wg.Add(1)
				go func(c io.Closer) { defer wg.Done(); c.Close() }(c)
			}
			wg.Wait()
			srv.Close()
		}

		return c1, c2, stop, nil
	}

	t.Run("a", func(t *testing.T) {
		nettest.TestConn(t, fn)
	})
	// Same suite with the roles of the two ends swapped.
	t.Run("b", func(t *testing.T) {
		nettest.TestConn(t, func() (c1, c2 net.Conn, stop func(), err error) {
			c1, c2, stop, err = fn()
			return c2, c1, stop, err
		})
	})
}
