package main

import (
	"io/ioutil"
	"os"
	"path"
	"testing"
	"time"

	"code.justin.tv/video/gotranscoder/pkg/avdata"
	"code.justin.tv/video/gotranscoder/pkg/m3u8"
	"code.justin.tv/video/gotranscoder/pkg/usher"
	"code.justin.tv/video/gotranscoder/pkg/usher/usherfakes"
	"code.justin.tv/video/protocols/hls"
	"code.justin.tv/video/protocols/hlsext"

	"github.com/golang/protobuf/ptypes"
	"github.com/golang/protobuf/ptypes/duration"
	"github.com/golang/protobuf/ptypes/timestamp"
	"github.com/stretchr/testify/assert"
)

// Fixture values shared by the tests in this file. newTestSegmentHandler
// installs them into the package-level sessionID, channel and usherTranscode
// globals before building a handler.
//
// testChannel is a variable rather than a constant because its address is
// assigned to the channel global. testUsherTranscode.Destination is also passed
// to newStream when building the expected weaver playlist.
var (
	testChannel        = "monstercat"
	testSessionID      = "12345"
	testUsherTranscode = &usher.HlsTranscode{
		Destination: 999999,
	}
)

// GetTestConfig returns an EncoderConfig containing a single transcode
// quality labelled "high" with a bitrate of 1600000.
func GetTestConfig() *EncoderConfig {
	high := TranscodeQualitySettings{
		Label:   "high",
		Bitrate: 1600000,
	}

	cfg := &EncoderConfig{}
	cfg.Transcodes = []TranscodeQualitySettings{high}
	return cfg
}

// GetTestConfig_2 returns an EncoderConfig containing a single transcode
// quality labelled "mobile" that, unlike GetTestConfig, also sets the muxrate
// and audio parameters.
func GetTestConfig_2() *EncoderConfig {
	mobile := TranscodeQualitySettings{
		Label:         "mobile",
		Bitrate:       180000,
		Muxrate:       230000,
		AudioChannels: 2,
		AudioBitrate:  24000,
	}

	cfg := &EncoderConfig{}
	cfg.Transcodes = []TranscodeQualitySettings{mobile}
	return cfg
}

// newSegmentHandlerUsher builds a fake Usher that fulfills SegmentHandler's
// expectations: GetHlsTranscode is stubbed to return a transcode carrying the
// given id together with a nil error.
func newSegmentHandlerUsher(id int) *usherfakes.FakeUsher {
	fake := new(usherfakes.FakeUsher)
	transcode := &usher.HlsTranscode{ID: id}
	fake.GetHlsTranscodeReturns(transcode, nil)
	return fake
}

// newTestSegmentHandler builds a SegmentHandler wired to a fake Usher and a
// freshly created temporary segment directory.
//
// As a side effect it overwrites the package-level sessionID, channel,
// usherTranscode and encoderConfig globals with test values. The returned func
// restores all four globals to their previous values and removes the temporary
// directory; callers should defer it. The test is failed immediately if the
// temporary directory cannot be created.
func newTestSegmentHandler(t *testing.T, id int, encoderCfg *EncoderConfig) (SegmentHandler, *usherfakes.FakeUsher, func()) {
	// Named fakeUsher so the imported usher package is not shadowed.
	fakeUsher := newSegmentHandlerUsher(id)

	tmpdir, err := ioutil.TempDir("", "segment-handler-test")
	if err != nil {
		t.Fatalf("unable to create temp dir: %s", err)
	}

	// Copy the state of globals so we can reset them after the test
	oldSessionID := sessionID
	oldChannel := channel
	oldUsherTranscode := usherTranscode
	oldEncoderConfig := encoderConfig

	sessionID = testSessionID
	channel = &testChannel
	usherTranscode = testUsherTranscode
	encoderConfig = encoderCfg

	closer := func() {
		sessionID = oldSessionID
		channel = oldChannel
		usherTranscode = oldUsherTranscode
		// encoderConfig must be restored too, otherwise the config of one
		// test leaks into whatever runs next.
		encoderConfig = oldEncoderConfig
		// Removal is best effort; report a failure without failing the test.
		if err := os.RemoveAll(tmpdir); err != nil {
			t.Logf("unable to remove temp dir %s: %s", tmpdir, err)
		}
	}

	sh := SegmentHandler{
		usher:               fakeUsher,
		SegmentPath:         tmpdir,
		TranscodeID:         id,
		TargetDuration:      2000,
		segmentHandlerStart: time.Now(),
		// Very short intervals so the periodic Usher updates fire within
		// the lifetime of a test.
		transcodeUsherUpdateInterval: time.Millisecond,
		usherUpdateStreamInterval:    time.Millisecond,
	}
	return sh, fakeUsher, closer
}

