package main

import (
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	crand "crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"errors"
	"flag"
	"fmt"
	"io"
	"log"
	"math/big"
	"math/rand"
	"net"
	"net/http"
	"net/http/httptrace"
	"net/url"
	"os"
	"os/signal"
	"runtime/pprof"
	"runtime/trace"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"code.justin.tv/rhys/nursery/cmd/multicp/netpipe"
	"code.justin.tv/rhys/nursery/cmd/multicp/picker"
	"code.justin.tv/rhys/nursery/cmd/multicp/pipeclient"
	"code.justin.tv/rhys/nursery/cmd/multicp/pipeserver"
	"code.justin.tv/rhys/nursery/cmd/multicp/quicpath"
	"github.com/lucas-clemente/quic-go"
	"golang.org/x/time/rate"
)

// main parses the command-line flags and optionally starts a CPU profile
// (-cpuprofile) and/or an execution trace (-trace). It then runs the echo
// server when -listen is set, or the round-trip client otherwise. The profile
// and trace are flushed and closed when main returns normally; log.Fatalf
// exits without running them.
func main() {
	listen := flag.String("listen", "", "Bind address (for server mode)")
	target := flag.String("target", "", "Remote address (for client mode)")
	connCount := flag.Int("conns", 1, "Number of client connections to use")
	interval := flag.Duration("interval", 100*time.Millisecond, "Base sleep time between client messages")
	logSlow := flag.Duration("log-slow", 0, "Minimum round-trip time to warrant logging")
	useNetpipe := flag.Bool("netpipe", false, "Enable netpipe-based multipath")
	useQUIC := flag.Bool("quic", false, "Use QUIC-based multipath")
	cpuprofile := flag.String("cpuprofile", "", "Path to write CPU profile")
	execTrace := flag.String("trace", "", "Path to write execution trace")

	flag.Parse()

	log.SetFlags(log.LUTC | log.Ldate | log.Ltime | log.Lmicroseconds)

	if name := *cpuprofile; name != "" {
		f, err := os.Create(name)
		if err != nil {
			log.Fatalf("create profile: %v", err)
		}
		// Deferred calls run last-in first-out, so the profile is stopped
		// before its file is closed.
		defer func() {
			if err := f.Close(); err != nil {
				log.Fatalf("close profile: %v", err)
			}
		}()

		if err := pprof.StartCPUProfile(f); err != nil {
			log.Fatalf("start profile: %v", err)
		}
		defer pprof.StopCPUProfile()
	}

	if name := *execTrace; name != "" {
		f, err := os.Create(name)
		if err != nil {
			log.Fatalf("create trace: %v", err)
		}
		// As above: trace.Stop runs before the file is closed.
		defer func() {
			if err := f.Close(); err != nil {
				log.Fatalf("close trace: %v", err)
			}
		}()

		if err := trace.Start(f); err != nil {
			log.Fatalf("start trace: %v", err)
		}
		defer trace.Stop()
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// The first interrupt cancels ctx so that either mode can wind down and
	// main can run its deferred cleanup. After that we stop relaying SIGINT,
	// so a second interrupt gets the default behavior and terminates a process
	// that is stuck shutting down. Previously every later interrupt was
	// swallowed.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, os.Interrupt)
	go func() {
		<-sigCh
		signal.Stop(sigCh)
		cancel()
	}()

	if *listen != "" {
		serverMode(ctx, *listen, *useNetpipe, *useQUIC)
	} else {
		clientMode(ctx, *target, *connCount, *interval, *logSlow, *useNetpipe, *useQUIC)
	}
}

func clientMode(ctx context.Context, target string, connCount int, interval time.Duration, logSlow time.Duration, useNetpipe, useQUIC bool) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	host, port, err := net.SplitHostPort(target)
	if err != nil {
		log.Fatalf("split: %v", err)
	}
	dns := &net.Resolver{}
	dialer := &net.Dialer{}

	portNum, err := dns.LookupPort(ctx, "tcp", port)
	if err != nil {
		log.Fatalf("LookupPort: %v", err)
	}

	ips, err := dns.LookupIPAddr(ctx, host)
	if err != nil {
		log.Fatalf("LookupIPAddr: %v", err)
	}

	dial := func(ctx context.Context) (net.Conn, error) {
		addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(portNum))
		c, err := dialer.DialContext(ctx, "tcp", addr)
		return c, err
	}

	if useNetpipe {
		addr := &url.URL{
			Scheme: "https",
			Host:   net.JoinHostPort(ips[0].String(), strconv.Itoa(portNum)),
		}
		target := addr.String()
		httpClient := newHTTPClient()
		client := pipeclient.NewClient(netpipe.NewNetPipeProtobufClient(target, httpClient))
		dial = func(ctx context.Context) (net.Conn, error) {
			pipe, err := client.Open(ctx, target)
			if err != nil {
				return nil, err
			}
			conn := pipeclient.PipeConn(pipe)
			return conn, nil
		}
	}

	if useQUIC {
		dialer := &quicpath.Dialer{
			PathCount:      10,
			HealthInterval: 100 * time.Millisecond,
		}
		addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(portNum))
		var mu sync.Mutex
		dial = func(ctx context.Context) (net.Conn, error) {
			mu.Lock()
			defer mu.Unlock()
			return dialer.DialContext(ctx, "udp", addr)
		}

		// pconn, err := listenMultiPacket(10)
		// if err != nil {
		// 	log.Fatalf("multipacket %v", err)
		// }
		// dial = func(ctx context.Context) (net.Conn, error) {
		// 	addr := &net.UDPAddr{IP: ips[0].IP, Port: portNum}
		// 	session, err := quic.DialContext(ctx, pconn, addr, "remote", &tls.Config{
		// 		InsecureSkipVerify: true,
		// 		NextProtos:         []string{"h2"},
		// 	}, &quic.Config{
		// 		ConnectionIDLength: 18,
		// 		KeepAlive:          true,
		// 	})
		// 	if err != nil {
		// 		return nil, err
		// 	}
		// 	stream1, err := session.OpenStream()
		// 	if err != nil {
		// 		return nil, err
		// 	}
		// 	stream2, err := session.OpenStream()
		// 	if err != nil {
		// 		return nil, err
		// 	}

		// 	go func() {
		// 		for {
		// 			d := 130 * time.Millisecond
		// 			time.Sleep(d - time.Duration(rand.Int63n(int64(d)/5)))
		// 			pconn.shuffle()
		// 			err := stream2.SetWriteDeadline(time.Now().Add(50 * time.Millisecond))
		// 			if err != nil {
		// 				return
		// 			}
		// 			_, err = stream2.Write([]byte("h"))
		// 			if err != nil {
		// 				return
		// 			}
		// 		}
		// 	}()

		// 	conn := &quicConn{
		// 		Stream:     stream1,
		// 		localAddr:  session.LocalAddr(),
		// 		remoteAddr: session.RemoteAddr(),
		// 	}
		// 	return conn, nil
		// }
	}

	var wg sync.WaitGroup
	for i := 0; i < connCount; i++ {
		ctx := pprof.WithLabels(ctx, pprof.Labels(
			"clientConn", strconv.Itoa(i),
		))
		pprof.SetGoroutineLabels(ctx)

		rng := rand.New(rand.NewSource(int64(i)))
		wg.Add(1)
		go func() {
			defer wg.Done()

			c, err := dial(ctx)
			if err != nil {
				log.Printf("Dial: %v", err)
				cancel()
				return
			}
			defer c.Close()

			wg.Add(1)
			go func() {
				defer wg.Done()
				<-ctx.Done()
				c.SetDeadline(time.Now())
			}()

			buf := make([]byte, 1024)
			wait := interval
			wait -= time.Duration(rng.Int63n(int64(interval)))
			next := time.NewTimer(wait)
			defer next.Stop()
			for {
				select {
				case <-ctx.Done():
					return
				case <-next.C:
				}

				wait := interval
				wait -= time.Duration(rng.Int63n(int64(interval))) / 5
				next.Reset(wait)

				t0 := time.Now().UTC()
				nw, err := c.Write(buf)
				if err != nil {
					log.Printf("Write: %v", err)
					cancel()
					return
				}
				nr, err := io.ReadFull(c, buf[:nw])
				if err != nil {
					log.Printf("ReadFull: %v", err)
					cancel()
					return
				}
				_ = nr
				t1 := time.Now().UTC()

				rtt := t1.Sub(t0)
				if rtt >= logSlow {
					fmt.Fprintf(os.Stdout, "%s %s %s %d\n",
						t0.Format("2006-01-02T15:04:05.000000Z07:00"),
						c.LocalAddr().String(), c.RemoteAddr().String(), rtt.Microseconds())
				}
			}
		}()
	}
	wg.Wait()
}

