package videoplaygeosession

import (
	"context"
	"fmt"
	"time"

	"code.justin.tv/cb/semki/internal/clients/dynamo"
	"code.justin.tv/cb/semki/internal/clients/sqs"
	"code.justin.tv/cb/semki/internal/stats"

	"github.com/pkg/errors"
	log "github.com/sirupsen/logrus"
)

const (
	// Name uniquely identifies this stat. It is used as the SQS message name,
	// in log messages and when reporting SQS success/failure.
	Name = "videoplaygeosession"

	// TableFormat is the fmt pattern for the backing DynamoDB table name; its
	// single %s verb is filled with the environment name.
	TableFormat = "cb-semki-%s-video-play-geo-session"

	// RewriteHours is the minimum number of hours of sessions that must be
	// overwritten. A stream previously stored as one <48hr session may since
	// have been split into three <=24hr sessions, so enough history has to be
	// rewritten to replace all of them.
	RewriteHours = 96 // 4 days * 24 hours
)

// Stat contains the clients the stat needs. It is the receiver for
// Calculate, which reads from Redshift and pushes results toward DynamoDB
// through the ingest SQS queue.
type Stat struct {
	// Clients bundles the shared dependencies used by Calculate and
	// sendRowToIngest: the Redshift and SQS clients, the Pool semaphore that
	// bounds concurrent sends, and the OnSem*/OnSQS* callbacks.
	Clients *stats.Clients
	// Env is the environment name supplied to InitStat (presumably the same
	// value callers pass to GetTableName — confirm against callers).
	Env     string
}

// InitStat builds a Stat bound to the given clients and environment name so
// that Calculate can be called on it.
func InitStat(clients *stats.Clients, env string) *Stat {
	stat := new(Stat)
	stat.Clients = clients
	stat.Env = env
	return stat
}

// GetTableName expands TableFormat with env to produce this stat's DynamoDB
// table name, e.g. env "prod" yields "cb-semki-prod-video-play-geo-session".
func GetTableName(env string) string {
	tableName := fmt.Sprintf(TableFormat, env)
	return tableName
}

// Calculate calculates sessions and sends results to dynamo.
//
// It queries Redshift for per-country counts between start and end. It folds
// consecutive rows that share a channel ID and segment start time into one
// DynamoRow, whose SegmentID is "<channelID>:<segment start>" and whose
// CountryBreakdown maps country to count. It then ships those rows to the
// ingest SQS queue.
//
// Up to dynamo.BatchSize rows are packed per SQS message, and up to
// sqs.BatchSize messages per send. Each send acquires one unit of
// s.Clients.Pool and runs in its own goroutine (sendRowToIngest). Calculate
// can therefore return before every send has completed.
//
// Grouping only merges adjacent rows. NOTE(review): this assumes the query
// returns rows ordered by channel and segment start — confirm.
//
// An error is returned if the Redshift query fails or if the semaphore cannot
// be acquired (for example when ctx is cancelled). Batches already
// dispatched at that point are not rolled back. An empty query result is not
// an error.
func (s *Stat) Calculate(ctx context.Context, start time.Time, end time.Time) error {
	rows, err := s.Clients.Redshift.GetVideoPlayGeoSessionAggregate(ctx, start, end)
	if err != nil {
		msg := fmt.Sprintf("stat %s: redshift query failed", Name)

		log.WithError(err).Error(msg)
		return errors.Wrap(err, msg)
	}

	// The grouping state below is seeded from rows[0]; without this guard an
	// empty result panics with index out of range.
	if len(rows) == 0 {
		return nil
	}

	// newMessage wraps one batch of Dynamo rows in an SQS message for this stat.
	newMessage := func(batch []*DynamoRow) sqs.Message {
		return sqs.Message{
			Name:    Name,
			Message: batch,
			Retry:   nil,
		}
	}

	// dispatch takes one pool slot and sends batch in the background; the
	// slot is released by sendRowToIngest. The count reported to the ingest
	// callbacks is the number of SQS messages in the batch.
	dispatch := func(batch []sqs.Message) error {
		if err := s.Clients.Pool.Acquire(ctx, 1); err != nil {
			msg := fmt.Sprintf("stat %s: failed to acquire semaphore", Name)

			log.WithError(err).Error(msg)
			return errors.Wrap(err, msg)
		}

		go s.sendRowToIngest(batch, len(batch))
		return nil
	}

	dynBatch := []*DynamoRow{}
	sqsBatch := []sqs.Message{}

	currChannel := rows[0].ChannelID
	currStart := rows[0].SegmentStartTime
	var currRow *DynamoRow

	for _, row := range rows {
		// A new (channel, segment start) pair closes out the previous segment.
		if currChannel != row.ChannelID || !currStart.Equal(row.SegmentStartTime) {
			dynBatch = append(dynBatch, currRow)
			currRow = nil

			if len(dynBatch) == dynamo.BatchSize {
				sqsBatch = append(sqsBatch, newMessage(dynBatch))
				dynBatch = []*DynamoRow{}
			}

			if len(sqsBatch) == sqs.BatchSize {
				if err := dispatch(sqsBatch); err != nil {
					return err
				}
				// Start a fresh slice: the dispatched one is now owned by the
				// sending goroutine.
				sqsBatch = []sqs.Message{}
			}

			currChannel = row.ChannelID
			currStart = row.SegmentStartTime
		}

		if currRow == nil {
			currRow = &DynamoRow{
				SegmentID:        fmt.Sprintf("%s:%s", row.ChannelID, row.SegmentStartTime.Format(dynamo.DynamoTimeFormat)),
				CountryBreakdown: map[string]int64{},
			}
		}
		// A country repeated within one segment overwrites the earlier count.
		currRow.CountryBreakdown[row.Country] = row.Count
	}

	// rows is non-empty, so currRow holds the final segment. Flush it along
	// with any partially filled batches.
	dynBatch = append(dynBatch, currRow)
	sqsBatch = append(sqsBatch, newMessage(dynBatch))

	return dispatch(sqsBatch)
}

// sendRowToIngest pushes one batch of SQS messages to the ingest queue and
// reports the outcome through the client callbacks. Calculate runs it in its
// own goroutine after acquiring one unit of s.Clients.Pool. That unit is
// released when the send returns, whether or not it succeeded.
//
// The send uses context.Background(), so it is not cancelled by the context
// Calculate was given. count is only used for logging and the success
// callback.
func (s *Stat) sendRowToIngest(messages []sqs.Message, count int) {
	s.Clients.OnSemAcquire()
	defer func() {
		s.Clients.Pool.Release(1)
		s.Clients.OnSemRelease()
	}()

	if err := s.Clients.SQS.AddBatch(context.Background(), messages, Name); err != nil {
		log.WithError(err).Errorf("stat %s: failed to send batch %d to ingest SQS", Name, count)

		s.Clients.OnSQSFailure(Name)
		return
	}

	s.Clients.OnSQSSuccess(Name, count)
}
