package StarfruitNyxClient

import (
	"context"
	"fmt"
	"testing"
	"time"

	rpc "code.justin.tv/amzn/StarfruitNyxTwirp"
	telemetry "code.justin.tv/amzn/TwitchTelemetry"

	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/kinesis"
	"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
	"github.com/klauspost/compress/zstd"
	"github.com/pkg/errors"
)

// Kinesis API mock types.
type (
	// testKinesis stands in for one regional Kinesis client. It embeds
	// kinesisiface.KinesisAPI only to satisfy the interface; only
	// PutRecordWithContext is implemented, so any other API method would hit
	// the embedded interface (left nil by newClient in this file).
	testKinesis struct {
		kinesisiface.KinesisAPI

		regionName string                  // region label attached to every published event
		down       bool                    // when true, PutRecordWithContext returns an error
		lastRecord *kinesis.PutRecordInput // most recent input accepted by the mock

		eventsChan chan *testKinesisEvent // receives every accepted record
	}

	// testKinesisEvent pairs an accepted PutRecord input with the region name
	// of the mock that accepted it.
	testKinesisEvent struct {
		regionName string
		record     *kinesis.PutRecordInput
	}
)

// PutRecordWithContext is the mock's only implemented Kinesis call.
//
// While the mock is up, it stores i as lastRecord and publishes it, tagged
// with this mock's region, on eventsChan. The send blocks if eventsChan is
// full. On success it returns a nil output and a nil error. While the mock is
// down, it records nothing and returns an error. ctx and o are ignored.
func (tk *testKinesis) PutRecordWithContext(ctx context.Context, i *kinesis.PutRecordInput, o ...request.Option) (*kinesis.PutRecordOutput, error) {
	if !tk.down {
		tk.lastRecord = i
		published := &testKinesisEvent{regionName: tk.regionName, record: i}
		tk.eventsChan <- published
		return nil, nil
	}

	return nil, errors.New("test kinesis is down")
}

// noopObserver is a TwitchTelemetry sample observer that discards every
// sample it is given; Flush and Stop are likewise no-ops.
type noopObserver struct{}

func (*noopObserver) ObserveSample(*telemetry.Sample) {}
func (*noopObserver) Flush()                          {}
func (*noopObserver) Stop()                           {}

