package rtmp

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

// ChunkStreamWriter serializes RTMP messages into chunks on an underlying
// writer.
type ChunkStreamWriter interface {
	// Write splits msg into one or more chunks and writes them, returning
	// the first error encountered.
	Write(msg *RawMessage) error
}

// chunkStreamWriter is the ChunkStreamWriter implementation returned by
// NewChunkStreamWriter. It keeps per-chunk-stream state so that each new
// message header can be compressed against the previous one on the same
// chunk stream.
type chunkStreamWriter struct {
	// w receives the encoded chunks.
	w io.Writer

	// chunkSize is the maximum number of payload bytes per chunk. It starts
	// at DEFAULT_CHUNK_SIZE and is replaced when a SET_CHUNK_SIZE message is
	// written through this writer.
	chunkSize    uint32
	// chunkStreams holds outgoing state keyed by chunk stream id.
	chunkStreams outChunkStreamMap
}

// outChunkStream is the writer-side state for a single chunk stream.
type outChunkStream struct {
	id         uint32
	// lastHeader is the header of the last message completely written on
	// this chunk stream (nil before the first one). buildHeader compares
	// the next message against it to pick the most compact header format.
	lastHeader *Header
}

// outChunkStreamMap maps chunk stream ids to their outgoing state.
type outChunkStreamMap map[uint32]*outChunkStream

// Get returns the state for chunk stream csid and creates it on first use.
// Ids 0 and 1 are rejected: in the basic header those values in the first
// byte are the markers for the 2- and 3-byte id encodings.
func (m outChunkStreamMap) Get(csid uint32) (*outChunkStream, error) {
	if csid < 2 {
		return nil, fmt.Errorf("invalid chunk stream id: %d", csid)
	}

	if existing, found := m[csid]; found {
		return existing, nil
	}

	created := &outChunkStream{id: csid}
	m[csid] = created
	return created, nil
}

// NewChunkStreamWriter returns a ChunkStreamWriter that writes chunks to w.
// The outgoing chunk size starts at DEFAULT_CHUNK_SIZE, and no chunk stream
// has any history yet.
func NewChunkStreamWriter(w io.Writer) ChunkStreamWriter {
	csw := &chunkStreamWriter{w: w, chunkSize: DEFAULT_CHUNK_SIZE}
	csw.chunkStreams = make(outChunkStreamMap)
	return csw
}

// Write encodes msg as one or more chunks on the underlying writer.
//
// The first chunk carries the header chosen by buildHeader. Every remaining
// chunk uses a HEADER_FMT_CONTINUATION header. At least one chunk is always
// written, even for an empty payload. Only after the whole message has been
// written is its header recorded as the chunk stream's lastHeader.
func (csw *chunkStreamWriter) Write(msg *RawMessage) error {
	if size := msg.Data.Len(); size > MAX_MESSAGE_LENGTH {
		return fmt.Errorf("cannot write packet length > 24 bits, len = %d", size)
	}

	stream, err := csw.chunkStreams.Get(msg.ChunkStreamID)
	if err != nil {
		return err
	}

	header, err := csw.buildHeader(msg)
	if err != nil {
		return err
	}

	// Drain a separate buffer over the same bytes so that chunking does not
	// advance msg.Data itself.
	payload := bytes.NewBuffer(msg.Data.Bytes())

	for {
		if err := csw.writeChunk(header, payload); err != nil {
			return err
		}
		// Every chunk after the first continues the same message.
		header.Fmt = HEADER_FMT_CONTINUATION
		if payload.Len() == 0 {
			break
		}
	}

	stream.lastHeader = header
	return nil
}

// buildHeader chooses the most compact header that can describe the first
// chunk of msg, relative to the last message written on the same chunk
// stream. It does not build headers for continuation chunks of a message
// that is already partly written; Write handles those.
//
// Selection rules:
//   - HEADER_FMT_FULL, with the absolute timestamp, when there is no
//     previous header, when the timestamp went backwards, when the chunk
//     stream is not CS_ID_DATA, or when the message stream id changed.
//   - HEADER_FMT_SAME_STREAM, with a timestamp delta, when the type or the
//     length changed.
//   - HEADER_FMT_SAME_LENGTH_AND_STREAM, with a timestamp delta, when only
//     the delta differs from the previous header's timestamp field.
//   - HEADER_FMT_CONTINUATION when everything matches.
//
// A timestamp field that does not fit in 24 bits is moved into
// ExtendedTimestamp, and Timestamp is set to the EXTENDED_TIMESTAMP marker.
func (csw *chunkStreamWriter) buildHeader(msg *RawMessage) (*Header, error) {
	cs, err := csw.chunkStreams.Get(msg.ChunkStreamID)
	if err != nil {
		return nil, err
	}

	last := cs.lastHeader
	hdr := &Header{
		ChunkStreamID:     msg.ChunkStreamID,
		AbsoluteTimestamp: msg.Timestamp,
		MessageLength:     uint32(msg.Data.Len()),
		MessageTypeID:     msg.Type,
		MessageStreamID:   msg.StreamID,
	}

	// Short-circuiting keeps every dereference of last behind the nil check.
	needsFull := last == nil ||
		last.AbsoluteTimestamp > hdr.AbsoluteTimestamp ||
		hdr.ChunkStreamID != CS_ID_DATA ||
		last.MessageStreamID != hdr.MessageStreamID

	if needsFull {
		hdr.Fmt = HEADER_FMT_FULL
		hdr.Timestamp = hdr.AbsoluteTimestamp
	} else {
		delta := hdr.AbsoluteTimestamp - last.AbsoluteTimestamp

		// Compare against the previous header's real timestamp field value,
		// undoing the EXTENDED_TIMESTAMP escape if it was used.
		lastTs := last.Timestamp
		if lastTs == EXTENDED_TIMESTAMP {
			lastTs = last.ExtendedTimestamp
		}

		switch {
		case last.MessageTypeID != hdr.MessageTypeID || last.MessageLength != hdr.MessageLength:
			// The stream id matches, so the 7-byte form is the best available.
			hdr.Fmt = HEADER_FMT_SAME_STREAM
		case lastTs != delta:
			// Length, type and stream match; only the timestamp differs.
			hdr.Fmt = HEADER_FMT_SAME_LENGTH_AND_STREAM
		default:
			hdr.Fmt = HEADER_FMT_CONTINUATION
		}
		hdr.Timestamp = delta
	}

	if hdr.Timestamp >= EXTENDED_TIMESTAMP {
		hdr.ExtendedTimestamp, hdr.Timestamp = hdr.Timestamp, EXTENDED_TIMESTAMP
	}

	return hdr, nil
}

