package vod

import (
	"encoding/json"
	"fmt"
	"image"
	"image/color"
	"image/jpeg"
	"io/ioutil"
	"math"
	"math/rand"
	"net/http"
	"os"
	"path"
	"testing"
	"time"

	"code.justin.tv/video/gotranscoder/pkg/avdata"
	"code.justin.tv/video/gotranscoder/pkg/m3u8"
	"code.justin.tv/video/gotranscoder/pkg/statsd"
	"code.justin.tv/video/gotranscoder/pkg/vod/vodfakes"

	"github.com/stretchr/testify/assert"
)

// SegmentFactory creates mock segments backed by temp files written under
// os.TempDir()/<label>. Call Cleanup() to delete the directories those files
// were written to.
type SegmentFactory struct {
	// BytesWritten is the running total of data bytes written by New.
	BytesWritten int

	// createdDirs records every directory New has written into so Cleanup
	// can remove it; segNumber is the SegmentNumber for the next segment.
	createdDirs map[string]struct{}
	segNumber   int64
}

// NewSegmentFactory returns an empty SegmentFactory ready to produce
// segments. Callers should defer Cleanup() to remove the temp directories.
func NewSegmentFactory() *SegmentFactory {
	sf := &SegmentFactory{}
	sf.createdDirs = map[string]struct{}{}
	return sf
}

// New writes data to a fresh temp file under os.TempDir()/label and returns a
// segment describing it.
//
// The segment's Duration is dur expressed in milliseconds. Its FrameCount is
// 30 times the whole number of seconds in dur. SegmentNumber increases by one
// per call on the same factory. Any filesystem failure prints the error and
// exits the test binary, because the tests cannot proceed without the file.
func (sf *SegmentFactory) New(dur time.Duration, label string, data string) *avdata.Segment {
	// Make sure directory exists
	dirPath := path.Join(os.TempDir(), label)
	if err := os.MkdirAll(dirPath, 0777); err != nil {
		fmt.Println(err.Error())
		os.Exit(1)
	}
	sf.createdDirs[dirPath] = struct{}{}

	file, err := ioutil.TempFile(dirPath, "vodtest")
	if err != nil {
		fmt.Println(err.Error())
		os.Exit(1)
	}

	// Write dummy data to file
	n, err := file.Write([]byte(data))
	if err != nil {
		fmt.Println(err.Error())
		os.Exit(1)
	}
	if n != len(data) {
		fmt.Println("Temp file didn't write all data")
		os.Exit(1)
	}
	// Close explicitly so a failed flush is reported instead of ignored.
	if err := file.Close(); err != nil {
		fmt.Println(err.Error())
		os.Exit(1)
	}
	sf.BytesWritten += n

	seg := &avdata.Segment{
		Label:         label,
		SegmentNumber: sf.segNumber,
		SegmentSize:   int64(n),
		FrameCount:    int32(30 * int(dur.Seconds())),
		SegmentName:   path.Base(file.Name()),
		Duration:      1000 * dur.Seconds(),
	}

	// update state
	sf.segNumber++

	return seg
}

// Cleanup removes every directory this factory has written segments into.
// It is best-effort teardown, so removal errors are ignored.
func (sf *SegmentFactory) Cleanup() {
	for d := range sf.createdDirs {
		_ = os.RemoveAll(d)
	}
}

// TestBasic runs the pusher through a basic cycle:
//  1. push four 3-second "source" segments;
//  2. wait for the source queue to drain and for its playlist to record a
//     first chunk of 12000 ms;
//  3. push 101 more 2-second segments and shut the pusher down.
func TestBasic(t *testing.T) {
	assert := assert.New(t)
	pusher, cleanFunc, _, _ := testQueueSetup(100, 4, mockOriginSettings{})
	defer cleanFunc()

	segs := NewSegmentFactory()
	defer segs.Cleanup()

	data := []string{"Kappa", "FrankerZ", "MingLee", "NotLikeThis"}
	for _, d := range data {
		pusher.Process(segs.New(3*time.Second, "source", d), nil)
	}

	// Named sourceQueue (not queue) so the package-level queue type is not
	// shadowed inside this test.
	sourceQueue := pusher.queues["source"]
	for len(sourceQueue.C) > 0 {
		time.Sleep(100 * time.Millisecond)
	}

	// Poll until the first chunk reaches the playlist. The short sleep keeps
	// the loop from spinning a CPU core while it waits.
	deadline := time.Now().Add(2 * time.Second)
	for sourceQueue.playlist.Len() < 1 {
		if time.Now().After(deadline) {
			t.Fatal("Didn't save playlist")
		}
		time.Sleep(10 * time.Millisecond)
	}
	pl := sourceQueue.playlist.Get()
	// The first chunk is expected to span all four 3 s segments (12000 ms).
	assert.Equal(12000.0, pl[0].Duration)

	for i := 0; i < 101; i++ {
		pusher.Process(segs.New(2*time.Second, "source", data[i%len(data)]), nil)
	}

	pusher.Quit()
}