// Tests
func TestClient(t *testing.T) {

	// Generate an EdgeSegmentRequestEvent
	eventPDX := &rpc.EdgeSegmentRequestEvent{
		Uuid:                  []byte("testUUID"),
		Time:                  time.Now(),
		ChannelArn:            "arn:aws:starfruit:us-west-2:twitch:channel:testChannelID",
		StreamID:              11111,
		Rendition:             "testRendition",
		ContentLength:         1000000,
		ContentDuration:       2 * time.Second,
		Pop:                   "testPoP",
		Edge:                  "testEdge",
		ViewerIPAddress:       "1.1.1.1",
		ViewerPort:            012,
		SessionID:             12345,
		MappingIP:             "1.1.1.1",
		StreamOriginDC:        "sjc02",
		Headers:               []byte("{ 'Content-Type': 'application/json' }"),
		RequestFormat:         "WEAVER",
		Priority:              rpc.SessionPriority_NEW,
		SegmentSequenceNumber: 234,
	}

	eventCMH := &rpc.EdgeSegmentRequestEvent{
		Uuid:                  []byte("testUUID"),
		Time:                  time.Now(),
		ChannelArn:            "arn:aws:starfruit:us-east-2:twitch:channel:testChannelID",
		StreamID:              11111,
		Rendition:             "testRendition",
		ContentLength:         1000000,
		ContentDuration:       2 * time.Second,
		Pop:                   "testPoP",
		Edge:                  "testEdge",
		ViewerIPAddress:       "1.1.1.1",
		ViewerPort:            012,
		SessionID:             12345,
		MappingIP:             "1.1.1.1",
		StreamOriginDC:        "sjc02",
		Headers:               []byte("{ 'Content-Type': 'application/json' }"),
		RequestFormat:         "WEAVER",
		Priority:              rpc.SessionPriority_NEW,
		SegmentSequenceNumber: 234,
	}

	rsrEventPDX := &rpc.ReplicationSegmentRequestEvent{
		Uuid:                  []byte("testUUID"),
		Time:                  time.Now(),
		ChannelArn:            "arn:aws:starfruit:us-west-2:twitch:channel:testChannelID",
		Rendition:             "testRendition",
		StreamID:              11111,
		ContentLength:         1000000,
		SegmentSequenceNumber: 234,
		RequesterPop:          "ams02",
		UpstreamPop:           "lhr03",
		RequestHost:           "video-pr-abcdef",
		UpstreamHost:          "video-pr-123456",
		StreamOriginDc:        "sjc02",
		StatusCode:            200,
		ErrorMessage:          "",
		RequestDuration:       100 * time.Millisecond,
	}

	t.Run("TestAddEvent", func(t *testing.T) {
		eventsByType := map[EventType]NyxProtoEvent{
			EventTypeESR: eventPDX,
			EventTypeRSR: rsrEventPDX,
		}
		for eventType, event := range eventsByType {
			c, _, _, err := newClient(eventType)
			if err != nil {
				t.Fatalf("failed to create client %s", err.Error())
			}

			err = c.AddEvent(event)
			if err != nil {
				t.Fatalf("failed to add event: %v", err)
			}

			// validate event submission
			s, ok := c.streams[RegionPDX]
			if !ok {
				t.Fatalf("failed to get correct stream")
			}

			if clen := len(s.events); clen < 1 {
				t.Fatal("failed to add message to incoming events")
			}

			addedEvent := <-s.events
			if event != addedEvent {
				t.Fatal("incorrect event in incoming events")
			}
		}
	})

	t.Run("TestRun", func(t *testing.T) {
		c, eventsChan, _, err := newClient(EventTypeESR)
		if err != nil {
			t.Fatalf("failed to create client %s", err.Error())
		}

		runCtx, flush := context.WithCancel(context.Background())

		doneCh := make(chan struct{})
		go func() {
			err := c.Run(runCtx)
			if err != nil {
				t.Fatalf("couldn't run %s", err.Error())
			}

			close(doneCh)
		}()

		err = c.AddEvent(eventPDX)
		if err != nil {
			t.Fatalf("failed to add event: %v", err)
		}

		flush()
		<-doneCh

		input := <-eventsChan

		events, err := decodeData(input.record.Data)
		if err != nil {
			t.Fatal("kinesis record detail failed")
		}

		if len(events.Events) != 1 {
			t.Fatalf("incorrect number of events in kinesis record")
		}

		stream := input.regionName
		if stream != "PDX" {
			t.Fatalf("record not sent to correct aws region")
		}
	})

	t.Run("TestRunSecondaryRegion", func(t *testing.T) {
		c, eventsChan, _, err := newClient(EventTypeESR)
		if err != nil {
			t.Fatalf("failed to create client %s", err.Error())
		}

		runCtx, flush := context.WithCancel(context.Background())

		doneCh := make(chan struct{})
		go func() {
			err := c.Run(runCtx)
			if err != nil {
				t.Fatalf("couldn't run %s", err.Error())
			}

			close(doneCh)
		}()

		err = c.AddEvent(eventCMH)
		if err != nil {
			t.Fatalf("failed to add event: %v", err)
		}

		fmt.Println("added event")

		flush()
		<-doneCh

		input := <-eventsChan

		events, err := decodeData(input.record.Data)
		if err != nil {
			t.Fatal("kinesis record detail failed")
		}

		if len(events.Events) != 1 {
			t.Fatalf("incorrect number of events in kinesis record")
		}

		stream := input.regionName
		if stream != "CMH" {
			t.Fatalf("record not sent to correct aws region")
		}
	})

	t.Run("TestFailure", func(t *testing.T) {
		c, eventsChan, regionalKinesisClients, err := newClient(EventTypeESR)
		if err != nil {
			t.Fatalf("failed to create client %s", err.Error())
		}

		pdxKinesis, ok := regionalKinesisClients["PDX"]
		if !ok {
			t.Fatalf("regional kinesis client not found")
		}

		pdxKinesis.down = true
		defer func() {
			pdxKinesis.down = false
		}()

		runCtx, flush := context.WithCancel(context.Background())

		doneCh := make(chan struct{})
		go func() {
			c.Run(runCtx)
			close(doneCh)
		}()

		err = c.AddEvent(eventPDX)
		if err != nil {
			t.Fatalf("failed to add event: %v", err)
		}

		flush()
		<-doneCh

		if len(eventsChan) > 0 {
			t.Fatalf("kinesis improperly received a record")
		}
	})
}