// serverMode runs an echo server on bind until ctx is cancelled.
//
// With useQUIC it listens on UDP (IPv4) through quicpath, using a freshly
// generated self-signed certificate. Otherwise it listens on TCP. When
// useNetpipe is set, the TCP listener is wrapped by listenNetPipe
// (TLS + netpipe Twirp). Each accepted connection is served by serveConn in
// its own goroutine. Listener setup failures are fatal.
func serverMode(ctx context.Context, bind string, useNetpipe, useQUIC bool) {
	var wg sync.WaitGroup
	defer wg.Wait()

	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// stopAccepting decides whether an Accept loop should exit after err.
	// Cancelling ctx closes the listener, which makes Accept fail. A plain TCP
	// listener reports that as "use of closed network connection" rather than
	// io.EOF, so ctx is checked too. Without that check, an interrupt used to
	// leave the TCP server logging and retrying forever. Any other error is
	// logged and followed by a 3s back-off, which is cut short if ctx is
	// cancelled.
	stopAccepting := func(err error) bool {
		if err == io.EOF || err == context.Canceled || ctx.Err() != nil {
			return true
		}
		log.Printf("accept: %v", err)
		select {
		case <-ctx.Done():
			return true
		case <-time.After(3 * time.Second):
			return false
		}
	}

	if useQUIC {
		tlsConf, err := newServerTLSConfig()
		if err != nil {
			log.Fatalf("listen: %v", err)
		}
		pc, err := net.ListenPacket("udp4", bind)
		if err != nil {
			log.Fatalf("listen: %v", err)
		}
		ql, err := quicpath.Listen(pc, tlsConf, &quic.Config{
			ConnectionIDLength: 18,
			KeepAlive:          true,
		})
		if err != nil {
			log.Fatalf("listen: %v", err)
		}
		defer ql.Close()

		wg.Add(1)
		go func() {
			defer wg.Done()
			<-ctx.Done()
			ql.Close()
		}()

		for {
			conn, err := ql.Accept()
			if err != nil {
				if stopAccepting(err) {
					return
				}
				continue
			}
			go func() {
				// defer session.CloseWithError(42, "done")
				// stream1, err := session.AcceptStream(ctx)
				// if err != nil {
				// 	return
				// }
				// stream2, err := session.AcceptStream(ctx)
				// if err != nil {
				// 	return
				// }
				// go func() {
				// 	io.Copy(ioutil.Discard, stream2)
				// }()
				// conn := &quicConn{
				// 	Stream:     stream1,
				// 	localAddr:  session.LocalAddr(),
				// 	remoteAddr: session.RemoteAddr(),
				// }
				serveConn(conn)
			}()
		}
	}

	l, err := net.Listen("tcp", bind)
	if err != nil {
		log.Fatalf("listen: %v", err)
	}
	if useNetpipe {
		l, err = listenNetPipe(l)
		if err != nil {
			log.Fatalf("listenNetPipe: %v", err)
		}
	}
	wg.Add(1)
	go func() {
		defer wg.Done()
		<-ctx.Done()
		l.Close()
	}()

	for {
		c, err := l.Accept()
		if err != nil {
			if stopAccepting(err) {
				return
			}
			continue
		}
		go serveConn(c)
	}
}