// TestSegmentHandlerInitialization starts a handler, lets it run briefly with
// no segments processed, shuts it down, and then inspects which Usher calls
// were made in the meantime.
func TestSegmentHandlerInitialization(t *testing.T) {
	assert := assert.New(t)

	// Named fakeUsher so the imported usher package is not shadowed.
	handler, fakeUsher, cleanup := newTestSegmentHandler(t, 1, GetTestConfig())
	defer cleanup()
	handler.Initialize()

	// The helper configures 1ms update intervals; give them time to elapse.
	time.Sleep(3 * time.Millisecond)

	handler.close()
	handler.wg.Wait()

	// testify takes (expected, actual); keeping that order makes failure
	// messages label the values correctly.

	// UpdateHlsTranscode should be called every transcodeUsherUpdateInterval
	assert.NotEqual(0, fakeUsher.UpdateHlsTranscodeCallCount())
	// UpdateStreamProperties should not get called until at least one segment
	// has been written to the history
	assert.Equal(0, fakeUsher.UpdateStreamPropertiesCallCount())
}

func TestSegmentHandlerProcessOne(t *testing.T) {
	assert := assert.New(t)

	handler, _, cleanup := newTestSegmentHandler(t, 1, GetTestConfig())
	defer cleanup()
	handler.Initialize()

	handler.Process(avdata.Segment{
		Label:         "high",
		SegmentNumber: 1,
	}, nil)

	handler.close()
	handler.wg.Wait()

	assert.Contains(handler.TranscodeQualities, "high")
	assert.Len(handler.segmentHistory["high"].Segments, 1)

	// TODO: assert that we sent the segment and playlist to tenfoot

	assertFileExists(assert, path.Join(handler.SegmentPath, "high", "index-live.m3u8"))
	assertFileExists(assert, path.Join(handler.SegmentPath, "high", "index-history.m3u8"))
	assertFileExists(assert, path.Join(handler.SegmentPath, "high", "playlist.pb"))
}

