package rtmp

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

// chunkError reports a chunk header that cannot be decoded in the current
// chunk stream state. readChunk treats it as potentially recoverable and
// retries the header parse from the rewind buffer.
type chunkError struct {
	fmt uint8 // header format (fmt) of the offending chunk
}

// Error implements the error interface.
func (e *chunkError) Error() string {
	const prefix = "chunk error: invalid header fmt "
	return prefix + fmt.Sprint(e.fmt)
}

// ChunkStreamReader reassembles RTMP messages from a stream of chunks.
type ChunkStreamReader interface {
	// Read returns the next complete message, reading as many chunks as
	// needed, or the first error encountered while doing so.
	Read() (*RawMessage, error)
}

// inChunkStream holds the per-chunk-stream state needed to decode chunks whose
// headers omit fields (fmt 1-3) and to reassemble messages split across
// several chunks.
type inChunkStream struct {
	id                    uint32  // chunk stream ID this state belongs to
	lastHeader            *Header // most recent header; omitted fields are copied from it
	lastAbsoluteTimestamp uint32  // absolute timestamp that timestamp deltas are added to

	// currentMessage is the message being reassembled on this chunk stream,
	// or nil when no message is in progress.
	currentMessage *RawMessage
}

// inChunkStreamMap indexes inbound chunk stream state by chunk stream ID.
type inChunkStreamMap map[uint32]*inChunkStream

// Get returns the state for chunk stream csid. The first time an ID is seen,
// an empty entry is created and stored in the map before being returned.
func (csm inChunkStreamMap) Get(csid uint32) *inChunkStream {
	if cs, found := csm[csid]; found {
		return cs
	}
	cs := &inChunkStream{id: csid}
	csm[csid] = cs
	return cs
}

// chunkStreamReader is the ChunkStreamReader implementation returned by
// NewChunkStreamReader. It keeps per-chunk-stream decoding state and the
// current inbound chunk size.
type chunkStreamReader struct {
	// r is the underlying byte stream; tr reads from r and also copies every
	// byte it returns into rewind.
	r, tr io.Reader

	// rewind captures bytes read through tr since it was last reset, so that
	// a header parse failing with a chunkError can be retried from those
	// bytes. It exists to work around chunk stream bugs in other RTMP tools.
	rewind *bytes.Buffer

	// chunkSize is the maximum payload size of an inbound chunk (initially
	// DEFAULT_CHUNK_SIZE, updated by SET_CHUNK_SIZE messages); chunkStreams
	// holds per-chunk-stream decoding state keyed by chunk stream ID.
	chunkSize    uint32
	chunkStreams inChunkStreamMap
}

// NewChunkStreamReader returns a ChunkStreamReader that reassembles RTMP
// messages from the chunks read from r, starting with DEFAULT_CHUNK_SIZE as
// the inbound chunk size.
func NewChunkStreamReader(r io.Reader) ChunkStreamReader {
	csr := &chunkStreamReader{
		r:            r,
		rewind:       new(bytes.Buffer),
		chunkSize:    DEFAULT_CHUNK_SIZE,
		chunkStreams: inChunkStreamMap{},
	}
	// Every byte pulled through tr is also recorded in the rewind buffer.
	csr.tr = io.TeeReader(r, csr.rewind)
	return csr
}

// Read returns the next complete message. It keeps reading chunks, from any
// chunk stream, until one of them finishes a message, and returns the first
// error encountered along the way.
func (csr *chunkStreamReader) Read() (*RawMessage, error) {
	var msg *RawMessage
	for msg == nil {
		var err error
		if msg, err = csr.readChunk(); err != nil {
			return nil, err
		}
	}
	return msg, nil
}

