package rtmp

import (
	"fmt"
	"log"
	"net"
	"os"
	"testing"
	"time"

	goctx "context"

	rtmpctx "code.justin.tv/video/gortmp/pkg/context"
	"github.com/valyala/fasthttp/fasthttputil"
)

// logger writes to stderr with no prefix and no timestamp/file flags; it is
// attached to the server context in serverTestWorker.
var logger = log.New(os.Stderr, "", 0)

// serverHandler is the message handler used by the in-process test server.
// It forwards every message to its receiver and signals the test once a
// publish command has been handled.
type serverHandler struct {
	// publishChan is closed by handlePublish after a publish command has been
	// forwarded successfully; tests wait on it.
	publishChan chan struct{}
}

// OnMediaStreamCreated ignores the event; the test handler keeps no per-stream
// state.
func (sh *serverHandler) OnMediaStreamCreated(_ goctx.Context, _ MediaStream) {}

// OnMediaStreamDestroyed ignores the event; the test handler keeps no
// per-stream state to tear down.
func (sh *serverHandler) OnMediaStreamDestroyed(_ goctx.Context, _ MediaStream) {}

// Handle dispatches m: commands go through handleCommand, and anything else
// is passed straight to the receiver r.
func (sh *serverHandler) Handle(ctx goctx.Context, r Receiver, m Message) error {
	switch msg := m.(type) {
	case Command:
		return sh.handleCommand(ctx, r, msg)
	default:
		return r.Handle(m)
	}
}

// handleCommand intercepts PublishCommand so the test can observe it; every
// other command is handed to r unchanged.
func (sh *serverHandler) handleCommand(ctx goctx.Context, r Receiver, cmd Command) error {
	if publish, isPublish := cmd.(PublishCommand); isPublish {
		return sh.handlePublish(ctx, r, publish)
	}
	return r.Handle(cmd)
}

// handlePublish forwards a publish command to r and, on success, closes
// sh.publishChan to tell the waiting test that a publish happened.
//
// r must be a Stream; otherwise an error is returned and nothing is forwarded.
// Any error from r.Handle is wrapped and returned, and publishChan is left
// open in that case.
func (sh *serverHandler) handlePublish(ctx goctx.Context, r Receiver, publish PublishCommand) error {
	if _, ok := r.(Stream); !ok {
		// A real RTMP server needs to be able to do this
		return fmt.Errorf("failed to get stream for publish command")
	}

	if err := r.Handle(publish); err != nil {
		return fmt.Errorf("invokePublish error after authorize: %w", err)
	}

	// Close only once: a second publish on the same handler would otherwise
	// panic with "close of closed channel".
	// NOTE(review): this check-then-close assumes Handle is not invoked
	// concurrently for one handler — confirm against the server's dispatch.
	select {
	case <-sh.publishChan:
		// Already signalled by an earlier publish.
	default:
		close(sh.publishChan)
	}
	return nil
}

func serverTestWorker(publishChan chan struct{}, conn net.Conn) error {
	serverHandler := &serverHandler{publishChan}
	ms := NewMediaServer(serverHandler)
	err := ms.ServeRTMP(rtmpctx.WithLogger(goctx.Background(), logger), conn)
	return err
}

// flushingConn wraps a BasicConn so that each message written through it is
// flushed immediately (see Write).
type flushingConn struct{ BasicConn }

// Write converts msg to its raw form, writes it to the embedded BasicConn,
// and flushes after every message. The first error encountered (conversion,
// write, or flush) is returned and later steps are skipped.
func (fc *flushingConn) Write(msg Message) error {
	raw, err := msg.RawMessage()
	if err == nil {
		err = fc.BasicConn.Write(raw)
	}
	if err == nil {
		err = fc.BasicConn.Flush()
	}
	return err
}

// parseMessage decodes raw into a typed Message.
//
// On the protocol-control chunk stream, USER_CONTROL_MESSAGE is parsed as an
// event and every other type as a protocol control message; parse errors
// there are wrapped with the offending raw message. On other chunk streams,
// AMF0/AMF3 commands and user control events are parsed, and any other type
// is returned unparsed as raw itself. On error the returned Message is nil.
//
// The real server keeps this as a method on serverConn to get its context;
// it is attached to flushingConn here to keep it from creeping out of the
// test code.
func (fc *flushingConn) parseMessage(raw *RawMessage) (Message, error) {
	if raw.ChunkStreamID == CS_ID_PROTOCOL_CONTROL {
		var (
			msg Message
			err error
		)
		if raw.Type == USER_CONTROL_MESSAGE {
			msg, err = ParseEvent(raw)
		} else {
			msg, err = ParseControlMessage(raw)
		}
		if err != nil {
			// Return a literal nil so callers never see a partially parsed
			// (or typed-nil) message alongside the error.
			return nil, fmt.Errorf("error parsing message %#v: %w", raw, err)
		}
		return msg, nil
	}

	switch raw.Type {
	case COMMAND_AMF0, COMMAND_AMF3:
		return ParseCommand(raw)
	case USER_CONTROL_MESSAGE:
		return ParseEvent(raw)
	default:
		return raw, nil
	}
}

// publishClient is a test publisher that drives an RTMP publish over the
// supplied connection and returns any error it hits.
type publishClient interface {
	Publish(c net.Conn) error
}

// runServerPublisher connects p to an in-process RTMP server over an
// in-memory pipe. It fails the test unless the server handles a publish
// command within five seconds.
//
// The server and the publisher each run in their own goroutine. If either
// returns an error before the publish is observed, the test fails with that
// error.
func runServerPublisher(t *testing.T, p publishClient) {
	t.Helper()

	pc := fasthttputil.NewPipeConns()
	// Closing the pipe on return unblocks both goroutines so they can exit.
	defer pc.Close()

	serverPublishChan := make(chan struct{})
	// One slot per goroutine so an error send never blocks. Once this
	// function returns nobody reads errCh, and with an unbuffered channel a
	// late error (e.g. caused by pc.Close) would leak the goroutine forever.
	errCh := make(chan error, 2)
	go func() {
		if err := serverTestWorker(serverPublishChan, pc.Conn1()); err != nil {
			errCh <- fmt.Errorf("server exit: %w", err)
		}
	}()
	go func() {
		if err := p.Publish(pc.Conn2()); err != nil {
			errCh <- fmt.Errorf("publisher exit: %w", err)
		}
	}()

	select {
	case <-serverPublishChan:
	case err := <-errCh:
		t.Fatal(err)
	case <-time.After(5 * time.Second):
		t.Fatal("Connection timeout")
	}
}