// UTIL Functions
// decodeData reverses the client's record encoding for tests.
//
// It unmarshals data as an rpc.NyxPayload. If the payload's CompressionLevel
// is ZSTD, it zstd-decompresses the inner Data; otherwise it uses the Data
// as-is. It then unmarshals the result as rpc.EdgeSegmentRequestEvents.
// Any unmarshal or decompression error is returned wrapped with the failing
// step.
func decodeData(data []byte) (*rpc.EdgeSegmentRequestEvents, error) {
	p := &rpc.NyxPayload{}
	if err := p.Unmarshal(data); err != nil {
		return nil, fmt.Errorf("unmarshal nyx payload: %w", err)
	}

	raw := p.Data
	if p.CompressionLevel == rpc.CompressionLevel_ZSTD {
		// Only build a decoder when it is actually needed. Close it
		// afterwards: a zstd.Decoder holds resources until Close is called.
		decoder, err := zstd.NewReader(nil)
		if err != nil {
			return nil, fmt.Errorf("create zstd decoder: %w", err)
		}
		defer decoder.Close()

		if raw, err = decoder.DecodeAll(p.Data, nil); err != nil {
			return nil, fmt.Errorf("zstd decode payload: %w", err)
		}
	}

	e := &rpc.EdgeSegmentRequestEvents{}
	if err := e.Unmarshal(raw); err != nil {
		return nil, fmt.Errorf("unmarshal edge segment request events: %w", err)
	}

	return e, nil
}

// newClient builds a Client for eventType and replaces every regional stream's
// Kinesis client with a testKinesis mock.
//
// It returns four values:
//   - the client;
//   - the channel all mocks publish accepted records on;
//   - the mocks keyed by region name, so a test can mark one down;
//   - any setup error.
//
// It also fails if the client has a stream for a region with no mock.
func newClient(eventType EventType) (*Client, chan *testKinesisEvent, map[string]*testKinesis, error) {
	sess, err := session.NewSession()
	if err != nil {
		return nil, nil, nil, fmt.Errorf("failed to create session: %w", err)
	}

	c, err := NewClient(ClientConfig{
		EventType:        eventType,
		Session:          sess,
		SendTimeout:      time.Second,
		SendRetryTimeout: 3 * time.Second,
		FlushInterval:    time.Second,
	})
	if err != nil {
		return nil, nil, nil, fmt.Errorf("failed to create client: %w", err)
	}

	// All mocks share one buffered channel. Tests identify the destination
	// region via testKinesisEvent.regionName. The buffer lets mocks publish
	// without blocking the client before the test starts receiving.
	eventsChan := make(chan *testKinesisEvent, 100)

	regionNames := []string{
		RegionPDX.name,
		RegionCMH.name,
		RegionIAD.name,
		RegionDUB.name,
		RegionFRA.name,
		RegionICN.name,
		RegionNRT.name,
	}
	regionalKinesisClients := make(map[string]*testKinesis, len(regionNames))
	for _, name := range regionNames {
		regionalKinesisClients[name] = &testKinesis{
			regionName: name,
			eventsChan: eventsChan,
		}
	}

	for region, s := range c.streams {
		tk, ok := regionalKinesisClients[region.name]
		if !ok {
			return nil, nil, nil, fmt.Errorf("regional kinesis client not configured for region %q", region.name)
		}

		s.client = tk
	}

	return c, eventsChan, regionalKinesisClients, nil
}
