package conn

import (
	"context"
	"io"
	"net"
	"sync"
	"testing"
	"time"

	"golang.org/x/net/nettest"
	"golang.org/x/sync/errgroup"
)

// makePipe returns a nettest.MakePipe that builds two in-memory ReadWriter
// endpoints (127.0.0.1 <-> 127.0.0.2) cross-wired through io.Pipe, so the
// TimeoutReader/TimeoutWriter pair can be exercised by nettest as a net.Conn.
// The t parameter is currently unused.
func makePipe(t *testing.T) nettest.MakePipe {
	return func() (c1, c2 net.Conn, stop func(), err error) {
		addr1 := &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}
		addr2 := &net.IPAddr{IP: net.IPv4(127, 0, 0, 2)}

		left := &ReadWriter{Reader: new(TimeoutReader), Writer: new(TimeoutWriter), Local: addr1, Remote: addr2}
		right := &ReadWriter{Reader: new(TimeoutReader), Writer: new(TimeoutWriter), Local: addr2, Remote: addr1}

		ctx, cancel := context.WithCancel(context.Background())
		var wg sync.WaitGroup

		// spawn runs fn in a goroutine tracked by wg. When fn returns, the
		// shared context is cancelled (which tears down every pipe) before
		// the goroutine reports itself done.
		spawn := func(fn func()) {
			wg.Add(1)
			go func() {
				defer wg.Done()
				defer cancel()
				fn()
			}()
		}

		// link routes everything written to src's Writer into dst's Reader.
		link := func(dst, src *ReadWriter) {
			pr, pw := io.Pipe()

			// Pump the pipe into dst's reader; closing pw afterwards makes
			// any further write from src fail instead of blocking.
			spawn(func() {
				dst.Reader.(io.ReaderFrom).ReadFrom(pr)
				pw.Close()
			})
			// Drain src's writer into the pipe.
			spawn(func() {
				src.Writer.(io.WriterTo).WriteTo(pw)
				pw.Close()
			})
			// On cancellation, close the pipe so both pumps above unblock.
			spawn(func() {
				<-ctx.Done()
				pw.Close()
			})
		}

		link(left, right)
		link(right, left)

		stop = func() {
			cancel()
			left.Close()
			right.Close()
			wg.Wait()
		}
		return left, right, stop, nil
	}
}

func TestConn(t *testing.T) {
	nettest.TestConn(t, makePipe(t))
}

// TestNothing runs nettest.TestConn against a plain, unwrapped net.Pipe.
// It acts as a control for the wrapped-connection tests in this file.
func TestNothing(t *testing.T) {
	nettest.TestConn(t, func() (net.Conn, net.Conn, func(), error) {
		left, right := net.Pipe()
		shutdown := func() {
			left.Close()
			right.Close()
		}
		return left, right, shutdown, nil
	})
}

// TestReaderConn runs the nettest suite on a net.Pipe whose first end is
// wrapped in a readerConn. Subtest "a" passes the wrapped end as c1;
// subtest "b" swaps the ends so the wrapped conn is exercised as c2 too.
func TestReaderConn(t *testing.T) {
	build := func() (c1, c2 net.Conn, stop func(), err error) {
		raw, peer := net.Pipe()
		var wrapped net.Conn = newReaderConn(raw)
		shutdown := func() {
			wrapped.Close()
			peer.Close()
		}
		return wrapped, peer, shutdown, nil
	}
	swapped := func() (c1, c2 net.Conn, stop func(), err error) {
		first, second, halt, buildErr := build()
		return second, first, halt, buildErr
	}

	cases := []struct {
		name string
		mp   nettest.MakePipe
	}{
		{name: "a", mp: build},
		{name: "b", mp: swapped},
	}
	for _, tc := range cases {
		mp := tc.mp
		t.Run(tc.name, func(t *testing.T) {
			nettest.TestConn(t, mp)
		})
	}
}

// TestWriterConn runs the nettest suite on a net.Pipe whose first end is
// wrapped in a writerConn. Subtest "a" passes the wrapped end as c1;
// subtest "b" swaps the ends so the wrapped conn is exercised as c2 too.
func TestWriterConn(t *testing.T) {
	build := func() (c1, c2 net.Conn, stop func(), err error) {
		raw, peer := net.Pipe()
		var wrapped net.Conn = newWriterConn(raw)
		shutdown := func() {
			wrapped.Close()
			peer.Close()
		}
		return wrapped, peer, shutdown, nil
	}
	swapped := func() (c1, c2 net.Conn, stop func(), err error) {
		first, second, halt, buildErr := build()
		return second, first, halt, buildErr
	}

	cases := []struct {
		name string
		mp   nettest.MakePipe
	}{
		{name: "a", mp: build},
		{name: "b", mp: swapped},
	}
	for _, tc := range cases {
		mp := tc.mp
		t.Run(tc.name, func(t *testing.T) {
			nettest.TestConn(t, mp)
		})
	}
}

