package quicpath

import (
	"bufio"
	"context"
	crand "crypto/rand"
	"crypto/tls"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net"
	"os"
	"sort"
	"strings"
	"sync"
	"time"

	"code.justin.tv/rhys/nursery/cmd/multicp/picker"
	"github.com/lucas-clemente/quic-go"
)

// Fallbacks used when the corresponding Dialer setting is zero or negative.
const (
	// defaultPathCount is how many UDP paths DialContext opens per remote.
	defaultPathCount = 10

	// defaultHealthInterval is the nominal pause between health probes on a path.
	defaultHealthInterval = time.Second / 10
)

// quicConn presents one QUIC stream as a net.Conn. Reads, writes, deadlines
// and Close come from the embedded quic.Stream; the two addresses are those
// of the owning QUIC session, captured when the stream was opened or accepted.
type quicConn struct {
	quic.Stream

	localAddr, remoteAddr net.Addr
}

var _ net.Conn = (*quicConn)(nil)

// LocalAddr reports the local address of the stream's QUIC session.
func (qc *quicConn) LocalAddr() net.Addr {
	return qc.localAddr
}

// RemoteAddr reports the peer address of the stream's QUIC session.
func (qc *quicConn) RemoteAddr() net.Addr {
	return qc.remoteAddr
}

// Dialer dials multi-path QUIC connections: each remote address is reached
// over several UDP sockets ("paths"), and outgoing packets are steered toward
// a path with a low recently measured round-trip time.
//
// The zero value is ready to use.
type Dialer struct {
	// PathCount is the number of UDP paths opened per remote address.
	// Zero or negative means defaultPathCount.
	PathCount      int
	// HealthInterval is the nominal time between RTT probes on each path.
	// Zero or negative means defaultHealthInterval.
	HealthInterval time.Duration

	// base dials the individual UDP paths; nil means a zero net.Dialer.
	base *net.Dialer

	// mu guards remotes.
	mu      sync.Mutex
	// remotes caches one path set per dialed target, keyed by the quoted
	// network and address ("%q:%q").
	remotes map[string]*multiTargetPacketConn
}

// DialContext opens a multi-path QUIC connection to address and returns the
// first stream of that connection as a net.Conn.
//
// The first call for a given network/address pair builds a path set: it dials
// PathCount UDP sockets to the remote, runs a "multicp-health" QUIC session
// over each to measure round-trip times, and caches the set on d. Later calls
// for the same pair reuse the cached set. Every call then dials a new
// "multicp-data" QUIC session over the set; once its handshake is done, its
// packets are sent on whichever path pickPath selects.
//
// Paths are always dialed with network "udp"; network only distinguishes
// cache entries. If building the path set fails, the entry is dropped from
// the cache so that a later call can try again.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	base := d.base
	if base == nil {
		base = &net.Dialer{}
	}

	key := fmt.Sprintf("%q:%q", network, address)

	d.mu.Lock()
	if d.remotes == nil {
		d.remotes = make(map[string]*multiTargetPacketConn)
	}
	remote, alreadyDialed := d.remotes[key]
	if !alreadyDialed {
		remote = &multiTargetPacketConn{
			healthInterval: d.HealthInterval,
			reads:          make(chan packetRead),
		}
		d.remotes[key] = remote
	}
	d.mu.Unlock()

	// NOTE(review): a concurrent caller that finds remote already cached does
	// not wait for the first caller's setup to finish, and reads
	// remote.remoteAddr without synchronization. Fixing that needs a
	// readiness signal stored alongside the cache entry.
	if !alreadyDialed {
		if err := d.setupRemote(ctx, base, remote, address); err != nil {
			// Without this, every later dial to address would skip setup and
			// reuse the broken path set.
			d.forgetRemote(key, remote)
			return nil, err
		}
	}

	session, err := dialDataSession(ctx, remote)
	if err != nil {
		return nil, err
	}

	stream, err := session.OpenStream()
	if err != nil {
		session.CloseWithError(1, "")
		log.Printf("quic.OpenStream: %v", err)
		return nil, err
	}

	return &quicConn{
		Stream:     stream,
		localAddr:  session.LocalAddr(),
		remoteAddr: session.RemoteAddr(),
	}, nil
}

// forgetRemote removes remote from the cache, but only if key still maps to it.
func (d *Dialer) forgetRemote(key string, remote *multiTargetPacketConn) {
	d.mu.Lock()
	defer d.mu.Unlock()
	if d.remotes[key] == remote {
		delete(d.remotes, key)
	}
}

