package kinesis

import (
	"fmt"
	"sort"
	"testing"
	"time"

	"code.justin.tv/web/jax/common/log"
	"code.justin.tv/web/jax/updater/kinesis/kinesis_mocks"
	"github.com/aws/aws-sdk-go/aws"
	awsKinesis "github.com/aws/aws-sdk-go/service/kinesis"

	. "github.com/smartystreets/goconvey/convey"
)

func TestGetRecordsChannel(t *testing.T) {
	streamName := "stream"
	index := "index"
	Convey("fails to describe stream", t, func() {
		kinesisClient := &kinesis_mocks.Kinesis{}
		client := &client{
			KinesisClient: kinesisClient,
			Index:         index,
			StreamName:    streamName,
		}

		kinesisClient.On("DescribeStream", &awsKinesis.DescribeStreamInput{
			StreamName: aws.String(streamName),
		}).Return(nil, fmt.Errorf("what"))

		_, err := client.GetRecordsChannel()

		So(err, ShouldNotBeNil)
	})

	Convey("described stream", t, func() {
		kinesisClient := &kinesis_mocks.Kinesis{}
		client := &client{
			KinesisClient: kinesisClient,
			Index:         index,
			StreamName:    streamName,
		}

		kinesisClient.On("DescribeStream", &awsKinesis.DescribeStreamInput{
			StreamName: aws.String(streamName),
		}).Return(&awsKinesis.DescribeStreamOutput{
			StreamDescription: &awsKinesis.StreamDescription{
				Shards: []*awsKinesis.Shard{
					&awsKinesis.Shard{
						ShardId: aws.String("1"),
					},
					&awsKinesis.Shard{
						ShardId: aws.String("2"),
					},
				},
				HasMoreShards: aws.Bool(true),
			},
		}, nil)

		kinesisClient.On("DescribeStream", &awsKinesis.DescribeStreamInput{
			StreamName:            aws.String(streamName),
			ExclusiveStartShardId: aws.String("2"),
		}).Return(&awsKinesis.DescribeStreamOutput{
			StreamDescription: &awsKinesis.StreamDescription{
				Shards: []*awsKinesis.Shard{
					&awsKinesis.Shard{
						ShardId: aws.String("3"),
					},
				},
				HasMoreShards: aws.Bool(false),
			},
		}, nil)

		Convey("fails to get shard iterator", func() {
			kinesisClient.On("GetShardIterator", &awsKinesis.GetShardIteratorInput{
				ShardId:           aws.String("1"),
				StreamName:        aws.String(streamName),
				ShardIteratorType: aws.String("LATEST"),
			}).Return(nil, fmt.Errorf("wat"))

			_, err := client.GetRecordsChannel()

			So(err, ShouldNotBeNil)
		})

		Convey("get shard iterators", func() {
			shardIterator1 := aws.String("what")
			shardIterator2 := aws.String("what2")
			errorShardIterator := aws.String("waht2")

			kinesisClient.On("GetShardIterator", &awsKinesis.GetShardIteratorInput{
				ShardId:           aws.String("1"),
				StreamName:        aws.String(streamName),
				ShardIteratorType: aws.String("LATEST"),
			}).Return(&awsKinesis.GetShardIteratorOutput{
				ShardIterator: shardIterator1,
			}, nil)

			kinesisClient.On("GetShardIterator", &awsKinesis.GetShardIteratorInput{
				ShardId:           aws.String("2"),
				StreamName:        aws.String(streamName),
				ShardIteratorType: aws.String("LATEST"),
			}).Return(&awsKinesis.GetShardIteratorOutput{
				ShardIterator: shardIterator2,
			}, nil)

			kinesisClient.On("GetShardIterator", &awsKinesis.GetShardIteratorInput{
				ShardId:           aws.String("3"),
				StreamName:        aws.String(streamName),
				ShardIteratorType: aws.String("LATEST"),
			}).Return(&awsKinesis.GetShardIteratorOutput{
				ShardIterator: errorShardIterator,
			}, nil)

			Convey("gets records", func() {
				log.Init(nil)
				nextShardIterator := aws.String("what3")

				kinesisClient.On("GetRecords", &awsKinesis.GetRecordsInput{
					Limit:         aws.Int64(5000),
					ShardIterator: shardIterator1,
				}).Return(&awsKinesis.GetRecordsOutput{
					Records: []*awsKinesis.Record{
						&awsKinesis.Record{
							Data: []byte(`hello`),
						},
					},
					NextShardIterator: nextShardIterator,
				}, nil)

				kinesisClient.On("GetRecords", &awsKinesis.GetRecordsInput{
					Limit:         aws.Int64(5000),
					ShardIterator: nextShardIterator,
				}).Return(&awsKinesis.GetRecordsOutput{
					Records: []*awsKinesis.Record{
						&awsKinesis.Record{
							Data: []byte(`hi`),
						},
					},
					NextShardIterator: nil,
				}, nil)

				kinesisClient.On("GetRecords", &awsKinesis.GetRecordsInput{
					Limit:         aws.Int64(5000),
					ShardIterator: shardIterator2,
				}).Return(&awsKinesis.GetRecordsOutput{
					Records: []*awsKinesis.Record{
						&awsKinesis.Record{
							Data: []byte(`hello`),
						},
					},
					NextShardIterator: nil,
				}, nil)

				kinesisClient.On("GetRecords", &awsKinesis.GetRecordsInput{
					Limit:         aws.Int64(5000),
					ShardIterator: errorShardIterator,
				}).Return(nil, fmt.Errorf("what"))

				ch, err := client.GetRecordsChannel()

				So(err, ShouldBeNil)

				records := <-ch

				So(records.Records, ShouldResemble, []*awsKinesis.Record{
					&awsKinesis.Record{
						Data: []byte(`hello`),
					},
				})

				records = <-ch

				So(records.Records, ShouldResemble, []*awsKinesis.Record{
					&awsKinesis.Record{
						Data: []byte(`hello`),
					},
				})

				records = <-ch

				So(records.Records, ShouldResemble, []*awsKinesis.Record{
					&awsKinesis.Record{
						Data: []byte(`hi`),
					},
				})
			})
		})
	})
}