// TestBasicOrigin runs the same scenario as TestBasic with the mock origin
// enabled:
//  1. push four 3-second "source" segments;
//  2. wait up to 10 s for the playlist to record a first chunk of 12000 ms;
//  3. push 101 more 2-second segments and shut the pusher down.
func TestBasicOrigin(t *testing.T) {
	assert := assert.New(t)
	pusher, cleanFunc, _, _ := testQueueSetup(100, 4, mockOriginSettings{
		enable: true,
	})
	defer cleanFunc()

	segs := NewSegmentFactory()
	defer segs.Cleanup()

	data := []string{"Kappa", "FrankerZ", "MingLee", "NotLikeThis"}
	for _, d := range data {
		pusher.Process(segs.New(3*time.Second, "source", d), nil)
	}

	// Named sourceQueue (not queue) so the package-level queue type is not
	// shadowed inside this test.
	sourceQueue := pusher.queues["source"]
	for len(sourceQueue.C) > 0 {
		time.Sleep(100 * time.Millisecond)
	}

	// Poll until the first chunk reaches the playlist. The short sleep keeps
	// the loop from spinning a CPU core while it waits.
	deadline := time.Now().Add(10 * time.Second)
	for sourceQueue.playlist.Len() < 1 {
		if time.Now().After(deadline) {
			t.Fatal("Didn't save playlist")
		}
		time.Sleep(10 * time.Millisecond)
	}
	pl := sourceQueue.playlist.Get()
	// The first chunk is expected to span all four 3 s segments (12000 ms).
	assert.Equal(12000.0, pl[0].Duration)

	for i := 0; i < 101; i++ {
		pusher.Process(segs.New(2*time.Second, "source", data[i%len(data)]), nil)
	}

	pusher.Quit()
}

