package server

import (
	"encoding/hex"
	"fmt"
	golog "log"
	"net"
	"sync"
	"sync/atomic"
	"time"

	goctx "context"

	rtmpctx "code.justin.tv/event-engineering/gortmp/pkg/context"
	rtmplog "code.justin.tv/event-engineering/gortmp/pkg/log"
	gortmp "code.justin.tv/event-engineering/gortmp/pkg/rtmp"
)

// Handler serves a single, already-accepted RTMP connection.
type Handler interface {
	// ServeRTMP handles a single connection. Server.HandleConn logs and
	// returns any non-nil error reported here.
	ServeRTMP(goctx.Context, net.Conn) error
}

// MediaServerHandler receives the callbacks a MediaServer issues while it
// serves connections and manages its media streams.
type MediaServerHandler interface {
	// OnMediaStreamCreated is called by MediaServer.AddStream after a stream
	// has been registered under its name.
	OnMediaStreamCreated(goctx.Context, gortmp.MediaStream)
	// OnMediaStreamDestroyed is called by MediaServer.DeleteStream after a
	// stream has been removed from the registry and closed.
	OnMediaStreamDestroyed(goctx.Context, gortmp.MediaStream)
	// Handle is called by MediaServer.ServeRTMP for each parsed message. The
	// receiver is the connection for messages on stream id 0 and the matching
	// server stream otherwise. A non-nil error makes ServeRTMP return it.
	Handle(goctx.Context, gortmp.Receiver, gortmp.Message) error
}

// ServerConfig carries the settings read by NewServer.
type ServerConfig struct {
	// Handler serves each connection once its handshake has succeeded.
	Handler Handler
	// Logger is attached to the server's context by NewServer.
	Logger  *golog.Logger
	// Context is the base context of the server; NewServer substitutes
	// context.Background() when it is nil.
	Context goctx.Context
}

// Server accepts network connections, performs the RTMP handshake on each and
// hands the connection to its Handler.
type Server struct {
	// handler serves every connection that completes the handshake.
	handler Handler
	// NOTE(review): logger is not read anywhere in the visible code.
	logger  *golog.Logger
	// ctx is the context passed to the handshake, the handler and log calls.
	ctx     goctx.Context
}

// MediaServer implements Handler: it parses the messages arriving on each
// connection, dispatches them to a MediaServerHandler and keeps a registry of
// named media streams.
type MediaServer struct {
	// handler receives stream lifecycle callbacks and every parsed message.
	handler     MediaServerHandler
	// streams maps a stream name to its registered media stream; it is
	// accessed under streamsMu.
	streams     map[string]gortmp.MediaStream
	streamsMu   sync.Mutex
	// connections counts the connections currently inside ServeRTMP; it is
	// updated and read with sync/atomic.
	connections int64
	// ctx is set to context.Background() by NewMediaServer.
	ctx         goctx.Context
}

// serverConn is the per-connection state created by MediaServer.ServeRTMP.
type serverConn struct {
	gortmp.BasicConn
	// server is the MediaServer that owns this connection.
	server *MediaServer

	// streams holds the message streams opened on this connection, keyed by
	// stream id. It is not protected by a lock.
	streams map[uint32]*serverStream

	// ctx is the connection's current context; read and replaced through
	// getCtx/setCtx, which hold ctxMu.
	ctx   goctx.Context
	ctxMu sync.Mutex
}

// serverStream is one message stream opened on a serverConn.
type serverStream struct {
	// id is the RTMP message stream id stamped on outgoing messages.
	id   uint32
	// conn is the connection this stream belongs to.
	conn *serverConn

	// mediaPlayer is set while the stream is playing a media stream.
	mediaPlayer gortmp.MediaPlayer
	// mediaStream is set once the stream has published.
	mediaStream gortmp.MediaStream
	// mediaStreamOwned reports whether the mediaStream's lifecycle is tied to
	// this server stream; when true, cleanup deletes it from the MediaServer.
	mediaStreamOwned bool
}

// Server methods

