package rtmp

import (
	"bytes"
	"errors"
	"fmt"
	"math"
	"sync"

	"code.justin.tv/video/gortmp/pkg/codec"
	rtmpctx "code.justin.tv/video/gortmp/pkg/context"
	"code.justin.tv/video/gortmp/pkg/log"
	"context"
)

// MediaStreamClosed is the error Publish and Subscribe return once the
// stream has been closed.
var MediaStreamClosed = errors.New("media stream closed")

// maxLoopLogs is the limit argument handed to NLogger.Logf for the warnings
// emitted by loop (presumably the maximum number of times each is logged —
// confirm against the log package).
const maxLoopLogs = 1000

// MediaStreamStats is a snapshot of a stream's traffic counters.
//
// VideoBytes, AudioBytes and Bytes are cumulative totals, in bytes, of the
// Size of every published tag (Bytes counts tags of every type; VideoBytes
// and AudioBytes only their own kind).
//
// The *BytesRate fields are byte rates sampled at video keyframes: the growth
// of the matching byte counter since the previous sample, divided by the tag
// timestamp delta and multiplied by 1000. That is bytes per second when tag
// timestamps are in milliseconds (the FLV convention) — not kbps.
type MediaStreamStats struct {
	VideoBytes int64
	AudioBytes int64
	Bytes      int64

	VideoBytesRate int64
	AudioBytesRate int64
	BytesRate      int64
}

// MediaStreamInfo is a snapshot of the header tags a stream has latched and
// the most recent tag it has seen of each kind.
type MediaStreamInfo struct {
	// Audio: latched sequence header and the latest audio tag.
	AudioHeader, LastAudioPacket *FlvTag
	// Video: latched sequence header and the latest video tag.
	VideoHeader, LastVideoPacket *FlvTag
	// Data: latched AMF0 data tag and the latest data tag.
	DataHeader, LastDataPacket *FlvTag
}

// MediaStream is a publish/subscribe hub for FLV tags: tags handed to
// Publish are fanned out to the channels returned by Subscribe.
type MediaStream interface {
	// Name reports the stream's name.
	Name() string

	// Publish submits a tag for delivery to subscribers.
	Publish(*FlvTag) error
	// Subscribe returns a new channel on which tags are delivered.
	Subscribe() (chan *FlvTag, error)
	// Unsubscribe detaches a channel obtained from Subscribe.
	Unsubscribe(chan *FlvTag)

	// Wait blocks until the stream is closed.
	Wait()
	// Close shuts the stream down.
	Close()
	// IsClosed reports whether Close has been called.
	IsClosed() bool

	// Stats returns the stream's traffic counters.
	Stats() MediaStreamStats
	// Info returns the stream's latched header and last-seen tags.
	Info() MediaStreamInfo
}

// mediaStream is the channel-based MediaStream implementation. Subscriber
// bookkeeping and latched headers live inside the loop goroutine; the public
// methods communicate with it over the request channels below.
type mediaStream struct {
	name string

	// Requests serviced by loop.
	pub         chan *FlvTag
	sub, unsub  chan chan *FlvTag
	infoRequest chan chan MediaStreamInfo

	// Cumulative byte counters and the byte rates derived from them by
	// updateBps.
	videoBytes, audioBytes, bytes Counter
	videoBps, audioBps, bps       Counter

	// Counter values and timestamp captured at the previous updateBps sample.
	lastVideoBytes, lastAudioBytes, lastBytes, lastTs int64

	// done is closed exactly once (guarded by doneOnce) to shut the stream down.
	done     chan bool
	doneOnce sync.Once

	log log.Logger
}

// NewMediaStream creates a stream named streamName and starts its routing
// goroutine. The stream name is attached to ctx (rtmpctx.WithStreamName)
// before the "mediastream" logger is derived from it.
func NewMediaStream(ctx context.Context, streamName string) MediaStream {
	logger := log.FromContext(rtmpctx.WithStreamName(ctx, streamName), "mediastream")

	ms := new(mediaStream)
	ms.name = streamName
	ms.pub = make(chan *FlvTag)
	ms.sub = make(chan chan *FlvTag)
	ms.unsub = make(chan chan *FlvTag)
	ms.infoRequest = make(chan chan MediaStreamInfo)
	ms.done = make(chan bool)
	ms.log = logger

	go ms.loop()
	return ms
}