// TestMultiDiscontinuity interleaves "source" and "medium" segments. Before
// processing two of the "medium" segments it deletes their backing files.
// It then checks that:
//   - the source playlist recorded at least two discontinuities;
//   - for each queue, the generated playlist's stream duration and
//     q.Duration() equal the summed duration of its non-discontinuity chunks;
//   - for each queue, the target duration equals the longest such chunk.
func TestMultiDiscontinuity(t *testing.T) {
	assert := assert.New(t)
	pusher, cleanFunc, _, _ := testQueueSetup(2, 2, mockOriginSettings{})
	defer cleanFunc()

	sourceSegs := NewSegmentFactory()
	defer sourceSegs.Cleanup()

	mediumSegs := NewSegmentFactory()
	defer mediumSegs.Cleanup()

	// processBadMediumSeg introduces a bad segment, which will mark the chunk
	// as Lost: the segment's backing file is deleted before it is processed.
	// A failed removal would silently defeat the test, so it is asserted.
	processBadMediumSeg := func() {
		seg := mediumSegs.New(3*time.Second, "medium", "KappaRoss")
		assert.Nil(os.Remove(path.Join(os.TempDir(), "medium", seg.SegmentName)))
		pusher.Process(seg, nil)
	}

	for i := 0; i < 100; i++ {
		pusher.Process(sourceSegs.New(10*time.Second, "source", "KappaRoss"), nil)
		pusher.Process(mediumSegs.New(10*time.Second, "medium", "KappaRoss"), nil)
	}

	for i := 0; i < 10; i++ {
		pusher.Process(sourceSegs.New(3*time.Second, "source", "KappaRoss"), nil)
	}

	// One bad medium seg, 8 good ones, then another bad one.
	processBadMediumSeg()
	for i := 0; i < 8; i++ {
		pusher.Process(mediumSegs.New(3*time.Second, "medium", "KappaRoss"), nil)
	}
	processBadMediumSeg()

	for i := 0; i < 100; i++ {
		pusher.Process(sourceSegs.New(10*time.Second, "source", "KappaRoss"), nil)
		pusher.Process(mediumSegs.New(10*time.Second, "medium", "KappaRoss"), nil)
	}
	pusher.Quit()

	sQueue := pusher.queues["source"]
	mQueue := pusher.queues["medium"]

	countDiscontinuities := func(q *queue) int {
		n := 0
		for _, chunk := range q.playlist.Get() {
			if chunk.URL == m3u8.DiscontinuityURL {
				n++
			}
		}
		return n
	}

	// Count discontinuities in both playlists. Only the source count is
	// asserted; the medium count is logged to help diagnose failures.
	sDiscs := countDiscontinuities(sQueue)
	mDiscs := countDiscontinuities(mQueue)
	t.Logf("DISCS %v %v %v", sDiscs, mDiscs, mQueue.playlist.Len())
	assert.True(sDiscs >= 2, "Playlist didn't record correct number of discontinuities")

	// Check correct duration
	for _, q := range []*queue{sQueue, mQueue} {
		var streamDur, targetDur float64
		for _, c := range q.playlist.Get() {
			if c.URL == m3u8.DiscontinuityURL {
				continue
			}
			seconds := c.Duration / 1000 // chunk durations are in milliseconds
			streamDur += seconds
			targetDur = math.Max(targetDur, seconds)
		}
		pl := q.genPlaylist()
		assert.EqualValues(streamDur, pl.StreamDuration, fmt.Sprintf("Incorrect stream duration in playlist. Format: %s", q.format))
		assert.EqualValues(streamDur, q.Duration().Seconds(), fmt.Sprintf("Incorrect stream duration in playlist. Format: %s", q.format))
		assert.EqualValues(targetDur, pl.TargetDuration, fmt.Sprintf("Incorrect target duration in playlist. Format: %s", q.format))
	}
}

// TestMisc calls the plugin-interface methods Name and Initialize, then hands
// Process a value of an unsupported type (an int). It makes no assertions of
// its own; it only exercises these paths.
func TestMisc(t *testing.T) {
	p, cleanup, _, _ := testQueueSetup(100, 4, mockOriginSettings{})
	defer cleanup()

	// Make sure VOD implements the plugin interface
	p.Name()
	p.Initialize()

	// Call process with another type
	p.Process(52, nil)
}

// ThumbnailFactory creates mock JPEG thumbnails in os.TempDir()/thumb.
// Call Cleanup() to delete that directory and the files in it.
type ThumbnailFactory struct {
	// counter is copied into each thumbnail's Size field and incremented by
	// every call to New; directory is where thumbnail files are written.
	counter   int64
	directory string
}

// spritesFactory creates and manages mock sprite files in
// os.TempDir()/sprites. Call Cleanup() to delete that directory.
type spritesFactory struct {
	// counter is the index of the next sprite. It selects the file name and
	// is copied into the sprite's Size field. directory is where files go.
	counter   int
	directory string
}

// NewThumbnailFactory makes sure os.TempDir()/thumb exists and returns a
// factory that writes thumbnails there. If the directory cannot be created,
// it prints the error and exits the test binary.
func NewThumbnailFactory() *ThumbnailFactory {
	tf := &ThumbnailFactory{directory: path.Join(os.TempDir(), "thumb")}
	if err := os.MkdirAll(tf.directory, 0777); err != nil {
		fmt.Println(err.Error())
		os.Exit(1)
	}
	return tf
}

// newSpritesFactory makes sure os.TempDir()/sprites exists and returns a
// factory that writes sprite files there. If the directory cannot be
// created, it prints the error and exits the test binary.
func newSpritesFactory() *spritesFactory {
	sf := &spritesFactory{directory: path.Join(os.TempDir(), "sprites")}
	if err := os.MkdirAll(sf.directory, 0777); err != nil {
		fmt.Println(err.Error())
		os.Exit(1)
	}
	return sf
}

