package pipeclient

import (
	"bytes"
	"context"
	"io"
	"net"
	"net/http/httptrace"
	"os"
	"runtime/trace"
	"sync"
	"time"

	"code.justin.tv/rhys/nursery/cmd/multicp/conn"
	"code.justin.tv/rhys/nursery/cmd/multicp/join"
	"code.justin.tv/rhys/nursery/cmd/multicp/netpipe"
	"github.com/golang/protobuf/ptypes"
	"golang.org/x/time/rate"
)

// netpipeTimeout is the deadline applied to each call made to the netpipe
// service (CreatePipe, Read, WriteAt, DeletePipe). recvData asks the server
// to wait at most half of it for incoming data.
const netpipeTimeout = 200 * time.Millisecond

// Client creates pipes through a netpipe service. Pipes opened by a Client
// keep a reference to it and issue their RPCs through the same service.
type Client struct {
	// dest is the netpipe service that serves CreatePipe for the Client and
	// Read/WriteAt/DeletePipe for its pipes.
	dest netpipe.NetPipe
}

// NewClient returns a Client whose pipes are created and driven through svc.
func NewClient(svc netpipe.NetPipe) *Client {
	cl := new(Client)
	cl.dest = svc
	return cl
}

// Open asks the netpipe service to create a pipe to target and returns a Pipe
// wrapping it. Everything Open does, including the CreatePipe call, shares a
// single netpipeTimeout deadline derived from ctx. On success Open starts the
// pipe's writeLoop and readLoop goroutines.
//
// If the server's keepalive duration cannot be converted, Open makes a
// best-effort DeletePipe call so the server-side pipe is not orphaned, then
// returns the conversion error.
func (c *Client) Open(ctx context.Context, target string) (*Pipe, error) {
	ctx, cancel := context.WithTimeout(ctx, netpipeTimeout)
	defer cancel()
	ctx, task := trace.NewTask(ctx, "pipeclient.Client.Open")
	defer task.End()

	// If dest issues CreatePipe over net/http, remember the local address of
	// the connection it used; PipeConn folds it into the conn's LocalAddr.
	var localAddr string
	ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
		GotConn: func(info httptrace.GotConnInfo) {
			localAddr = info.Conn.LocalAddr().String()
		},
	})

	req := &netpipe.CreatePipeRequest{
		Target:            target,
		KeepaliveDuration: ptypes.DurationProto(30 * time.Second),
	}
	var resp *netpipe.Pipe
	var err error
	trace.WithRegion(ctx, "client/netpipe.CreatePipe", func() {
		resp, err = c.dest.CreatePipe(ctx, req)
		if err != nil {
			trace.Logf(ctx, "err", "%q", err)
		}
	})
	if err != nil {
		return nil, err
	}

	pipe := &Pipe{
		c: c,

		target:            target,
		name:              resp.GetName(),
		maxMessageBytes:   int(resp.GetMaxMessageBytes()),
		localAddrOnCreate: localAddr,

		join: join.NewBuffer(),
	}
	pipe.writeCond = sync.NewCond(&pipe.mu)
	pipe.writeLoopCond = sync.NewCond(&pipe.mu)
	pipe.readCond = sync.NewCond(&pipe.mu)
	pipe.readLoopCond = sync.NewCond(&pipe.mu)

	pipe.keepaliveDuration, err = ptypes.Duration(resp.GetKeepaliveDuration())
	if err != nil {
		trace.Logf(ctx, "err", "%q", err)
		// The server already created the pipe. Release it instead of leaving
		// it behind. ctx may be close to its deadline after CreatePipe, so
		// the cleanup call gets a fresh context.
		pipe.sendDelete(context.Background())
		return nil, err
	}

	go pipe.writeLoop()
	go pipe.readLoop()

	trace.Logf(ctx, "client/netpipe.CreatePipe",
		"pointer=%p name=%q MaxMessageBytes=%d",
		pipe, pipe.name, pipe.maxMessageBytes)
	return pipe, nil
}

