package ms2s3

import (
	"context"
	"encoding/binary"
	"errors"
	"sync"

	gortmp "code.justin.tv/event-engineering/gortmp/pkg/rtmp"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/s3/s3iface"
	"github.com/aws/aws-sdk-go/service/s3/s3manager"
	"github.com/sirupsen/logrus"
)

// flvHeader is the fixed prefix written at the start of every recording: the FLV
// file header, PreviousTagSize0, and a single onMetaData script tag whose ECMA
// array holds one property, "duration", set to 0.
var flvHeader = []byte{
	// FLV file header: signature "FLV", version 1, flags 0x05 (audio + video), header length 9.
	'F', 'L', 'V', 0x01, 0x05, 0x00, 0x00, 0x00, 0x09,
	// PreviousTagSize0 (always 0).
	0x00, 0x00, 0x00, 0x00,
	// Script data tag header (11 bytes): type 0x12, data size 0x28 (40), timestamp 0, stream ID 0.
	0x12, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
	// AMF0 string marker 0x02, length 10, "onMetaData".
	0x02, 0x00, 0x0a, 'o', 'n', 'M', 'e', 't', 'a', 'D', 'a', 't', 'a',
	// AMF0 ECMA array marker 0x08 with an entry count of 1.
	0x08, 0x00, 0x00, 0x00, 0x01,
	// Property name: length 8, "duration".
	0x00, 0x08, 'd', 'u', 'r', 'a', 't', 'i', 'o', 'n',
	// AMF0 number marker 0x00 followed by the 8-byte double 0.0.
	0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
	// Object end: empty name (0x0000) and end marker 0x09.
	0x00, 0x00, 0x09,
	// PreviousTagSize of the script tag: 11 + 40 = 51 (0x33).
	0x00, 0x00, 0x00, 0x33,
}

// FLV tag type identifiers, compared against FlvTag.Type and written as the
// first byte of each FLV tag header.
const (
	// AUDIO_TAG is the FLV tag type of an audio tag.
	AUDIO_TAG byte = 0x08
	// VIDEO_TAG is the FLV tag type of a video tag.
	VIDEO_TAG byte = 0x09
)

// MediaStreamtoS3 records RTMP media streams to objects in S3.
type MediaStreamtoS3 interface {
	// RecordToS3 starts recording ms to the S3 object destKey. It blocks until
	// recording has started, then returns while the upload carries on in the
	// background.
	RecordToS3(ctx context.Context, ms gortmp.MediaStream, destKey string) error

	// Stop asks the in-progress recordings to finish.
	Stop()

	// WaitUntilDone blocks until the uploads started by RecordToS3 have returned.
	WaitUntilDone()
}

// mediastreamtos3 is the MediaStreamtoS3 implementation returned by New.
type mediastreamtos3 struct {
	// Dependencies supplied to New.
	logger     logrus.FieldLogger
	s3         s3iface.S3API
	destBucket string

	// stopped is set by Stop, cleared by RecordToS3 and polled by each tag-writing
	// goroutine; doneWg counts the S3 upload goroutines WaitUntilDone waits for.
	// NOTE(review): stopped is shared between goroutines without synchronization
	// (a data race under -race); consider an atomic flag or a stop channel.
	stopped bool
	doneWg  sync.WaitGroup

	// 11-byte scratch buffer (the FLV tag header size), allocated by New.
	tagHeaderBuf []byte
}

// New returns a MediaStreamtoS3 that uploads recordings into destBucket through
// s3Client and reports progress through logger.
func New(s3Client s3iface.S3API, destBucket string, logger logrus.FieldLogger) MediaStreamtoS3 {
	recorder := &mediastreamtos3{
		s3:         s3Client,
		destBucket: destBucket,
		logger:     logger,
	}
	// Sized for one FLV tag header.
	recorder.tagHeaderBuf = make([]byte, 11)
	return recorder
}

// Stop raises the stop flag shared by every recording started on s. Each
// tag-writing goroutine checks it before reading its next tag and, once it sees
// it, closes the reader feeding its S3 upload. A later RecordToS3 call clears
// the flag again. Use WaitUntilDone to wait for the uploads to return.
//
// NOTE(review): the flag is written here and read by other goroutines without
// synchronization, and a goroutine blocked waiting for the next tag will not
// notice it until a tag arrives.
func (s *mediastreamtos3) Stop() { s.stopped = true }

