package join

import (
	"container/heap"
	"fmt"
	"io"
	"sort"
	"sync"
)

// A Buffer reassembles data segments, which may arrive in any order, into a
// sequential stream. Read hands out the stream's bytes in order, beginning at
// byte offset zero.
//
// WriteAt may supply data for any offset; bytes at offsets that Read has
// already returned are ignored. When several WriteAt calls cover the same
// byte offset, Read may return the data from any one of them.
//
// Use NewBuffer to create a Buffer; the zero value has a nil condition
// variable and is not usable. Every method serializes on the mutex behind
// that condition variable.
type Buffer struct {
	// cond's lock guards the fields below; Read waits on cond for new data
	// or for Close.
	cond *sync.Cond

	// base is the stream offset of the next byte Read will return.
	base int64

	// closed is set by Close. It makes WriteAt fail and stops Read from
	// waiting for more data.
	closed bool

	// data holds the buffered segments as a min-heap ordered by offset.
	data buffletHeap
}

// bufflet is one buffered data segment: payload holds the bytes that belong
// to the stream starting at byte offset offset.
type bufflet struct {
	payload []byte
	offset  int64
}

// buffletHeap tracks data fragments. It implements heap.Interface as a
// min-heap ordered by bufflet offset, so when it is modified only through
// heap.Push and heap.Pop the lowest-offset segment sits at index 0.
//
// TODO: The datastructure here decides whether the peer can force us to use a
// lot of memory by sending a small amount of data. We need a strategy to defend
// against that.
//
// TODO: Decide how to respond when we get overlapping writes. Does the first or
// the last win, or is it undefined? For good senders the data will be
// identical, but it would be good to have a clear decision on this.
type buffletHeap []*bufflet

var _ heap.Interface = (*buffletHeap)(nil)

// Len reports the number of buffered segments.
func (h *buffletHeap) Len() int {
	return len(*h)
}

// Swap exchanges the segments at indexes i and j.
func (h *buffletHeap) Swap(i, j int) {
	(*h)[i], (*h)[j] = (*h)[j], (*h)[i]
}

// Less orders segments by ascending stream offset. Segments with equal
// offsets compare as equal, so their relative order is unspecified.
func (h *buffletHeap) Less(i, j int) bool {
	return (*h)[i].offset < (*h)[j].offset
}

// Push appends x, which must be a *bufflet. It is intended to be called by
// heap.Push, which then restores the heap ordering.
func (h *buffletHeap) Push(x interface{}) {
	*h = append(*h, x.(*bufflet))
}

// Pop removes and returns the last element. It is intended to be called by
// heap.Pop, which moves the minimum element to the end first.
func (h *buffletHeap) Pop() interface{} {
	old := *h
	last := len(old) - 1
	tail := old[last]
	// Clear the vacated slot so the backing array no longer keeps the popped
	// bufflet (and its payload) reachable after the caller drops it.
	old[last] = nil
	*h = old[:last]
	return tail
}

// NewBuffer returns an empty, open Buffer whose read position is offset 0.
// It allocates the mutex-backed condition variable that the zero Buffer
// lacks.
func NewBuffer() *Buffer {
	return &Buffer{cond: sync.NewCond(&sync.Mutex{})}
}