// Pipe is the client side of a pipe created through the netpipe service.
// Writes are buffered in buf and shipped to the server by writeLoop. Data that
// readLoop fetches from the server is stored in join at the offsets the
// server reports.
type Pipe struct {
	c *Client // the Client that opened the pipe; its dest serves every RPC

	// Set once in Open:
	// - target: the target passed to Open.
	// - name: the server-assigned pipe name.
	// - maxMessageBytes: the server's message-size limit, which caps
	//   writeLoop's per-call send size.
	// - keepaliveDuration: the server-reported keepalive.
	// - localAddrOnCreate: the local address observed for the CreatePipe
	//   connection, if any.
	target            string
	name              string
	maxMessageBytes   int
	keepaliveDuration time.Duration
	localAddrOnCreate string

	// join holds data received from the server. recvData writes into it at
	// server offsets; Read and PipeConn read out of it.
	join *join.Buffer

	// mu guards buf, offset, committedReadOffset and closed. All four Conds
	// are built on mu:
	//   writeCond     - Write waits for buffer space or closure.
	//   writeLoopCond - writeLoop waits for buffered data or closure.
	//   readCond      - Read waits for new contiguous data or closure.
	//   readLoopCond  - readLoop waits before each Read RPC.
	// buf holds outgoing bytes not yet committed by the server. offset is the
	// absolute stream offset of buf's first byte; updateCommit advances it.
	// committedReadOffset is the read offset the client has committed; it is
	// raised by sendData after a successful WriteAt, and also by recvData.
	// closed is set by Close or when the server reports end of file.
	mu                  sync.Mutex
	writeCond           *sync.Cond
	writeLoopCond       *sync.Cond
	readCond            *sync.Cond
	readLoopCond        *sync.Cond
	buf                 bytes.Buffer
	offset              int64
	committedReadOffset int64
	closed              bool
}

// Close marks the pipe closed and wakes every goroutine waiting on one of its
// condition variables:
//   - pending Write calls return os.ErrClosed;
//   - pending Read calls return io.EOF;
//   - writeLoop keeps sending buffered data and deletes the remote pipe once
//     the buffer is empty.
//
// Close always returns nil and is safe to call more than once.
func (p *Pipe) Close() error {
	ctx := context.Background()
	ctx, task := trace.NewTask(ctx, "pipeclient.Pipe.Close")
	defer task.End()
	trace.Logf(ctx, "pointer", "%p", p)

	trace.WithRegion(ctx, "Lock", p.mu.Lock)
	defer p.mu.Unlock()

	p.closed = true
	// Every wait predicate (lockedWriteReady, lockedWriteLoopReady,
	// lockedReadReady, lockedReadLoopReady) is satisfied once p.closed is
	// set. All waiters must be woken, including readers: after Close nothing
	// else would ever signal readCond.
	p.writeCond.Broadcast()
	p.writeLoopCond.Broadcast()
	p.readCond.Broadcast()
	p.readLoopCond.Broadcast()
	return nil
}

// Write appends b to the pipe's outgoing buffer, first blocking while the
// buffer is at its cap. writeLoop sends the buffered bytes to the server
// asynchronously. Once the pipe is closed, Write returns os.ErrClosed.
func (p *Pipe) Write(b []byte) (int, error) {
	ctx, task := trace.NewTask(context.Background(), "pipeclient.Pipe.Write")
	defer task.End()
	trace.Logf(ctx, "pointer", "%p", p)

	trace.WithRegion(ctx, "Lock", p.mu.Lock)
	defer p.mu.Unlock()

	trace.WithRegion(ctx, "Wait", func() {
		for {
			if p.lockedWriteReady(ctx) {
				return
			}
			p.writeCond.Wait()
		}
	})

	if p.closed {
		return 0, os.ErrClosed
	}

	written, werr := p.buf.Write(b)

	// Wake writeLoop to ship the new bytes. If there is still room, pass the
	// wakeup on to another blocked writer.
	p.writeLoopCond.Signal()
	if p.lockedWriteReady(ctx) {
		p.writeCond.Signal()
	}
	return written, werr
}

