package vod

import (
	"math"
	"os"
	"path"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"

	"code.justin.tv/video/gotranscoder/pkg/m3u8"
)

// testQueueSetup builds a Pusher backed by mock vinyl and origin servers via
// NewPusherWithMockDefault. It returns the pusher, the cleanup function from
// that constructor, and a call counter. The counter reports how many times
// the given call was recorded by the "vinyl" or "origin" mock. An unknown
// service name, or a call that was never recorded, yields 0.
func testQueueSetup(queueSize, chunkFactor int, origSettings mockOriginSettings) (*Pusher, func(), func(string, string) int) {
	pusher, cleanFunc, vinylMock, originMock := NewPusherWithMockDefault(queueSize, chunkFactor, origSettings)

	callCount := func(service, call string) int {
		switch service {
		case "vinyl":
			// A missing key yields the zero value, i.e. "never called".
			return vinylMock.Calls[call]
		case "origin":
			return originMock.Calls[call]
		default:
			return 0
		}
	}

	return pusher, cleanFunc, callCount
}
// TestDiscontinuity covers two kinds of failed segment: one whose data was
// deleted before queuing, and several rejected because the queue was full.
// It checks that such segments produce discontinuity markers in the
// playlist. It also checks that the playlist's stream and target durations,
// and the queue's duration, are derived from the non-discontinuity chunks
// only.
func TestDiscontinuity(t *testing.T) {
	assert := assert.New(t)
	pusher, cleanFunc, _ := testQueueSetup(10, 3, mockOriginSettings{})
	defer cleanFunc()

	queue := pusher.queues["source"]

	segs := NewSegmentFactory()
	defer segs.Cleanup()

	// Upper bound for the busy-wait loops below. A regression then fails the
	// test with a clear message instead of hanging until the global
	// `go test` timeout.
	const waitTimeout = 30 * time.Second

	for i := 0; i < 2; i++ {
		queue.AddSegment(segs.New(3*time.Second, "source", "VOD sample data"))
	}

	// Add a good segment followed by one with missing data, which should
	// become a discontinuity.
	queue.AddSegment(segs.New(5*time.Second, "source", "VOD sample data"))
	badSeg := segs.New(5*time.Second, "source", "This will be deleted")
	// NOTE(review): assumes the segment factory writes segment data to
	// os.TempDir()/<Label>/<SegmentName>. If the removal fails, the segment
	// stays readable and the missing-data path silently goes untested, so
	// fail loudly.
	assert.NoError(os.Remove(path.Join(os.TempDir(), badSeg.Label, badSeg.SegmentName)),
		"deleting data for the bad segment")
	queue.AddSegment(badSeg)

	// Keep adding segments until the queue has overflowed 5 times, which
	// should lead to further discontinuities.
	deadline := time.Now().Add(waitTimeout)
	for missed := 0; missed < 5; {
		if time.Now().After(deadline) {
			t.Fatalf("queue overflowed only %d of 5 times within %v", missed, waitTimeout)
		}
		seg := segs.New(5*time.Second, "source", "VOD sample data")
		if queue.AddSegment(seg) != nil {
			missed++
			time.Sleep(time.Millisecond)
		}
	}

	// Wait until the queue is flushed (queue.C is empty).
	deadline = time.Now().Add(waitTimeout)
	for len(queue.C) > 0 {
		if time.Now().After(deadline) {
			t.Fatalf("queue still holds %d segments after %v", len(queue.C), waitTimeout)
		}
		time.Sleep(time.Millisecond)
	}

	// Add ten more good segments after the overflow.
	for i := 0; i < 10; i++ {
		queue.AddSegment(segs.New(5*time.Second, "source", "VOD sample data"))
	}

	pusher.Quit()

	// Count discontinuity markers and compute the expected durations from the
	// remaining chunks in a single pass. Stream duration is the sum and
	// target duration the maximum of the chunk durations. Chunk.Duration is
	// divided by 1000, presumably milliseconds -> seconds.
	var (
		discontinuities      int
		streamDur, targetDur float64
	)
	chunks := queue.playlist.Get()
	for _, chunk := range chunks {
		if chunk.URL == m3u8.DiscontinuityURL {
			discontinuities++
			continue
		}
		seconds := chunk.Duration / 1000
		streamDur += seconds
		targetDur = math.Max(targetDur, seconds)
	}
	assert.NotZero(discontinuities)

	pl := queue.genPlaylist()
	assert.EqualValues(streamDur, pl.StreamDuration, "Incorrect stream duration in playlist")
	assert.EqualValues(streamDur, queue.Duration().Seconds(), "Incorrect stream duration in queue")
	assert.EqualValues(targetDur, pl.TargetDuration, "Incorrect target duration in playlist")
}

// TestOriginFail points the pusher at a mock origin configured with status
// code 500. It checks that readSegmentFromOrigin returns an error, that
// origin saw exactly one ReadSegment call, and that no vinyl call was made.
func TestOriginFail(t *testing.T) {
	assert := assert.New(t)

	// The mock origin is configured to respond with a 5xx status.
	failingOrigin := mockOriginSettings{enable: true, code: 500}
	pusher, cleanFunc, callCount := testQueueSetup(10, 3, failingOrigin)
	defer cleanFunc()

	segs := NewSegmentFactory()
	defer segs.Cleanup()

	seg := segs.New(3*time.Second, "source", "VOD sample data")
	_, err := pusher.queues["source"].readSegmentFromOrigin(seg)
	assert.NotNil(err)

	// A failed origin read must not reach any vinyl endpoint.
	for _, vinylCall := range []string{"register", "thumbnail", "finalize"} {
		assert.Zero(callCount("vinyl", vinylCall))
	}

	// The download from origin was attempted exactly once.
	assert.Equal(1, callCount("origin", "ReadSegment"))
}

// TestOriginBasic covers the success path of readSegmentFromOrigin. The mock
// origin is enabled and its code field is left at the zero value. The read
// must succeed, origin must see exactly one ReadSegment call, and vinyl must
// not be called.
func TestOriginBasic(t *testing.T) {
	assert := assert.New(t)
	pusher, cleanFunc, verifyFunc := testQueueSetup(10, 3, mockOriginSettings{
		enable: true,
	})
	defer cleanFunc()

	queue := pusher.queues["source"]

	segs := NewSegmentFactory()
	defer segs.Cleanup()

	// Verify that the segment read succeeds. NoError (unlike Nil) prints the
	// unexpected error on failure.
	_, err := queue.readSegmentFromOrigin(segs.New(3*time.Second, "source", "VOD sample data"))
	assert.NoError(err)

	// Verify that vinyl was not called.
	for _, call := range []string{"register", "thumbnail", "finalize"} {
		assert.Zero(verifyFunc("vinyl", call))
	}

	// Verify that origin was called once to download the segment.
	assert.Equal(1, verifyFunc("origin", "ReadSegment"))
}