// setupRemote fills remote with health-checked paths to address.
//
// The first path is dialed on its own. It resolves address, and the remaining
// paths dial the resulting IP:port so that every path targets the same remote.
// Setup fails if the first path fails, or if fewer than half of the requested
// paths come up; in the latter case the paths that did come up are closed.
func (d *Dialer) setupRemote(ctx context.Context, base *net.Dialer, remote *multiTargetPacketConn, address string) error {
	var id [32]byte
	if _, err := io.ReadFull(crand.Reader, id[:]); err != nil {
		log.Printf("crypto/rand: %v", err)
		return err
	}

	pathCount := d.PathCount
	if pathCount <= 0 {
		pathCount = defaultPathCount
	}

	first, err := dialUDPPath(ctx, base, address, 0)
	if err != nil {
		return err
	}
	// Record this before the first health probe starts: the probe goroutine
	// reads remote.remoteAddr each time it logs a measurement.
	remote.remoteAddr = first.RemoteAddr()
	resolved := remote.remoteAddr.String()

	paths := make([]*dialPath, pathCount)
	errs := make([]error, pathCount)
	if paths[0], errs[0] = openHealthPath(ctx, remote, first, id[:], 0); errs[0] != nil {
		return errs[0]
	}

	var wg sync.WaitGroup
	for i := 1; i < pathCount; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			conn, err := dialUDPPath(ctx, base, resolved, i)
			if err != nil {
				errs[i] = err
				return
			}
			paths[i], errs[i] = openHealthPath(ctx, remote, conn, id[:], i)
		}(i)
	}
	wg.Wait()

	working := make([]*dialPath, 0, pathCount)
	var firstErr error
	for i, path := range paths {
		switch {
		case errs[i] == nil:
			working = append(working, path)
		case firstErr == nil:
			firstErr = errs[i]
		}
	}
	if len(working) < pathCount/2 {
		// The path set is being discarded; don't leave its health sessions
		// (which have KeepAlive enabled) running.
		for _, path := range working {
			path.session.CloseWithError(0, "")
			path.base.Close()
		}
		return firstErr
	}
	return nil
}

// dialUDPPath dials one UDP socket to address and checks that the dialer
// returned a *net.UDPConn, which the path machinery requires. i only labels
// log lines.
func dialUDPPath(ctx context.Context, base *net.Dialer, address string, i int) (*net.UDPConn, error) {
	conn, err := base.DialContext(ctx, "udp", address)
	if err != nil {
		log.Printf("net.DialContext[%d]: %v", i, err)
		return nil, err
	}
	udp, ok := conn.(*net.UDPConn)
	if !ok {
		conn.Close()
		return nil, fmt.Errorf("dialer returned unexpected connection type %T instead of *net.UDPConn", conn)
	}
	return udp, nil
}

// openHealthPath registers conn as a path of remote and dials a
// "multicp-health" QUIC session whose packets are pinned to conn. It then
// starts the periodic RTT probe on the session's first stream. On failure it
// closes what it opened. i only labels log lines.
//
// NOTE(review): on failure the path stays registered in remote (there is no
// removal API), so pickPath may still choose the closed socket.
func openHealthPath(ctx context.Context, remote *multiTargetPacketConn, conn *net.UDPConn, id []byte, i int) (*dialPath, error) {
	source := &singleTargetPacketConn{conn: conn}
	startHealth := remote.add(source, id)

	// The session shares remote with every other path (so all inbound
	// datagrams reach one QUIC packet handler), but pathAddr.source makes its
	// outbound packets leave through this path's socket.
	addr := &pathAddr{Addr: conn.RemoteAddr(), source: source}
	session, err := quic.DialContext(ctx, remote, addr, "remote", &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"multicp-health"},
	}, &quic.Config{
		ConnectionIDLength: 18,
		KeepAlive:          true,
	})
	if err != nil {
		conn.Close()
		log.Printf("quic.DialContext[%d]: %v", i, err)
		return nil, err
	}

	health, err := session.OpenStream()
	if err != nil {
		session.CloseWithError(0, "")
		conn.Close()
		log.Printf("quic.OpenStream[%d]: %v", i, err)
		return nil, err
	}

	startHealth(health)
	return &dialPath{base: conn, session: session, health: health}, nil
}

// dialDataSession dials a "multicp-data" QUIC session over remote. It makes
// up to five attempts and stops early if ctx is done.
//
// During each attempt, the packets are pinned to one path chosen by pickPath.
// Afterwards the pin is cleared, so later packets to that address are routed
// per packet by pickPath.
func dialDataSession(ctx context.Context, remote *multiTargetPacketConn) (quic.Session, error) {
	var err error
	for attempt := 0; attempt < 5; attempt++ {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}

		// TODO: pick source port to use until dial is complete
		addr := &pathAddr{Addr: remote.remoteAddr}
		addr.source, err = remote.pickPath()
		if err != nil {
			return nil, err
		}

		var session quic.Session
		session, err = quic.DialContext(ctx, remote, addr, "remote", &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"multicp-data"},
		}, &quic.Config{
			ConnectionIDLength: 18,
			KeepAlive:          true,
		})

		addr.mu.Lock()
		addr.source = nil
		addr.mu.Unlock()

		if err == nil {
			return session, nil
		}

		log.Printf("quic.DialContext: %T %#v %v", err, err, err)
		if qe, ok := err.(interface {
			IsApplicationError() bool
			IsCryptoError() bool
			Temporary() bool
			Timeout() bool
		}); ok {
			log.Printf("IsApplicationError: %t", qe.IsApplicationError())
			log.Printf("IsCryptoError: %t", qe.IsCryptoError())
			log.Printf("Temporary: %t", qe.Temporary())
			log.Printf("Timeout: %t", qe.Timeout())
		}
	}
	return nil, err
}