// NewServer builds a Server from config.
//
// The server's context is config.Context, or context.Background() when that
// is nil, with config.Logger attached through rtmpctx.WithLogger. The logger
// is also stored on the Server so that its logger field reflects the
// configuration instead of always being nil.
func NewServer(config ServerConfig) *Server {
	base := config.Context
	if base == nil {
		base = goctx.Background()
	}

	return &Server{
		handler: config.Handler,
		logger:  config.Logger,
		ctx:     rtmpctx.WithLogger(base, config.Logger),
	}
}

// ListenAndServe listens on the TCP address addr and serves incoming
// connections with Serve. It returns the listen error, or whatever error
// makes Serve stop.
func (s *Server) ListenAndServe(addr string) error {
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	// The listener is created here, so it is released here: without this the
	// socket stays bound after Serve gives up on an Accept error.
	defer ln.Close()

	rtmplog.Infof(s.ctx, "Listener Addr: %s", ln.Addr())
	return s.Serve(ln)
}

// Serve accepts connections from ln until Accept fails and returns that
// error. Each connection is handshaked and handled on its own goroutine.
//
// A connection whose handshake fails is closed here, because it never reaches
// the handler that would otherwise be responsible for closing it.
func (s *Server) Serve(ln net.Listener) error {
	for {
		client, err := ln.Accept()
		if err != nil {
			return err
		}

		rtmplog.Infof(s.ctx, "ServeRTMP(%s)", client.RemoteAddr())

		go func() {
			handshake, err := gortmp.SHandshake(s.ctx, client)
			if err != nil {
				rtmplog.Debugf(s.ctx, "handshake with %s failed: %s", client.RemoteAddr(), err)
				rtmplog.Tracef(s.ctx, "Handshake failed, input bytes: \n%s", hex.Dump(handshake))
				// Best effort: the peer may already be gone, so a close
				// error carries no additional information.
				_ = client.Close()
				return
			}

			// HandleConn logs its own failure; nothing else to do with it.
			_ = s.HandleConn(client)
		}()
	}
}

// HandleConn passes conn to the server's handler using the server context.
// A handler error is logged together with the remote address and returned.
func (s *Server) HandleConn(conn net.Conn) error {
	err := s.handler.ServeRTMP(s.ctx, conn)
	if err == nil {
		return nil
	}

	rtmplog.Errorf(s.ctx, "ServeRTMP(%s) failed with %s", conn.RemoteAddr(), err)
	return err
}

// ServeWithWaitGroup behaves like Serve but registers every connection
// goroutine with wg, so the caller can wait for in-flight connections after
// the accept loop has returned. It returns the error that stopped Accept.
//
// A connection whose handshake fails is closed here, because it never reaches
// the handler that would otherwise be responsible for closing it.
func (s *Server) ServeWithWaitGroup(ln net.Listener, wg *sync.WaitGroup) error {
	for {
		client, err := ln.Accept()
		if err != nil {
			return err
		}

		rtmplog.Infof(s.ctx, "ServeRTMP(%s)", client.RemoteAddr())
		wg.Add(1)
		go func() {
			defer wg.Done()

			handshake, err := gortmp.SHandshake(s.ctx, client)
			if err != nil {
				rtmplog.Debugf(s.ctx, "handshake with %s failed: %s", client.RemoteAddr(), err)
				rtmplog.Tracef(s.ctx, "Handshake failed, input bytes: \n%s", hex.Dump(handshake))
				// Best effort: the peer may already be gone, so a close
				// error carries no additional information.
				_ = client.Close()
				return
			}

			// HandleConn logs its own failure; nothing else to do with it.
			_ = s.HandleConn(client)
		}()
	}
}

// Context returns the context the server was built with; it is the one
// passed to the handshake and to the handler.
func (s *Server) Context() goctx.Context {
	serverCtx := s.ctx
	return serverCtx
}

// MediaServer methods

// NewMediaServer returns a MediaServer that reports to handler. It starts
// with an empty stream registry and context.Background() as its context.
func NewMediaServer(handler MediaServerHandler) *MediaServer {
	ms := new(MediaServer)
	ms.handler = handler
	ms.streams = map[string]gortmp.MediaStream{}
	ms.ctx = goctx.Background()
	return ms
}