// serveConn is the echo handler. It writes every chunk it reads from c back
// to c. When a read or a write fails (including the peer's io.EOF), it logs
// the failure together with both endpoint addresses and closes c.
func serveConn(c net.Conn) {
	defer c.Close()

	chunk := make([]byte, 4<<10)
	for {
		got, readErr := c.Read(chunk)
		if readErr != nil {
			log.Printf("read %v/%v: %v", c.LocalAddr(), c.RemoteAddr(), readErr)
			return
		}
		if _, writeErr := c.Write(chunk[:got]); writeErr != nil {
			log.Printf("write %v/%v: %v", c.LocalAddr(), c.RemoteAddr(), writeErr)
			return
		}
	}
}

// listenNetPipe serves the netpipe Twirp API over HTTPS on base, using a
// freshly generated self-signed certificate. It returns a listener whose
// Accept yields one net.Conn for each pipe a remote client opens.
//
// Closing the returned listener closes the HTTP server, and with it the
// listeners it serves on. After that, Accept returns io.EOF and pipes still
// waiting to be handed off fail instead of blocking. If the HTTP server stops
// for any other reason, the error is logged and the listener is closed the
// same way.
func listenNetPipe(base net.Listener) (net.Listener, error) {
	tlsConfig, err := newServerTLSConfig()
	if err != nil {
		return nil, err
	}

	tlsBase := tls.NewListener(base, tlsConfig)

	// The HTTP handler stores each request's remote address in its context,
	// so the pipe server can report it as the conn's RemoteAddr.
	type contextKey struct{ v string }
	remoteAddrContextKey := &contextKey{v: "remote-addr"}

	mux := http.NewServeMux()
	l := &pipeListener{
		conns:  make(chan net.Conn),
		closed: make(chan struct{}),
		base:   tlsBase,
		httpServer: &http.Server{
			Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				ctx := context.WithValue(r.Context(), remoteAddrContextKey, r.RemoteAddr)
				mux.ServeHTTP(w, r.WithContext(ctx))
			}),
		},
	}

	server := &pipeserver.Server{
		NewConn: func(ctx context.Context, target string, c net.Conn) error {
			// Hand the new conn to Accept. Once the listener is closed, nobody
			// will take it, so fail instead of blocking. (This used to panic with
			// a send on a closed channel.)
			// NOTE(review): whether pipeserver closes c when NewConn returns an
			// error is not visible here -- confirm, to avoid leaking c.
			select {
			case l.conns <- c:
				return nil
			case <-l.closed:
				return errors.New("netpipe listener closed")
			}
		},
		LocalAddr: func(ctx context.Context) net.Addr {
			return &net.UnixAddr{Name: (&url.URL{Scheme: "https", Host: tlsBase.Addr().String()}).String()}
		},
		RemoteAddr: func(ctx context.Context) net.Addr {
			addr, _ := ctx.Value(remoteAddrContextKey).(string)
			return &net.UnixAddr{Name: addr}
		},
	}
	twirpServer := netpipe.NewNetPipeServer(server, nil)
	mux.Handle(netpipe.NetPipePathPrefix, twirpServer)

	go func() {
		// Serve returns http.ErrServerClosed after Close. Any other error means
		// no more pipes can arrive, so close the listener to let Accept report
		// io.EOF instead of blocking forever.
		if err := l.run(); err != nil && err != http.ErrServerClosed {
			log.Printf("netpipe serve: %v", err)
			l.Close()
		}
	}()

	return l, nil
}