// func (d *Dialer) addPath(ctx context.Context, address string) error {
// 	conn, err := net.DialUDP("udp", nil, addr)
// 	if err != nil {
// 		return err
// 	}

// 	dialNetwork := "udp"

// 	conn, err := base.DialContext(ctx, dialNetwork, address)
// 	if err != nil {
// 		return nil, err
// 	}

// 	uconn, ok := conn.(*net.UDPConn)
// 	if !ok {
// 		conn.Close()
// 		return nil, fmt.Errorf("dialer returned unexpected connection type %T instead of *net.UDPConn", conn)
// 	}

// 	pconn := &singleTargetPacketConn{conn: uconn}

// 	stats := &pathStats{}
// 	healthSession, err := quic.DialContext(ctx, pconn, uconn.RemoteAddr(), "remote", &tls.Config{
// 		InsecureSkipVerify: true,
// 		NextProtos:         []string{"multicp-health"},
// 	}, &quic.Config{
// 		ConnectionIDLength: 18,
// 		KeepAlive:          true,
// 	})
// 	if err != nil {
// 		return nil, err
// 	}

// 	health, err := healthSession.OpenStream()
// 	if err != nil {
// 		conn.Close()
// 		return nil, err
// 	}

// 	panic(conn)
// }

// packetRead is one pending ReadFrom call on a multiTargetPacketConn, handed
// to whichever per-path reader goroutine receives it first.
type packetRead struct {
	// buf is the caller's buffer; the datagram is copied into it.
	buf []byte
	// fn reports the result back to the caller: bytes copied into buf, the
	// sender address, and the read error.
	fn  func(n int, addr net.Addr, err error)
}

// multiTargetPacketConn is a net.PacketConn backed by several path sockets to
// one remote. Inbound datagrams from every path are funneled into ReadFrom;
// each outbound datagram is sent on a path chosen by WriteTo (either a path
// pinned through a *pathAddr, or one selected by pickPath).
type multiTargetPacketConn struct {
	// healthInterval is the probe period for health checkers started via add;
	// zero or negative means defaultHealthInterval.
	healthInterval time.Duration

	// reads carries pending ReadFrom requests to the per-path reader
	// goroutines started by add. Close closes it.
	reads      chan packetRead
	// remoteAddr is the remote's resolved address. Dialer.DialContext writes
	// it during path setup without holding mu.
	remoteAddr net.Addr

	// mu guards options.
	mu      sync.Mutex
	// options holds the registered paths, in the order they were added.
	options []pathOption
}

var _ net.PacketConn = (*multiTargetPacketConn)(nil)

// WriteTo sends p to addr over one of pc's paths.
//
// If addr is a *pathAddr, it is unwrapped to its embedded destination, and
// its source (when set) is the path the packet leaves from. Otherwise, or when
// no source is set, pickPath chooses the path.
func (pc *multiTargetPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
	var via net.PacketConn

	if pa, isPath := addr.(*pathAddr); isPath {
		addr = pa.Addr

		pa.mu.Lock()
		via = pa.source
		pa.mu.Unlock()
	}

	if via == nil {
		picked, err := pc.pickPath()
		if err != nil {
			return 0, err
		}
		via = picked
	}

	return via.WriteTo(p, addr)
}

// pickPath chooses the path for the next outbound packet.
//
// It asks picker for two candidate path indices ("power of two choices") and
// prefers, among those candidates:
//   - the one with the lowest measured RTT that has no health probe
//     outstanding;
//   - otherwise the one whose outstanding probe started most recently;
//   - otherwise the lowest-indexed candidate.
//
// It returns an error only when no path has been registered.
func (pc *multiTargetPacketConn) pickPath() (net.PacketConn, error) {
	pc.mu.Lock()
	defer pc.mu.Unlock()

	if len(pc.options) == 0 {
		return nil, fmt.Errorf("no path available")
	}

	// TODO: If the first two selections don't look great, consider picking more
	// from the pool.

	// "Power of Two Choices" load balancing.
	p := picker.New(len(pc.options), nil)
	candidates := make([]int, 0, 2)
	for i := 0; i < cap(candidates); i++ {
		if opt := p.Pick(); opt >= 0 {
			candidates = append(candidates, opt)
		}
	}
	// Sorting makes ties resolve toward the lower index. Collapse a repeated
	// pick so the same path is never evaluated twice.
	sort.Ints(candidates)
	if len(candidates) == 2 && candidates[0] == candidates[1] {
		candidates = candidates[:1]
	}
	if len(candidates) == 0 {
		// The picker offered nothing; fall back to the first path rather than
		// indexing an empty slice.
		candidates = append(candidates, 0)
	}

	var (
		fastIndex = -1
		fastRTT   time.Duration

		latestStartIndex = -1
		latestStartTime  time.Time
	)

	for _, i := range candidates {
		// Copy the stats out under their own lock, one path at a time. This
		// never holds two stats locks at once, so there is no lock order to
		// get wrong.
		stats := pc.options[i].stats
		stats.mu.Lock()
		rtt, start := stats.rtt, stats.pingStart
		stats.mu.Unlock()

		if start.IsZero() {
			// No probe outstanding; rank by measured RTT (0 = never measured).
			if rtt > 0 && (fastRTT == 0 || rtt < fastRTT) {
				fastIndex, fastRTT = i, rtt
			}
		} else if latestStartTime.IsZero() || start.After(latestStartTime) {
			latestStartIndex, latestStartTime = i, start
		}
	}

	switch {
	case fastIndex >= 0:
		return pc.options[fastIndex].conn, nil
	case latestStartIndex >= 0:
		// Every candidate has a probe in flight: take the one whose probe has
		// been outstanding for the shortest time.
		return pc.options[latestStartIndex].conn, nil
	default:
		// No candidate has useful stats yet.
		return pc.options[candidates[0]].conn, nil
	}
}