// WaitUntilDone blocks until every S3 upload goroutine started by RecordToS3
// has returned.
func (s *mediastreamtos3) WaitUntilDone() { s.doneWg.Wait() }

// writeTag appends one FLV tag to br. It writes three parts in order:
//   - the 11-byte tag header built from tag.Type, tag.Size and tag.Timestamp
//     (stream ID always 0);
//   - the payload tag.Bytes;
//   - the 4-byte big-endian PreviousTagSize (tag.Size + 11).
//
// The header is assembled in a per-call array instead of the shared
// s.tagHeaderBuf. RecordToS3 returns while its writer goroutine is still running,
// so several recordings on the same instance can call writeTag concurrently and
// must not share scratch space.
func (s *mediastreamtos3) writeTag(br *blockingReader, tag *gortmp.FlvTag) error {
	s.logger.Debugf("Writing tag with timestamp %v Type %v", tag.Timestamp, tag.Type)

	var header [11]byte
	header[0] = tag.Type
	// DataSize: lower 24 bits of tag.Size, big-endian.
	header[1] = byte(tag.Size >> 16)
	header[2] = byte(tag.Size >> 8)
	header[3] = byte(tag.Size)
	// Timestamp: lower 24 bits big-endian, then the upper 8 bits as TimestampExtended.
	header[4] = byte(tag.Timestamp >> 16)
	header[5] = byte(tag.Timestamp >> 8)
	header[6] = byte(tag.Timestamp)
	header[7] = byte(tag.Timestamp >> 24)
	// header[8:11] (StreamID) stays 0.

	if _, err := br.Write(header[:]); err != nil {
		s.logger.WithError(err).Warnf("Failed to write tag header to buffer")
		return err
	}

	if _, err := br.Write(tag.Bytes); err != nil {
		s.logger.WithError(err).Warnf("Failed to write tag data to buffer")
		return err
	}

	// PreviousTagSize: the full size of this tag including its 11-byte header.
	var trailer [4]byte
	binary.BigEndian.PutUint32(trailer[:], tag.Size+11)
	if _, err := br.Write(trailer[:]); err != nil {
		s.logger.WithError(err).Warnf("Failed to write tag length")
		return err
	}

	return nil
}

