package scanproto

import (
	"errors"
	"fmt"
	"math/rand"
	"sync"
	"time"

	"code.justin.tv/release/trace/api"
	"code.justin.tv/release/trace/internal/awsutils"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/kinesis"
	"github.com/golang/protobuf/proto"
	"golang.org/x/net/context"
	"golang.org/x/net/trace"
	"golang.org/x/time/rate"
)

const (
	// Kinesis's DescribeStream API is restricted to describing at most
	// 10,000 shards per response. listShards passes this value as the Limit
	// of every DescribeStream request it makes.
	maxShardsPerDescribe = 10000
	// Maximum number of additional Kinesis DescribeStream pages listShards
	// reads through after the first response when that response reports
	// that more shards are available.
	maxDescribeStreamPages = 10
)

// simpleKinesisTxSource implements scanproto.TransactionSource. Doesn't use
// the Multilang Daemon, so it shares no state with other processes.
type simpleKinesisTxSource struct {
	// client is the Kinesis client used for every API call made by the
	// source and its shard consumers; stream is the name of the stream
	// being read.
	client *kinesis.Kinesis
	stream string

	// txs and errors are the channels handed out by Transactions and
	// Errors. Run closes both after all shard consumers have exited.
	txs    chan *api.Transaction
	errors chan error

	// ctx is cancelled by cancel, which Stop and every exiting shard
	// consumer call. wg counts Run's setup phase plus one entry per running
	// shard consumer goroutine.
	ctx    context.Context
	cancel func()
	wg     sync.WaitGroup

	// stopMu guards stop. Stop sets stop to true; Run checks it on entry and
	// returns immediately when it is set.
	stopMu sync.Mutex
	stop   bool
}

// NewKinesisTransactionSource builds a TransactionSource that reads from the
// named Kinesis stream.
//
// The supplied client is not used directly. Its configuration is copied,
// passed through awsutils.WithChitin together with a cancellable context
// derived from ctx, and used to construct a fresh client on which the X-Trace
// handlers are installed. No goroutines are started here; call Run on the
// returned source to begin reading.
func NewKinesisTransactionSource(ctx context.Context, client *kinesis.Kinesis, stream string) TransactionSource {
	srcCtx, cancel := context.WithCancel(ctx)

	// Derive a private client so that handler changes never touch the
	// caller's client.
	tracedCfg := awsutils.WithChitin(srcCtx, client.Config.Copy())
	tracedClient := kinesis.New(session.New(tracedCfg))
	awsutils.InstallXTraceHandlers(&tracedClient.Handlers)

	src := &simpleKinesisTxSource{
		client: tracedClient,
		stream: stream,
		ctx:    srcCtx,
		cancel: cancel,
		txs:    make(chan *api.Transaction, readBuffer),
		// One slot of buffering lets a single error be reported without a
		// receiver being ready.
		errors: make(chan error, 1),
	}
	return src
}

// listShards returns the IDs of the shards in the named Kinesis stream, in
// the order DescribeStream reports them.
//
// It issues one DescribeStream request and then, for as long as the latest
// response reports HasMoreShards, up to maxDescribeStreamPages follow-up
// requests that resume after the last shard ID seen so far. Shards beyond
// that page budget are not returned, and no error is reported for the
// truncation.
//
// An error is returned if a request fails, if a response carries no
// StreamDescription, or if a response claims more shards are available while
// no shard ID has been collected to resume from.
func listShards(k *kinesis.Kinesis, stream string) (ids []string, err error) {
	req := &kinesis.DescribeStreamInput{
		StreamName: aws.String(stream),
		Limit:      aws.Int64(maxShardsPerDescribe),
	}

	// TODO: use aws-sdk-go's pagination

	// page 0 is the initial request; pages 1..maxDescribeStreamPages are the
	// follow-ups.
	for page := 0; ; page++ {
		var resp *kinesis.DescribeStreamOutput
		resp, err = k.DescribeStreamWithContext(context.TODO(), req)
		if err != nil {
			return nil, err
		}
		desc := resp.StreamDescription
		if desc == nil {
			return nil, errors.New("empty DescribeStream response")
		}

		if ids == nil {
			// Size for the first page; also guarantees a non-nil result for
			// a stream that reports no shards.
			ids = make([]string, 0, len(desc.Shards))
		}
		for _, shard := range desc.Shards {
			ids = append(ids, aws.StringValue(shard.ShardId))
		}

		if !aws.BoolValue(desc.HasMoreShards) || page >= maxDescribeStreamPages {
			return ids, nil
		}

		// Pagination resumes after the last shard seen. Without any shard
		// there is nothing to resume from, and indexing ids would panic.
		if len(ids) == 0 {
			return nil, errors.New("DescribeStream reported more shards but returned none")
		}
		req.ExclusiveStartShardId = aws.String(ids[len(ids)-1])
	}
}

