package redis

import (
	"context"
	"fmt"
	"time"

	telemetry "code.justin.tv/amzn/TwitchTelemetry"
	"code.justin.tv/chat/rediczar"
	"code.justin.tv/creator-collab/log"
	"code.justin.tv/creator-collab/log/errors"
)

// rateLimitKeyPrefix namespaces every Redis key the rate limiter uses. Keys are
// built as "<prefix>:<operation>:<sourceID>", where operation is "host" or "unhost".
const rateLimitKeyPrefix = "ah:rateLimit"

// RateLimiter supports throttling operations.
//
// Each method returns true when the operation may proceed and false when it
// should be throttled. Limits are tracked separately per source ID.
// (RateLimiterImpl also returns true when the rate-limit check itself fails.)
type RateLimiter interface {
	// RateLimitHostBySourceID reports whether a host operation keyed by
	// sourceID is currently allowed.
	RateLimitHostBySourceID(ctx context.Context, sourceID string) bool
	// RateLimitUnhostBySourceID reports whether an unhost operation keyed by
	// sourceID is currently allowed.
	RateLimitUnhostBySourceID(ctx context.Context, sourceID string) bool
}

// RateLimiterParams configures NewRateLimiter.
//
// RedisAddress and IsCluster are forwarded to the Redis client constructor.
// SampleReporter is shared by the Redis client and the rate limiter's own
// metrics. Logger receives errors from failed rate-limit checks.
type RateLimiterParams struct {
	RedisAddress   string
	IsCluster      bool
	SampleReporter *telemetry.SampleReporter
	Logger         log.Logger

	// Host limit: MaxHostOperations is passed as the per-source-ID limit of a
	// sliding window lasting HostWindowDuration.
	MaxHostOperations  int
	HostWindowDuration time.Duration

	// Unhost limit: MaxUnhostOperations is passed as the per-source-ID limit of
	// a sliding window lasting UnhostWindowDuration.
	MaxUnhostOperations  int
	UnhostWindowDuration time.Duration
}

// Compile-time assertion that *RateLimiterImpl implements RateLimiter. A typed
// nil pointer is enough for the check and avoids creating a zero-value struct.
var _ RateLimiter = (*RateLimiterImpl)(nil)

// NewRateLimiter builds a Redis-backed RateLimiterImpl from p.
//
// p.RedisAddress and p.IsCluster are forwarded to newRedisClient (IsCluster
// presumably selects cluster vs. single-node mode). p.SampleReporter is shared
// by the Redis client and the limiter's metrics reporter. p must be non-nil.
func NewRateLimiter(p *RateLimiterParams) *RateLimiterImpl {
	limiter := &RateLimiterImpl{
		logger:      p.Logger,
		redisClient: newRedisClient(p.RedisAddress, p.IsCluster, false, p.SampleReporter),

		hostWindowSize:       p.MaxHostOperations,
		hostWindowDuration:   p.HostWindowDuration,
		unhostWindowSize:     p.MaxUnhostOperations,
		unhostWindowDuration: p.UnhostWindowDuration,
	}

	// Created after the Redis client, preserving the original construction order.
	limiter.rateLimiterSampleReporter = newRateLimiterSampleReporter(p.SampleReporter, p.Logger)

	return limiter
}

// RateLimiterImpl implements RateLimiter with Redis sliding-window rate limits.
// Construct it with NewRateLimiter; the zero value has no Redis client.
type RateLimiterImpl struct {
	// Dependencies: error logging, the Redis client that performs the
	// sliding-window check, and the allow/throttle/error metrics reporter.
	logger                    log.Logger
	redisClient               rediczar.ThickClient
	rateLimiterSampleReporter *rateLimiterSampleReporter

	// Host limit: at most hostWindowSize operations per hostWindowDuration.
	hostWindowSize     int
	hostWindowDuration time.Duration

	// Unhost limit: at most unhostWindowSize operations per unhostWindowDuration.
	unhostWindowSize     int
	unhostWindowDuration time.Duration
}

// RateLimitHostBySourceID reports whether a host operation keyed by sourceID
// fits within the host window (hostWindowSize operations per
// hostWindowDuration). It returns true when allowed, and also when the Redis
// check fails (fail-open). The outcome is reported under the operation name "Host".
func (r *RateLimiterImpl) RateLimitHostBySourceID(ctx context.Context, sourceID string) bool {
	return r.rateLimitWithMetrics(
		ctx,
		fmt.Sprintf("%s:host:%s", rateLimitKeyPrefix, sourceID),
		r.hostWindowDuration,
		r.hostWindowSize,
		"Host",
	)
}

// RateLimitUnhostBySourceID reports whether an unhost operation keyed by
// sourceID fits within the unhost window (unhostWindowSize operations per
// unhostWindowDuration). It returns true when allowed, and also when the Redis
// check fails (fail-open). The outcome is reported under the operation name "Unhost".
func (r *RateLimiterImpl) RateLimitUnhostBySourceID(ctx context.Context, sourceID string) bool {
	return r.rateLimitWithMetrics(
		ctx,
		fmt.Sprintf("%s:unhost:%s", rateLimitKeyPrefix, sourceID),
		r.unhostWindowDuration,
		r.unhostWindowSize,
		"Unhost",
	)
}

// rateLimitWithMetrics runs the sliding-window check for key and records the
// outcome for operationName on the sample reporter.
//
// It returns the check's verdict: true to allow, false to throttle. If the
// check errors, the error is logged, an error sample is reported, and true is
// returned, so operations are not throttled when Redis is unavailable.
func (r *RateLimiterImpl) rateLimitWithMetrics(
	ctx context.Context,
	key string,
	windowDuration time.Duration,
	windowSize int,
	operationName string,
) bool {
	allowed, err := r.rateLimit(ctx, key, windowDuration, windowSize)

	switch {
	case err != nil:
		// Fail open: the failure is logged and counted, and the operation is let through.
		r.logger.Error(err)
		r.rateLimiterSampleReporter.report(operationName, rateLimitResultError)
		return true
	case !allowed:
		r.rateLimiterSampleReporter.report(operationName, rateLimitResultThrottle)
		return false
	default:
		r.rateLimiterSampleReporter.report(operationName, rateLimitResultAllow)
		return true
	}
}

// rateLimit asks Redis whether one more operation under key fits in a sliding
// window of windowDuration, passing windowSize as the window's operation limit.
//
// The Redis call runs under a 100ms timeout derived from ctx; an earlier
// deadline already on ctx still applies. On success it returns the Redis
// verdict and a nil error. On failure it returns true together with an error
// annotated with the key and window settings; callers decide what to do with it.
func (r *RateLimiterImpl) rateLimit(ctx context.Context, key string, windowDuration time.Duration, windowSize int) (bool, error) {
	const redisCallTimeout = 100 * time.Millisecond

	callCtx, cancel := context.WithTimeout(ctx, redisCallTimeout)
	defer cancel()

	limit := int64(windowSize)
	allowed, err := r.redisClient.SlidingWindowRateLimit(callCtx, key, windowDuration, limit)
	if err == nil {
		return allowed, nil
	}

	fields := errors.Fields{
		"key":             key,
		"window_duration": windowDuration.String(),
		"window_size":     windowSize,
	}
	return true, errors.Wrap(err, "rate limiting operation failed", fields)
}