// RecordToS3 subscribes to ms and records it as an FLV file to s.destBucket
// under destKey.
//
// It blocks until it has seen the AVC and AAC sequence headers (packet type 0)
// and then the first video keyframe. At that point it writes the headers at
// timestamp 0, returns nil, and keeps recording in the background. Timestamps
// are offset so that the resulting file starts at 0.
//
// It returns a non-nil error in these cases:
//   - subscribing fails;
//   - ctx is done before recording starts;
//   - the tag stream ends, Stop is called, or a write fails before the first
//     keyframe.
//
// Call Stop to end a recording and WaitUntilDone to wait for its upload.
func (s *mediastreamtos3) RecordToS3(ctx context.Context, ms gortmp.MediaStream, destKey string) error {
	// NOTE(review): stopped is shared by every recording on s and accessed from
	// several goroutines without synchronization; clearing it here also un-stops
	// recordings that are still running.
	s.stopped = false

	// Since we're streaming data, this reader will block reads until it gets a write.
	// Consequently we have to manually indicate that we're done by closing this reader.
	br := newBlockingReader(s.logger)

	tags, err := ms.Subscribe()
	if err != nil {
		s.logger.WithError(err).Warnf("Failed to subscribe to mediastream")
		return err
	}

	s.doneWg.Add(1)
	go func() {
		defer ms.Unsubscribe(tags)
		defer s.doneWg.Done()

		uploader := s3manager.NewUploaderWithClient(s.s3)
		result, err := uploader.UploadWithContext(ctx, &s3manager.UploadInput{
			Bucket: aws.String(s.destBucket),
			Key:    aws.String(destKey),
			Body:   br,
		})
		if err != nil {
			s.logger.WithError(err).Warnf("Error uploading file %v to S3", destKey)
			br.Close()
			return
		}

		// If UploadID is empty then the file was small enough that we didn't need a multipart upload
		if result.UploadID == "" {
			s.logger.Infof("Uploaded file %v", destKey)
		} else {
			s.logger.Infof("Uploaded file %v with multipart upload ID %v", destKey, result.UploadID)
		}
	}()

	// isVideoKeyframe reports whether tag is an AVC NALU packet (AVCPacketType 1)
	// whose VideoTagHeader FrameType (upper 4 bits of the first payload byte) is 1.
	// AVCPacketType alone also matches inter frames, which cannot start a file.
	isVideoKeyframe := func(tag *gortmp.FlvTag) bool {
		if tag.Type != VIDEO_TAG || len(tag.Bytes) == 0 || tag.Bytes[0]>>4 != 1 {
			return false
		}
		header, err := tag.GetVideoHeader()
		return err == nil && header.AVCPacketType == 1
	}

	// recordTags writes the FLV file into br. It calls onStart exactly once, right
	// after the sequence headers have been written for the first keyframe. It
	// returns when the tag stream ends, Stop is called, ctx is done, or a write fails.
	recordTags := func(onStart func()) error {
		var videoHeader, audioHeader *gortmp.FlvTag
		var recording bool
		var initialTimestamp uint32

		if _, err := br.Write(flvHeader); err != nil {
			s.logger.WithError(err).Warnf("Failed to write FLV Header")
			return err
		}

		for {
			if s.stopped {
				s.logger.Info("Stopping")
				return nil
			}

			var tag *gortmp.FlvTag
			var ok bool
			select {
			case <-ctx.Done():
				s.logger.WithError(ctx.Err()).Info("Context done")
				return ctx.Err()
			case tag, ok = <-tags:
			}
			if !ok {
				s.logger.Info("Tags channel closed")
				return nil
			}

			// Phase 1: collect the most recent video and audio sequence headers
			// until both have been seen.
			if videoHeader == nil || audioHeader == nil {
				switch tag.Type {
				case VIDEO_TAG:
					if header, err := tag.GetVideoHeader(); err == nil && header.AVCPacketType == 0 {
						videoHeader = tag
						s.logger.Debugf("Video Header has timestamp %v", tag.Timestamp)
					}
				case AUDIO_TAG:
					if header, err := tag.GetAudioHeader(); err == nil && header.AACPacketType == 0 {
						audioHeader = tag
						s.logger.Debugf("Audio Header has timestamp %v", tag.Timestamp)
					}
				}
				continue
			}

			// Phase 2: drop everything until the first video keyframe, then write the
			// sequence headers at timestamp 0 and record from that keyframe onwards.
			if !recording {
				if !isVideoKeyframe(tag) {
					continue
				}
				recording = true
				initialTimestamp = tag.Timestamp
				s.logger.Debugf("Setting initial timestamp to %v", initialTimestamp)

				// NOTE(review): these tags come from ms; if the stream hands the same
				// *FlvTag to several subscribers, they will observe this mutation.
				videoHeader.Timestamp = 0
				audioHeader.Timestamp = 0

				if err := s.writeTag(br, videoHeader); err != nil {
					s.logger.WithError(err).Warn("Error writing video header")
					return err
				}
				if err := s.writeTag(br, audioHeader); err != nil {
					s.logger.WithError(err).Warn("Error writing audio header")
					return err
				}
				onStart()
			}

			// Shift timestamps so the first keyframe lands at 0; anything earlier is
			// clamped to 0 instead of letting the uint32 subtraction wrap.
			if tag.Timestamp > initialTimestamp {
				s.logger.Debugf("Original timestamp: %v Offset timestamp: %v", tag.Timestamp, tag.Timestamp-initialTimestamp)
				tag.Timestamp -= initialTimestamp
			} else {
				tag.Timestamp = 0
			}

			if err := s.writeTag(br, tag); err != nil {
				return err
			}
		}
	}

	// ready receives exactly one value: nil once recording has started, or the
	// reason the writer gave up before it could start. The buffer of one means the
	// writer never blocks on it, even if RecordToS3 has already returned.
	ready := make(chan error, 1)

	go func() {
		// Closing br on every exit path lets the upload see EOF and finish.
		defer br.Close()

		started := false
		recErr := recordTags(func() {
			started = true
			ready <- nil
		})
		if !started {
			if recErr == nil {
				recErr = errors.New("mediastream recording ended before the first video keyframe")
			}
			ready <- recErr
		}
	}()

	select {
	case err := <-ready:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}