// Run lists the stream's shards, starts one consumer goroutine per shard,
// and blocks until every consumer has exited. It then closes the channels
// returned by Transactions and Errors.
//
// If Stop has already been called, Run returns immediately without starting
// anything. If the shard listing fails, the error is sent on the Errors
// channel, the source's context is cancelled, and both channels are closed,
// so that neither receivers nor a later Stop wait forever.
//
// Each consumer cancels the source's context when it exits, so the end of
// any one shard consumer shuts the whole source down.
func (s *simpleKinesisTxSource) Run() {
	s.stopMu.Lock()
	if s.stop {
		s.stopMu.Unlock()
		return
	}
	// Hold one WaitGroup count for Run's own setup so that a concurrent
	// Stop cannot see a zero counter while consumers are still being added.
	s.wg.Add(1)
	s.stopMu.Unlock()

	ids, err := listShards(s.client, s.stream)
	if err != nil {
		s.errors <- err
		// Release everything the success path would have released. Without
		// the Done call, Stop would block in wg.Wait forever.
		s.cancel()
		s.wg.Done()
		close(s.txs)
		close(s.errors)
		return
	}

	// Allow one retry per shard per minute, averaged across all shards
	limitGlobalRetry := rate.NewLimiter(rate.Every(60*time.Second)*rate.Limit(len(ids)), 1)
	for _, id := range ids {
		s.wg.Add(1)
		sc := &shardConsumer{
			source:               s,
			ShardId:              id,
			limitShardGetRecords: rate.NewLimiter(0.5*rate.Every(1*time.Second), 1),
			limitGlobalRetry:     limitGlobalRetry,
		}
		go func() {
			defer s.wg.Done()
			defer s.cancel()
			sc.consume(context.Background())
		}()
	}

	// Setup is complete; drop Run's own count and wait for the consumers.
	s.wg.Done()

	s.wg.Wait()
	close(s.txs)
	close(s.errors)
}

// reportErr delivers err on the errors channel unless the source is shutting
// down.
//
// The send blocks until a receiver takes the error (or the channel's buffer
// has room), but gives up as soon as the source's context is cancelled. That
// way a shard consumer can never be left stuck here while Stop waits for it.
func (s *simpleKinesisTxSource) reportErr(err error) {
	if s.ctx.Err() != nil {
		// We're shutting down. The user probably doesn't care about these
		// errors, since they're probably caused by the shutdown.
		return
	}

	select {
	case s.errors <- err:
	case <-s.ctx.Done():
		// Shutdown began while waiting for a receiver; drop the error for
		// the same reason as above.
	}
}

// shardConsumer reads records from a single shard on behalf of a
// simpleKinesisTxSource and forwards the decoded transactions to it.
type shardConsumer struct {
	// source is the owning transaction source; it supplies the Kinesis
	// client, the stream name, the output channels and the shutdown context.
	source *simpleKinesisTxSource

	// ShardId identifies the shard being read. Iterator is the current shard
	// iterator: GetShardIterator sets it and GetRecords replaces it with the
	// next iterator after each call; consume stops once it is empty.
	// SequenceNumber is the sequence number of the last record returned by
	// GetRecords; when non-empty, GetShardIterator resumes after it instead
	// of starting at the latest record.
	ShardId        string
	Iterator       string
	SequenceNumber string

	// limitShardGetRecords is waited on before every GetRecords attempt.
	// limitGlobalRetry is additionally waited on for the second and later
	// attempts within one GetRecords call; Run hands the same limiter to
	// every shard consumer.
	limitShardGetRecords *rate.Limiter
	limitGlobalRetry     *rate.Limiter

	// eventLog is created at the start of consume and finished and cleared
	// when consume returns.
	eventLog trace.EventLog
}