// ServeRTMP runs the message loop for one connection and closes conn when it
// returns. The connection is counted in Connections for as long as the loop
// runs, and every stream opened on it is cleaned up on exit.
//
// The connection context is derived from ctx (so values attached by the
// caller, such as its logger, are kept) with the local and remote addresses
// added; the server's own context is used only when ctx is nil.
//
// Each message read is parsed and passed to the handler: with the connection
// as receiver when its stream id is 0, with the matching server stream
// otherwise. A message for an unknown stream id is logged and skipped. The
// first read, parse or handler error ends the loop and is returned.
func (s *MediaServer) ServeRTMP(ctx goctx.Context, conn net.Conn) error {
	defer conn.Close()

	atomic.AddInt64(&s.connections, 1)
	defer atomic.AddInt64(&s.connections, -1)

	if ctx == nil {
		ctx = s.ctx
	}
	ctx = rtmpctx.WithLocalAddr(ctx, conn.LocalAddr())
	ctx = rtmpctx.WithRemoteAddr(ctx, conn.RemoteAddr())

	sc := &serverConn{
		BasicConn: gortmp.NewBasicConn(conn),
		streams:   make(map[uint32]*serverStream),
		server:    s,
		ctx:       ctx,
	}
	defer sc.cleanup()

	for {
		raw, err := sc.Read()
		if err != nil {
			rtmplog.Debugf(sc.getCtx(), "connection read error: %s", err)
			return err
		}

		msg, err := sc.parseMessage(raw)
		if err != nil {
			rtmplog.Debugf(sc.getCtx(), "message parse error: %s", err)
			return err
		}

		if raw.StreamID == 0 {
			if err := s.handler.Handle(sc.getCtx(), sc, msg); err != nil {
				rtmplog.Debugf(sc.getCtx(), "conn handler error: %s", err)
				return err
			}
			continue
		}

		ss, ok := sc.streams[raw.StreamID]
		if !ok {
			// Skip the message but keep the connection: returning here would
			// drop the client instead of ignoring, as the log line states.
			rtmplog.Warnf(sc.getCtx(), "ignoring message for invalid stream id %d", raw.StreamID)
			continue
		}

		if err := s.handler.Handle(sc.getCtx(), ss, msg); err != nil {
			rtmplog.Debugf(sc.getCtx(), "stream handler error: %s", err)
			return err
		}
	}
}

// Context returns the media server's own context.
func (s *MediaServer) Context() goctx.Context {
	serverCtx := s.ctx
	return serverCtx
}

// Connections returns the number of connections currently being served by
// ServeRTMP. The counter is read atomically.
func (s *MediaServer) Connections() int64 {
	active := atomic.LoadInt64(&s.connections)
	return active
}

// MediaStream looks up the media stream registered under name. It returns
// the map's zero value (a nil MediaStream) when no such stream exists.
func (s *MediaServer) MediaStream(name string) gortmp.MediaStream {
	s.streamsMu.Lock()
	stream := s.streams[name]
	s.streamsMu.Unlock()

	return stream
}

// createStream creates a media stream called name and registers it with
// AddStream.
//
// When registration fails (for example because the name is already taken)
// the freshly created stream is closed and nil is returned with the error, so
// a caller never receives — and nothing leaks — a stream that the server does
// not know about.
func (s *MediaServer) createStream(ctx goctx.Context, name string) (gortmp.MediaStream, error) {
	stream := gortmp.NewMediaStream(ctx, name)
	if err := s.AddStream(ctx, stream); err != nil {
		stream.Close()
		return nil, err
	}

	return stream, nil
}

// AddStream registers stream under stream.Name(). It fails when a stream with
// that name is already registered. On success the handler's
// OnMediaStreamCreated callback is invoked, after the registry lock has been
// released.
func (s *MediaServer) AddStream(ctx goctx.Context, stream gortmp.MediaStream) error {
	name := stream.Name()

	s.streamsMu.Lock()
	_, exists := s.streams[name]
	if !exists {
		s.streams[name] = stream
	}
	s.streamsMu.Unlock()

	if exists {
		return fmt.Errorf("media stream %s already exists", name)
	}

	s.handler.OnMediaStreamCreated(ctx, stream)
	return nil
}

