package viewcount

import (
	"context"
	"strconv"
	"time"

	"code.justin.tv/businessviewcount/aperture/internal/clients/memcached"
	"code.justin.tv/businessviewcount/aperture/internal/clients/stats"
	"code.justin.tv/foundation/twitchclient"
	viewcount "code.justin.tv/video/viewcount-api/lib/client"

	log "github.com/sirupsen/logrus"
)

// maxRetries caps how many attempts sendRequestForViewcounts makes against
// viewcount-api, counting the first (non-retry) attempt.
const maxRetries = 5

// retryDelaySeconds is the linear backoff step, in seconds: attempt n (n >= 2)
// waits (n-1)*retryDelaySeconds seconds before it runs.
const retryDelaySeconds = 1

// Viewcount describes the operations used to read view counts from
// viewcount-api. *Client in this package implements it.
type Viewcount interface {
	// ForChannel returns the view count for the single channel channelID.
	ForChannel(ctx context.Context, channelID uint64) (*viewcount.Viewcount, error)
	// ForAllChannels returns view counts for all live channels
	// (map keys are presumably channel IDs — confirm against viewcount-api).
	ForAllChannels(ctx context.Context) (map[uint64]*viewcount.Viewcount, error)
}

// Client wraps the viewcount-api protobuf client, fronting it with a memcached
// cache and reporting timings and cache hit/miss counts to statsd.
// Construct it with NewClient.
type Client struct {
	// viewcountClient issues the actual requests to viewcount-api.
	viewcountClient *viewcount.Client
	// cacheClient is read before calling viewcount-api and written back
	// (asynchronously) after a cache miss.
	cacheClient     memcached.Cache
	// statsd receives execution-time metrics and cache hit/miss counters.
	statsd          stats.StatSender
}

// NewClient builds a Client that talks to the viewcount-api instance at host.
//
// The underlying HTTP client is a twitchclient HTTP client that reports its
// own stats through statsd. Reads are served from cache when possible. An
// error is returned only if the viewcount-api client cannot be constructed.
func NewClient(host string, cache memcached.Cache, statsd stats.StatSender) (*Client, error) {
	httpConf := twitchclient.ClientConf{
		Host:  host,
		Stats: statsd,
	}

	apiClient, err := viewcount.New(&viewcount.Config{
		URL:        host,
		HTTPClient: twitchclient.NewHTTPClient(httpConf),
	})
	if err != nil {
		return nil, err
	}

	wrapped := &Client{
		viewcountClient: apiClient,
		cacheClient:     cache,
		statsd:          statsd,
	}
	return wrapped, nil
}

// ForAllChannels returns view counts for all live channels.
//
// The cached value is served when the cache lookup succeeds and returns a
// non-nil result. Otherwise viewcount-api is queried (with retries, via
// sendRequestForViewcounts). The fresh result is then written back to the
// cache in a background goroutine; a write-back failure is only logged.
func (c *Client) ForAllChannels(ctx context.Context) (map[uint64]*viewcount.Viewcount, error) {
	// time.Now() is evaluated here, when the defer is registered.
	defer func(begin time.Time) {
		c.statsd.ExecutionTime("viewcount_client.for_all_channels", time.Since(begin))
	}(time.Now())

	cached, lookupErr := c.cacheClient.GetViewcountForAllChannels(ctx)
	if lookupErr == nil && cached != nil {
		go c.statsd.Increment("viewcount_client.for_all_channels.cache_hit", 1)
		return cached, nil
	}

	// Cache errors are deliberately treated the same as a miss.
	go c.statsd.Increment("viewcount_client.for_all_channels.cache_miss", 1)

	fresh, fetchErr := c.sendRequestForViewcounts(ctx)
	if fetchErr != nil {
		return nil, fetchErr
	}

	// Background context: the write-back must not be cut short when the
	// caller's request context ends.
	go func() {
		if setErr := c.cacheClient.SetViewcountForAllChannels(context.Background(), fresh); setErr != nil {
			log.WithError(setErr).Error("failed to cache viewcounts for all")
		}
	}()

	return fresh, nil
}