// add registers conn as a new path of pc and starts a goroutine that forwards
// datagrams arriving on conn to ReadFrom callers.
//
// It returns a function that, given the path's health-check stream, starts
// probing that path. The prober first sends "ID <hex id>\n". It then loops:
// write "0", wait for "1", record the round-trip time in the path's stats,
// write "2", and sleep for about the health interval. Probing stops at the
// first stream error or unexpected reply.
func (pc *multiTargetPacketConn) add(conn net.PacketConn, id []byte) func(health io.ReadWriter) {
	healthInterval := pc.healthInterval
	if healthInterval <= 0 {
		healthInterval = defaultHealthInterval
	}

	stats := &pathStats{}

	pc.mu.Lock()
	pc.options = append(pc.options, pathOption{stats: stats, conn: conn})
	pc.mu.Unlock()

	go pc.forwardReads(conn)

	return func(health io.ReadWriter) {
		go pc.probeHealth(conn, health, stats, id, healthInterval)
	}
}

// forwardReads delivers each datagram read from conn to the next pending
// ReadFrom call on pc. It returns when pc is closed (reads is closed). It also
// returns right after forwarding a read error that is not a temporary
// net.Error (for example, conn was closed), instead of forwarding the same
// failure indefinitely.
func (pc *multiTargetPacketConn) forwardReads(conn net.PacketConn) {
	buf := make([]byte, 10240)
	for {
		n, addr, err := conn.ReadFrom(buf)

		read, ok := <-pc.reads
		if !ok {
			return
		}
		read.fn(copy(read.buf, buf[:n]), addr, err)

		if err != nil {
			if ne, isNet := err.(net.Error); !isNet || !ne.Temporary() {
				return
			}
		}
	}
}

// probeHealth runs the client side of the health-check protocol on health,
// recording each measured round trip in stats and logging it to stdout.
func (pc *multiTargetPacketConn) probeHealth(conn net.PacketConn, health io.ReadWriter, stats *pathStats, id []byte, interval time.Duration) {
	if _, err := fmt.Fprintf(health, "ID %02x\n", id); err != nil {
		return
	}

	reply := make([]byte, 1)
	for i := 0; ; i++ {
		stats.mu.Lock()
		t0 := time.Now().UTC()
		stats.pingStart = t0
		stats.mu.Unlock()

		if _, err := health.Write([]byte("0")); err != nil {
			return
		}
		if _, err := io.ReadFull(health, reply); err != nil || reply[0] != '1' {
			return
		}
		t1 := time.Now().UTC()

		rtt := t1.Sub(t0)
		stats.mu.Lock()
		stats.rtt = rtt
		stats.pingStart = time.Time{}
		stats.mu.Unlock()

		// Format the addresses with %v rather than calling .String(): a nil
		// remoteAddr then prints as "<nil>" instead of panicking.
		fmt.Fprintf(os.Stdout, "%s %v %v %d health\n",
			t0.Format("2006-01-02T15:04:05.000000Z07:00"),
			conn.LocalAddr(), pc.remoteAddr, rtt.Microseconds())

		// Ignore this error: the next probe will fail too, leaving pingStart
		// set so the stats clearly show the path is broken.
		_, _ = health.Write([]byte("2"))

		time.Sleep(healthDelay(interval, i == 0))
	}
}

// healthDelay returns the pause before the next probe. After the first probe
// it is in (0, interval]; after later probes it is in (0.8*interval, interval].
// When interval is too short to jitter (rand.Int63n panics on a
// non-positive bound), interval is returned unchanged.
func healthDelay(interval time.Duration, first bool) time.Duration {
	spread := int64(interval) / 5 // 80 to 100%
	if first {
		spread = int64(interval) // 0 to 100%
	}
	if spread <= 0 {
		return interval
	}
	return interval - time.Duration(rand.Int63n(spread))
}

// ReadFrom blocks until a datagram arrives on any path and copies it into p.
// It returns the number of bytes copied, the sender address reported by that
// path, and that path's read error.
//
// The request is handed to whichever per-path reader goroutine (see add) gets
// a datagram next. After Close, ReadFrom returns a non-temporary *net.OpError
// instead of panicking on the closed reads channel.
func (pc *multiTargetPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
	type result struct {
		n    int
		addr net.Addr
		err  error
	}
	// Buffered so the reader goroutine never waits on us to collect the result.
	done := make(chan result, 1)
	req := packetRead{
		buf: p,
		fn: func(rn int, raddr net.Addr, rerr error) {
			done <- result{n: rn, addr: raddr, err: rerr}
		},
	}

	if !pc.submitRead(req) {
		return 0, nil, &net.OpError{
			Op:  "read",
			Net: "udp",
			Err: fmt.Errorf("use of closed multi-path connection"),
		}
	}

	r := <-done
	return r.n, r.addr, r.err
}

