package rtmp

import (
	"encoding/hex"
	"fmt"
	golog "log"
	"net"
	"sync"
	"sync/atomic"
	"time"

	goctx "context"

	rtmpctx "code.justin.tv/video/gortmp/pkg/context"
	rtmplog "code.justin.tv/video/gortmp/pkg/log"
)

// Handler serves a single accepted RTMP connection.
type Handler interface {
	// ServeRTMP handles a single connection. Server.Serve and
	// Server.ServeWithWaitGroup call it only after the server-side handshake on
	// the connection has succeeded, and log any non-nil error it returns.
	ServeRTMP(goctx.Context, net.Conn) error
}

// MediaServerHandler receives stream lifecycle notifications and the parsed
// messages of every connection served by a MediaServer.
type MediaServerHandler interface {
	// OnMediaStreamCreated is called by MediaServer.AddStream after the stream
	// has been registered under its name.
	OnMediaStreamCreated(goctx.Context, MediaStream)
	// OnMediaStreamDestroyed is called by MediaServer.DeleteStream after the
	// stream has been unregistered and closed.
	OnMediaStreamDestroyed(goctx.Context, MediaStream)
	// Handle is called for each parsed message. The Receiver is the connection
	// for messages on stream ID 0 and the matching server stream otherwise. A
	// non-nil error ends the connection's serve loop.
	Handle(goctx.Context, Receiver, Message) error
}

// ServerConfig carries the settings consumed by NewServer.
type ServerConfig struct {
	// Handler serves each connection once its handshake has completed.
	Handler Handler
	// Logger is attached to the server's context via rtmpctx.WithLogger.
	Logger *golog.Logger
	// Context is the parent of the server's context; when nil, NewServer uses
	// context.Background().
	Context goctx.Context
}

// Server accepts network connections, performs the server-side handshake and
// hands each connection to its Handler.
type Server struct {
	// handler serves every connection whose handshake succeeded.
	handler Handler
	// logger is not assigned by NewServer in this file; the configured logger
	// is carried by ctx instead.
	logger *golog.Logger
	// ctx is the context passed to the handshake, the handler and the logging
	// calls made by the accept loops.
	ctx goctx.Context
}

// MediaServer is a Handler implementation that tracks named media streams and
// dispatches the messages of each connection to a MediaServerHandler.
type MediaServer struct {
	// connections counts the connections currently inside ServeRTMP. It is only
	// accessed through sync/atomic, so it must stay the FIRST field: on 32-bit
	// platforms only the first word of an allocated struct is guaranteed to be
	// 64-bit aligned, and a misaligned 64-bit atomic operation panics.
	connections int64

	// handler receives stream lifecycle callbacks and every parsed message.
	handler MediaServerHandler
	// streams holds the registered media streams keyed by name; guarded by
	// streamsMu.
	streams   map[string]MediaStream
	streamsMu sync.Mutex
	// ctx is the server's own context, returned by Context().
	ctx goctx.Context
}

// serverConn is the per-connection state kept by MediaServer.ServeRTMP.
type serverConn struct {
	// BasicConn is the underlying message connection wrapped around the
	// net.Conn being served.
	BasicConn
	// server is the MediaServer that owns this connection.
	server *MediaServer

	// streams maps a message stream ID to the server stream opened on this
	// connection under that ID.
	streams map[uint32]*serverStream

	// ctx is the connection's current context; it is read and replaced through
	// getCtx/setCtx, which hold ctxMu.
	ctx   goctx.Context
	ctxMu sync.Mutex
}

// serverStream is one message stream opened on a serverConn.
type serverStream struct {
	// id is the message stream ID stamped on every message written through
	// this stream.
	id uint32
	// conn is the connection the stream belongs to.
	conn *serverConn

	// mediaPlayer is the player started by play, or nil when nothing is
	// being played on this stream.
	mediaPlayer MediaPlayer
	// mediaStream is the stream created by publish, or nil when this stream
	// is not publishing.
	mediaStream MediaStream
	// mediaStreamOwned reports whether the mediaStream's lifecycle is tied to
	// this server stream; cleanup deletes the media stream from the server only
	// while this is true.
	mediaStreamOwned bool
}

// Server methods