// pipeListener adapts the netpipe HTTP server to net.Listener. Pipes opened
// through the server are handed to Accept over conns.
type pipeListener struct {
	conns      chan net.Conn // unbuffered hand-off from NewConn to Accept
	closed     chan struct{} // closed exactly once, by Close
	closeOnce  sync.Once
	closeErr   error // result of the first Close
	base       net.Listener
	httpServer *http.Server
}

var _ net.Listener = (*pipeListener)(nil)

// run serves HTTP on the base listener until the server stops. It always
// returns a non-nil error.
func (l *pipeListener) run() error {
	return l.httpServer.Serve(l.base)
}

// Accept waits for the next pipe. It returns io.EOF once the listener has
// been closed.
func (l *pipeListener) Accept() (net.Conn, error) {
	select {
	case conn := <-l.conns:
		return conn, nil
	case <-l.closed:
		return nil, io.EOF
	}
}

// Addr returns the address of the underlying TLS listener.
func (l *pipeListener) Addr() net.Addr {
	return l.base.Addr()
}

// Close wakes pending Accept and NewConn calls and closes the HTTP server.
// It is safe to call more than once; later calls return the first result.
func (l *pipeListener) Close() error {
	l.closeOnce.Do(func() {
		close(l.closed)
		l.closeErr = l.httpServer.Close()
	})
	return l.closeErr
}

// newServerTLSConfig returns a TLS 1.3-only server configuration. Its
// certificate is a freshly generated, self-signed ECDSA P-256 certificate for
// IP 127.0.0.1, valid from 30 minutes ago for 1000 hours. The configuration
// advertises the ALPN protocols "h2", "multicp-health" and "multicp-data".
//
// The certificate cannot be verified by ordinary clients: it is self-signed,
// and the HTTP client built by newHTTPClient skips verification.
//
// Adapted from
// https://github.com/golang/go/blob/go1.15.1/src/crypto/tls/generate_cert.go
func newServerTLSConfig() (*tls.Config, error) {
	priv, err := ecdsa.GenerateKey(elliptic.P256(), crand.Reader)
	if err != nil {
		return nil, fmt.Errorf("generate key: %w", err)
	}

	// Backdated, presumably to tolerate modest clock skew between peers.
	notBefore := time.Now().Add(-30 * time.Minute)
	notAfter := notBefore.Add(1000 * time.Hour)

	serialNumber, err := crand.Int(crand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return nil, fmt.Errorf("generate serial number: %w", err)
	}

	template := x509.Certificate{
		SerialNumber:          serialNumber,
		Subject:               pkix.Name{Organization: []string{"Acme Co"}},
		NotBefore:             notBefore,
		NotAfter:              notAfter,
		KeyUsage:              x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IPAddresses:           []net.IP{net.IPv4(127, 0, 0, 1)},
	}

	derBytes, err := x509.CreateCertificate(crand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		return nil, fmt.Errorf("create certificate: %w", err)
	}

	privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
	if err != nil {
		return nil, fmt.Errorf("marshal private key: %w", err)
	}

	cert, err := tls.X509KeyPair(
		pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}),
		pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}),
	)
	if err != nil {
		// Previously ignored, which could hand out a config with an empty
		// certificate.
		return nil, fmt.Errorf("load key pair: %w", err)
	}

	config := &tls.Config{
		Certificates: []tls.Certificate{cert},
		NextProtos: []string{
			"h2",             // HTTP/2 (standard)
			"multicp-health", // Check health of a five-tuple
			"multicp-data",   // Transfer application data
		},
		MinVersion: tls.VersionTLS13,
		MaxVersion: tls.VersionTLS13,
	}

	return config, nil
}