// New creates an empty sprite file in the factory directory and returns a
// Sprite pointing at it.
//
// Even-numbered calls create a file named "sprite-<n>". Odd-numbered calls
// all reuse "sprite.json", which os.Create truncates each time. Size is the
// call index, not the file size.
func (th *spritesFactory) New() *avdata.Sprite {
	n := th.counter
	name := "sprite.json"
	if n%2 == 0 {
		name = fmt.Sprintf("sprite-%d", n)
	}

	f, err := os.Create(path.Join(th.directory, name))
	if err != nil {
		fmt.Println(err.Error())
		os.Exit(1)
	}
	defer f.Close()

	th.counter = n + 1
	return &avdata.Sprite{
		Label:     "sprite",
		Path:      f.Name(),
		Size:      int64(n),
		Timestamp: time.Now().Unix(),
	}
}

// Cleanup deletes the sprite directory and everything in it. Errors are
// ignored because this is best-effort teardown.
func (th *spritesFactory) Cleanup() {
	_ = os.RemoveAll(th.directory)
}

// New writes a 128x128 grayscale JPEG to a new temp file in the factory
// directory and returns a Thumbnail pointing at it. The image is a diagonal
// gradient with small random jitter.
//
// Size is set to the factory counter, not the file size. Any filesystem or
// encoding failure prints the error and exits the test binary.
func (th *ThumbnailFactory) New() *avdata.Thumbnail {
	file, err := ioutil.TempFile(th.directory, "vodtest")
	if err != nil {
		fmt.Println(err.Error())
		os.Exit(1)
	}

	// generate image
	img := image.NewGray(image.Rect(0, 0, 128, 128))
	for y := 0; y < 128; y++ {
		for x := 0; x < 128; x++ {
			c := x + y + (rand.Intn(9) - 4) // jitter a bit
			// Clamp to the 8-bit range on both ends. Near (0,0) the jitter can
			// make c negative, and uint8(c) would wrap it to a bright pixel.
			switch {
			case c < 0:
				c = 0
			case c > 255:
				c = 255
			}
			img.SetGray(x, y, color.Gray{Y: uint8(c)})
		}
	}

	// Write dummy data to file
	if err := jpeg.Encode(file, img, nil); err != nil {
		fmt.Println(err.Error())
		os.Exit(1)
	}
	// Close explicitly so a failed flush is reported instead of ignored.
	if err := file.Close(); err != nil {
		fmt.Println(err.Error())
		os.Exit(1)
	}

	thumb := &avdata.Thumbnail{
		Label:     "thumb",
		Path:      file.Name(),
		Size:      th.counter,
		Timestamp: int32(time.Now().Unix()),
	}
	th.counter++
	return thumb
}

// Cleanup deletes the thumbnail directory and everything in it. Errors are
// ignored because this is best-effort teardown.
func (th *ThumbnailFactory) Cleanup() {
	_ = os.RemoveAll(th.directory)
}

// TestThumbnails pushes 20 thumbnails through the pusher, 1 ms apart. After
// Quit it expects pusher.thumbnails to hold exactly 4 entries.
func TestThumbnails(t *testing.T) {
	assert := assert.New(t)
	pusher, cleanFunc, _, _ := testQueueSetup(100, 4, mockOriginSettings{})
	defer cleanFunc()

	factory := NewThumbnailFactory()
	defer factory.Cleanup()

	for remaining := 20; remaining > 0; remaining-- {
		pusher.Process(factory.New(), nil)
		time.Sleep(time.Millisecond)
	}

	pusher.Quit()
	assert.Equal(4, len(pusher.thumbnails))
}

// TestSprites processes 20 sprites and then checks two things:
//   - the fake S3 connection saw one Put call per sprite;
//   - exactly one file remains in the sprite directory.
func TestSprites(t *testing.T) {
	assert := assert.New(t)
	pusher, cleanFunc, _, s3Conn := testQueueSetup(100, 4, mockOriginSettings{})
	defer cleanFunc()

	factory := newSpritesFactory()
	defer factory.Cleanup()

	const numSpriteAssets = 20
	for n := numSpriteAssets; n > 0; n-- {
		pusher.Process(factory.New(), nil)
	}
	pusher.Quit()

	assert.Equal(numSpriteAssets, s3Conn.PutCallCount())

	remaining, err := ioutil.ReadDir(factory.directory)
	assert.Nil(err)
	assert.Equal(1, len(remaining))
}

