package rtmp

import (
	"bytes"
	crand "crypto/rand"
	"fmt"
	"io"
	mrand "math/rand"
	"testing"
)

// mustRawMessage converts msg into its raw, wire-level form and panics if the
// conversion reports an error. It exists so test fixtures can be built inline
// inside composite literals.
func mustRawMessage(msg Message) *RawMessage {
	raw, err := msg.RawMessage()
	if err == nil {
		return raw
	}
	panic(err)
}

// The tests in this file write messages and read them back through a
// loopback pipe, in an attempt to verify that gortmp is at least internally
// consistent.

// randomPayload returns a freshly allocated slice of length bytes filled from
// crypto/rand. It panics if reading the random source fails.
func randomPayload(length int) []byte {
	buf := make([]byte, length)
	_, err := crand.Read(buf)
	if err != nil {
		panic(err)
	}
	return buf
}

// randomType returns a pseudo-random message type id in the range [25, 254].
func randomType() uint8 {
	const lowest = 25
	offset := mrand.Intn(0xff - lowest)
	return uint8(lowest + offset)
}

// randomStreamID returns a pseudo-random 32-bit message stream ID; any uint32
// value may be returned.
func randomStreamID() uint32 {
	id := mrand.Uint32()
	return id
}

// randomChunkStreamID returns a pseudo-random chunk stream ID in the inclusive
// range [2, 65599]: the smallest and largest IDs the RTMP chunk basic header
// can encode (0 and 1 are not usable as IDs; they select the 2- and 3-byte
// basic header forms).
func randomChunkStreamID() uint32 {
	const (
		minID = 2
		maxID = 65599
	)
	// Intn's bound is exclusive; the +1 keeps maxID itself reachable so the
	// largest encodable ID is exercised too.
	return uint32(mrand.Intn(maxID-minID+1) + minID)
}

// randomTimestamp returns a pseudo-random timestamp in [0, 0xfffffe], i.e. one
// that fits in the 24-bit chunk header timestamp field without needing the
// extended timestamp (0xffffff itself is the escape value that signals one).
func randomTimestamp() uint32 {
	// Int31n's bound is exclusive, so this covers every value below 0xffffff,
	// odd and even alike. Masking with 0xfffffe would also clear bit 0 and
	// only ever produce even timestamps.
	return uint32(mrand.Int31n(0xffffff))
}

// randomExtendedTimestamp returns a pseudo-random timestamp that is always
// greater than 0xffffff, so it never fits in a 24-bit timestamp field.
func randomExtendedTimestamp() uint32 {
	const bit24 = uint32(1) << 24 // 0x1000000
	ts := mrand.Uint32()
	ts |= bit24
	return ts
}

// compareMessages checks that two raw messages agree on every field a chunk
// stream round trip must preserve: chunk stream ID, message stream ID,
// timestamp, type, and payload bytes.
//
// It returns nil on a match. It returns a "header mismatch" error if any
// header field differs; header fields are checked before the payload. It
// returns a "payload mismatch" error if only the payloads differ.
func compareMessages(a, b *RawMessage) error {
	// Errors are built only on failure, so successful comparisons never pay
	// for formatting (possibly large) payload dumps.
	if a.ChunkStreamID != b.ChunkStreamID ||
		a.StreamID != b.StreamID ||
		a.Timestamp != b.Timestamp ||
		a.Type != b.Type {
		return fmt.Errorf("header mismatch: %+v != %+v", a, b)
	}

	aData, bData := a.Data.Bytes(), b.Data.Bytes()
	if !bytes.Equal(aData, bData) {
		return fmt.Errorf("payload mismatch: %+v != %+v", aData, bData)
	}

	return nil
}

// TestSetChunkSizeLoopback writes a sequence of messages through a
// ChunkStreamWriter that is connected by an in-memory pipe to a
// ChunkStreamReader. It checks that every message is read back unchanged and
// in order:
//
//  1. send a message smaller than the default chunk size
//  2. send a message larger than the default chunk size
//  3. set the chunk size to twice the default
//  4. send a message larger than the old chunk size, smaller than the new
//  5. send a message larger than the new chunk size
func TestSetChunkSizeLoopback(t *testing.T) {
	pr, pw := io.Pipe()
	csr := NewChunkStreamReader(pr)
	csw := NewChunkStreamWriter(pw)

	// dataMessage builds a data message with a random payload of the given
	// size. The type comes from randomType, which only yields values in
	// [25, 254], well above the low-numbered RTMP protocol control types
	// (Set Chunk Size is type 1). This way a random draw cannot make the
	// reader act on the payload as a control message.
	dataMessage := func(size int) *RawMessage {
		return &RawMessage{
			ChunkStreamID: CS_ID_DATA,
			Type:          randomType(),
			StreamID:      randomStreamID(),
			Data:          bytes.NewBuffer(randomPayload(size)),
		}
	}

	messages := []*RawMessage{
		dataMessage(int(DEFAULT_CHUNK_SIZE - 1)),
		dataMessage(int(DEFAULT_CHUNK_SIZE + 1)),
		mustRawMessage(SetChunkSizeMessage{DEFAULT_CHUNK_SIZE * 2}),
		dataMessage(int(DEFAULT_CHUNK_SIZE*2 - 1)),
		dataMessage(int(DEFAULT_CHUNK_SIZE*2 + 1)),
	}

	// Both channels are buffered, so neither goroutine blocks on reporting
	// even if the test has already failed via the other one.
	readDone := make(chan error, 1)
	writeDone := make(chan error, 1)

	go func() {
		// Closing the write end unblocks the reader if writing stops early.
		defer pw.Close()
		for _, message := range messages {
			if err := csw.Write(message); err != nil {
				writeDone <- err
				return
			}
		}
		writeDone <- nil
	}()

	go func() {
		// Closing the read end unblocks the writer if reading stops early.
		defer pr.Close()
		for _, want := range messages {
			got, err := csr.Read()
			if err != nil {
				readDone <- err
				return
			}
			if err := compareMessages(want, got); err != nil {
				readDone <- err
				return
			}
		}
		readDone <- nil
	}()

	// Wait for both sides, failing on whichever reports an error first.
	for i := 0; i < 2; i++ {
		select {
		case err := <-readDone:
			if err != nil {
				t.Fatalf("read failure: %v", err)
			}
		case err := <-writeDone:
			if err != nil {
				t.Fatalf("write failure: %v", err)
			}
		}
	}
}