// TestSegmentHandlerWeaverPlaylistGeneration seeds the "high" history with five
// segments (Duration 2000.0 each) plus one ad break, and compares the
// ExtendedPlaylist returned by weaverPlaylist with a hand-built expectation.
// Time-dependent fields are checked within a tolerance and then cleared so the
// remainder can be compared exactly.
func TestSegmentHandlerWeaverPlaylistGeneration(t *testing.T) {
	assert := assert.New(t)

	handler, _, cleanup := newTestSegmentHandler(t, 1, GetTestConfig())
	defer cleanup()
	handler.Initialize()

	handler.historiesLock.Lock()
	// started 10 seconds ago, because we have 10 seconds of durations
	handler.segmentHandlerStart = time.Now().Add(-10 * time.Second)
	handler.segmentHistory["high"] = &SegmentHistory{
		Segments: []avdata.Segment{
			{Label: "high", SegmentNumber: 1, SegmentName: "index-0001-abcd.ts", Duration: 2000.0},
			{Label: "high", SegmentNumber: 2, SegmentName: "index-0002-defg.ts", Duration: 2000.0},
			{Label: "high", SegmentNumber: 3, SegmentName: "index-0003-hijk.ts", Duration: 2000.0},
			{Label: "high", SegmentNumber: 4, SegmentName: "index-0004-lmno.ts", Duration: 2000.0},
			{Label: "high", SegmentNumber: 5, SegmentName: "index-0005-pqrs.ts", Duration: 2000.0},
		},
		AdBreaks: []*hlsext.AdBreakRequest{
			{StartSequenceNumber: 4, AdData: []byte("payload")},
		},
		StreamOffsets: make(map[string]*m3u8.StreamOffset),
	}
	handler.segmentHistory["high"].StreamOffsets[constWeaverPlaylist] = &m3u8.StreamOffset{}
	handler.historiesLock.Unlock()
	handler.close()
	handler.wg.Wait()

	var err error
	handler.streams["high"], err = newStream(testChannel, "high", testUsherTranscode.Destination)
	assert.NoError(err)

	handler.avgKbpsSeen = 2000.0
	have, err := handler.weaverPlaylist(5, "high", handler.segmentHistory["high"])
	if !assert.NoError(err) {
		// Without a playlist the field accesses below cannot be trusted.
		return
	}

	nowPB, err := ptypes.TimestampProto(time.Now())
	assert.NoError(err)

	// Proto messages are built with keyed fields so the literals stay valid
	// (and vet-clean) if the generated structs gain fields.
	want := &hlsext.ExtendedPlaylist{
		Version: 0,
		Stream:  handler.streams["high"],
		Playlist: &hls.Playlist{
			Segments: []*hls.Segment{
				{
					Uri:            "index-0001-abcd.ts",
					SequenceNumber: 1,
					Duration:       &duration.Duration{Seconds: 2},
					Discontinuity:  false,
				},
				{
					Uri:            "index-0002-defg.ts",
					SequenceNumber: 2,
					Duration:       &duration.Duration{Seconds: 2},
					Discontinuity:  false,
				},
				{
					Uri:            "index-0003-hijk.ts",
					SequenceNumber: 3,
					Duration:       &duration.Duration{Seconds: 2},
					Discontinuity:  false,
				},
				{
					Uri:            "index-0004-lmno.ts",
					SequenceNumber: 4,
					Duration:       &duration.Duration{Seconds: 2},
					Discontinuity:  false,
				},
				{
					Uri:            "index-0005-pqrs.ts",
					SequenceNumber: 5,
					Duration:       &duration.Duration{Seconds: 2},
					Discontinuity:  false,
				},
			},
			Creation:       nowPB,
			TargetDuration: 2,
			MediaSequence:  0,
			TimeElapsed:    &duration.Duration{},
			TimeTotal:      &duration.Duration{Seconds: 10},
			Final:          false,
		},
		Ads: []*hlsext.AdBreakRequest{
			{StartSequenceNumber: 4, AdData: []byte("payload")},
		},
		Bitrate: 1600000,
	}

	// Make sure time-dependent values are approximately valid, and then
	// clear them out.
	haveElapsed, wantElapsed := convertProtoDurations(assert, have.Playlist.TimeElapsed, want.Playlist.TimeElapsed)
	assertInDurationDelta(assert, haveElapsed, wantElapsed, 100*time.Millisecond)
	have.Playlist.TimeElapsed = nil
	want.Playlist.TimeElapsed = nil

	haveTotal, wantTotal := convertProtoDurations(assert, have.Playlist.TimeTotal, want.Playlist.TimeTotal)
	assertInDurationDelta(assert, haveTotal, wantTotal, 100*time.Millisecond)
	have.Playlist.TimeTotal = nil
	want.Playlist.TimeTotal = nil

	haveCreation, wantCreation := convertProtoTimestamps(assert, have.Playlist.Creation, want.Playlist.Creation)
	assert.WithinDuration(haveCreation, wantCreation, time.Second)
	have.Playlist.Creation = nil
	want.Playlist.Creation = nil

	assert.Equal(want, have)
}