// newCodecs builds a pair of fake codec descriptors for label:
//   - video: "MyVideoCodec" at 1920x1080;
//   - audio: "MyAudioCodec".
func newCodecs(label string) (video, audio *avdata.Codec) {
	audio = &avdata.Codec{Label: label, AudioCodec: "MyAudioCodec"}
	video = &avdata.Codec{
		Label:      label,
		VideoCodec: "MyVideoCodec",
		Width:      1920,
		Height:     1080,
	}
	return video, audio
}

// TestCodec interleaves "source" segments with video/audio codec updates for
// "source" and then "medium". For both labels it checks that pusher.formats
// records the codec names and the 1920x1080 resolution.
func TestCodec(t *testing.T) {
	assert := assert.New(t)
	pusher, cleanFunc, _, _ := testQueueSetup(100, 4, mockOriginSettings{})
	defer cleanFunc()

	segs := NewSegmentFactory()
	defer segs.Cleanup()

	payloads := []string{"Kappa", "FrankerZ", "MingLee", "NotLikeThis"}
	pushSourceSegments := func(dur time.Duration) {
		for _, p := range payloads {
			pusher.Process(segs.New(dur, "source", p), nil)
		}
	}
	pushCodecs := func(label string) {
		v, a := newCodecs(label)
		pusher.Process(v, nil)
		pusher.Process(a, nil)
	}

	pushSourceSegments(3 * time.Second)
	pushCodecs("source")
	pushSourceSegments(5 * time.Second)
	pushCodecs("medium")
	pushSourceSegments(10 * time.Second)

	pusher.Quit()
	for _, label := range []string{"source", "medium"} {
		format := pusher.formats[label]
		assert.Equal("MyVideoCodec", format.VideoCodec)
		assert.Equal("MyAudioCodec", format.AudioCodec)
		assert.Equal(1920, format.Width)
		assert.Equal(1080, format.Height)
	}
}

// TestOriginThumbs enables the mock origin and processes a thumbnail whose
// Path is under pusher.HlsUrlBase. It expects:
//   - Process to return no error;
//   - origin's ReadThumbnail to be hit once;
//   - Vinyl's thumbnail endpoint not to be hit.
func TestOriginThumbs(t *testing.T) {
	assert := assert.New(t)
	pusher, cleanFunc, verifyFunc, _ := testQueueSetup(10, 3, mockOriginSettings{enable: true})
	defer cleanFunc()

	factory := NewThumbnailFactory()
	defer factory.Cleanup()

	// Process a thumbnail
	th := factory.New()
	th.Path = fmt.Sprintf("%s/thumb/thumb-0000000027.jpg", pusher.HlsUrlBase)
	processErr := pusher.Process(th, nil)
	pusher.Quit()

	assert.Nil(processErr)

	// Vinyl must not be called; origin must be called exactly once.
	assert.Zero(verifyFunc("vinyl", "thumbnail"))
	assert.Equal(1, verifyFunc("origin", "ReadThumbnail"))
}

// TestOriginThumbsError configures the mock origin to answer with code 500
// and calls readThumbnailFromOrigin directly. It expects:
//   - the call to return an error;
//   - origin's ReadThumbnail to be hit once;
//   - Vinyl's thumbnail endpoint not to be hit.
func TestOriginThumbsError(t *testing.T) {
	assert := assert.New(t)
	pusher, cleanFunc, verifyFunc, _ := testQueueSetup(10, 3, mockOriginSettings{enable: true, code: 500})
	defer cleanFunc()

	factory := NewThumbnailFactory()
	defer factory.Cleanup()

	// Process a thumbnail
	th := factory.New()
	th.Path = fmt.Sprintf("%s/thumb/thumb-0000000027.jpg", pusher.HlsUrlBase)
	_, readErr := pusher.readThumbnailFromOrigin(th)
	pusher.Quit()

	assert.NotNil(readErr)

	// Vinyl must not be called; origin must be called exactly once.
	assert.Zero(verifyFunc("vinyl", "thumbnail"))
	assert.Equal(1, verifyFunc("origin", "ReadThumbnail"))
}