// writeChunk writes a single chunk: the header, then up to csw.chunkSize
// payload bytes taken from the front of buf. buf is advanced past the bytes
// that were written.
//
// If this chunk starts a 4-byte SET_CHUNK_SIZE message, the announced size
// takes effect for chunks written after this one. A value of zero is
// ignored, because a zero chunk size could never drain a message.
func (csw *chunkStreamWriter) writeChunk(header *Header, buf *bytes.Buffer) error {
	if err := csw.writeHeader(header); err != nil {
		return err
	}

	// A SET_CHUNK_SIZE payload is exactly 4 bytes. Requiring the whole
	// message to still be unread prevents a malformed, longer message from
	// having its trailing bytes misread as a chunk size.
	var newChunkSize uint32
	if header.MessageTypeID == SET_CHUNK_SIZE && header.MessageLength == 4 && buf.Len() == 4 {
		newChunkSize = binary.BigEndian.Uint32(buf.Bytes())
	}

	n := buf.Len()
	if uint32(n) > csw.chunkSize {
		n = int(csw.chunkSize)
	}

	// buf.Next returns a view of the next n bytes without copying. This
	// avoids the per-chunk scratch buffer that io.CopyN (via LimitedReader)
	// would allocate.
	payload := buf.Next(n)
	written, err := csw.w.Write(payload)
	if err != nil {
		return err
	}
	if written != len(payload) {
		return io.ErrShortWrite
	}

	if newChunkSize != 0 {
		csw.chunkSize = newChunkSize
	}

	return nil
}

// writeHeader writes the complete header of one chunk: the basic header (via
// writeBaseHeader), followed by the message header whose layout depends on
// header.Fmt:
//
//	HEADER_FMT_FULL                    timestamp(3) length(3) type(1) stream id(4, little-endian)
//	HEADER_FMT_SAME_STREAM             timestamp(3) length(3) type(1)
//	HEADER_FMT_SAME_LENGTH_AND_STREAM  timestamp(3)
//	HEADER_FMT_CONTINUATION            (empty)
//
// When the 24-bit timestamp field holds the EXTENDED_TIMESTAMP marker, a
// 4-byte big-endian extended timestamp is appended. This also applies to
// continuation headers.
//
// The message header is built and its format validated before anything is
// written. An invalid Fmt therefore returns an error without leaving a lone
// basic header on the wire.
func (csw *chunkStreamWriter) writeHeader(header *Header) error {
	// Extra bytes reserved at the end for the extended timestamp.
	var ext int
	if header.Timestamp == EXTENDED_TIMESTAMP {
		ext = 4
	}

	var buf []byte
	switch header.Fmt {
	case HEADER_FMT_FULL:
		buf = make([]byte, 11+ext)
		PutUint24(buf[0:], header.Timestamp)
		PutUint24(buf[3:], header.MessageLength)
		buf[6] = header.MessageTypeID
		binary.LittleEndian.PutUint32(buf[7:], header.MessageStreamID)
	case HEADER_FMT_SAME_STREAM:
		buf = make([]byte, 7+ext)
		PutUint24(buf[0:], header.Timestamp)
		PutUint24(buf[3:], header.MessageLength)
		buf[6] = header.MessageTypeID
	case HEADER_FMT_SAME_LENGTH_AND_STREAM:
		buf = make([]byte, 3+ext)
		PutUint24(buf[0:], header.Timestamp)
	case HEADER_FMT_CONTINUATION:
		buf = make([]byte, ext)
	default:
		return fmt.Errorf("invalid header format: %d", header.Fmt)
	}

	if ext != 0 {
		binary.BigEndian.PutUint32(buf[len(buf)-4:], header.ExtendedTimestamp)
	}

	if err := csw.writeBaseHeader(header); err != nil {
		return err
	}

	_, err := csw.w.Write(buf)
	return err
}

// writeBaseHeader writes the chunk basic header. The format occupies the top
// two bits of the first byte, and the chunk stream id follows in one of
// three encodings:
//
//	id <= 63      1 byte:  fmt<<6 | id
//	id <= 319     2 bytes: fmt<<6 | 0, id-64
//	id <= 65599   3 bytes: fmt<<6 | 1, low byte of id-64, high byte of id-64
//
// Larger ids cannot be encoded and produce an error before anything is
// written.
func (csw *chunkStreamWriter) writeBaseHeader(header *Header) error {
	id := header.ChunkStreamID
	lead := byte(header.Fmt << 6)

	var out []byte
	switch {
	case id > 65599:
		return fmt.Errorf("invalid chunk stream id")
	case id > 319:
		rel := id - 64
		out = []byte{lead | 1, byte(rel), byte(rel >> 8)}
	case id > 63:
		out = []byte{lead, byte(id - 64)}
	default:
		out = []byte{lead | uint8(id)}
	}

	_, err := csw.w.Write(out)
	return err
}