// NewServer builds a Server from config. The server context is derived from
// config.Context (context.Background() when that is nil) with config.Logger
// attached through rtmpctx.WithLogger.
func NewServer(config ServerConfig) *Server {
	parent := config.Context
	if parent == nil {
		parent = goctx.Background()
	}

	return &Server{
		handler: config.Handler,
		ctx:     rtmpctx.WithLogger(parent, config.Logger),
	}
}

// ListenAndServe listens on the TCP address addr and serves incoming
// connections with Serve. It returns the listen error, or the error that made
// Serve stop accepting.
func (s *Server) ListenAndServe(addr string) error {
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	// The listener is created here and never handed to the caller, so it has to
	// be released here too; otherwise the socket stays bound after Serve
	// returns on an Accept error.
	defer ln.Close()

	rtmplog.Infof(s.ctx, "Listener Addr: %s", ln.Addr())
	return s.Serve(ln)
}

// Serve accepts connections from ln until Accept fails and returns that error.
// Every accepted connection is processed on its own goroutine: the server-side
// handshake runs first and, when it succeeds, the connection is handed to the
// configured Handler. Handler errors are logged, not returned.
func (s *Server) Serve(ln net.Listener) error {
	for {
		client, err := ln.Accept()
		if err != nil {
			return err
		}

		rtmplog.Infof(s.ctx, "ServeRTMP(%s)", client.RemoteAddr())

		go func() {
			if handshake, err := SHandshake(s.ctx, client); err != nil {
				rtmplog.Tracef(s.ctx, "Handshake failed, input bytes: \n%s", hex.Dump(handshake))
				// The handler never sees a connection whose handshake failed,
				// so it must be closed here or its descriptor leaks.
				client.Close()
				return
			}
			if err := s.handler.ServeRTMP(s.ctx, client); err != nil {
				rtmplog.Errorf(s.ctx, "ServeRTMP(%s) failed with %s", client.RemoteAddr(), err)
			}
		}()
	}
}

// ServeWithWaitGroup behaves like Serve, but registers every connection
// goroutine with wg (Add before the goroutine starts, Done when it finishes) so
// the caller can wait for in-flight connections. It returns the error that made
// Accept fail.
func (s *Server) ServeWithWaitGroup(ln net.Listener, wg *sync.WaitGroup) error {
	for {
		client, err := ln.Accept()
		if err != nil {
			return err
		}

		rtmplog.Infof(s.ctx, "ServeRTMP(%s)", client.RemoteAddr())
		wg.Add(1)
		go func() {
			defer wg.Done()
			if handshake, err := SHandshake(s.ctx, client); err != nil {
				rtmplog.Tracef(s.ctx, "Handshake failed, input bytes: \n%s", hex.Dump(handshake))
				// The handler never sees a connection whose handshake failed,
				// so it must be closed here or its descriptor leaks.
				client.Close()
				return
			}
			if err := s.handler.ServeRTMP(s.ctx, client); err != nil {
				rtmplog.Errorf(s.ctx, "ServeRTMP(%s) failed with %s", client.RemoteAddr(), err)
			}
		}()
	}
}

// Context returns the server's context, i.e. the configured parent context with
// the configured logger attached by NewServer.
func (s *Server) Context() goctx.Context {
	return s.ctx
}

//MediaServer Methods

// NewMediaServer returns a MediaServer that reports to handler. It starts with
// an empty stream registry and context.Background() as its context.
func NewMediaServer(handler MediaServerHandler) *MediaServer {
	ms := new(MediaServer)
	ms.handler = handler
	ms.streams = make(map[string]MediaStream)
	ms.ctx = goctx.Background()
	return ms
}