// DeleteStream removes the stream registered under name, closes it and
// invokes the handler's OnMediaStreamDestroyed callback. Closing and the
// callback happen after the registry lock has been released. It fails when no
// stream is registered under name.
func (s *MediaServer) DeleteStream(ctx goctx.Context, name string) error {
	s.streamsMu.Lock()
	stream, found := s.streams[name]
	if found {
		delete(s.streams, name)
	}
	s.streamsMu.Unlock()

	if !found {
		return fmt.Errorf("media stream %s not found", name)
	}

	if stream != nil {
		stream.Close()
		s.handler.OnMediaStreamDestroyed(ctx, stream)
	}
	return nil
}

// Conn methods

// parseMessage turns a raw message into a typed one.
//
// On the protocol-control chunk stream, user control messages are parsed as
// events and everything else as control messages; a parse error there is
// logged before being returned. On any other chunk stream, AMF0/AMF3 commands
// and user control messages are parsed, and all remaining types are returned
// as the raw message unchanged.
func (sc *serverConn) parseMessage(raw *gortmp.RawMessage) (gortmp.Message, error) {
	if raw.ChunkStreamID != gortmp.CS_ID_PROTOCOL_CONTROL {
		switch raw.Type {
		case gortmp.COMMAND_AMF0, gortmp.COMMAND_AMF3:
			return gortmp.ParseCommand(raw)
		case gortmp.USER_CONTROL_MESSAGE:
			return gortmp.ParseEvent(raw)
		}
		return raw, nil
	}

	var (
		parsed   gortmp.Message
		parseErr error
	)
	switch raw.Type {
	case gortmp.USER_CONTROL_MESSAGE:
		parsed, parseErr = gortmp.ParseEvent(raw)
	default:
		parsed, parseErr = gortmp.ParseControlMessage(raw)
	}

	if parseErr != nil {
		rtmplog.Debugf(sc.getCtx(), "error parsing message %#v: %s", raw, parseErr)
	}
	return parsed, parseErr
}

// setCtx replaces the connection's context while holding ctxMu.
func (sc *serverConn) setCtx(ctx goctx.Context) {
	sc.ctxMu.Lock()
	defer sc.ctxMu.Unlock()

	sc.ctx = ctx
}

// getCtx returns the connection's current context, read while holding ctxMu.
func (sc *serverConn) getCtx() goctx.Context {
	sc.ctxMu.Lock()
	current := sc.ctx
	sc.ctxMu.Unlock()

	return current
}

// Write converts msg to its raw form, writes it to the underlying connection
// and flushes. The first error encountered is returned; nothing is flushed
// when conversion or the write fails.
func (sc *serverConn) Write(msg gortmp.Message) error {
	raw, err := msg.RawMessage()
	if err == nil {
		err = sc.BasicConn.Write(raw)
	}
	if err != nil {
		return err
	}

	return sc.BasicConn.Flush()
}

// cleanup runs cleanup on every stream still open on the connection. The
// streams stay in the map.
func (sc *serverConn) cleanup() {
	for id := range sc.streams {
		sc.streams[id].cleanup()
	}
}

// Handle dispatches a connection-level command to its invoke* method:
// connect, createStream, deleteStream, FCPublish and initStream are
// recognised. Any other message is accepted and ignored.
func (sc *serverConn) Handle(msg gortmp.Message) error {
	switch cmd := msg.(type) {
	case gortmp.ConnectCommand:
		return sc.invokeConnect(cmd)
	case gortmp.CreateStreamCommand:
		return sc.invokeCreateStream(cmd)
	case gortmp.DeleteStreamCommand:
		return sc.invokeDeleteStream(cmd)
	case gortmp.FCPublishCommand:
		return sc.invokeFCPublishCommand(cmd)
	case gortmp.InitStreamCommand:
		return sc.invokeInitStream(cmd)
	}

	return nil
}