// newHTTPClient builds the HTTP client used for netpipe RPCs.
//
// A spreadTransport spreads requests across ten independent http.Transports.
// Each transport is wrapped for per-host request counting and for per-request
// logging to stdout. New connections, across all transports together, are
// rate-limited to one every 10ms with a burst of 10. Server certificates are
// not verified.
func newHTTPClient() *http.Client {
	const transportCount = 10

	dialLimit := rate.NewLimiter(rate.Every(10*time.Millisecond), 10)
	var netDialer net.Dialer

	limitedDial := func(ctx context.Context, network, address string) (net.Conn, error) {
		if waitErr := dialLimit.Wait(ctx); waitErr != nil {
			// Wait can fail before ctx ends, for example when ctx's deadline
			// would pass before a token is available. Hold the dial until ctx
			// actually ends, so the caller sees ctx's own error.
			<-ctx.Done()
			return nil, ctx.Err()
		}
		return netDialer.DialContext(ctx, network, address)
	}

	spread := &spreadTransport{
		base: make([]*countingTransport, transportCount),
		intn: rand.Intn,
	}
	for i := range spread.base {
		spread.base[i] = &countingTransport{
			base: &loggingTransport{
				base: &http.Transport{
					DialContext:       limitedDial,
					Proxy:             http.ProxyFromEnvironment,
					ForceAttemptHTTP2: true,

					MaxConnsPerHost:     100,
					MaxIdleConns:        100,
					MaxIdleConnsPerHost: 100,

					TLSHandshakeTimeout: 10 * time.Second,
					IdleConnTimeout:     2 * time.Minute,

					TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
				},
			},
		}
	}

	return &http.Client{Transport: spread}
}

// spreadTransport is an http.RoundTripper that load-balances requests over a
// fixed set of counting transports. For some requests it also races several
// attempts ("hedging") on different transports.
type spreadTransport struct {
	base []*countingTransport // candidate transports
	intn func(n int) int      // random source for the picker, e.g. rand.Intn
}