// ServeRTMP runs the message loop for a single connection: it reads raw
// messages, parses them and passes them to the MediaServerHandler, addressed
// either to the connection (stream ID 0) or to the matching server stream.
//
// The connection is closed and all of its streams are cleaned up when the loop
// ends. The loop ends with the first read, parse or handler error, which is
// returned, or with nil when a message targets an unknown stream ID.
//
// The per-connection context is derived from ctx, so values and cancellation
// supplied by the caller (for example the logger configured on Server) apply to
// the connection. A nil ctx falls back to the MediaServer's own context.
func (s *MediaServer) ServeRTMP(ctx goctx.Context, conn net.Conn) error {
	defer conn.Close()

	atomic.AddInt64(&s.connections, 1)
	defer atomic.AddInt64(&s.connections, -1)

	// Previously ctx was overwritten with s.ctx before being read, silently
	// discarding whatever the caller passed in.
	if ctx == nil {
		ctx = s.ctx
	}
	ctx = rtmpctx.WithLocalAddr(ctx, conn.LocalAddr())
	ctx = rtmpctx.WithRemoteAddr(ctx, conn.RemoteAddr())

	sc := &serverConn{
		BasicConn: NewBasicConn(conn),
		streams:   make(map[uint32]*serverStream),
		server:    s,
		ctx:       ctx,
	}
	defer sc.cleanup()

	for {
		raw, err := sc.Read()
		if err != nil {
			rtmplog.Debugf(sc.getCtx(), "connection read error: %s", err)
			return err
		}

		//rtmplog.Tracef(sc.getCtx(), "msg: %v %v %v %v", raw.ChunkStreamID, raw.Timestamp, raw.Type, raw.StreamID)

		msg, err := sc.parseMessage(raw)
		if err != nil {
			rtmplog.Debugf(sc.getCtx(), "message parse error: %s", err)
			return err
		}

		if raw.StreamID == 0 {
			// Stream ID 0 addresses the connection itself.
			if err := s.handler.Handle(sc.getCtx(), sc, msg); err != nil {
				rtmplog.Debugf(sc.getCtx(), "conn handler error: %s", err)
				return err
			}
		} else if ss, ok := sc.streams[raw.StreamID]; ok {
			if err := s.handler.Handle(sc.getCtx(), ss, msg); err != nil {
				rtmplog.Debugf(sc.getCtx(), "stream handler error: %s", err)
				return err
			}
		} else {
			// NOTE: despite the wording of the log message, returning here ends
			// the serve loop and therefore closes the connection.
			rtmplog.Warnf(sc.getCtx(), "ignoring message for invalid stream id %d", raw.StreamID)
			return nil
		}
	}
}

// Context returns the MediaServer's own context, which NewMediaServer sets to
// context.Background().
func (s *MediaServer) Context() goctx.Context {
	return s.ctx
}

// Connections returns the number of connections currently being served by
// ServeRTMP. The counter is read atomically.
func (s *MediaServer) Connections() int64 {
	return atomic.LoadInt64(&s.connections)
}

// MediaStream looks up the registered media stream called name. It returns nil
// when no stream is registered under that name.
func (s *MediaServer) MediaStream(name string) MediaStream {
	s.streamsMu.Lock()
	stream := s.streams[name]
	s.streamsMu.Unlock()

	return stream
}

// createStream creates a new media stream called name and registers it with
// AddStream. When registration fails (for example because the name is already
// taken) it returns a nil stream together with the error, so callers can never
// mistake an unregistered stream for a usable one.
func (s *MediaServer) createStream(ctx goctx.Context, name string) (MediaStream, error) {
	stream := NewMediaStream(ctx, name)
	if err := s.AddStream(ctx, stream); err != nil {
		return nil, err
	}

	return stream, nil
}

// AddStream registers stream under stream.Name(). It fails when a stream with
// that name is already registered; otherwise the handler's
// OnMediaStreamCreated callback is invoked, outside the registry lock.
func (s *MediaServer) AddStream(ctx goctx.Context, stream MediaStream) error {
	name := stream.Name()

	// register reports whether the name was free and has now been claimed.
	register := func() bool {
		s.streamsMu.Lock()
		defer s.streamsMu.Unlock()

		if _, taken := s.streams[name]; taken {
			return false
		}
		s.streams[name] = stream
		return true
	}

	if !register() {
		return fmt.Errorf("media stream %s already exists", name)
	}

	s.handler.OnMediaStreamCreated(ctx, stream)
	return nil
}

// DeleteStream unregisters the media stream called name. A stream that was
// found is closed and reported through the handler's OnMediaStreamDestroyed
// callback, both outside the registry lock. An error is returned when no stream
// is registered under name.
func (s *MediaServer) DeleteStream(ctx goctx.Context, name string) error {
	s.streamsMu.Lock()
	stream, found := s.streams[name]
	if found {
		delete(s.streams, name)
	}
	s.streamsMu.Unlock()

	if !found {
		return fmt.Errorf("media stream %s not found", name)
	}

	// A nil value stored under the name is removed but has nothing to close.
	if stream != nil {
		stream.Close()
		s.handler.OnMediaStreamDestroyed(ctx, stream)
	}

	return nil
}