// ForChannel returns the view count for a single channel.
//
// The cache is consulted first. A cache error or a nil cached value counts as
// a miss. On a miss the value is fetched from viewcount-api and then written
// back to the cache in a background goroutine; a write-back failure is only
// logged.
//
// Metrics: "viewcount_client.for_channel" times the whole call, and
// "viewcount.for_channel" times only the viewcount-api round trip. The latter
// mirrors "viewcount.for_all_channels" in sendRequestForViewcounts.
func (c *Client) ForChannel(ctx context.Context, channelID uint64) (*viewcount.Viewcount, error) {
	start := time.Now()
	defer func() {
		c.statsd.ExecutionTime("viewcount_client.for_channel", time.Since(start))
	}()

	// The same key is used for the cache read and the write-back.
	cacheKey := strconv.FormatUint(channelID, 10)

	// Named "cached" (not "viewcount") so the imported viewcount package is
	// not shadowed inside this function.
	cached, err := c.cacheClient.GetViewcountForChannel(ctx, cacheKey)
	if err == nil && cached != nil {
		go c.statsd.Increment("viewcount_client.for_channel.cache_hit", 1)
		return cached, nil
	}

	go c.statsd.Increment("viewcount_client.for_channel.cache_miss", 1)

	// Time the service call itself, not the cache lookup.
	serviceStart := time.Now()
	views, err := c.viewcountClient.ForChannel(ctx, channelID, "aperture")
	c.statsd.ExecutionTime("viewcount.for_channel", time.Since(serviceStart))
	if err != nil {
		log.WithField("channel_id", channelID).WithError(err).Error("failed to retrieve viewcounts for channel")
		return nil, err
	}

	// Background context: the write-back must not be cut short when the
	// caller's request context ends.
	go func() {
		if setErr := c.cacheClient.SetViewcountForChannel(context.Background(), cacheKey, views); setErr != nil {
			log.WithError(setErr).Error("failed to cache viewcount for single channel")
		}
	}()

	return views, nil
}

// sendRequestForViewcounts fetches view counts for all channels from
// viewcount-api, making up to maxRetries attempts.
//
// Attempt n (n >= 2) is preceded by a linear backoff of
// (n-1)*retryDelaySeconds seconds (1s, 2s, 3s, 4s with the current constants).
// The backoff wait is abandoned as soon as ctx is done. Before every attempt
// the context is checked, and ctx.Err() is returned if it has been cancelled
// or has timed out. If every attempt fails, the last error from viewcount-api
// is returned.
func (c *Client) sendRequestForViewcounts(ctx context.Context) (map[uint64]*viewcount.Viewcount, error) {
	var lastError error

	for try := 1; try <= maxRetries; try++ {
		if try > 1 {
			log.WithFields(log.Fields{
				// Number of retries so far; the first retry is attempt 2.
				"retry_count":    try - 1,
				"previous_error": lastError,
			}).Warn("retrying call to viewcount-api after previous failure")

			// Linear backoff that, unlike time.Sleep, wakes up immediately
			// when the caller's context is cancelled or times out.
			backoff := time.Duration(try-1) * retryDelaySeconds * time.Second
			timer := time.NewTimer(backoff)
			select {
			case <-ctx.Done():
				timer.Stop()
			case <-timer.C:
			}
		}

		// Check if context is canceled or has timed out (also covers a
		// cancellation that interrupted the backoff above).
		if ctxErr := ctx.Err(); ctxErr != nil {
			log.WithError(ctxErr).Error("context done while sending request to viewcount-api")
			return nil, ctxErr
		}

		start := time.Now()
		viewcounts, err := c.viewcountClient.ForAllChannels(ctx, "aperture")
		c.statsd.ExecutionTime("viewcount.for_all_channels", time.Since(start))

		if err != nil {
			log.WithError(err).Error("failed to send all channels request to viewcount-api")
			lastError = err
			continue
		}

		return viewcounts, nil
	}

	return nil, lastError
}