// readHeader reads one chunk's basic header and message header from r and
// returns the fully resolved header for that chunk.
//
// Fields that the chunk's header format omits (fmt 1, 2 and 3) are copied
// from the previous header seen on the same chunk stream, and the result is
// stored back as that chunk stream's lastHeader. If the (possibly inherited)
// timestamp field equals EXTENDED_TIMESTAMP, a 4-byte big-endian extended
// timestamp is read as well.
//
// AbsoluteTimestamp is the timestamp itself for fmt 0, and the chunk stream's
// lastAbsoluteTimestamp plus the timestamp (used as a delta) otherwise. While a
// message is being reassembled, it is pinned to that message's timestamp.
//
// It returns a *chunkError in two cases: a non-continuation chunk arrives
// while a message is still in progress on the chunk stream, or fmt 1-3 is seen
// on a chunk stream with no previous header. Read errors from r are returned
// unchanged.
//
// Side effects: it may create the chunk stream's state entry, it updates that
// entry's lastHeader on success, and it resets csr.rewind once the message
// header has been read.
func (csr *chunkStreamReader) readHeader(r io.Reader) (*Header, error) {
	vfmt, csi, err := csr.readBaseHeader(r)
	if err != nil {
		return nil, err
	}

	// Look up the state for this chunk stream, creating it if it is new.
	chunkStream := csr.chunkStreams.Get(csi)

	header := &Header{
		ChunkStreamID: csi,
		Fmt:           vfmt,
	}

	// if we get anything other than a continuation chunk mid message, throw an error rather than attempting to parse
	if chunkStream.currentMessage != nil && header.Fmt != HEADER_FMT_CONTINUATION {
		return nil, &chunkError{header.Fmt}
	}

	// read in all the values per the rtmp spec
	switch header.Fmt {
	case HEADER_FMT_FULL:
		// 11 bytes: timestamp (3), message length (3), message type ID (1)
		// and a little-endian message stream ID (4).
		buf := make([]byte, 11)
		if _, err := io.ReadFull(r, buf); err != nil {
			return nil, err
		}
		header.Timestamp = Uint24(buf[0:3])
		header.MessageLength = Uint24(buf[3:6])
		header.MessageTypeID = uint8(buf[6])
		header.MessageStreamID = binary.LittleEndian.Uint32(buf[7:])

		chunkStream.lastHeader = header
	case HEADER_FMT_SAME_STREAM:
		// 7 bytes: timestamp delta (3), message length (3) and message type
		// ID (1). The message stream ID is inherited. The bytes are consumed
		// before the lastHeader check below.
		buf := make([]byte, 7)
		if _, err = io.ReadFull(r, buf); err != nil {
			return nil, err
		}
		header.Timestamp = Uint24(buf[:3])
		header.MessageLength = Uint24(buf[3:6])
		header.MessageTypeID = uint8(buf[6])

		if chunkStream.lastHeader == nil {
			return nil, &chunkError{header.Fmt}
		}

		header.MessageStreamID = chunkStream.lastHeader.MessageStreamID
		chunkStream.lastHeader = header
	case HEADER_FMT_SAME_LENGTH_AND_STREAM:
		// 3 bytes: a timestamp delta only. The stream ID, length and type are
		// inherited.
		buf := make([]byte, 3)
		if _, err := io.ReadFull(r, buf); err != nil {
			return nil, err
		}
		header.Timestamp = Uint24(buf)

		if chunkStream.lastHeader == nil {
			return nil, &chunkError{header.Fmt}
		}

		header.MessageStreamID = chunkStream.lastHeader.MessageStreamID
		header.MessageLength = chunkStream.lastHeader.MessageLength
		header.MessageTypeID = chunkStream.lastHeader.MessageTypeID
		chunkStream.lastHeader = header
	case HEADER_FMT_CONTINUATION:
		// No message header bytes. Every field, including any extended
		// timestamp, is inherited from the previous header.
		if chunkStream.lastHeader == nil {
			return nil, &chunkError{header.Fmt}
		}

		header.MessageStreamID = chunkStream.lastHeader.MessageStreamID
		header.MessageLength = chunkStream.lastHeader.MessageLength
		header.MessageTypeID = chunkStream.lastHeader.MessageTypeID
		header.Timestamp = chunkStream.lastHeader.Timestamp
		header.ExtendedTimestamp = chunkStream.lastHeader.ExtendedTimestamp
		chunkStream.lastHeader = header
	default:
		// NOTE(review): presumably unreachable because readBaseHeader yields a
		// 2-bit fmt; confirm the HEADER_FMT_* constants cover all of 0-3.
		return nil, &chunkError{header.Fmt}
	}

	// this is here because some rtmp implementations incorrectly leave out
	// an extended timestamp field here when they should not.  resetting the
	// rewind buffer here means that if we fail to parse the next chunk, we can
	// back up and re-try parsing from this point forward.
	csr.rewind.Reset()

	// determine which timestamp to use when calculating
	// the absolute timestamp, and read in the extended
	// timestamp if there is one
	timestamp := header.Timestamp
	if header.Timestamp == EXTENDED_TIMESTAMP {
		buf := make([]byte, 4)
		if _, err := io.ReadFull(r, buf); err != nil {
			return nil, err
		}

		extTs := binary.BigEndian.Uint32(buf)

		// ignore this timestamp and instead use the previously
		// received timestamp in the case of a continuation chunk
		// (the 4 bytes are still consumed from r above)
		if header.Fmt != HEADER_FMT_CONTINUATION {
			header.ExtendedTimestamp = extTs
			timestamp = extTs
		} else {
			timestamp = header.ExtendedTimestamp
		}
	}

	// calculate this chunk's timestamp: fmt 0 carries an absolute value,
	// every other format is treated as a delta on the chunk stream's running
	// absolute timestamp
	switch header.Fmt {
	case HEADER_FMT_FULL:
		header.AbsoluteTimestamp = timestamp
	case HEADER_FMT_SAME_STREAM:
		fallthrough
	case HEADER_FMT_SAME_LENGTH_AND_STREAM:
		fallthrough
	case HEADER_FMT_CONTINUATION:
		header.AbsoluteTimestamp = chunkStream.lastAbsoluteTimestamp + timestamp
	}

	// carry over current timestamp if we're mid message
	if chunkStream.currentMessage != nil {
		header.AbsoluteTimestamp = chunkStream.currentMessage.Timestamp
	}

	return header, nil
}