// Conn methods

// parseMessage converts a raw message into a typed Message. On the protocol
// control chunk stream, user control messages are parsed as events and
// everything else as control messages, with parse failures logged at debug
// level. On any other chunk stream, AMF0/AMF3 commands and user control
// messages are parsed and all other types are returned as the raw message.
func (sc *serverConn) parseMessage(raw *RawMessage) (Message, error) {
	if raw.ChunkStreamID != CS_ID_PROTOCOL_CONTROL {
		switch raw.Type {
		case COMMAND_AMF0, COMMAND_AMF3:
			return ParseCommand(raw)
		case USER_CONTROL_MESSAGE:
			return ParseEvent(raw)
		}
		return raw, nil
	}

	var (
		parsed Message
		err    error
	)

	switch raw.Type {
	case USER_CONTROL_MESSAGE:
		parsed, err = ParseEvent(raw)
	default:
		parsed, err = ParseControlMessage(raw)
	}

	if err != nil {
		rtmplog.Debugf(sc.getCtx(), "error parsing message %#v: %s", raw, err)
	}

	return parsed, err
}

// setCtx replaces the connection's context while holding ctxMu.
func (sc *serverConn) setCtx(ctx goctx.Context) {
	sc.ctxMu.Lock()
	defer sc.ctxMu.Unlock()

	sc.ctx = ctx
}

// getCtx returns the connection's current context, read while holding ctxMu.
func (sc *serverConn) getCtx() goctx.Context {
	sc.ctxMu.Lock()
	current := sc.ctx
	sc.ctxMu.Unlock()

	return current
}

// Write serializes msg into its raw form, writes it to the underlying
// connection and flushes. The first error encountered in those three steps is
// returned.
func (sc *serverConn) Write(msg Message) error {
	rawMsg, convErr := msg.RawMessage()
	if convErr != nil {
		return convErr
	}

	if writeErr := sc.BasicConn.Write(rawMsg); writeErr != nil {
		return writeErr
	}

	// Flush right away so each message leaves as soon as it is written.
	return sc.BasicConn.Flush()
}

// cleanup releases the resources of every stream opened on this connection.
// The streams themselves stay in the map.
func (sc *serverConn) cleanup() {
	for id := range sc.streams {
		sc.streams[id].cleanup()
	}
}

// Handle dispatches a connection-level command to its invoke method. Messages
// of any other type are accepted and ignored.
func (sc *serverConn) Handle(msg Message) error {
	switch cmd := msg.(type) {
	case ConnectCommand:
		return sc.invokeConnect(cmd)
	case CreateStreamCommand:
		return sc.invokeCreateStream(cmd)
	case DeleteStreamCommand:
		return sc.invokeDeleteStream(cmd)
	case FCPublishCommand:
		return sc.invokeFCPublishCommand(cmd)
	case InitStreamCommand:
		return sc.invokeInitStream(cmd)
	}

	// Not a command this connection acts on.
	return nil
}

// invokeConnect answers a connect command. The command is rejected unless its
// "app" property is a string. Otherwise the reply sequence is written in order:
// window acknowledgement size, peer bandwidth, stream begin for stream 0, chunk
// size, and finally the result carrying the default connect properties and
// information. The first write error aborts the sequence and is returned.
func (sc *serverConn) invokeConnect(cmd ConnectCommand) error {
	if _, isString := cmd.Properties["app"].(string); !isString {
		return ErrConnectRejected(fmt.Errorf("invalid app: %#v", cmd.Properties["app"]))
	}

	rtmplog.Infof(sc.getCtx(), "invokeConnect(%+v)", cmd)

	result := ResultCommand{
		TransactionID: cmd.TransactionID,
		Properties:    DefaultConnectProperties,
		Info:          DefaultConnectInformation,
	}

	replies := []Message{
		WindowAcknowledgementSizeMessage{2500000},
		SetPeerBandwidthMessage{2500000, BANDWIDTH_LIMIT_DYNAMIC},
		StreamBeginEvent{0}, // this should be a Stream begin event
		SetChunkSizeMessage{4096},
		result,
	}

	for i := range replies {
		if err := sc.Write(replies[i]); err != nil {
			return err
		}
	}

	return nil
}

