package usermutation

import (
	"fmt"
	"log"
	"os"
	"os/signal"
	"sync"
	"syscall"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/kinesis"
)

// maxDescribeNumberOfShards is the Limit sent with each DescribeStream request,
// i.e. the largest page of shards asked for per call.
const maxDescribeNumberOfShards = 1000

// kinesisStream is the name of the Kinesis stream whose shards are described
// and handed to workers.
const kinesisStream = "user-mutations"

// StreamConsumer is a long-running consumer of a Kinesis stream.
//
// Start begins consumption and blocks until the consumer is shut down. It
// returns a non-nil error only when consumption could not be started.
// Shutdown stops consumption and releases a blocked Start.
type StreamConsumer interface {
	Start() error
	Shutdown()
}

// streamConsumerImpl is the StreamConsumer returned by NewStreamConsumer. It
// initializes one worker per shard of kinesisStream through a WorkerManager
// and blocks in Start until Shutdown is called or the process receives
// SIGINT/SIGTERM.
type streamConsumerImpl struct {
	checkpointStore CheckpointStore
	kinesisClient   KinesisClient
	workerManager   WorkerManager
	callback        ConsumerFunc

	// mu guards waitGroup and shutdownOnce: Start replaces them on every run
	// while Shutdown may read them from a different goroutine.
	mu sync.Mutex
	// waitGroup is held at 1 while Start is running; Shutdown releases it.
	waitGroup *sync.WaitGroup
	// shutdownOnce makes Shutdown idempotent per Start run. Both the signal
	// handler and an explicit caller may invoke Shutdown, and a second
	// waitGroup.Done would panic with a negative counter.
	shutdownOnce *sync.Once
}

// NewStreamConsumer allocates and returns a new StreamConsumer that lists
// shards with kinesisClient and registers one worker per shard with
// workerManager, passing each the given kinesisClient, checkpointStore and
// callback.
func NewStreamConsumer(kinesisClient KinesisClient, checkpointStore CheckpointStore, workerManager WorkerManager, callback ConsumerFunc) StreamConsumer {
	return &streamConsumerImpl{
		kinesisClient:   kinesisClient,
		checkpointStore: checkpointStore,
		callback:        callback,
		workerManager:   workerManager,
	}
}

// Start lists the stream's shards, initializes one worker per shard and
// spawns them, then blocks until Shutdown is called. Shutdown may be called
// directly or by the SIGINT/SIGTERM handler installed here. Start returns an
// error if the shards cannot be listed or the stream has none; otherwise it
// returns nil once shut down.
func (s *streamConsumerImpl) Start() error {
	wg := &sync.WaitGroup{}
	wg.Add(1)
	s.mu.Lock()
	s.waitGroup = wg
	s.shutdownOnce = &sync.Once{}
	s.mu.Unlock()

	shards, err := s.getShards()
	if err != nil {
		return err
	}
	if len(shards) == 0 {
		return fmt.Errorf("No Kinesis shards found")
	}
	log.Printf("Found %d shards...", len(shards))

	for _, shard := range shards {
		s.workerManager.InitializeWorker(*shard.ShardId, s.kinesisClient, s.checkpointStore, s.callback)
	}
	s.workerManager.SpawnWorkers()

	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, os.Interrupt, syscall.SIGTERM)
	// Unregister when the run loop ends so the process regains default
	// SIGINT/SIGTERM handling instead of silently swallowing those signals.
	defer signal.Stop(sigs)
	// finished lets the signal goroutine exit when Start returns after an
	// explicit Shutdown; otherwise it would stay blocked on sigs forever.
	finished := make(chan struct{})
	defer close(finished)
	go func() {
		select {
		case <-sigs:
			s.Shutdown()
		case <-finished:
		}
	}()

	wg.Wait()
	log.Printf("Terminating run loop")
	return nil
}

// Shutdown stops all workers and releases a blocked Start. It may be called
// repeatedly and from multiple goroutines; only the first call after each
// Start takes effect. Calling it before Start has ever run is a no-op.
func (s *streamConsumerImpl) Shutdown() {
	s.mu.Lock()
	wg, once := s.waitGroup, s.shutdownOnce
	s.mu.Unlock()
	if once == nil {
		// Start has never run: no workers exist and nothing is waiting.
		return
	}
	once.Do(func() {
		s.workerManager.Shutdown()
		wg.Done()
	})
}

// getShards returns every shard of kinesisStream. It pages through
// DescribeStream, requesting maxDescribeNumberOfShards per call and resuming
// after the last shard of the previous page until HasMoreShards is false.
// A DescribeStream error is returned unchanged. A malformed response (no
// stream description, or an empty page that claims more shards) yields an
// error rather than a panic or an endless loop.
func (s *streamConsumerImpl) getShards() ([]*kinesis.Shard, error) {
	shards := []*kinesis.Shard{}
	params := &kinesis.DescribeStreamInput{
		StreamName: aws.String(kinesisStream),
		Limit:      aws.Int64(maxDescribeNumberOfShards),
	}
	for {
		out, err := s.kinesisClient.DescribeStream(params)
		if err != nil {
			return nil, err
		}
		desc := out.StreamDescription
		if desc == nil {
			return nil, fmt.Errorf("describe stream %q: response has no stream description", kinesisStream)
		}
		shards = append(shards, desc.Shards...)

		// aws.BoolValue treats an absent flag as false instead of dereferencing nil.
		if !aws.BoolValue(desc.HasMoreShards) {
			return shards, nil
		}
		// Resuming needs the last shard of this page. An empty page offers no
		// resume point: retrying with the previous start id would repeat the
		// same request forever, and with no shards at all it would index -1.
		if len(desc.Shards) == 0 {
			return nil, fmt.Errorf("describe stream %q: HasMoreShards set on a page with no shards", kinesisStream)
		}
		params.ExclusiveStartShardId = desc.Shards[len(desc.Shards)-1].ShardId
	}
}