// RoundTrip implements http.RoundTripper.
//
// Each attempt chooses a transport with "the power of two choices". Up to two
// candidates are drawn from a picker, and the one with fewer outstanding
// requests to req's host wins; ties go to the one with fewer total requests.
//
// Requests to the NetPipe WriteAt method get up to five attempts, with a new
// one started every 50ms until a response arrives. All other requests get a
// single attempt. Once every attempt has started, RoundTrip returns the next
// outcome delivered: a successful response or an attempt's error. When a
// response is returned, the other attempts are cancelled and waited for only
// after its body is closed or read to EOF.
//
// Each attempt sends a fresh copy of the body obtained from req.GetBody. A
// request with a body must therefore be replayable; http.NewRequest sets
// GetBody for *bytes.Buffer, *bytes.Reader and *strings.Reader bodies.
// req.Body itself is always closed, as the RoundTripper contract requires.
func (t *spreadTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	// Decide how each attempt gets its own copy of the body. A missing body
	// (nil or http.NoBody) is reused as is. This used to panic on nil.
	var getBody func() (io.ReadCloser, error)
	switch {
	case req.Body == nil || req.Body == http.NoBody:
		getBody = func() (io.ReadCloser, error) { return req.Body, nil }
	case req.GetBody != nil:
		getBody = req.GetBody
	default:
		req.Body.Close()
		return nil, errors.New("spreadTransport: request body cannot be replayed (GetBody is nil)")
	}
	// The attempts read copies from getBody; the original body is ours to close.
	if req.Body != nil {
		req.Body.Close()
	}

	ctx, cancel := context.WithCancel(req.Context())

	var pickMu sync.Mutex
	candidates := picker.New(len(t.base), t.intn)
	// pickBest returns the less loaded of up to two candidates, or nil when
	// the picker has none left (Pick < 0). The losing candidate is handed back
	// to the picker; the winner is not. Presumably, depending on picker's
	// semantics, this means later attempts for this request use other
	// transports.
	pickBest := func() *countingTransport {
		pickMu.Lock()
		defer pickMu.Unlock()

		options := make([]int, 0, 2)
		for i := 0; i < cap(options); i++ {
			opt := candidates.Pick()
			if opt < 0 {
				break
			}
			options = append(options, opt)
		}
		if len(options) == 0 {
			return nil
		}

		choice := options[0]
		for _, other := range options[1:] {
			totalA, outstandingA := t.base[choice].counts(req)
			totalB, outstandingB := t.base[other].counts(req)
			if outstandingA > outstandingB || (outstandingA == outstandingB && totalA > totalB) {
				choice, other = other, choice
			}
			candidates.Replace(other)
		}
		return t.base[choice]
	}

	// Successful attempts report on responses and failed ones on failures.
	// Keeping successes off the failure channel means a response whose body
	// was already closed during cancellation can never be returned.
	responses := make(chan *http.Response)
	failures := make(chan error)

	var wg sync.WaitGroup
	send := func() {
		wg.Add(1)
		go func() {
			defer wg.Done()

			var resp *http.Response
			err := errors.New("no transport available")
			if transport := pickBest(); transport != nil {
				r2 := req.WithContext(ctx)
				r2.Body, err = getBody()
				if err == nil {
					resp, err = transport.RoundTrip(r2)
				}
			}

			if err != nil {
				// A RoundTripper should not return both a response and an
				// error. If one does, release the response rather than leak it.
				if resp != nil {
					resp.Body.Close()
				}
				select {
				case failures <- err:
				case <-ctx.Done():
				}
				return
			}

			resp.Request = req
			select {
			case responses <- resp:
				// The receiver now owns resp.Body.
			case <-ctx.Done():
				resp.Body.Close()
			}
		}()
	}

	// cleanup cancels any attempts still running and waits for them to exit.
	// It runs when RoundTrip returns, unless a response is handed to the
	// caller; then it runs once that response's body is done.
	cleanup := func() {
		cancel()
		wg.Wait()
	}
	defer func() {
		if cleanup != nil {
			cleanup()
		}
	}()

	// deliver passes resp to the caller and defers cleanup until its body is
	// closed or read to EOF.
	deliver := func(resp *http.Response) (*http.Response, error) {
		resp.Body = &notifyBody{
			base: resp.Body,
			done: cleanup,
		}
		cleanup = nil
		return resp, nil
	}

	tries := 1
	switch req.URL.Path {
	case "/twirp/nursery.multicp.netpipe.NetPipe/WriteAt":
		tries = 5
	}

	tick := time.NewTicker(50 * time.Millisecond)
	defer tick.Stop()
	for i := 0; i < tries; i++ {
		send()

		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case resp := <-responses:
			return deliver(resp)
		case <-tick.C:
			// Add another attempt to the race.
		}
	}

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case resp := <-responses:
		return deliver(resp)
	case err := <-failures:
		return nil, err
	}
}

// countingTransport wraps an http.RoundTripper and keeps per-host counts:
// how many requests it has started (total), and how many are still in flight
// (outstanding). A request stays outstanding until its response body is read
// to EOF or closed, or until the round trip ends without a response.
type countingTransport struct {
	base http.RoundTripper

	mu                  sync.Mutex
	totalRequests       map[string]int64 // guarded by mu; keyed by countKey
	outstandingRequests map[string]int64 // guarded by mu; keyed by countKey
}

// countKey groups requests by their URL's host (including any port).
func countKey(req *http.Request) string {
	host := req.URL.Host
	return host
}

// initLocked allocates the count maps on first use, so that the zero value
// is usable. t.mu must be held.
func (t *countingTransport) initLocked() {
	for _, m := range [...]*map[string]int64{&t.totalRequests, &t.outstandingRequests} {
		if *m == nil {
			*m = make(map[string]int64)
		}
	}
}

// counts reports the total and outstanding request counts for req's host.
func (t *countingTransport) counts(req *http.Request) (total, outstanding int64) {
	key := countKey(req)

	t.mu.Lock()
	defer t.mu.Unlock()
	t.initLocked()
	return t.totalRequests[key], t.outstandingRequests[key]
}

// RoundTrip counts req as started and outstanding, then forwards a shallow
// copy of it to the wrapped transport. If a response comes back, its Request
// is set to req and its body is wrapped, so the outstanding count drops when
// the caller finishes with the body. Otherwise the count drops before
// RoundTrip returns.
func (t *countingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	key := countKey(req)

	t.mu.Lock()
	t.initLocked()
	t.totalRequests[key]++
	t.outstandingRequests[key]++
	t.mu.Unlock()

	release := func() {
		t.mu.Lock()
		t.outstandingRequests[key]--
		t.mu.Unlock()
	}

	handedOff := false
	defer func() {
		if !handedOff {
			release()
		}
	}()

	resp, err := t.base.RoundTrip(req.WithContext(req.Context()))
	if resp != nil {
		resp.Request = req
		resp.Body = &notifyBody{
			base: resp.Body,
			done: release,
		}
		handedOff = true
	}

	return resp, err
}