// invokeCreateStream allocates a new stream on the connection and replies with
// a result carrying the new stream ID followed by a stream begin event for it.
// Allocation and write errors are returned.
func (sc *serverConn) invokeCreateStream(cmd CreateStreamCommand) error {
	created, err := sc.createStream()
	if err != nil {
		return err
	}

	result := ResultCommand{
		TransactionID: cmd.TransactionID,
		Info:          created.id,
	}

	replies := []Message{result, StreamBeginEvent{created.id}}
	for i := range replies {
		if err := sc.Write(replies[i]); err != nil {
			return err
		}
	}

	return nil
}

// invokeInitStream registers a stream under the ID requested by the command;
// see initStream for the conditions under which that fails.
func (sc *serverConn) invokeInitStream(cmd InitStreamCommand) error {
	return sc.initStream(cmd.InitStreamID)
}

// invokeDeleteStream removes the stream named by the command from the
// connection and releases its resources. It returns an error when the
// connection has no stream with that ID.
func (sc *serverConn) invokeDeleteStream(cmd DeleteStreamCommand) error {
	target, known := sc.streams[cmd.DeleteStreamID]
	if !known {
		return fmt.Errorf("no such stream id: %d", cmd.DeleteStreamID)
	}

	delete(sc.streams, cmd.DeleteStreamID)
	target.cleanup()
	return nil
}

// invokeFCPublishCommand acknowledges an FCPublish command by writing an
// onFCPublish reply whose status is a publish-start status naming the stream.
func (sc *serverConn) invokeFCPublishCommand(cmd FCPublishCommand) error {
	description := fmt.Sprintf("FCPublish to stream %s", cmd.Name)
	reply := OnFCPublishCommand{
		Status: StatusPublishStart(description),
	}

	return sc.Write(reply)
}

// initStream registers a new stream under the caller-chosen streamID.
//
// It fails when the connection already holds MAX_CONN_STREAM_COUNT streams,
// when streamID is 0 (stream 0 addresses the connection itself in the serve
// loop), or when the ID is already in use.
func (sc *serverConn) initStream(streamID uint32) error {
	// ">=" rather than ">": with ">" a connection already holding the maximum
	// number of streams could still add one more.
	if len(sc.streams) >= MAX_CONN_STREAM_COUNT {
		return fmt.Errorf("maximum stream count per connection reached")
	}
	if streamID == 0 {
		return fmt.Errorf("cannot init stream 0")
	}
	if _, ok := sc.streams[streamID]; ok {
		return fmt.Errorf("initStreamID %d is already occupied", streamID)
	}

	sc.streams[streamID] = &serverStream{
		id:   streamID,
		conn: sc,
	}
	return nil
}

// createStream registers a new stream under the lowest free ID in
// 1..MAX_CONN_STREAM_COUNT and returns it. It fails when the connection already
// holds MAX_CONN_STREAM_COUNT streams or when every ID in that range is taken.
func (sc *serverConn) createStream() (*serverStream, error) {
	// ">=" rather than ">": streams registered through initStream may use IDs
	// outside 1..MAX_CONN_STREAM_COUNT, so with ">" a connection already at the
	// limit could still grow by one through a free low ID.
	if len(sc.streams) >= MAX_CONN_STREAM_COUNT {
		return nil, fmt.Errorf("maximum stream count per connection reached")
	}

	for i := uint32(1); i <= MAX_CONN_STREAM_COUNT; i++ {
		if _, ok := sc.streams[i]; ok {
			continue
		}

		stream := &serverStream{
			id:   i,
			conn: sc,
		}
		sc.streams[i] = stream
		return stream, nil
	}

	return nil, fmt.Errorf("maximum stream count per connection reached")
}

// Stream methods

// Write sends msg on this stream: the raw form of the message is stamped with
// the stream's ID, given the data chunk stream ID when it has none, and then
// written through the owning connection.
func (ss *serverStream) Write(msg Message) error {
	out, err := msg.RawMessage()
	if err != nil {
		return err
	}

	// A chunk stream ID chosen by the message itself is preserved; only an
	// unset (zero) one is defaulted.
	if out.ChunkStreamID == 0 {
		out.ChunkStreamID = CS_ID_DATA
	}
	out.StreamID = ss.id

	return ss.conn.Write(out)
}