// submitRead offers req to the reader goroutines. It reports false if Close
// closed pc.reads before or during the send. Sending on a closed channel
// panics, and recovering from that panic is the only way to detect the close
// without adding state to multiTargetPacketConn.
func (pc *multiTargetPacketConn) submitRead(req packetRead) (ok bool) {
	defer func() {
		if recover() != nil {
			ok = false
		}
	}()
	pc.reads <- req
	return true
}

// LocalAddr returns the local address of the first registered path. Before any
// path exists, it returns a placeholder IP address of 127.0.0.1.
func (pc *multiTargetPacketConn) LocalAddr() net.Addr {
	pc.mu.Lock()
	defer pc.mu.Unlock()

	if len(pc.options) > 0 {
		return pc.options[0].conn.LocalAddr()
	}
	return &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}
}

// each calls fn with every path index while holding pc.mu. It returns the
// first non-nil error fn produced, and still visits the remaining paths after
// an error.
func (pc *multiTargetPacketConn) each(fn func(i int) error) error {
	pc.mu.Lock()
	defer pc.mu.Unlock()

	return pc.eachLocked(fn)
}

// eachLocked is each for callers that already hold pc.mu.
func (pc *multiTargetPacketConn) eachLocked(fn func(i int) error) error {
	var first error
	for i, count := 0, len(pc.options); i < count; i++ {
		err := fn(i)
		if first == nil {
			first = err
		}
	}
	return first
}

// Close stops the per-path reader goroutines by closing pc.reads, then closes
// every path socket and returns the first error.
//
// NOTE(review): pc.reads is closed unconditionally, so a second Close panics.
func (pc *multiTargetPacketConn) Close() error {
	close(pc.reads)
	return pc.eachConn(net.PacketConn.Close)
}

// SetDeadline applies t to every path socket.
func (pc *multiTargetPacketConn) SetDeadline(t time.Time) error {
	return pc.eachConn(func(c net.PacketConn) error { return c.SetDeadline(t) })
}

// SetReadDeadline applies t as the read deadline of every path socket.
func (pc *multiTargetPacketConn) SetReadDeadline(t time.Time) error {
	return pc.eachConn(func(c net.PacketConn) error { return c.SetReadDeadline(t) })
}

// SetWriteDeadline applies t as the write deadline of every path socket.
func (pc *multiTargetPacketConn) SetWriteDeadline(t time.Time) error {
	return pc.eachConn(func(c net.PacketConn) error { return c.SetWriteDeadline(t) })
}

// eachConn calls fn on every path socket while holding pc.mu and returns the
// first error; all sockets are visited even after a failure.
func (pc *multiTargetPacketConn) eachConn(fn func(net.PacketConn) error) error {
	return pc.each(func(i int) error { return fn(pc.options[i].conn) })
}

// dialPath holds the pieces of one client path created while DialContext
// builds a path set.
type dialPath struct {
	// base is the path's connected UDP socket.
	base    *net.UDPConn
	// NOTE(review): conn is never assigned in this file.
	conn    net.PacketConn
	// session is the "multicp-health" QUIC session run over this path.
	session quic.Session
	// health is the stream that carries this path's RTT probes.
	health  quic.Stream
}

// pathOption is one path registered with a multiTargetPacketConn.
type pathOption struct {
	// stats holds the path's latest health-probe results.
	stats *pathStats
	// conn is the path socket; it is read by the path's reader goroutine and
	// returned by pickPath for sending.
	conn  net.PacketConn
}

// pathStats records health-probe timing for one path.
type pathStats struct {
	// mu guards rtt and pingStart.
	mu        sync.Mutex
	// rtt is the most recently measured round-trip time; zero until the first
	// probe completes.
	rtt       time.Duration
	// pingStart is when the outstanding probe started, or the zero time when
	// no probe is outstanding.
	pingStart time.Time
}

// singleTargetPacketConn adapts a connected *net.UDPConn to net.PacketConn.
// Because the socket is connected, it only talks to its one remote address:
// ReadFrom always reports that address as the sender, and WriteTo refuses any
// other destination.
type singleTargetPacketConn struct {
	conn *net.UDPConn
}

var _ net.PacketConn = (*singleTargetPacketConn)(nil)

// ReadFrom reads one datagram from the connected socket and reports the
// socket's remote address as its sender.
func (pc *singleTargetPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
	n, err := pc.conn.Read(p)
	return n, pc.conn.RemoteAddr(), err
}