// lockedWriteReady reports whether Write may proceed: either the outgoing
// buffer holds fewer than 16 KiB or the pipe is closed. Callers hold p.mu.
func (p *Pipe) lockedWriteReady(ctx context.Context) bool {
	// TODO: coordinate with server
	// TODO: split huge writes
	const maxBufferSize = 16 << 10

	buffered := p.buf.Len()
	ready := p.closed || buffered < maxBufferSize
	trace.Logf(ctx, "lockedWriteReady",
		"ready=%t p.buf.Len=%d p.closed=%t",
		ready, buffered, p.closed)
	return ready
}

// lockedWriteLoopReady reports whether writeLoop has work to do: bytes waiting
// in the outgoing buffer, or a closed pipe to wind down. Callers hold p.mu.
func (p *Pipe) lockedWriteLoopReady(ctx context.Context) bool {
	pending := p.buf.Len()
	ready := p.closed || pending > 0
	trace.Logf(ctx, "lockedWriteLoopReady",
		"ready=%t p.buf.Len=%d p.closed=%t",
		ready, pending, p.closed)
	return ready
}

// Read waits until lockedReadReady reports new data or closure, then reads
// from the pipe's join buffer into b. Once the pipe is closed, Read returns
// io.EOF.
//
// NOTE(review): the closed check happens before p.join is consulted. Data
// still sitting in p.join when the pipe closes (e.g. delivered together with
// the server's end-of-file) is therefore never returned. Confirm this is
// intended.
func (p *Pipe) Read(b []byte) (int, error) {
	ctx, task := trace.NewTask(context.Background(), "pipeclient.Pipe.Read")
	defer task.End()
	trace.Logf(ctx, "pointer", "%p", p)

	trace.WithRegion(ctx, "Lock", p.mu.Lock)
	defer p.mu.Unlock()

	trace.WithRegion(ctx, "Wait", func() {
		for {
			if p.lockedReadReady(ctx) {
				return
			}
			p.readCond.Wait()
		}
	})

	if p.closed {
		return 0, io.EOF
	}

	count, rerr := p.join.Read(b)

	// Let readLoop know buffer space may have freed up. If more data is
	// ready, pass the wakeup on to another blocked reader.
	p.readLoopCond.Signal()
	if p.lockedReadReady(ctx) {
		p.readCond.Signal()
	}
	return count, rerr
}

// lockedReadReady reports whether Read may proceed. That is the case when
// p.join holds more contiguous bytes than p.committedReadOffset, or when the
// pipe is closed. Callers hold p.mu.
func (p *Pipe) lockedReadReady(ctx context.Context) bool {
	contiguous := p.join.ContiguousBytes()
	committed := p.committedReadOffset

	ready := p.closed || contiguous > committed
	trace.Logf(ctx, "lockedReadReady",
		"ready=%t available=%d returned=%d p.closed=%t",
		ready, contiguous, committed, p.closed)
	return ready
}

// lockedReadLoopReady reports whether readLoop may issue another Read RPC.
// It currently always reports true; p.closed is only traced. Callers hold
// p.mu.
func (p *Pipe) lockedReadLoopReady(ctx context.Context) bool {
	// TODO: limit buffered data
	const ready = true
	trace.Logf(ctx, "lockedReadLoopReady",
		"ready=%t p.closed=%t",
		ready, p.closed)
	return ready
}