// TestVinylStatsdStr checks which statsd counters reportVinylStatusCode
// increments:
//   - a 200 response increments nothing;
//   - a nil response increments the dial-failure counter;
//   - a non-200 response increments a counter named after its status code.
func TestVinylStatsdStr(t *testing.T) {
	settings := &statsd.Settings{Enabled: true}
	settings.Connect()
	c := &vodfakes.FakeStatter{}
	statsd.SetMockConnection(c)
	assert := assert.New(t)

	// A 200 status code should not be reported
	reportVinylStatusCode(&http.Response{StatusCode: http.StatusOK})
	assert.Equal(0, c.IncCallCount())

	// A nil response must not panic and should be reported as a dial failure.
	// Each IncArgsForCall lookup is guarded by a call-count check, so a
	// missing Inc fails the assertion instead of indexing past the recorded
	// calls.
	reportVinylStatusCode(nil)
	if assert.True(c.IncCallCount() >= 1, "expected a dial-failure Inc call") {
		resp, _, _ := c.IncArgsForCall(0)
		assert.Equal(fmt.Sprintf("%s.%s", statsdVinylAPI, statsdVinylDialFailure), resp)
	}

	// A non-200 status code should be reported under its numeric value
	statusCode := http.StatusBadRequest
	reportVinylStatusCode(&http.Response{StatusCode: statusCode})
	if assert.True(c.IncCallCount() >= 2, "expected a status-code Inc call") {
		resp, _, _ := c.IncArgsForCall(1)
		assert.Equal(fmt.Sprintf("%s.%d", statsdVinylAPI, statusCode), resp)
	}
}

// TestVinylFpsAndBitrate processes one source segment and inspects the
// requests received by the fake Vinyl server:
//   - the register request must carry the testFps*/testBitrate* values for
//     both formats;
//   - the finalize request must carry values derived from the segment
//     (frames per second and bits per second) for the source format, and
//     zeros for the medium format, which received no segments.
func TestVinylFpsAndBitrate(t *testing.T) {
	assert := assert.New(t)

	// Create a fake segment
	segs := NewSegmentFactory()
	defer segs.Cleanup()
	segmentData := "Some random data"
	segmentDuration := 2 * time.Second
	segment := segs.New(segmentDuration, testLabelSource, segmentData)
	segment.FrameCount = 84

	mockS3Conn := NewMockS3Conn()
	pusher, cleanFunc, vinylServer, _ := NewPusherWithMockDefault(10, 2, mockOriginSettings{}, mockS3Conn)
	defer cleanFunc()

	pusher.Process(segment, make(chan struct{}))
	pusher.Quit()

	// Each body is inspected only if its request arrived. Indexing an empty
	// slice would panic and abort the whole test binary instead of failing
	// this test cleanly.
	registerRequests := vinylServer.getResponses("/v1/vods/past_broadcast")
	if assert.NotZero(len(registerRequests)) {
		var registerBody registerWrapper
		assert.Nil(json.Unmarshal(registerRequests[0], &registerBody))
		assert.Equal(testFpsSource, registerBody.PastBroadcast.Formats[testLabelSource].Fps)
		assert.Equal(testFpsMedium, registerBody.PastBroadcast.Formats[testLabelMedium].Fps)
		assert.Equal(testBitrateSource, registerBody.PastBroadcast.Formats[testLabelSource].Bitrate)
		assert.Equal(testBitrateMedium, registerBody.PastBroadcast.Formats[testLabelMedium].Bitrate)
	}

	finalizeRequests := vinylServer.getResponses(fmt.Sprintf("/v1/vods/%d", testVodId))
	if assert.NotZero(len(finalizeRequests)) {
		var finalizeBody finalizeVODProps
		assert.Nil(json.Unmarshal(finalizeRequests[0], &finalizeBody))
		assert.Equal(0, finalizeBody.Formats[testLabelMedium].Bitrate)
		assert.Equal(float64(0), finalizeBody.Formats[testLabelMedium].Fps)
		assert.Equal(float64(segment.FrameCount)/segmentDuration.Seconds(), finalizeBody.Formats[testLabelSource].Fps)
		assert.Equal(int(float64(8*len(segmentData))/segmentDuration.Seconds()), finalizeBody.Formats[testLabelSource].Bitrate)
	}
}