// throughputExceededRetry calls fn repeatedly until it either succeeds or
// fails with an error other than ProvisionedThroughputExceededException.
//
// The first call happens immediately. After a throughput-exceeded failure
// the consumer sleeps for a jittered interval and also waits for the pacing
// timer before trying again. Any other error is passed to the source's
// reportErr and returned to the caller.
func (sc *shardConsumer) throughputExceededRetry(fn func() error) error {
	pace := time.NewTimer(0)
	defer pace.Stop()

	for {
		<-pace.C
		resetTo(pace, 1*time.Second-time.Duration(rand.Int63n(int64(500*time.Millisecond))))

		err := fn()
		if err == nil {
			return nil
		}

		awsErr, isAWS := err.(awserr.Error)
		if !isAWS || awsErr.Code() != "ProvisionedThroughputExceededException" {
			// Not a throttling error: surface it and stop retrying.
			sc.source.reportErr(err)
			return err
		}

		sc.eventLog.Printf("ProvisionedThroughputExceededException")
		// Back off, with jitter. Sufficient backoff is in our best
		// interest, since AWS seems to disable HTTP Keep-Alive when we
		// make requests too quickly.
		time.Sleep(1*time.Second - time.Duration(rand.Int63n(int64(500*time.Millisecond))))
	}
}

// GetShardIterator requests a fresh shard iterator for the consumer's shard
// and stores it in sc.Iterator.
//
// With no SequenceNumber recorded, the iterator starts at the latest record;
// otherwise it starts just after the recorded sequence number. Throttled
// requests are retried through throughputExceededRetry; any other failure is
// returned, and sc.Iterator is left untouched.
func (sc *shardConsumer) GetShardIterator() error {
	input := &kinesis.GetShardIteratorInput{
		ShardId:    aws.String(sc.ShardId),
		StreamName: aws.String(sc.source.stream),
	}
	switch sc.SequenceNumber {
	case "":
		input.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeLatest)
	default:
		input.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAfterSequenceNumber)
		input.StartingSequenceNumber = aws.String(sc.SequenceNumber)
	}

	var output *kinesis.GetShardIteratorOutput
	request := func() error {
		var err error
		output, err = sc.source.client.GetShardIteratorWithContext(context.TODO(), input)
		if err != nil {
			sc.eventLog.Errorf("GetShardIterator error: %v", err)
		}
		return err
	}
	if err := sc.throughputExceededRetry(request); err != nil {
		return err
	}

	sc.Iterator = aws.StringValue(output.ShardIterator)
	return nil
}

// GetRecords fetches the next batch of records using the consumer's current
// iterator.
//
// Every attempt first waits on the per-shard rate limiter; the second and
// later attempts also wait on the limiter shared between shards. Throttled
// requests are retried through throughputExceededRetry. On success the
// consumer's Iterator is advanced to the response's next iterator and, if
// any records came back, SequenceNumber is set to that of the last record.
// On failure the consumer's state is left unchanged and the error is
// returned.
func (sc *shardConsumer) GetRecords(ctx context.Context) ([]*kinesis.Record, error) {
	input := &kinesis.GetRecordsInput{
		Limit:         aws.Int64(3),
		ShardIterator: aws.String(sc.Iterator),
	}

	var (
		output  *kinesis.GetRecordsOutput
		attempt int
	)
	fetch := func() error {
		attempt++

		if err := sc.limitShardGetRecords.Wait(ctx); err != nil {
			sc.eventLog.Errorf("GetRecords shard rate limit error: %v", err)
			return err
		}
		// Only retries draw on the limiter shared across shards.
		if attempt > 1 {
			if err := sc.limitGlobalRetry.Wait(ctx); err != nil {
				sc.eventLog.Errorf("GetRecords global rate limit error: %v", err)
				return err
			}
		}

		var err error
		output, err = sc.source.client.GetRecordsWithContext(ctx, input)
		if err != nil {
			sc.eventLog.Errorf("GetRecords attempt %d error: %v", attempt, err)
			return err
		}
		return nil
	}

	if err := sc.throughputExceededRetry(fetch); err != nil {
		return nil, err
	}

	sc.Iterator = aws.StringValue(output.NextShardIterator)
	if last := len(output.Records) - 1; last >= 0 {
		sc.SequenceNumber = aws.StringValue(output.Records[last].SequenceNumber)
	}
	return output.Records, nil
}