// WriteTo sends p on the connected socket. addr must be a *net.UDPAddr equal
// to the socket's remote address (IP, port and zone). Any other destination
// yields a *singleTargetRemoteError and nothing is sent.
func (pc *singleTargetPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
	raddr, _ := pc.conn.RemoteAddr().(*net.UDPAddr)
	tried, _ := addr.(*net.UDPAddr)
	if !sameUDPAddr(raddr, tried) {
		// Report the mismatch to the caller; panicking here would take down
		// whichever goroutine happened to be writing.
		return 0, &singleTargetRemoteError{
			raddr: raddr,
			tried: addr,
		}
	}
	return pc.conn.Write(p)
}

// sameUDPAddr reports whether a and b are both non-nil and name the same
// IP, port and zone.
func sameUDPAddr(a, b *net.UDPAddr) bool {
	return a != nil && b != nil &&
		a.IP.Equal(b.IP) && a.Port == b.Port && a.Zone == b.Zone
}

// Close closes the underlying socket.
func (pc *singleTargetPacketConn) Close() error {
	return pc.conn.Close()
}

// LocalAddr returns the underlying socket's local address.
func (pc *singleTargetPacketConn) LocalAddr() net.Addr {
	return pc.conn.LocalAddr()
}

// SetDeadline sets the underlying socket's read and write deadlines.
func (pc *singleTargetPacketConn) SetDeadline(t time.Time) error {
	return pc.conn.SetDeadline(t)
}

// SetReadDeadline sets the underlying socket's read deadline.
func (pc *singleTargetPacketConn) SetReadDeadline(t time.Time) error {
	return pc.conn.SetReadDeadline(t)
}

// SetWriteDeadline sets the underlying socket's write deadline.
func (pc *singleTargetPacketConn) SetWriteDeadline(t time.Time) error {
	return pc.conn.SetWriteDeadline(t)
}

// singleTargetRemoteError is returned by singleTargetPacketConn.WriteTo when
// asked to send somewhere other than the socket's remote address. It is a
// permanent net.Error.
type singleTargetRemoteError struct {
	raddr net.Addr
	tried net.Addr
}

var _ net.Error = (*singleTargetRemoteError)(nil)

func (err *singleTargetRemoteError) Error() string {
	return fmt.Sprintf("incorrect remote address: socket bound to %q so cannot send data to %q",
		err.raddr, err.tried)
}

// Timeout reports false: the error is not a timeout.
func (err *singleTargetRemoteError) Timeout() bool { return false }

// Temporary reports false: retrying the same write cannot succeed.
func (err *singleTargetRemoteError) Temporary() bool { return false }

// pathAddr is a net.Addr that specifies a source path in addition to the
// usual destination address.
//
// When the client dials QUIC health-check sessions, their packets need to
// leave from one particular source port. However, the way the QUIC package
// listens for inbound data requires a single net.PacketConn for all
// communication with a given remote address. This type lets that single
// net.PacketConn (multiTargetPacketConn) span many local sockets while still
// sending a session's packets from a specific one. When source is nil,
// multiTargetPacketConn.WriteTo picks a path itself. The data-session dial
// uses this by pinning source only while the dial is in progress.
type pathAddr struct {
	net.Addr

	// mu guards source.
	mu     sync.Mutex
	// source is the path socket to send from, or nil for "any path".
	source net.PacketConn
}

// uniqueAddr wraps a net.Addr so that each wrapper, by convention, has its own
// memory address. The net package appears to return a fresh *net.UDPAddr from
// each ReadFrom call; serverPacketConn.ReadFrom wraps every sender address in
// a new uniqueAddr, which guarantees that property and keeps other packages
// from undoing it.
//
// The server attaches this address to QUIC sessions and passes it back in its
// WriteTo calls. That lets the PacketConn apply per-session policy, such as
// sending a health-check session's packets back on its own path while
// spreading data-session packets across the client's paths.
type uniqueAddr struct {
	net.Addr
}

// key returns the wrapped address joined with this wrapper's pointer value,
// so two wrappers of the same address produce different keys.
func (u *uniqueAddr) key() string {
	addr, identity := u.Addr, u
	return fmt.Sprintf("%s/%p", addr, identity)
}

// serverPacketConn is the net.PacketConn that Listen hands to quic.Listen.
// It wraps each inbound datagram's sender in a fresh *uniqueAddr. On the way
// out, it consults the Listener's health-check registry to decide which
// client address each packet is sent to.
type serverPacketConn struct {
	// base is the real socket.
	base     net.PacketConn
	// listener owns the health-check registry consulted by WriteTo.
	listener *Listener

	// mu guards directReturn.
	// NOTE(review): nothing in this file writes directReturn; registerID
	// populates Listener.directReturn instead.
	mu           sync.Mutex
	directReturn map[*uniqueAddr]struct{}
}

var _ net.PacketConn = (*serverPacketConn)(nil)

// WriteTo sends p toward the client address wrapped in addr, which must be a
// *uniqueAddr produced by ReadFrom.
//
// Packets for a registered health-check session (see Listener.registerID) go
// straight back to that session's address, so each health session exercises
// its own path. Other packets to a client whose health checker registered an
// ID go to one address picked among that client's registered addresses.
// Packets to unknown clients go directly to addr.
func (pc *serverPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
	ua, ok := addr.(*uniqueAddr)
	if !ok {
		return 0, fmt.Errorf("cannot operate on address type %T", addr)
	}
	// Choose the destination under the locks, but send after releasing them
	// so a slow write never blocks registration or other writers.
	return pc.base.WriteTo(p, pc.destination(ua))
}