// TestFutureSegmentGeneration checks twitchTranscoderFutureSegments in two
// situations: a history whose last segment advertises two future segment
// names, and a quality label for which this test seeds no history at all.
func TestFutureSegmentGeneration(t *testing.T) {
	assert := assert.New(t)

	handler, _, cleanup := newTestSegmentHandler(t, 1, GetTestConfig())
	defer cleanup()
	handler.Initialize()

	// Initialize has already been called, so seed the history under
	// historiesLock, matching TestSegmentHandlerWeaverPlaylistGeneration.
	handler.historiesLock.Lock()
	handler.segmentHistory["high"] = &SegmentHistory{
		Segments: []avdata.Segment{
			{
				Label:          "high",
				SegmentNumber:  1,
				SegmentName:    "index-0001-abcd.ts",
				Duration:       2000.0,
				FutureSegments: []string{},
			},
			{
				Label:          "high",
				SegmentNumber:  2,
				SegmentName:    "index-0002-defg.ts",
				Duration:       2000.0,
				FutureSegments: []string{},
			},
			{
				Label:          "high",
				SegmentNumber:  3,
				SegmentName:    "index-0003-hijk.ts",
				Duration:       2000.0,
				FutureSegments: []string{},
			},
			{
				Label:          "high",
				SegmentNumber:  4,
				SegmentName:    "index-0004-lmno.ts",
				Duration:       2000.0,
				FutureSegments: []string{},
			},
			{
				Label:          "high",
				SegmentNumber:  5,
				SegmentName:    "index-0005-pqrs.ts",
				Duration:       2000.0,
				FutureSegments: []string{"index-0006-hquu.ts", "index-0007-aiqu.ts"},
			},
		},
		AdBreaks: []*hlsext.AdBreakRequest{
			{StartSequenceNumber: 4, AdData: []byte("payload")},
		},
		StreamOffsets: make(map[string]*m3u8.StreamOffset),
	}
	handler.historiesLock.Unlock()
	handler.close()
	handler.wg.Wait()

	// Keyed Duration fields keep the literals valid if the generated struct
	// gains fields; testify takes (expected, actual).
	have := handler.twitchTranscoderFutureSegments(handler.segmentHistory["high"])
	want := []*hls.Segment{
		{
			Uri:            "index-0006-hquu.ts",
			SequenceNumber: 6,
			Duration:       &duration.Duration{Seconds: 2},
		},
		{
			Uri:            "index-0007-aiqu.ts",
			SequenceNumber: 7,
			Duration:       &duration.Duration{Seconds: 2},
		},
	}
	assert.Equal(want, have)

	// No history is seeded for "medium"; the expected result is an empty,
	// non-nil slice.
	have = handler.twitchTranscoderFutureSegments(handler.segmentHistory["medium"])
	want = []*hls.Segment{}
	assert.Equal(want, have)
}

// TestSegmentHandlerWeaverBitrate checks the value weaverBitrate reports for
// each test encoder configuration.
func TestSegmentHandlerWeaverBitrate(t *testing.T) {
	assert := assert.New(t)

	cases := []struct {
		label string         // quality label passed to weaverBitrate
		cfg   *EncoderConfig // encoder configuration installed for the case
		want  int            // expected bitrate
	}{
		{label: "high", cfg: GetTestConfig(), want: 1600000},
		{label: "mobile", cfg: GetTestConfig_2(), want: 182000},
	}

	for _, tc := range cases {
		// Each case runs in its own function so the deferred cleanup fires
		// before the next case starts, instead of piling up until the whole
		// test returns.
		func() {
			handler, _, cleanup := newTestSegmentHandler(t, 1, tc.cfg)
			defer cleanup()
			handler.Initialize()

			videoBitrate := handler.weaverBitrate(tc.label)
			assert.Equal(tc.want, int(videoBitrate), "quality %s", tc.label)

			handler.close()
			handler.wg.Wait()
		}()
	}
}

// assertInDurationDelta asserts, via InDelta, that d1 and d2 are within delta
// of each other. The two durations are passed as integer values and the
// tolerance as a float64.
func assertInDurationDelta(assert *assert.Assertions, d1, d2, delta time.Duration) {
	tolerance := float64(delta)
	first, second := int(d1), int(d2)
	assert.InDelta(first, second, tolerance)
}

// convertProtoTimestamps converts two protobuf timestamps to time.Time values
// with ptypes.Timestamp. A conversion error on either input is recorded as an
// assertion failure; the values returned by ptypes.Timestamp are passed back
// regardless.
func convertProtoTimestamps(assert *assert.Assertions, pt1, pt2 *timestamp.Timestamp) (t1, t2 time.Time) {
	var err error

	t1, err = ptypes.Timestamp(pt1)
	assert.NoError(err)

	t2, err = ptypes.Timestamp(pt2)
	// The second conversion must be checked as well; otherwise an invalid pt2
	// goes unreported.
	assert.NoError(err)

	return t1, t2
}

// convertProtoDurations converts two protobuf durations to time.Duration
// values with ptypes.Duration. A conversion error on either input is recorded
// as an assertion failure; the converted values are returned in input order.
func convertProtoDurations(assert *assert.Assertions, pd1, pd2 *duration.Duration) (d1, d2 time.Duration) {
	var err error

	d1, err = ptypes.Duration(pd1)
	assert.NoError(err)

	d2, err = ptypes.Duration(pd2)
	assert.NoError(err)

	return d1, d2
}

// assertFileExists records an assertion failure unless os.Stat succeeds for
// path. Any Stat error counts as a failure, not only "does not exist": other
// errors (for example a permission problem) also mean the file's presence
// could not be confirmed.
func assertFileExists(assert *assert.Assertions, path string) {
	_, err := os.Stat(path)
	assert.NoError(err, "File should exist: %s", path)
}