// readerConn wraps a net.Conn so that reads are served from a TimeoutReader
// fed by a background goroutine. Writes, addresses and SetWriteDeadline fall
// through to the embedded connection.
type readerConn struct {
	net.Conn
	reader TimeoutReader
}

// newReaderConn wraps c and starts a goroutine that copies everything
// arriving on c into the wrapper's TimeoutReader. The goroutine's result is
// discarded; Close shuts down both the reader and c.
func newReaderConn(c net.Conn) *readerConn {
	wrapped := new(readerConn)
	wrapped.Conn = c
	go wrapped.reader.ReadFrom(c)
	return wrapped
}

// SetDeadline applies t to both the read side (the TimeoutReader) and the
// write side (the embedded conn). Both are always attempted. If both fail,
// the read-side error wins.
func (rc *readerConn) SetDeadline(t time.Time) error {
	readErr := rc.SetReadDeadline(t)
	writeErr := rc.SetWriteDeadline(t)
	if readErr == nil {
		return writeErr
	}
	return readErr
}

// SetReadDeadline sets the deadline on the TimeoutReader rather than on the
// underlying conn, since reads are served from the reader.
func (rc *readerConn) SetReadDeadline(t time.Time) error {
	return rc.reader.SetReadDeadline(t)
}

// Close closes the TimeoutReader and the embedded conn concurrently and
// returns the first error reported by either.
// NOTE(review): the concurrency is presumably needed because one close may
// block until the other half is shut down — confirm before serializing.
func (rc *readerConn) Close() error {
	group := new(errgroup.Group)
	for _, closer := range []func() error{rc.reader.Close, rc.Conn.Close} {
		group.Go(closer)
	}
	return group.Wait()
}

// Read reads from the TimeoutReader, not directly from the wrapped conn.
func (rc *readerConn) Read(p []byte) (int, error) {
	return rc.reader.Read(p)
}

// writerConn wraps a net.Conn so that writes go into a TimeoutWriter that a
// background goroutine drains into the connection. Reads, addresses and
// SetReadDeadline fall through to the embedded connection.
type writerConn struct {
	net.Conn
	writer TimeoutWriter
}

// newWriterConn wraps c and starts a goroutine that copies everything
// written to the wrapper's TimeoutWriter out to c. The goroutine's result is
// discarded; Close shuts down both the writer and c.
func newWriterConn(c net.Conn) *writerConn {
	wrapped := new(writerConn)
	wrapped.Conn = c
	go wrapped.writer.WriteTo(c)
	return wrapped
}

// SetDeadline applies t to both the read side (the embedded conn) and the
// write side (the TimeoutWriter). Both are always attempted. If both fail,
// the read-side error wins.
func (wc *writerConn) SetDeadline(t time.Time) error {
	readErr := wc.SetReadDeadline(t)
	writeErr := wc.SetWriteDeadline(t)
	if readErr == nil {
		return writeErr
	}
	return readErr
}

// SetWriteDeadline sets the deadline on the TimeoutWriter rather than on the
// underlying conn, since writes go through the writer.
func (wc *writerConn) SetWriteDeadline(t time.Time) error {
	return wc.writer.SetWriteDeadline(t)
}

// Close closes the TimeoutWriter and the embedded conn concurrently and
// returns the first error reported by either.
// NOTE(review): the concurrency is presumably needed because one close may
// block until the other half is shut down — confirm before serializing.
func (wc *writerConn) Close() error {
	group := new(errgroup.Group)
	for _, closer := range []func() error{wc.writer.Close, wc.Conn.Close} {
		group.Go(closer)
	}
	return group.Wait()
}

// Write hands p to the TimeoutWriter. The goroutine started by newWriterConn
// forwards the writer's contents to the wrapped conn.
func (wc *writerConn) Write(p []byte) (int, error) {
	return wc.writer.Write(p)
}

// TestCloseRace repeatedly calls Close concurrently with the corresponding
// Set*Deadline method on fresh zero-value TimeoutReader and TimeoutWriter
// values. It makes no assertions of its own; it exists to surface data races
// when run under -race.
func TestCloseRace(t *testing.T) {
	const iterations = 100

	t.Run("TimeoutReader", func(t *testing.T) {
		for i := 0; i < iterations; i++ {
			var r TimeoutReader
			// Named deadline (not t) so the *testing.T stays in scope.
			deadline := time.Now()
			var eg errgroup.Group
			eg.Go(r.Close)
			eg.Go(func() error { return r.SetReadDeadline(deadline) })
			// Errors are deliberately ignored: only the concurrent access
			// pattern matters, and the race detector reports any problem.
			_ = eg.Wait()
		}
	})
	t.Run("TimeoutWriter", func(t *testing.T) {
		for i := 0; i < iterations; i++ {
			var w TimeoutWriter
			deadline := time.Now()
			var eg errgroup.Group
			eg.Go(w.Close)
			eg.Go(func() error { return w.SetWriteDeadline(deadline) })
			// Errors are deliberately ignored; see the reader subtest.
			_ = eg.Wait()
		}
	})
}