// sendDelete asks the server to delete this pipe, bounded by netpipeTimeout.
// A failure is recorded in the execution trace and otherwise ignored.
func (p *Pipe) sendDelete(ctx context.Context) {
	ctx, cancel := context.WithTimeout(ctx, netpipeTimeout)
	defer cancel()
	ctx, task := trace.NewTask(ctx, "pipeclient.Pipe.sendDelete")
	defer task.End()
	trace.Logf(ctx, "pointer", "%p", p)

	req := &netpipe.DeletePipeRequest{Name: p.name}
	trace.WithRegion(ctx, "client/netpipe.DeletePipe", func() {
		if _, err := p.c.dest.DeletePipe(ctx, req); err != nil {
			trace.Logf(ctx, "err", "%q", err)
		}
	})
}

// recvData performs one Read RPC against the server. The call is bounded by
// netpipeTimeout and asks the server to wait at most half that for data.
// Returned bytes are stored in p.join at the offset the server reports. A
// blocked Read is then woken and, when new contiguous data arrived, it is
// acknowledged via sendData.
//
// If the RPC fails, recvData returns without touching any state; readLoop
// polls again. If the server reports end of file, p.join is closed and the
// pipe is marked closed. Every waiter is then woken so blocked Read/Write
// calls and writeLoop observe the closure.
func (p *Pipe) recvData(ctx context.Context) {
	ctx, cancel := context.WithTimeout(ctx, netpipeTimeout)
	defer cancel()
	ctx, task := trace.NewTask(ctx, "pipeclient.Pipe.recvData")
	defer task.End()
	trace.Logf(ctx, "pointer", "%p", p)

	const bufferSize = 1 << 10

	req := &netpipe.ReadRequest{
		Name:            p.name,
		MaxReadBytes:    bufferSize,
		TimeoutDuration: ptypes.DurationProto(netpipeTimeout / 2),
	}

	var resp *netpipe.ReadResponse
	var err error
	trace.WithRegion(ctx, "client/netpipe.Read", func() {
		// TODO: select which TCP connection to use
		resp, err = p.c.dest.Read(ctx, req)
		if err != nil {
			trace.Logf(ctx, "err", "%q", err)
		}
	})
	if err != nil {
		// Nothing was received, so there is nothing to store or acknowledge.
		// Returning here also keeps the RPC error from being overwritten by
		// the WriteAt result below.
		return
	}

	n, err := p.join.WriteAt(resp.GetReadData(), resp.GetReadOffset())
	if resp.GetEndOfFile() {
		_ = p.join.Close()
		trace.WithRegion(ctx, "Lock", p.mu.Lock)
		p.closed = true
		// Every wait predicate is satisfied once p.closed is set. Wake all
		// waiters so none stay blocked on a pipe whose state will not change
		// again.
		p.readCond.Broadcast()
		p.writeCond.Broadcast()
		p.writeLoopCond.Broadcast()
		p.mu.Unlock()
	}
	if err != nil {
		trace.Logf(ctx, "err", "%q", err)
		return
	}

	trace.WithRegion(ctx, "Lock", p.mu.Lock)
	if n > 0 && p.lockedReadReady(ctx) {
		p.readCond.Signal()
	}
	p.mu.Unlock()

	readOffset := p.join.ContiguousBytes()
	trace.WithRegion(ctx, "Lock", p.mu.Lock)
	prev := p.committedReadOffset
	p.mu.Unlock()

	if n > 0 && prev < readOffset {
		p.sendData(ctx, nil, 0, readOffset)
	}

	// NOTE(review): this advances committedReadOffset even when the
	// sendData call above failed. writeLoop's "readOffset > prevOffset" check
	// then sees nothing new and will not re-send the ack. Confirm intended.
	trace.WithRegion(ctx, "Lock", p.mu.Lock)
	if readOffset > p.committedReadOffset {
		p.committedReadOffset = readOffset
	}
	p.mu.Unlock()

	// if n > 0 && p.lockedWriteLoopReady(ctx) {
	// 	// Send the updated offset to the server
	// 	// TODO: maybe split this kind of write call into its own separate loop?
	// 	p.writeLoopCond.Signal()
	// }
}