// readChunk reads a single chunk (a header plus at most csr.chunkSize payload
// bytes) and appends its payload to the message being reassembled on that
// chunk stream.
//
// It returns the message once this chunk completes it, or (nil, nil) when the
// message still needs more chunks. A completed SET_CHUNK_SIZE message on the
// protocol control chunk stream updates csr.chunkSize for later chunks. A
// malformed one (wrong length, zero size, or top bit set) yields
// ErrInvalidChunk.
//
// When the header fails to parse with a *chunkError, the bytes captured in
// csr.rewind are replayed ahead of the underlying reader and the header is
// parsed again. If that retry succeeds, its header is used and the chunk's
// payload is read from the same replay-then-stream source. Otherwise both
// errors are reported together, and the original chunkError stays reachable
// via errors.As.
func (csr *chunkStreamReader) readChunk() (*RawMessage, error) {
	// src supplies this chunk's payload. Normally it is the underlying reader,
	// but after a rewind it must first drain any replayed bytes the retried
	// header did not consume, or those bytes would be silently skipped.
	// NOTE(review): replayed bytes extending past this chunk's payload are
	// still lost; keeping them would need extra state on chunkStreamReader.
	var src io.Reader = csr.r

	header, err := csr.readHeader(csr.tr)

	// if chunk error, try again with rewind
	if _, ok := err.(*chunkError); ok {
		// Copy the captured bytes: readHeader resets csr.rewind, so the replay
		// must not alias the buffer's storage.
		replay := make([]byte, csr.rewind.Len())
		copy(replay, csr.rewind.Bytes())
		src = io.MultiReader(bytes.NewReader(replay), csr.r)

		var rewindErr error
		header, rewindErr = csr.readHeader(src)
		if rewindErr != nil {
			return nil, fmt.Errorf("%w, rewind failed with %s", err, rewindErr)
		}
		// The retry recovered a valid header, so the original error no longer
		// applies.
		err = nil
	}

	if err != nil {
		return nil, err
	}

	chunkStream := csr.chunkStreams.Get(header.ChunkStreamID)

	message := chunkStream.currentMessage
	if message == nil {
		message = &RawMessage{
			ChunkStreamID: header.ChunkStreamID,
			Type:          header.MessageTypeID,
			Timestamp:     header.AbsoluteTimestamp,
			StreamID:      header.MessageStreamID,
			Data:          &bytes.Buffer{},
		}
	}

	// remain is what is left of the message; one chunk carries at most
	// csr.chunkSize bytes of it.
	remain := header.MessageLength - uint32(message.Data.Len())
	lastChunk := remain <= csr.chunkSize
	if !lastChunk {
		remain = csr.chunkSize
	}

	if _, err := io.CopyN(message.Data, src, int64(remain)); err != nil {
		return nil, err
	}

	if !lastChunk {
		chunkStream.currentMessage = message
		return nil, nil
	}

	chunkStream.lastAbsoluteTimestamp = header.AbsoluteTimestamp
	chunkStream.currentMessage = nil

	if message.ChunkStreamID == CS_ID_PROTOCOL_CONTROL && message.Type == SET_CHUNK_SIZE {
		if message.Data.Len() != 4 {
			return nil, ErrInvalidChunk
		}
		size := binary.BigEndian.Uint32(message.Data.Bytes())
		// The spec requires the top bit to be zero and the size to be at
		// least 1. A zero chunk size would leave every later message on this
		// connection permanently incomplete.
		if size == 0 || size > 0x7FFFFFFF {
			return nil, ErrInvalidChunk
		}
		csr.chunkSize = size
	}

	return message, nil
}

// readBaseHeader decodes an RTMP chunk basic header from r.
//
// The first byte carries the 2-bit chunk header format in its top bits and a
// 6-bit chunk stream ID field in its low bits. The low bits select one of
// three encodings:
//
//	0     -> one more byte follows; ID = 64 + byte            (64..319)
//	1     -> two more bytes follow; ID = 64 + b0 + b1*256     (64..65599)
//	2..63 -> the 6-bit value is the chunk stream ID itself
//
// It returns the header format, the chunk stream ID, and any read error, in
// which case the first two results are zero.
func (csr *chunkStreamReader) readBaseHeader(r io.Reader) (uint8, uint32, error) {
	var scratch [2]byte
	if _, err := io.ReadFull(r, scratch[:1]); err != nil {
		return 0, 0, err
	}

	format := scratch[0] >> 6
	idField := scratch[0] & 0x3f

	if idField > 1 {
		return format, uint32(idField), nil
	}

	// Extended forms: idField 0 needs one extra byte, idField 1 needs two.
	extra := scratch[:idField+1]
	if _, err := io.ReadFull(r, extra); err != nil {
		return 0, 0, err
	}

	csi := 64 + uint32(extra[0])
	if idField == 1 {
		csi += uint32(extra[1]) << 8
	}
	return format, csi, nil
}