// destination picks the raw address a packet addressed to ua should be sent to.
func (pc *serverPacketConn) destination(ua *uniqueAddr) net.Addr {
	pc.mu.Lock()
	_, direct := pc.directReturn[ua]
	pc.mu.Unlock()
	if direct {
		return ua.Addr
	}

	l := pc.listener
	l.mu.Lock()
	defer l.mu.Unlock()

	// registerID records health sessions here.
	if _, direct := l.directReturn[ua]; direct {
		return ua.Addr
	}

	id, ok := l.remoteIDs[ua.Addr.String()]
	if !ok {
		return ua.Addr
	}
	remotes := l.idRemotes[id]
	if len(remotes) == 0 {
		return ua.Addr
	}
	if j := picker.New(len(remotes), nil).Pick(); j >= 0 {
		return remotes[j].Addr
	}
	// The picker offered nothing; send directly rather than drop the packet.
	return ua.Addr
}

// ReadFrom reads one datagram from the underlying socket and wraps the
// sender's address in a new *uniqueAddr, so every datagram carries its own
// address value.
func (pc *serverPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
	var sender net.Addr
	n, sender, err = pc.base.ReadFrom(p)
	return n, &uniqueAddr{Addr: sender}, err
}

// Close closes the underlying socket.
func (pc *serverPacketConn) Close() error { return pc.base.Close() }

// LocalAddr returns the underlying socket's local address.
func (pc *serverPacketConn) LocalAddr() net.Addr { return pc.base.LocalAddr() }

// SetDeadline forwards to the underlying socket.
func (pc *serverPacketConn) SetDeadline(t time.Time) error { return pc.base.SetDeadline(t) }

// SetReadDeadline forwards to the underlying socket.
func (pc *serverPacketConn) SetReadDeadline(t time.Time) error { return pc.base.SetReadDeadline(t) }

// SetWriteDeadline forwards to the underlying socket.
func (pc *serverPacketConn) SetWriteDeadline(t time.Time) error { return pc.base.SetWriteDeadline(t) }

// Listen serves multi-path QUIC on conn and returns a Listener whose Accept
// yields one net.Conn per "multicp-data" session.
//
// conn is wrapped in a serverPacketConn before being handed to quic.Listen.
// Sessions are dispatched by their negotiated ALPN protocol ("multicp-data"
// or "multicp-health"; anything else is closed), so tlsConf presumably needs
// NextProtos listing both. Confirm against the callers.
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (*Listener, error) {
	l := &Listener{
		conns:        make(chan listenConn),
		remoteIDs:    make(map[string]string),
		idRemotes:    make(map[string][]*uniqueAddr),
		directReturn: make(map[*uniqueAddr]struct{}),
	}

	ql, err := quic.Listen(&serverPacketConn{base: conn, listener: l}, tlsConf, config)
	if err != nil {
		return nil, err
	}

	ctx, cancel := context.WithCancel(context.Background())
	l.base, l.ctx, l.cancel = ql, ctx, cancel

	go l.run()
	return l, nil
}

// Listener is the server side of Dialer. It is created by Listen and
// implements net.Listener, yielding one net.Conn per "multicp-data" QUIC
// session. "multicp-health" sessions are handled internally; they fill the
// registry that serverPacketConn consults when routing outbound packets.
type Listener struct {
	base quic.Listener

	// mu guards the registry maps below and makes Close's check of ctx and
	// the shutdown that follows atomic.
	mu           sync.Mutex
	// ctx is canceled (via cancel) by Close.
	ctx          context.Context
	cancel       context.CancelFunc
	// conns hands accepted data connections and listener errors to Accept.
	conns        chan listenConn
	// remoteIDs maps a client address (as a string) to the ID its health
	// checker sent.
	remoteIDs    map[string]string
	// idRemotes maps an ID to the addresses of every health session that sent it.
	idRemotes    map[string][]*uniqueAddr
	// directReturn holds the addresses of registered health sessions.
	directReturn map[*uniqueAddr]struct{}
}

var _ net.Listener = (*Listener)(nil)

// Accept waits for the next data connection, a net.Conn wrapping the first
// stream of a "multicp-data" QUIC session.
//
// If the underlying QUIC listener fails, that error is returned once and the
// listener then closes itself. After Close, Accept returns io.EOF.
func (l *Listener) Accept() (net.Conn, error) {
	select {
	case res := <-l.conns:
		return res.conn, res.err
	case <-l.ctx.Done():
		return nil, io.EOF
	}
}

// Close stops accepting sessions and closes the underlying QUIC listener.
// Calling it again returns the context's cancellation error.
func (l *Listener) Close() error {
	l.mu.Lock()
	if err := l.ctx.Err(); err != nil {
		l.mu.Unlock()
		return err
	}
	// Canceling ctx releases Accept callers and any goroutine waiting to hand
	// Accept a connection, so conns never needs to be closed (closing it
	// would make those pending sends panic).
	l.cancel()
	l.mu.Unlock()

	return l.base.Close()
}