// Name returns the name the stream was created with.
func (ms *mediaStream) Name() string { return ms.name }

// Publish hands tag to the routing loop. It blocks until the loop accepts the
// tag, or returns MediaStreamClosed if the stream is closed first (the select
// keeps Publish from hanging once the loop has exited).
func (ms *mediaStream) Publish(tag *FlvTag) error {
	var err error
	select {
	case <-ms.done:
		err = MediaStreamClosed
	case ms.pub <- tag:
	}
	return err
}

// Subscribe registers a new unbuffered channel with the routing loop and
// returns it, or returns MediaStreamClosed if the stream is closed first.
// The loop immediately sends any latched headers on the new channel and
// blocks on every send, so callers must keep reading from it.
func (ms *mediaStream) Subscribe() (chan *FlvTag, error) {
	out := make(chan *FlvTag)
	select {
	case <-ms.done:
		return nil, MediaStreamClosed
	case ms.sub <- out:
	}
	return out, nil
}

// Unsubscribe detaches ch, a channel previously returned by Subscribe. The
// routing loop closes ch when it removes it; if the stream is closed instead,
// the loop's shutdown closes every remaining subscriber channel.
//
// A nil ch is ignored: ranging over a nil channel blocks forever, so the
// drain goroutine below would otherwise leak.
//
// NOTE(review): a non-nil ch that is not (or no longer) subscribed and not
// yet closed is never closed by the loop, so its drain goroutine persists.
func (ms *mediaStream) Unsubscribe(ch chan *FlvTag) {
	if ch == nil {
		return
	}

	go func() {
		// Drain ch so a send the loop has already started on ch cannot block
		// while we wait for the loop to receive from unsub. The loop closes
		// ch (on unsubscribe or shutdown), which ends this goroutine.
		for range ch {
		}
	}()

	select {
	case ms.unsub <- ch:
	case <-ms.done:
	}
}

// Wait blocks until Close has been called. It does not wait for the routing
// loop to finish its shutdown work.
func (ms *mediaStream) Wait() { <-ms.done }

// Close closes the done channel, which stops the routing loop and unblocks
// Publish, Subscribe, Unsubscribe, Info and Wait. Repeated calls are no-ops
// thanks to doneOnce.
func (ms *mediaStream) Close() {
	ms.doneOnce.Do(func() { close(ms.done) })
}

// IsClosed reports, without blocking, whether Close has been called.
func (ms *mediaStream) IsClosed() bool {
	closed := false
	select {
	case <-ms.done:
		closed = true
	default:
	}
	return closed
}

// Stats returns the current byte counters and the most recently computed byte
// rates. Each counter is read separately, so while the stream is live the
// fields are not guaranteed to describe the same instant.
func (ms *mediaStream) Stats() MediaStreamStats {
	var st MediaStreamStats
	st.VideoBytes = ms.videoBytes.Get()
	st.AudioBytes = ms.audioBytes.Get()
	st.Bytes = ms.bytes.Get()
	st.VideoBytesRate = ms.videoBps.Get()
	st.AudioBytesRate = ms.audioBps.Get()
	st.BytesRate = ms.bps.Get()
	return st
}

// Info asks the routing loop for a copy of its latched header/last-tag state.
// It returns the zero MediaStreamInfo if the stream is closed before the loop
// takes the request.
func (ms *mediaStream) Info() MediaStreamInfo {
	reply := make(chan MediaStreamInfo)
	select {
	case <-ms.done:
		return MediaStreamInfo{}
	case ms.infoRequest <- reply:
	}
	return <-reply
}

// toInt truncates f to an int64, mapping NaN and ±Inf (which a zero time delta
// in updateBps produces) to 0.
func (ms *mediaStream) toInt(f float64) int64 {
	switch {
	case math.IsNaN(f), math.IsInf(f, 0):
		return 0
	default:
		return int64(f)
	}
}