// readLoop repeatedly calls recvData to poll the server for incoming data. It
// is rate-limited to 10k calls per second with a burst of 10, and exits after
// the first recvData that leaves the pipe marked closed.
func (p *Pipe) readLoop() {
	ctx, task := trace.NewTask(context.Background(), "pipeclient.Pipe.readLoop")
	defer task.End()
	trace.Logf(ctx, "pointer", "%p", p)

	// TODO: pull out rate limiter config
	limiter := rate.NewLimiter(rate.Limit(10e3), 10)

	for closed := false; !closed; {
		time.Sleep(limiter.Reserve().Delay())

		trace.WithRegion(ctx, "Lock", p.mu.Lock)
		trace.WithRegion(ctx, "Wait", func() {
			for !p.lockedReadLoopReady(ctx) {
				p.readLoopCond.Wait()
			}
		})
		p.mu.Unlock()

		p.recvData(ctx)

		trace.WithRegion(ctx, "Lock", p.mu.Lock)
		closed = p.closed
		p.mu.Unlock()
	}
}

// sendData issues one WriteAt RPC, bounded by netpipeTimeout. The call
// carries b at the absolute write offset and advertises readOffset as the
// client's committed read position. On success it trims server-acknowledged
// bytes from the outgoing buffer via updateCommit. It also raises
// p.committedReadOffset to readOffset if that is larger. On failure the error
// is traced and nothing else changes.
func (p *Pipe) sendData(ctx context.Context, b []byte, offset int64, readOffset int64) {
	ctx, cancel := context.WithTimeout(ctx, netpipeTimeout)
	defer cancel()
	ctx, task := trace.NewTask(ctx, "pipeclient.Pipe.sendData")
	defer task.End()
	trace.Logf(ctx, "pointer", "%p", p)

	req := &netpipe.WriteAtRequest{
		Name:                p.name,
		WriteData:           b,
		WriteOffset:         offset,
		CommittedReadOffset: readOffset,
	}

	var (
		resp *netpipe.WriteAtResponse
		err  error
	)
	trace.WithRegion(ctx, "client/netpipe.WriteAt", func() {
		// TODO: select which TCP connection to use
		if resp, err = p.c.dest.WriteAt(ctx, req); err != nil {
			trace.Logf(ctx, "err", "%q", err)
		}
	})
	if err != nil {
		// TODO: throttle errors differently?
		return
	}

	p.updateCommit(ctx, resp.GetCommittedWriteOffset())

	trace.WithRegion(ctx, "Lock", p.mu.Lock)
	if readOffset > p.committedReadOffset {
		p.committedReadOffset = readOffset
	}
	p.mu.Unlock()
}

// writeLoop ships buffered outgoing data to the server in WriteAt calls. The
// same calls also carry the client's latest contiguous read offset. Calls are
// rate-limited to 10k per second with a burst of 10. The loop returns once the
// pipe is closed and its outgoing buffer is empty, then deletes the remote
// pipe.
func (p *Pipe) writeLoop() {
	ctx := context.Background()
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	ctx, task := trace.NewTask(ctx, "pipeclient.Pipe.writeLoop")
	defer task.End()
	trace.Logf(ctx, "pointer", "%p", p)

	defer func() {
		// Flush out all of the data we've accepted from Write calls, then
		// delete the pipe.
		trace.Logf(ctx, "sendDelete", "")
		p.sendDelete(ctx)
	}()

	// TODO: pull out rate limiter config
	callLim := rate.NewLimiter(rate.Limit(10e3), 10)

	// Cap each send at the server's message limit, but only honor a positive
	// limit:
	//   - zero (presumably an unset field) would make every copy below yield
	//     no bytes, so buffered data is never sent and the loop spins;
	//   - a negative value would make make() panic.
	bufferSize := 1 << 10
	if limit := p.maxMessageBytes; limit > 0 && bufferSize > limit {
		bufferSize = limit
	}

	buf := make([]byte, bufferSize)
	for {
		err := callLim.Wait(ctx)
		if err != nil {
			return
		}

		trace.WithRegion(ctx, "Lock", p.mu.Lock)
		trace.WithRegion(ctx, "Wait", func() {
			for !p.lockedWriteLoopReady(ctx) {
				p.writeLoopCond.Wait()
			}
		})
		closed := p.closed
		offset := p.offset
		n := copy(buf, p.buf.Bytes())
		prevOffset := p.committedReadOffset
		p.mu.Unlock()

		readOffset := p.join.ContiguousBytes()

		if n == 0 && closed {
			return
		}

		// Call the server when there is data to send or a newer read offset
		// to report.
		shouldCall := n > 0 || readOffset > prevOffset
		// We'll need to send occasional WriteAt calls to keep the
		// connection alive, and we'll also need the signal of how long it's
		// been since the last call we know of that re-extended the
		// connection's life. That'll require some inspection of error
		// values, and clear thinking on which ones mean a) connection life
		// was extended, b) connection is already closed, or c) no
		// actionable information.

		if !shouldCall {
			continue
		}

		trace.Logf(ctx, "sendData",
			"len=%d offset=%d prevReadOffset=%d readOffset=%d",
			n, offset, prevOffset, readOffset)
		p.sendData(ctx, buf[:n], offset, readOffset)
	}
}

