package historyin

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/firehose"
	"github.com/aws/aws-sdk-go/service/firehose/firehoseiface"
)

// firehoseBatchMaxRecords caps how many records Run pops for a single
// PutRecordBatch call; Firehose rejects requests with more than 500 records.
const firehoseBatchMaxRecords = 500

// errInvalidPutBatchResponse reports a PutRecordBatch response whose
// per-record results cannot be matched one-to-one with the request.
var errInvalidPutBatchResponse = errors.New("invalid put batch response")

// batchRunner accumulates audits in Batch and ships them to a Firehose
// delivery stream with PutRecordBatch, retrying records that fail.
type batchRunner struct {
	// DeliveryStreamName is the Firehose delivery stream records are put to.
	DeliveryStreamName string

	// Firehose is the client used for PutRecordBatch calls.
	Firehose firehoseiface.FirehoseAPI

	// MaxBatchAge is the longest Run waits for a threshold signal before
	// flushing whatever is pending.
	MaxBatchAge time.Duration

	// Batch holds records waiting to be sent.
	Batch batch

	// Logger receives Firehose request and response-validation errors.
	Logger Logger

	// runnerState tracks the stop/drain/done lifecycle and supplies the
	// context used for Firehose calls.
	runnerState runnerState
}

// Add queues audit on the runner's pending batch so a later Run iteration can
// deliver it. Any error reported by the batch is returned unchanged.
func (br *batchRunner) Add(audit *Audit) error {
	if err := br.Batch.Add(audit); err != nil {
		return err
	}
	return nil
}

// CurrentBatchSize returns the number of items currently pending in the
// runner's batch, as reported by Batch.CurrentSize.
func (br *batchRunner) CurrentBatchSize() int {
	return br.Batch.CurrentSize()
}

// Run is the runner's delivery loop. Each iteration waits for a flush trigger
// (see waitForWork), pops up to firehoseBatchMaxRecords records from Batch and,
// if any were popped, hands them to sendBatch. Once the runner is stopped the
// loop exits and runnerState.MarkDone is called.
//
// NOTE(review): while draining, waitForWork returns immediately, so with an
// empty Batch this loop iterates without blocking until the runner is
// stopped — confirm Drain is always followed promptly by Stop.
func (br *batchRunner) Run() {
	for {
		if br.runnerState.Stopped() {
			break
		}

		ctx := br.runnerState.Context()
		br.waitForWork(ctx)

		if records := br.Batch.PopBatch(firehoseBatchMaxRecords); len(records) > 0 {
			br.sendBatch(ctx, records)
		}
	}

	br.runnerState.MarkDone()
}

// waitForWork blocks until there is a reason to flush the pending batch: ctx
// is done, MaxBatchAge elapses, or the batch signals a threshold breach (which
// is then acknowledged with MarkThresholdBreachRead). While the runner is
// draining it returns immediately without blocking.
func (br *batchRunner) waitForWork(ctx context.Context) {
	if br.runnerState.IsDraining() {
		return
	}

	// Stop the timer on every exit path so an early wake-up (ctx or threshold
	// breach) doesn't leave it pending until MaxBatchAge elapses.
	timer := time.NewTimer(br.MaxBatchAge)
	defer timer.Stop()

	select {
	case <-ctx.Done():
	case <-timer.C:
	case <-br.Batch.ThresholdBreach():
		br.Batch.MarkThresholdBreachRead()
	}
}

// Drain asks the runner's state to enter draining mode. Presumably this makes
// runnerState.IsDraining report true, in which case waitForWork stops
// blocking and Run pops batches back-to-back — TODO confirm against
// runnerState.
func (br *batchRunner) Drain() {
	br.runnerState.Drain()
}

// Stop signals the runner to stop and then waits up to timeout via
// runnerState.Wait. The result is whatever Wait reports — presumably true
// when Run has reached MarkDone within timeout (verify against runnerState).
func (br *batchRunner) Stop(timeout time.Duration) (stopped bool) {
	br.runnerState.Stop()
	stopped = br.runnerState.Wait(timeout)
	return stopped
}

// sendBatch delivers batch to the configured Firehose delivery stream and
// keeps retrying until every record has been accepted or the runner is
// stopped. ctx is passed to each PutRecordBatchWithContext call.
//
// Retry policy:
//   - a request error retries the whole batch;
//   - a response that cannot be matched to the request (failedOnly returns an
//     error) also retries the whole batch, accepting possible duplicates
//     rather than silently dropping records whose outcome is unknown;
//   - otherwise only records Firehose rejected (non-nil ErrorCode) are retried.
//
// Between attempts it backs off linearly (100ms, 200ms, ...) using
// runnerState.Wait; no backoff happens once the batch is fully delivered.
func (br *batchRunner) sendBatch(ctx context.Context, batch []*firehose.Record) {
	for nTry := 1; len(batch) > 0 && !br.runnerState.Stopped(); nTry++ {
		output, err := br.Firehose.PutRecordBatchWithContext(ctx, &firehose.PutRecordBatchInput{
			DeliveryStreamName: aws.String(br.DeliveryStreamName),
			Records:            batch,
		})
		if err != nil {
			br.Logger.Error(fmt.Errorf("error putting to firehose: %w", err))
		} else if failed, vErr := br.failedOnly(batch, output); vErr != nil {
			// Keep the full batch: we can't tell which records were accepted.
			br.Logger.Error(fmt.Errorf("error validating firehose batch: %w", vErr))
		} else {
			batch = failed
		}

		// Back off before the next attempt so persistent errors don't turn
		// into a tight loop against the Firehose API.
		if len(batch) > 0 && !br.runnerState.Stopped() {
			br.runnerState.Wait(time.Duration(nTry) * 100 * time.Millisecond)
		}
	}
}

// failedOnly returns, in request order, the records of batch whose matching
// entry in output.RequestResponses carries an ErrorCode. Firehose reports
// per-record results positionally, so a nil output, a RequestResponses slice
// whose length differs from batch, or a nil entry cannot be mapped back to
// records and yields errInvalidPutBatchResponse.
func (br *batchRunner) failedOnly(batch []*firehose.Record, output *firehose.PutRecordBatchOutput) ([]*firehose.Record, error) {
	if output == nil || len(batch) != len(output.RequestResponses) {
		return nil, errInvalidPutBatchResponse
	}

	failed := []*firehose.Record{}
	for i, resp := range output.RequestResponses {
		if resp == nil {
			return nil, errInvalidPutBatchResponse
		}
		if resp.ErrorCode != nil {
			failed = append(failed, batch[i])
		}
	}

	return failed, nil
}