// Conn returns the connection this stream was opened on.
func (ss *serverStream) Conn() Conn {
	return ss.conn
}

// Handle dispatches a stream-level message: play and publish commands go to
// their invoke methods and raw messages to handleRaw. Messages of any other
// type are accepted and ignored.
func (ss *serverStream) Handle(msg Message) error {
	switch m := msg.(type) {
	case PlayCommand:
		return ss.invokePlay(m)
	case PublishCommand:
		return ss.invokePublish(m)
	case *RawMessage:
		return ss.handleRaw(m)
	}

	// Not a message this stream acts on.
	return nil
}

// ReleaseMediaStream gives up the serverStream's ownership of its MediaStream
// and returns that stream. From then on the caller is responsible for deleting
// the stream from the MediaServer. It fails when the serverStream has no
// MediaStream or no longer owns it.
func (ss *serverStream) ReleaseMediaStream() (MediaStream, error) {
	switch {
	case ss.mediaStream == nil:
		return nil, fmt.Errorf("The server stream lacks a mediaStream")
	case !ss.mediaStreamOwned:
		return nil, fmt.Errorf("The server stream does not own the mediaStream")
	}

	ss.mediaStreamOwned = false
	return ss.mediaStream, nil
}

// handleRaw forwards video, audio and AMF0 data messages to the stream being
// published, wrapped in an FlvTag stamped with the current time. Other message
// types, and messages arriving while nothing is being published, are dropped
// without error.
func (ss *serverStream) handleRaw(msg *RawMessage) error {
	switch msg.Type {
	case VIDEO_TYPE, AUDIO_TYPE, DATA_AMF0:
		// media payload, handled below
	default:
		return nil
	}

	if ss.mediaStream == nil {
		return nil
	}

	tag := &FlvTag{
		Type:        msg.Type,
		Timestamp:   msg.Timestamp,
		Size:        uint32(msg.Data.Len()),
		Bytes:       msg.Data.Bytes(),
		ArrivalTime: time.Now(),
	}
	return ss.mediaStream.Publish(tag)
}

// mediaStreamRetries polls ms every 15 milliseconds for a media stream called
// name and returns it as soon as it exists. It returns nil when ctx is done
// before the stream appears.
//
// A single ticker drives the polling and is stopped on return, instead of
// allocating a fresh timer for every iteration.
func mediaStreamRetries(ctx goctx.Context, ms *MediaServer, name string) MediaStream {
	const pollInterval = 15 * time.Millisecond

	ticker := time.NewTicker(pollInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return nil
		// TODO convert this to something that wakes when the named stream is created
		case <-ticker.C:
			if st := ms.MediaStream(name); st != nil {
				return st
			}
		}
	}
}

// invokePlay handles a play command.
//
// The connection context is tagged with the stream name and the "play" status.
// The named media stream is then looked up; when it does not exist yet and the
// command carries a positive timeout, the lookup is retried until the timeout
// expires. A missing stream is reported to the client with a stream-not-found
// status and is NOT treated as an error, so the connection stays open. For an
// existing stream, play-reset and play-start statuses are written and a media
// player is started on this stream.
func (ss *serverStream) invokePlay(cmd PlayCommand) error {
	if cmd.Name == "" {
		return ErrPlayFailed(fmt.Errorf("invalid play stream name"))
	}

	conn := ss.conn

	ctx := conn.getCtx()
	ctx = rtmpctx.WithStreamName(ctx, cmd.Name)
	ctx = rtmpctx.WithConnStatus(ctx, "play")
	conn.setCtx(ctx)

	rtmplog.Infof(ctx, "invokePlay(%s)", cmd.Name)

	source := conn.server.MediaStream(cmd.Name)
	if source == nil && cmd.Timeout > 0 {
		// Give a publisher that has not shown up yet a chance to appear.
		waitCtx, stopWaiting := goctx.WithTimeout(ctx, cmd.Timeout)
		source = mediaStreamRetries(waitCtx, conn.server, cmd.Name)
		stopWaiting()
	}

	if source == nil {
		rtmplog.Infof(ctx, "invokePlay stream not found: %v", cmd.Name)
		notFound := OnStatusCommand{
			Info: NetStreamPlayInfo{
				Status:  StatusPlayStreamNotFound("not_found"),
				Details: cmd.Name,
			},
		}
		if err := conn.Write(notFound); err != nil {
			rtmplog.Infof(ctx, "invokePlay failed to write response: %s", err)
			return err
		}

		// Return nil here to avoid immediately terminating the connection
		return nil
	}

	reset := OnStatusCommand{
		Info: NetStreamPlayInfo{
			Status:  StatusPlayReset("reset"),
			Details: cmd.Name,
		},
	}
	if err := conn.Write(reset); err != nil {
		rtmplog.Infof(ctx, "invokePlay failed to write onStatus: %s", err)
		return err
	}

	start := OnStatusCommand{
		Info: NetStreamPlayInfo{
			Status:  StatusPlayStart("play"),
			Details: cmd.Name,
		},
	}
	if err := conn.Write(start); err != nil {
		return err
	}

	if err := ss.play(source); err != nil {
		rtmplog.Infof(ctx, "invokePlay failed: %s", err)
		return ErrPlayFailed(fmt.Errorf("failed to play %v, error: %s", cmd.Name, err))
	}

	return nil
}