// notifyBody wraps a response body and calls done exactly once: the first
// time Read reports io.EOF, or when the body is closed, whichever comes first.
type notifyBody struct {
	base io.ReadCloser

	doneOnce sync.Once
	done     func()
}

// Read reads from the wrapped body and fires done when it hits io.EOF.
func (b *notifyBody) Read(p []byte) (int, error) {
	read, readErr := b.base.Read(p)
	if readErr == io.EOF {
		b.doneOnce.Do(b.done)
	}
	return read, readErr
}

// Close closes the wrapped body, then fires done if it has not fired yet.
func (b *notifyBody) Close() error {
	closeErr := b.base.Close()
	b.doneOnce.Do(b.done)
	return closeErr
}

// loggingTransport wraps an http.RoundTripper and writes one line to stdout
// per request:
//
//	<start time> <local addr> <URL path> <duration in µs> <status code>
//
// The local address comes from the connection reported by httptrace's GotConn
// hook, and is empty if none was reported. The status code is 0 when there is
// no response. The duration is measured on wall-clock (UTC) timestamps.
type loggingTransport struct {
	base http.RoundTripper
}

// RoundTrip forwards req (with a client trace attached) to the wrapped
// transport, logs the outcome, and returns the response with Request set back
// to req.
func (t *loggingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	var localAddr, proto string
	var code int

	hooks := &httptrace.ClientTrace{
		GotConn: func(info httptrace.GotConnInfo) {
			localAddr = info.Conn.LocalAddr().String()
		},
	}
	traced := req.WithContext(httptrace.WithClientTrace(req.Context(), hooks))

	start := time.Now().UTC()
	resp, err := t.base.RoundTrip(traced)
	elapsed := time.Now().UTC().Sub(start)

	if resp != nil {
		resp.Request = req
		code, proto = resp.StatusCode, resp.Proto
	}

	fmt.Fprintf(os.Stdout, "%s %s %s %d %d\n",
		start.Format("2006-01-02T15:04:05.000000Z07:00"),
		localAddr, req.URL.Path, elapsed.Microseconds(), code)

	// Verbose variant, disabled; flip the condition when debugging.
	if false {
		log.Printf("proto=%q local=%q request=%q duration=%dµs code=%d err=%q",
			proto, localAddr, req.URL.Path, elapsed.Microseconds(), code, err)
	}

	return resp, err
}

// multiPacketConn is a net.PacketConn spread over several UDP sockets.
// ReadFrom returns the next packet received on any of them. WriteTo sends
// through the single "active" socket, and shuffle moves that to a different
// socket. The commented-out QUIC dialing code in clientMode shuffles it
// periodically, apparently so that one QUIC session hops between local ports.
type multiPacketConn struct {
	conns []net.PacketConn

	reads  chan packetRead // ReadFrom requests, served by the per-socket reader goroutines
	active int64           // index into conns used by WriteTo; accessed atomically

	// closed is closed by Close. It stops the reader goroutines and fails
	// pending and future ReadFrom calls. closeOnce and closeErr make Close
	// idempotent.
	closed    chan struct{}
	closeOnce sync.Once
	closeErr  error
}

var _ net.PacketConn = (*multiPacketConn)(nil)

// packetRead is one pending ReadFrom. A reader goroutine copies its packet
// into buf and reports the result through fn.
type packetRead struct {
	buf []byte
	fn  func(n int, addr net.Addr, err error)
}

// listenMultiPacket opens count UDPv4 sockets on all interfaces, each on an
// OS-assigned port, and starts one reader goroutine per socket. Each reader
// reads a packet, logs it, and then waits for a ReadFrom call to hand it to
// (or for Close). If any socket fails to open, the ones already opened are
// closed before the error is returned.
func listenMultiPacket(count int) (*multiPacketConn, error) {
	multi := &multiPacketConn{
		reads:  make(chan packetRead),
		closed: make(chan struct{}),
	}

	for i := 0; i < count; i++ {
		conn, err := net.ListenUDP("udp4", &net.UDPAddr{
			IP:   nil, // leave IP unset to get all interfaces
			Port: 0,   // leave port unset to get one assigned
		})
		if err != nil {
			for _, opened := range multi.conns {
				opened.Close()
			}
			return nil, err
		}
		multi.conns = append(multi.conns, conn)
	}

	for _, conn := range multi.conns {
		conn := conn
		buf := make([]byte, 10240)
		go func() {
			for {
				buf = buf[:cap(buf)]
				n, addr, err := conn.ReadFrom(buf)
				buf = buf[:n]

				if err == nil {
					max := 20
					if n < max {
						max = n
					}
					prefix := buf[:max]
					log.Printf("ReadFrom %v %v %d %02x", conn.LocalAddr(), addr, n, prefix)
				}

				var read packetRead
				select {
				case read = <-multi.reads:
				case <-multi.closed:
					return
				}

				// Packets larger than the caller's buffer are truncated.
				n = copy(read.buf, buf)
				read.fn(n, addr, err)
			}
		}()
	}

	return multi, nil
}