// invokeConnect answers a connect command. The command is rejected unless its
// "app" property is a string. Otherwise the window acknowledgement size, peer
// bandwidth, stream-begin event for stream 0, chunk size and the connect
// result are written in that order; the first write error is returned.
func (sc *serverConn) invokeConnect(cmd gortmp.ConnectCommand) error {
	if _, isString := cmd.Properties["app"].(string); !isString {
		return gortmp.ErrConnectRejected(fmt.Errorf("invalid app: %#v", cmd.Properties["app"]))
	}

	rtmplog.Infof(sc.getCtx(), "invokeConnect(%+v)", cmd)

	result := gortmp.ResultCommand{
		TransactionID: cmd.TransactionID,
		Properties:    gortmp.DefaultConnectProperties,
		Info:          gortmp.DefaultConnectInformation,
	}
	replies := []gortmp.Message{
		gortmp.WindowAcknowledgementSizeMessage{2500000},
		gortmp.SetPeerBandwidthMessage{2500000, gortmp.BANDWIDTH_LIMIT_DYNAMIC},
		gortmp.StreamBeginEvent{0}, // stream-begin for the control stream
		gortmp.SetChunkSizeMessage{4096},
		result,
	}

	for i := range replies {
		if err := sc.Write(replies[i]); err != nil {
			return err
		}
	}
	return nil
}

// invokeCreateStream allocates a new stream on the connection and replies
// with a result command carrying the new stream id, followed by a
// stream-begin event for it. The first error (allocation or write) is
// returned.
func (sc *serverConn) invokeCreateStream(cmd gortmp.CreateStreamCommand) error {
	stream, err := sc.createStream()
	if err != nil {
		return err
	}

	result := gortmp.ResultCommand{
		TransactionID: cmd.TransactionID,
		Info:          stream.id,
	}
	begin := gortmp.StreamBeginEvent{stream.id}

	if err := sc.Write(result); err != nil {
		return err
	}
	if err := sc.Write(begin); err != nil {
		return err
	}
	return nil
}

// invokeInitStream opens the stream id requested by the client; see
// initStream for the conditions under which that fails.
func (sc *serverConn) invokeInitStream(cmd gortmp.InitStreamCommand) error {
	requested := cmd.InitStreamID
	return sc.initStream(requested)
}

// invokeDeleteStream removes the stream named by the command from the
// connection and runs its cleanup. It fails when the connection has no stream
// with that id.
func (sc *serverConn) invokeDeleteStream(cmd gortmp.DeleteStreamCommand) error {
	stream, ok := sc.streams[cmd.DeleteStreamID]
	if !ok {
		return fmt.Errorf("no such stream id: %d", cmd.DeleteStreamID)
	}

	delete(sc.streams, cmd.DeleteStreamID)
	stream.cleanup()
	return nil
}

// invokeFCPublishCommand acknowledges an FCPublish command by writing an
// onFCPublish reply whose status names the stream from the command.
func (sc *serverConn) invokeFCPublishCommand(cmd gortmp.FCPublishCommand) error {
	description := fmt.Sprintf("FCPublish to stream %s", cmd.Name)
	reply := gortmp.OnFCPublishCommand{
		Status: gortmp.StatusPublishStart(description),
	}
	return sc.Write(reply)
}

// initStream opens a stream with the caller-chosen id streamID.
//
// It fails when the connection already holds MAX_CONN_STREAM_COUNT streams,
// when streamID is 0 (reserved for connection-level messages), or when the id
// is already in use. The limit check uses >= so that the connection can never
// hold more than MAX_CONN_STREAM_COUNT streams; a plain > would admit one
// extra.
func (sc *serverConn) initStream(streamID uint32) error {
	if len(sc.streams) >= gortmp.MAX_CONN_STREAM_COUNT {
		return fmt.Errorf("maximum stream count per connection reached")
	}
	if streamID == 0 {
		return fmt.Errorf("cannot init stream 0")
	}
	if _, taken := sc.streams[streamID]; taken {
		return fmt.Errorf("initStreamID %d is already occupied", streamID)
	}

	sc.streams[streamID] = &serverStream{
		id:   streamID,
		conn: sc,
	}
	return nil
}