// Addr returns the underlying QUIC listener's address.
func (l *Listener) Addr() net.Addr {
	return l.base.Addr()
}

// run accepts QUIC sessions and dispatches each by its negotiated ALPN
// protocol, until the underlying listener returns an error.
func (l *Listener) run() {
	for {
		sess, err := l.base.Accept(l.ctx)
		if err != nil {
			// sess is nil here and must not be touched. Report the error to
			// Accept (a no-op once closed), then shut down.
			// NOTE(review): this treats every Accept error as fatal; assumed
			// to come only from a closed listener or canceled ctx. Confirm
			// against quic-go.
			l.handleError(err)
			_ = l.Close()
			return
		}

		switch sess.ConnectionState().NegotiatedProtocol {
		case "multicp-data":
			go l.handleData(sess)
		case "multicp-health":
			go l.handleHealth(sess)
		default:
			sess.CloseWithError(2, "")
		}
	}
}

// handleError passes err to an Accept caller, unless the listener is closed
// first.
func (l *Listener) handleError(err error) {
	if l.ctx.Err() != nil {
		return
	}
	// No lock is held while sending: the send may block until the application
	// calls Accept.
	select {
	case l.conns <- listenConn{err: err}:
	case <-l.ctx.Done():
	}
}

// handleData accepts the first stream of a "multicp-data" session and hands
// it to Accept. If the stream can't be accepted, or the listener closes
// before Accept takes the connection, the session is closed with code 4.
func (l *Listener) handleData(sess quic.Session) {
	stream, err := sess.AcceptStream(l.ctx)
	if err != nil {
		sess.CloseWithError(4, "")
		return
	}

	conn := &quicConn{
		Stream:     stream,
		localAddr:  sess.LocalAddr(),
		remoteAddr: sess.RemoteAddr(),
	}

	// Don't hold l.mu here. The send can block until the application calls
	// Accept, and serverPacketConn.WriteTo and registerID also need l.mu.
	if l.ctx.Err() == nil {
		select {
		case l.conns <- listenConn{conn: conn}:
			return
		case <-l.ctx.Done():
		}
	}
	// The listener is closed, so nobody will ever accept this session.
	sess.CloseWithError(4, "")
}

// handleHealth serves one "multicp-health" session, the server side of the
// client's RTT probes.
//
// It accepts the session's first stream. The client must start with a line
// "ID <id>" (id at least 10 bytes), which registers this session's address
// under that ID via registerID. After that the stream is consumed one byte at
// a time: each "0" is answered with "1", and a "2" ends the probe. Other bytes
// are ignored. The session is closed with code 3 when the stream ends, on a
// malformed first line, or on a write error.
func (l *Listener) handleHealth(sess quic.Session) {
	defer sess.CloseWithError(3, "")
	health, err := sess.AcceptStream(l.ctx)
	if err != nil {
		return
	}

	// Listen wraps the socket in serverPacketConn, whose ReadFrom reports
	// *uniqueAddr senders; that pointer is what registerID records.
	addr, ok := sess.RemoteAddr().(*uniqueAddr)
	if !ok {
		return
	}

	// NOTE(review): these server-side stats are recorded but never read.
	stats := &pathStats{}

	// bufio.Scanner.Split may not be called once scanning has started, so
	// tokenization is switched indirectly: split delegates to realSplit,
	// which moves from lines (for the ID) to single bytes (for probes).
	sc := bufio.NewScanner(health)
	realSplit := bufio.ScanLines
	split := func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		return realSplit(data, atEOF)
	}
	sc.Split(split)

	var id string
	for sc.Scan() {
		cmd := sc.Text()
		if id == "" {
			// The first token must be the "ID <id>" line.
			parts := strings.SplitN(cmd, " ", 2)
			switch parts[0] {
			case "ID":
				if len(parts) < 2 || len(parts[1]) < 10 {
					return
				}
				id = parts[1]
				realSplit = bufio.ScanBytes
				l.registerID(id, addr)
			default:
				return
			}
		}

		// The ID line itself falls through here and matches neither case.
		switch cmd {
		case "0":
			stats.mu.Lock()
			stats.pingStart = time.Now().UTC()
			stats.mu.Unlock()

			_, err := health.Write([]byte("1"))
			if err != nil {
				return
			}
		case "2":
			pong := time.Now().UTC()

			stats.mu.Lock()
			stats.rtt = pong.Sub(stats.pingStart)
			stats.pingStart = time.Time{}
			stats.mu.Unlock()
		}
	}
}

// registerID records that the health-check session seen at addr belongs to
// the client identified by id. addr is appended to that ID's address list,
// addr's string form is mapped to id, and addr is added to directReturn.
func (l *Listener) registerID(id string, addr *uniqueAddr) {
	l.mu.Lock()
	defer l.mu.Unlock()

	known := l.idRemotes[id]
	l.idRemotes[id] = append(known, addr)

	l.remoteIDs[addr.Addr.String()] = id

	l.directReturn[addr] = struct{}{}
}

// listenConn is one result for Listener.Accept: either an accepted data
// connection or an error from the underlying QUIC listener.
type listenConn struct {
	conn net.Conn
	err  error
}