// updateBps samples the byte counters at timestamp ts (loop passes the tag
// timestamp of each video keyframe). Each rate counter is set to the growth of
// its byte counter since the previous sample, divided by the timestamp delta
// and multiplied by 1000: bytes per second for millisecond timestamps.
//
// Samples that do not move time forward are not turned into rates:
//   - ts == lastTs: the interval is empty, so the published rates are kept and
//     the bytes stay in the current window for the next sample, instead of the
//     rates being forced to 0 by the NaN/Inf guard in toInt.
//   - ts < lastTs: the timestamp went backwards (stream reset, or the uint32
//     tag timestamp wrapped), which would yield negative rates; the baseline
//     is restarted at this sample and the published rates are kept.
func (ms *mediaStream) updateBps(ts int64) {
	videoBytes := ms.videoBytes.Get()
	audioBytes := ms.audioBytes.Get()
	totalBytes := ms.bytes.Get()

	elapsed := ts - ms.lastTs
	if elapsed == 0 {
		return
	}

	if elapsed > 0 {
		dt := float64(elapsed)
		ms.videoBps.Set(ms.toInt(float64(videoBytes-ms.lastVideoBytes) / dt * 1000))
		ms.audioBps.Set(ms.toInt(float64(audioBytes-ms.lastAudioBytes) / dt * 1000))
		ms.bps.Set(ms.toInt(float64(totalBytes-ms.lastBytes) / dt * 1000))
	}

	// Start the next window at this sample.
	ms.lastVideoBytes = videoBytes
	ms.lastAudioBytes = audioBytes
	ms.lastBytes = totalBytes
	ms.lastTs = ts
}

// formatByteSlice renders b in Go syntax (%#v) for log messages. When b has
// more than n bytes, only the first n are shown, followed by "...".
// A non-positive n with a non-empty b renders as "[]byte{...}" (previously
// n == 0 produced the malformed "[]byte{, ...}" and a negative n panicked).
func formatByteSlice(b []byte, n int) string {
	if b == nil || n >= len(b) {
		return fmt.Sprintf("%#v", b)
	}
	if n <= 0 {
		return "[]byte{...}"
	}
	s := fmt.Sprintf("%#v", b[:n])
	// Replace the closing brace with an ellipsis marker to show truncation.
	return s[:len(s)-1] + ", ...}"
}

// videoSequenceHeaderLooksValid reports whether h is an AVC sequence header
// whose codec data parses with codec.NewAVCDecoderConfig and carries a
// non-empty first SPS and first PPS entry. Sequence headers of any other
// codec cannot be inspected and are reported as not valid.
func videoSequenceHeaderLooksValid(h *FlvVideoHeader) bool {
	if h.CodecID != FlvCodecAVC || h.CodecData == nil {
		return false
	}

	cfg, err := codec.NewAVCDecoderConfig(bytes.NewReader(h.CodecData))
	if err != nil {
		return false
	}

	hasSPS := len(cfg.SpsTable) > 0 && len(cfg.SpsTable[0]) > 0
	hasPPS := len(cfg.PpsTable) > 0 && len(cfg.PpsTable[0]) > 0
	return hasSPS && hasPPS
}

// audioSequenceHeaderLooksValid reports whether h is an AAC sequence header
// whose codec data plausibly holds an ISO/IEC 14496-3 AudioSpecificConfig.
// Sequence headers of any other sound format cannot be inspected and are
// reported as not valid.
func audioSequenceHeaderLooksValid(h *FlvAudioHeader) bool {
	if h.SoundFormat != FlvCodecAAC {
		return false
	}

	asc := h.CodecData
	// The shortest valid structure is 5+4+4 bits, padded to two bytes; the
	// length check also rejects nil codec data.
	if len(asc) < 2 {
		return false
	}

	// The top five bits of the first byte are the (first part of the) audio
	// object type; zero is AOT_NULL.
	return asc[0]>>3 != 0
}