// createStream opens a stream on the lowest free id in
// [1, MAX_CONN_STREAM_COUNT] and returns it.
//
// It fails when the connection already holds MAX_CONN_STREAM_COUNT streams.
// The limit check uses >= because streams opened through initStream may
// occupy ids outside that range: with a plain > the scan below could still
// find a free id and push the connection one stream over the limit.
func (sc *serverConn) createStream() (*serverStream, error) {
	if len(sc.streams) >= gortmp.MAX_CONN_STREAM_COUNT {
		return nil, fmt.Errorf("maximum stream count per connection reached")
	}

	for id := uint32(1); id <= gortmp.MAX_CONN_STREAM_COUNT; id++ {
		if _, taken := sc.streams[id]; taken {
			continue
		}

		stream := &serverStream{
			id:   id,
			conn: sc,
		}
		sc.streams[id] = stream
		return stream, nil
	}

	// Not reachable while the count check above holds, kept as a safeguard.
	return nil, fmt.Errorf("maximum stream count per connection reached")
}

// Stream methods

// Write sends msg on this stream: the raw message is stamped with the
// stream's id and, when it carries no chunk stream id yet, assigned the data
// chunk stream, then written through the owning connection.
func (ss *serverStream) Write(msg gortmp.Message) error {
	out, err := msg.RawMessage()
	if err != nil {
		return err
	}

	// A chunk stream id chosen by the message itself is left untouched.
	if out.ChunkStreamID == 0 {
		out.ChunkStreamID = gortmp.CS_ID_DATA
	}
	out.StreamID = ss.id

	return ss.conn.Write(out)
}

// Conn returns the connection this stream was opened on.
func (ss *serverStream) Conn() gortmp.Conn {
	owner := ss.conn
	return owner
}

// Handle dispatches a stream-level message: play and publish commands go to
// their invoke* methods and raw messages to handleRaw. Any other message is
// accepted and ignored.
func (ss *serverStream) Handle(msg gortmp.Message) error {
	switch m := msg.(type) {
	case gortmp.PlayCommand:
		return ss.invokePlay(m)
	case gortmp.PublishCommand:
		return ss.invokePublish(m)
	case *gortmp.RawMessage:
		return ss.handleRaw(m)
	}

	return nil
}

// ReleaseMediaStream gives up the serverStream's ownership of its MediaStream
// and returns that stream. From then on the caller is responsible for
// deleting it from the MediaServer; the serverStream's cleanup will no longer
// do so. It fails when there is no MediaStream or when it is not owned.
func (ss *serverStream) ReleaseMediaStream() (gortmp.MediaStream, error) {
	switch {
	case ss.mediaStream == nil:
		return nil, fmt.Errorf("The server stream lacks a mediaStream")
	case !ss.mediaStreamOwned:
		return nil, fmt.Errorf("The server stream does not own the mediaStream")
	}

	ss.mediaStreamOwned = false
	return ss.mediaStream, nil
}

// handleRaw publishes video, audio and AMF0 data messages to the stream's
// media stream as FLV tags stamped with the current time as arrival time.
// Other message types, and any message arriving before a media stream is
// attached, are dropped without error.
func (ss *serverStream) handleRaw(msg *gortmp.RawMessage) error {
	switch msg.Type {
	case gortmp.VIDEO_TYPE, gortmp.AUDIO_TYPE, gortmp.DATA_AMF0:
		// media payload, handled below
	default:
		return nil
	}

	if ss.mediaStream == nil {
		return nil
	}

	tag := &gortmp.FlvTag{
		Type:        msg.Type,
		Timestamp:   msg.Timestamp,
		Size:        uint32(msg.Data.Len()),
		Bytes:       msg.Data.Bytes(),
		ArrivalTime: time.Now(),
	}
	return ss.mediaStream.Publish(tag)
}