// updateCommit records that the server has committed outgoing data up to the
// absolute stream position offset. The newly acknowledged bytes, covering
// [p.offset, offset), are dropped from p.buf, p.offset advances, and one
// blocked writer is woken. Acknowledgments that are not ahead of p.offset, or
// that reach past the buffered data, are ignored.
func (p *Pipe) updateCommit(ctx context.Context, offset int64) {
	defer trace.StartRegion(ctx, "pipeclient.Pipe.updateCommit").End()
	trace.WithRegion(ctx, "Lock", p.mu.Lock)
	defer p.mu.Unlock()

	base := p.offset
	buffered := p.buf.Len()

	trace.Logf(ctx, "updateCommit",
		"pointer=%p name=%q offset=%d len=%d new=%d",
		p, p.name, base, buffered, offset)

	acked := offset - base
	switch {
	case acked <= 0:
		// Nothing new; possibly a delayed response to an earlier call.
		return
	case acked > int64(buffered):
		// The server acknowledges more than we could have sent; ignore it.
		return
	}

	_ = p.buf.Next(int(acked))
	p.offset = offset
	p.writeCond.Signal()
}

// PipeConn adapts p to the net.Conn interface. It starts two goroutines:
//   - one hands p.join (data received from the server) to a
//     conn.TimeoutReader;
//   - one hands p (which buffers data for the server) to a conn.TimeoutWriter.
//
// Their results are discarded. The returned conn lists p and p.join as its
// Closers, presumably closed together when the conn is closed (TODO confirm).
// LocalAddr is "<local address seen at CreatePipe>:<pipe name>", and
// RemoteAddr is the Open target.
func PipeConn(p *Pipe) net.Conn {
	local := pipeAddr(p.localAddrOnCreate + ":" + p.name)
	remote := pipeAddr(p.target)

	// Incoming data: server -> p.join -> reader.
	reader := new(conn.TimeoutReader)
	go func() { reader.ReadFrom(p.join) }()

	// Outgoing data: writer -> p -> server.
	writer := new(conn.TimeoutWriter)
	go func() { writer.WriteTo(p) }()

	return &conn.ReadWriter{
		Reader:  reader,
		Writer:  writer,
		Closers: []io.Closer{p, p.join}, // TODO: confirm this?
		Local:   local,
		Remote:  remote,
	}
}

// pipeAddr is a net.Addr for netpipe endpoints; the address text is kept
// verbatim.
type pipeAddr string

// Network always reports "netpipe".
func (a pipeAddr) Network() string {
	return "netpipe"
}

// String returns the address text unchanged.
func (a pipeAddr) String() string {
	return string(a)
}