// loop is the stream's routing goroutine. It owns the subscriber set and the
// latched MediaStreamInfo, and it serializes every Publish, Subscribe,
// Unsubscribe and Info request until done is closed. On exit it closes every
// channel that is still subscribed.
//
// Every send to a subscriber channel blocks, so a subscriber that stops
// reading stalls the whole stream, including shutdown, until it reads again.
func (ms *mediaStream) loop() {
	ms.log.Infof("start")
	defer ms.log.Infof("end")

	nlog := log.NewNLogger(ms.log)

	// info holds the latched headers and the last tag of each kind; Info
	// requests receive a copy of it.
	var info MediaStreamInfo
	// subs maps each subscriber channel to whether it has been started, i.e.
	// has already been sent its first video keyframe.
	subs := make(map[chan *FlvTag]bool)

	// shut down subs
	defer func() {
		for sub := range subs {
			close(sub)
		}
	}()

	// lastTs is the timestamp of the most recently published tag; it is
	// stamped onto the header copies sent to new subscribers.
	var lastTs uint32

	for {
		select {
		case tag := <-ms.pub:
			ms.bytes.Add(int64(tag.Size))

			lastTs = tag.Timestamp

			switch tag.Type {
			case AUDIO_TYPE:
				ms.audioBytes.Add(int64(tag.Size))

				header, err := tag.GetAudioHeader()
				if err == nil {
					ms.log.Tracef("audio: format(%v) AACPacketType(%v)", header.SoundFormat, header.AACPacketType)

					// store the audio sequence header if necessary: the first
					// one is latched as-is; a later, different one replaces it
					// only if it passes the sanity check.
					if header.IsSequenceHeader() {
						if info.AudioHeader == nil {
							info.AudioHeader = tag
						} else if !bytes.Equal(info.AudioHeader.Bytes, tag.Bytes) {
							if audioSequenceHeaderLooksValid(header) {
								info.AudioHeader = tag
							} else {
								nlog.Logf(log.LogWarn, maxLoopLogs, "audio: new sequence header is suspicious, not latching. Data: %s",
									formatByteSlice(header.CodecData, 32))
							}
						}
					}
				} else {
					nlog.Logf(log.LogWarn, maxLoopLogs, "audio: failed to parse header: %v", err)
				}

				info.LastAudioPacket = tag
				// Audio goes to every subscriber, started or not.
				for sub := range subs {
					sub <- tag
				}
			case VIDEO_TYPE:
				ms.videoBytes.Add(int64(tag.Size))

				// keyframe stays false when the header cannot be parsed;
				// header is then presumably nil and must not be dereferenced.
				keyframe := false

				// store the video sequence header if necessary
				header, err := tag.GetVideoHeader()
				if err == nil {
					ms.log.Tracef("video: FrameType(%v) CodecID(%v) AVCPacketType(%v)", header.FrameType, header.CodecID, header.AVCPacketType)

					// FrameType 1 is treated as a keyframe.
					keyframe = header.FrameType == 1

					if keyframe {
						ms.updateBps(int64(tag.Timestamp))
					}

					if header.IsSequenceHeader() {
						if info.VideoHeader == nil {
							info.VideoHeader = tag
						} else if !bytes.Equal(info.VideoHeader.Bytes, tag.Bytes) {
							if videoSequenceHeaderLooksValid(header) {
								info.VideoHeader = tag
							} else {
								nlog.Logf(log.LogWarn, maxLoopLogs, "video: new sequence header is suspicious, not latching. Data: %s",
									formatByteSlice(header.CodecData, 32))
							}
						}
					}
				} else {
					nlog.Logf(log.LogWarn, maxLoopLogs, "video: failed to parse header %v", err)
				}

				info.LastVideoPacket = tag
				// wait for the first video keyframe: subscribers that have not
				// started only receive video once a keyframe arrives.
				for sub, started := range subs {
					if !started {
						if keyframe {
							sub <- tag
							subs[sub] = true
							if info.VideoHeader == nil {
								ms.log.Tracef("Sending video data without a video header")
							}
						}
					} else {
						sub <- tag
					}
				}
			case DATA_AMF0:
				// Only the first data tag is latched as the data header.
				if info.DataHeader == nil {
					info.DataHeader = tag
				} else {
					ms.log.Tracef("Got new data header, ignoring")
				}

				info.LastDataPacket = tag
				for sub := range subs {
					sub <- tag
				}
			}
		case respCh := <-ms.infoRequest:
			respCh <- info

		case sub := <-ms.sub:
			// New subscribers start un-started and immediately receive copies
			// of the latched data, video and audio headers (in that order),
			// restamped with the latest tag timestamp.
			subs[sub] = false
			if info.DataHeader != nil {
				header := *info.DataHeader
				header.Timestamp = lastTs

				sub <- &header
			}
			if info.VideoHeader != nil {
				header := *info.VideoHeader
				header.Timestamp = lastTs

				sub <- &header
			}
			if info.AudioHeader != nil {
				header := *info.AudioHeader
				header.Timestamp = lastTs

				sub <- &header
			}
		case sub := <-ms.unsub:
			// Only close channels we still own, so a repeated Unsubscribe
			// cannot double-close.
			if _, found := subs[sub]; found {
				delete(subs, sub)
				close(sub)
			}
		case <-ms.done:
			return // shutdown
		}
	}
}