// consume reads the consumer's shard until the shard iterator runs out, the
// source's context is cancelled, or an unrecoverable error occurs, sending
// every decoded transaction to the source's transaction channel.
//
// ctx is passed to GetRecords. Shutdown is driven by the source's own
// context: both the pacing wait and the channel send give up as soon as that
// context is cancelled, so the goroutine exits even when nobody is draining
// the transaction channel any more.
//
// An expired iterator is replaced and reading continues. Other GetRecords
// errors trigger a long jittered backoff before the next attempt. A record
// that fails to unmarshal is reported through the source and ends the
// consumer.
func (sc *shardConsumer) consume(ctx context.Context) {
	sc.eventLog = trace.NewEventLog(
		"scanproto.shardConsumer",
		fmt.Sprintf("%s/%s", sc.source.stream, sc.ShardId))
	defer func() {
		sc.eventLog.Finish()
		sc.eventLog = nil
	}()

	if err := sc.GetShardIterator(); err != nil {
		// This error has been passed along to the package user already by
		// throughputExceededRetry.
		return
	}

	shutdown := sc.source.ctx.Done()

	next := time.NewTimer(100*time.Millisecond - time.Duration(rand.Int63n(int64(50*time.Millisecond))))
	defer next.Stop()
	for sc.Iterator != "" && sc.source.ctx.Err() == nil {
		// Pace the requests, but don't sit out the delay during shutdown.
		select {
		case <-next.C:
		case <-shutdown:
			return
		}
		resetTo(next, 1*time.Second-time.Duration(rand.Int63n(int64(500*time.Millisecond))))

		records, err := sc.GetRecords(ctx)
		if err != nil {
			if ae, ok := err.(awserr.Error); ok && ae.Code() == "ExpiredIteratorException" {
				sc.eventLog.Printf("ExpiredIteratorException")
				if err := sc.GetShardIterator(); err != nil {
					// This error has been passed along to the package user
					// already by throughputExceededRetry.
					return
				}
				continue
			}

			// We don't know what error this is - back off a lot in case it's
			// a bad one!
			resetTo(next, 10*time.Second-time.Duration(rand.Int63n(int64(5*time.Second))))
			continue
		}

		for _, record := range records {
			var ts api.TransactionSet
			if err := proto.Unmarshal(record.Data, &ts); err != nil {
				sc.source.reportErr(err)
				return
			}
			for _, tx := range ts.Transaction {
				// A plain send would block forever once the reader has gone
				// away, and Stop would then never return from wg.Wait.
				select {
				case sc.source.txs <- tx:
				case <-shutdown:
					return
				}
			}
		}
	}
}

// resetTo re-arms t so that it fires d from now, regardless of whether t is
// currently running, has already fired, or has had its tick received.
func resetTo(t *time.Timer, d time.Duration) {
	// Stop's result is deliberately ignored: the drain below handles both
	// the "already fired" and the "still pending" case.
	t.Stop()
	// Non-blocking receive: discard a tick that may already be sitting in
	// t.C so the next receive only sees the re-armed timer. The default case
	// keeps this from blocking when the channel is empty.
	select {
	case <-t.C:
	default:
	}
	t.Reset(d)
}

// Transactions returns the receive-only channel on which decoded
// transactions are delivered. Run closes it after all shard consumers have
// exited.
func (s *simpleKinesisTxSource) Transactions() <-chan *api.Transaction {
	return s.txs
}

// Errors returns the receive-only channel on which errors encountered while
// reading the stream are delivered. Run closes it after all shard consumers
// have exited.
func (s *simpleKinesisTxSource) Errors() <-chan error {
	return s.errors
}

// Stop shuts the source down and blocks until Run's setup phase and every
// shard consumer goroutine have finished.
//
// It first marks the source as stopped, so that a Run call made afterwards
// returns immediately, then cancels the source's context and waits on the
// WaitGroup.
func (s *simpleKinesisTxSource) Stop() {
	// Scope the lock to the flag update only; the wait below must not hold
	// it.
	func() {
		s.stopMu.Lock()
		defer s.stopMu.Unlock()
		s.stop = true
	}()

	s.cancel()
	s.wg.Wait()
}