// invokePublish handles a publish command. The stream name must be a non-empty
// string. The connection context is tagged with the name and the "publish"
// status, the media stream is created through publish, and a publish-start
// status is written back. A bad name or a failed publish is returned wrapped in
// ErrPublishBadName; a write error is returned as is.
func (ss *serverStream) invokePublish(cmd PublishCommand) error {
	name, isString := cmd.Name.(string)
	if !isString || name == "" {
		return ErrPublishBadName(fmt.Errorf("invalid publish stream name"))
	}

	tagged := ss.conn.getCtx()
	tagged = rtmpctx.WithStreamName(tagged, name)
	tagged = rtmpctx.WithConnStatus(tagged, "publish")
	ss.conn.setCtx(tagged)

	if err := ss.publish(ss.conn.getCtx(), name); err != nil {
		return ErrPublishBadName(err)
	}

	started := OnStatusCommand{
		StreamID:      cmd.StreamID,
		TransactionID: cmd.TransactionID,
		Info:          StatusPublishStart(fmt.Sprintf("Publishing %s.", name)),
	}
	return ss.conn.Write(started)
}

// publish manages the serverStream's mediaStream reference: it creates and
// registers a media stream called name on the MediaServer, using ctx, and
// records it as owned by this server stream.
//
// When the server stream already has a media stream the call is a no-op and
// returns nil, whatever name was requested. An error from stream creation (for
// example a name that is already registered) is returned.
func (ss *serverStream) publish(ctx goctx.Context, name string) error {
	// Ignore repeated publishes on a server stream that is already publishing.
	if ss.mediaStream != nil {
		return nil
	}

	// Use the context handed in by the caller; it was previously ignored in
	// favour of re-reading the connection context.
	ms, err := ss.conn.server.createStream(ctx, name)
	if err != nil {
		return err
	}

	ss.mediaStream = ms
	ss.mediaStreamOwned = true

	return nil
}

// play manages the serverStream's mediaPlayer reference. Any player already
// attached is closed first, even when ms turns out to be nil. A new player is
// then created for ms on this stream and started. It fails when ms is nil or
// when the player cannot be created.
func (ss *serverStream) play(ms MediaStream) error {
	if previous := ss.mediaPlayer; previous != nil {
		previous.Close()
		ss.mediaPlayer = nil
	}

	if ms == nil {
		return fmt.Errorf("stream not found")
	}

	player, err := NewMediaPlayer(ss.conn.getCtx(), ms, ss)
	if err != nil {
		return err
	}

	ss.mediaPlayer = player
	ss.mediaPlayer.Start()
	return nil
}

// cleanup releases what the server stream holds: a media stream it still owns
// is deleted from the MediaServer, and an attached media player is closed. Both
// references are cleared afterwards. A media stream that is not owned is left
// untouched.
func (ss *serverStream) cleanup() {
	if owned := ss.mediaStreamOwned && ss.mediaStream != nil; owned {
		// The result is deliberately discarded, matching best-effort teardown.
		_ = ss.conn.server.DeleteStream(ss.conn.getCtx(), ss.mediaStream.Name())
		ss.mediaStream = nil
		ss.mediaStreamOwned = false
	}

	if player := ss.mediaPlayer; player != nil {
		player.Close()
		ss.mediaPlayer = nil
	}
}