// WriteAt adds a copy of data segment p to the buffer at byte offset off.
//
// Bytes at offsets that Read has already returned are dropped. A write lying
// entirely before the read position is acknowledged and ignored. A write
// straddling it stores only the unread suffix.
//
// WriteAt returns an error if the Buffer is closed or if off+len(p) overflows
// int64. Otherwise it reports len(p) bytes written (0 for an empty p) and
// wakes one blocked Read.
func (b *Buffer) WriteAt(p []byte, off int64) (int, error) {
	b.cond.L.Lock()
	defer b.cond.L.Unlock()

	if b.closed {
		return 0, fmt.Errorf("closed")
	}
	end := off + int64(len(p))
	if end < off {
		// len(p) is non-negative, so end < off only when the sum wrapped.
		// Without this check a wrapped end looks "already consumed" and the
		// write would be silently acknowledged.
		return 0, fmt.Errorf("write at offset %d with length %d overflows int64", off, len(p))
	}
	if b.base >= end {
		// This write is to a part of the stream we've already completely
		// processed through a WriteTo/Read pair. Ignore it.
		return len(p), nil
	}
	if len(p) == 0 {
		return 0, nil
	}

	written := len(p)
	if off < b.base {
		// Drop the prefix that Read has already returned instead of storing
		// it. Read would discard those bytes anyway. Here base-off < len(p)
		// because end > base.
		p = p[b.base-off:]
		off = b.base
	}

	buf := &bufflet{
		offset:  off,
		payload: make([]byte, len(p)),
	}
	copy(buf.payload, p)

	heap.Push(&b.data, buf)
	b.cond.Signal()

	return written, nil
}

// Read consumes buffered data into p and returns the number of bytes copied.
//
// Read only returns data starting at the stream's next unread offset.
// Segments that start beyond a gap stay buffered until the gap is filled.
// Each call copies from a single buffered segment, so it may return fewer
// bytes than are contiguously available.
//
// If no data is ready, Read blocks until some arrives, or returns io.EOF once
// the Buffer is closed. A call with an empty p never blocks: it returns io.EOF
// if the Buffer is closed and no data is ready, and 0, nil otherwise.
func (b *Buffer) Read(p []byte) (int, error) {
	b.cond.L.Lock()
	defer b.cond.L.Unlock()

	for {
		if len(b.data) > 0 && b.data[0].offset <= b.base {
			// The lowest-offset segment starts at or before the read
			// position, so it may hold the next bytes of the stream.
			buf := heap.Pop(&b.data).(*bufflet)
			if buf.offset+int64(len(buf.payload)) <= b.base {
				// The entire bufflet is a duplicate of data we've already
				// returned. Discard it and try again.
				continue
			}

			// Discard the bytes at offsets that we've already returned. The
			// remainder now starts exactly at the read position.
			buf.payload = buf.payload[b.base-buf.offset:]
			buf.offset = b.base

			n := copy(p, buf.payload)
			buf.payload = buf.payload[n:]
			buf.offset += int64(n)
			b.base += int64(n)
			if len(buf.payload) > 0 {
				// This call didn't consume the whole segment, save the
				// remainder.
				heap.Push(&b.data, buf)
			}
			return n, nil
		}

		if b.closed {
			return 0, io.EOF
		}
		if len(p) == 0 {
			// Per the io.Reader contract, a zero-length read should return
			// immediately rather than wait for data it could not deliver.
			return 0, nil
		}
		b.cond.Wait()
	}
}

// Close marks the buffer as closed and always returns nil.
//
// After Close, WriteAt returns an error. Data already buffered can still be
// read. Once no data is ready at the read position, blocked and future Read
// calls return io.EOF instead of waiting. Calling Close more than once is
// harmless.
func (b *Buffer) Close() error {
	b.cond.L.Lock()
	defer b.cond.L.Unlock()

	b.closed = true
	// Wake every waiting reader so each one re-checks and sees the closed
	// state.
	b.cond.Broadcast()
	return nil
}

// ContiguousBytes reports the length of the gap-free prefix of the stream that
// the Buffer has received, counting from offset 0. Bytes that Read has
// already consumed are included. The buffered segments are not modified.
func (b *Buffer) ContiguousBytes() int64 {
	b.cond.L.Lock()
	defer b.cond.L.Unlock()

	// TODO: build a datastructure where this has a fast implementation

	// Sort a private copy so the heap layout of b.data is left untouched.
	segs := append([]*bufflet(nil), b.data...)
	sort.Slice(segs, func(x, y int) bool {
		return segs[x].offset < segs[y].offset
	})

	reached := b.base
	for _, seg := range segs {
		if seg.offset > reached {
			// The first gap ends the contiguous prefix.
			break
		}
		if segEnd := seg.offset + int64(len(seg.payload)); segEnd > reached {
			reached = segEnd
		}
	}
	return reached
}