// shuffle points WriteTo at a randomly chosen socket other than the current
// one. With fewer than two sockets there is nothing to switch to, so it does
// nothing (rand.Intn(0) used to panic here).
func (mpc *multiPacketConn) shuffle() {
	if len(mpc.conns) < 2 {
		return
	}
	old := atomic.LoadInt64(&mpc.active)
	next := rand.Intn(len(mpc.conns) - 1)
	if next >= int(old) {
		next++
	}
	// fine if we lost the race, we only care that it's different now (and don't
	// want to shuffle too much)
	atomic.CompareAndSwapInt64(&mpc.active, old, int64(next))
}

// eachConn calls fn for every socket index. It returns the last non-nil
// error, or nil if every call succeeded.
func (mpc *multiPacketConn) eachConn(fn func(i int) error) error {
	var err error
	for i := range mpc.conns {
		if e := fn(i); e != nil {
			err = e
		}
	}
	return err
}

// Close stops the reader goroutines, fails pending and future ReadFrom calls,
// and closes every socket, returning the last close error. Further calls are
// no-ops that return the same result. (Closing the reads channel used to
// panic any later ReadFrom or second Close.)
func (mpc *multiPacketConn) Close() error {
	mpc.closeOnce.Do(func() {
		close(mpc.closed)
		mpc.closeErr = mpc.eachConn(func(i int) error { return mpc.conns[i].Close() })
	})
	return mpc.closeErr
}

// LocalAddr reports the first socket's address only. It panics if there are
// no sockets.
func (mpc *multiPacketConn) LocalAddr() net.Addr {
	return mpc.conns[0].LocalAddr()
}

// ReadFrom waits for the next packet received on any socket and copies it
// into p. After Close it returns an error instead of blocking.
func (mpc *multiPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
	var wg sync.WaitGroup
	wg.Add(1)
	fn := func(fnn int, fnaddr net.Addr, fnerr error) {
		defer wg.Done()
		n, addr, err = fnn, fnaddr, fnerr
	}
	select {
	case mpc.reads <- packetRead{buf: p, fn: fn}:
	case <-mpc.closed:
		return 0, nil, errors.New("multiPacketConn: use of closed connection")
	}
	wg.Wait()
	return n, addr, err
}

// WriteTo sends p to addr through the currently active socket.
func (mpc *multiPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
	active := int(atomic.LoadInt64(&mpc.active))
	conn := mpc.conns[active]
	log.Printf("WriteTo %v %v", conn.LocalAddr(), addr)
	return conn.WriteTo(p, addr)
}

// SetDeadline, SetReadDeadline and SetWriteDeadline apply to every socket.
// Reads happen in the background reader goroutines, so an expired read
// deadline surfaces as an error handed to a later ReadFrom.
func (mpc *multiPacketConn) SetDeadline(t time.Time) error {
	return mpc.eachConn(func(i int) error { return mpc.conns[i].SetDeadline(t) })
}
func (mpc *multiPacketConn) SetReadDeadline(t time.Time) error {
	return mpc.eachConn(func(i int) error { return mpc.conns[i].SetReadDeadline(t) })
}
func (mpc *multiPacketConn) SetWriteDeadline(t time.Time) error {
	return mpc.eachConn(func(i int) error { return mpc.conns[i].SetWriteDeadline(t) })
}

// quicConn presents a single QUIC stream as a net.Conn. Everything except the
// two address methods comes from the embedded quic.Stream. The addresses are
// supplied by whoever builds the value; the commented-out QUIC code takes
// them from the session.
type quicConn struct {
	quic.Stream

	localAddr, remoteAddr net.Addr
}

var _ net.Conn = (*quicConn)(nil)

// LocalAddr returns the local address recorded when the conn was built.
func (qc *quicConn) LocalAddr() net.Addr {
	return qc.localAddr
}

// RemoteAddr returns the remote address recorded when the conn was built.
func (qc *quicConn) RemoteAddr() net.Addr {
	return qc.remoteAddr
}
