package pipeserver

import (
	"bytes"
	"context"
	"io"
	"net"
	"testing"
	"time"

	"code.justin.tv/rhys/nursery/cmd/multicp/netpipe"
	"github.com/golang/protobuf/ptypes"
)

// TestWrite drives a Server through out-of-order WriteAt calls and checks
// which bytes reach the dialed connection.
//
// The pipe's dial function is pointed at a local TCP listener, so the test can
// read exactly what the server forwards. It expects:
//   - a segment that starts past offset 0 is neither committed nor forwarded,
//     even when it is sent twice;
//   - once the gap is filled, the contiguous prefix is forwarded and the
//     committed offset advances;
//   - the forwarded bytes equal the original message.
//
// It is skipped because the 1ms read windows make it timing-sensitive.
func TestWrite(t *testing.T) {
	t.Skip("flaky")

	listener, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Listen; err = %v", err)
	}
	// Closing the listener makes Accept return an error, which ends the
	// accept goroutine below instead of leaking it past the test.
	defer listener.Close()

	// done unblocks the accept goroutine if it is holding a connection that
	// nobody will receive once the test has returned.
	done := make(chan struct{})
	defer close(done)

	accept := make(chan net.Conn)
	go func() {
		defer close(accept)
		for {
			c, err := listener.Accept()
			if err != nil {
				return
			}
			select {
			case accept <- c:
			case <-done:
				c.Close()
				return
			}
		}
	}()

	// Every pipe dials the local listener, whatever target it was asked for.
	addr := listener.Addr()
	dial := func(ctx context.Context, target string) (io.ReadWriteCloser, error) {
		conn, err := (&net.Dialer{}).DialContext(ctx, addr.Network(), addr.String())
		return conn, err
	}

	srv := &Server{
		dial:  dial,
		pipes: make(map[string]*pipe),
	}
	var client netpipe.NetPipe = srv

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Named "created" rather than "pipe" so it does not shadow the pipe type.
	created, err := client.CreatePipe(ctx, &netpipe.CreatePipeRequest{
		Target:            "foo",
		KeepaliveDuration: ptypes.DurationProto(50 * time.Millisecond),
	})
	if err != nil {
		t.Fatalf("CreatePipe; err = %v", err)
	}

	name := created.GetName()
	t.Logf("pipe %s", created)

	message := []byte("hello world\n")

	// Wait for the server to dial the listener.
	var conn net.Conn
	select {
	case c, ok := <-accept:
		if !ok {
			t.Fatalf("no connection available")
		}
		conn = c
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("no connection available")
	}
	defer conn.Close()

	// writeAt sends req and checks the committed write offset in the reply.
	writeAt := func(req *netpipe.WriteAtRequest, wantCommitted int64) {
		t.Helper()
		resp, err := client.WriteAt(ctx, req)
		if err != nil {
			t.Fatalf("WriteAt; err = %v", err)
		}
		if have, want := resp.GetCommittedWriteOffset(), wantCommitted; have != want {
			t.Errorf("committed write offset; %d != %d", have, want)
		}
	}

	// drain copies whatever the server forwarded to conn into srvReadBuf,
	// stopping at a short read deadline, and checks how many bytes arrived.
	// The copy is expected to end with a timeout because the server keeps the
	// connection open.
	const readWindow = 1 * time.Millisecond
	var srvReadBuf bytes.Buffer
	drain := func(wantN int64) {
		t.Helper()
		if err := conn.SetDeadline(time.Now().Add(readWindow)); err != nil {
			t.Errorf("SetDeadline; err = %v", err)
		}
		n, err := io.Copy(&srvReadBuf, conn)
		if have, want := n, wantN; have != want {
			t.Errorf("Read length; %d != %d", have, want)
		}
		if ne, ok := err.(net.Error); err == nil || !(ok && ne.Timeout()) {
			t.Errorf("Read error, not a timeout; err = %v", err)
		}
	}

	// Bytes 3,4 arrive before bytes 0..2 exist. Sending them twice must
	// still leave nothing committed or forwarded.
	writeAt(&netpipe.WriteAtRequest{
		Name:        name,
		WriteData:   message[3:5],
		WriteOffset: 3,
	}, 0)
	writeAt(&netpipe.WriteAtRequest{
		Name:        name,
		WriteData:   message[3:5],
		WriteOffset: 3,
	}, 0)
	drain(0)

	// Bytes 0..3 fill the gap (overlapping byte 3), making 0..4 contiguous.
	writeAt(&netpipe.WriteAtRequest{
		Name:        name,
		WriteData:   message[0:4],
		WriteOffset: 0,
	}, 5)
	drain(5)

	// The remainder completes the message.
	writeAt(&netpipe.WriteAtRequest{
		Name:        name,
		WriteData:   message[5:],
		WriteOffset: 5,
	}, 12)
	drain(7)

	if have, want := len(message), 12; have != want {
		t.Errorf("Write message length; %d != %d", have, want)
	}
	if have, want := srvReadBuf.Len(), 12; have != want {
		t.Errorf("Read buffer length; %d != %d", have, want)
	}

	if have, want := srvReadBuf.Bytes(), message; !bytes.Equal(have, want) {
		t.Errorf("Message is corrupted; %q != %q", have, want)
	}
}