// invokePlay handles a play command. An empty stream name is rejected.
// Otherwise the connection context is tagged with the stream name and the
// "play" status, the play-reset and play-start statuses are written on the
// connection, and playback of the named stream is started. A playback failure
// is reported as a play-failed error.
func (ss *serverStream) invokePlay(cmd gortmp.PlayCommand) error {
	if len(cmd.Name) == 0 {
		return gortmp.ErrPlayFailed(fmt.Errorf("invalid play stream name"))
	}

	ctx := rtmpctx.WithStreamName(ss.conn.getCtx(), cmd.Name)
	ctx = rtmpctx.WithConnStatus(ctx, "play")
	ss.conn.setCtx(ctx)

	rtmplog.Infof(ctx, "invokePlay(%s)", cmd.Name)

	reset := gortmp.OnStatusCommand{
		Info: gortmp.NetStreamPlayInfo{
			Status:  gortmp.StatusPlayReset("reset"),
			Details: cmd.Name,
		},
	}
	if err := ss.conn.Write(reset); err != nil {
		rtmplog.Infof(ctx, "invokePlay failed: %s", err)
		return err
	}

	start := gortmp.OnStatusCommand{
		Info: gortmp.NetStreamPlayInfo{
			Status:  gortmp.StatusPlayStart("play"),
			Details: cmd.Name,
		},
	}
	if err := ss.conn.Write(start); err != nil {
		return err
	}

	playErr := ss.play(cmd.Name)
	if playErr == nil {
		return nil
	}

	rtmplog.Infof(ctx, "invokePlay failed, invalid stream name %s", playErr)
	return gortmp.ErrPlayFailed(fmt.Errorf("invalid stream name: %v %s", cmd.Name, playErr))
}

// invokePublish handles a publish command. The name must be a non-empty
// string. The connection context is tagged with the stream name and the
// "publish" status, the media stream is created through publish, and a
// publish-start status is written on the connection. Name and publish
// failures are reported as bad-name errors.
func (ss *serverStream) invokePublish(cmd gortmp.PublishCommand) error {
	name, isString := cmd.Name.(string)
	if !isString || len(name) == 0 {
		return gortmp.ErrPublishBadName(fmt.Errorf("invalid publish stream name"))
	}

	tagged := rtmpctx.WithStreamName(ss.conn.getCtx(), name)
	tagged = rtmpctx.WithConnStatus(tagged, "publish")
	ss.conn.setCtx(tagged)

	if err := ss.publish(ss.conn.getCtx(), name); err != nil {
		return gortmp.ErrPublishBadName(err)
	}

	status := gortmp.OnStatusCommand{
		StreamID:      cmd.StreamID,
		TransactionID: cmd.TransactionID,
		Info:          gortmp.StatusPublishStart(fmt.Sprintf("Publishing %s.", name)),
	}
	if err := ss.conn.Write(status); err != nil {
		return err
	}
	return nil
}

// publish attaches a media stream called name to this server stream.
//
// The media stream is created and registered on the MediaServer using ctx,
// and the server stream takes ownership of it, so cleanup deletes it again.
// When a media stream is already attached the call is a no-op: repeated
// publishes on the same server stream are ignored, whatever name they carry.
func (ss *serverStream) publish(ctx goctx.Context, name string) error {
	if ss.mediaStream != nil {
		return nil
	}

	// Use the context handed in by the caller instead of fetching the
	// connection context a second time.
	ms, err := ss.conn.server.createStream(ctx, name)
	if err != nil {
		return err
	}

	ss.mediaStream = ms
	ss.mediaStreamOwned = true

	return nil
}

// play points this server stream's media player at the media stream called
// name. A player left over from an earlier play is closed first, even when
// the new stream then turns out not to exist. It fails when no stream is
// registered under name or when the player cannot be created.
func (ss *serverStream) play(name string) error {
	if ss.mediaPlayer != nil {
		ss.mediaPlayer.Close()
		ss.mediaPlayer = nil
	}

	source := ss.conn.server.MediaStream(name)
	if source == nil {
		return fmt.Errorf("stream not found")
	}

	player, err := gortmp.NewMediaPlayer(ss.conn.getCtx(), source, ss)
	if err != nil {
		return err
	}

	ss.mediaPlayer = player
	ss.mediaPlayer.Start()
	return nil
}

// cleanup releases what the server stream holds: an owned media stream is
// deleted from the MediaServer (the result of that call is not checked) and
// an active media player is closed. Both references are cleared, so calling
// cleanup again does nothing.
func (ss *serverStream) cleanup() {
	if owned := ss.mediaStreamOwned && ss.mediaStream != nil; owned {
		ss.conn.server.DeleteStream(ss.conn.getCtx(), ss.mediaStream.Name())
		ss.mediaStream = nil
		ss.mediaStreamOwned = false
	}

	if ss.mediaPlayer == nil {
		return
	}
	ss.mediaPlayer.Close()
	ss.mediaPlayer = nil
}
