package memcached

import (
	"context"
	"encoding/json"
	"strconv"
	"time"

	"code.justin.tv/businessviewcount/aperture/internal/util"

	"code.justin.tv/foundation/gomemcache/memcache"
	"code.justin.tv/hygienic/distconf"
	viewcount "code.justin.tv/video/viewcount-api/lib/client"

	multierror "github.com/hashicorp/go-multierror"
	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
	log "github.com/sirupsen/logrus"
)

const (
	// Key prefixes for per-channel memcache entries; the channel ID is
	// appended directly after the prefix.
	ratiosCacheKeyPrefix        = "ratios."
	freezeCacheKeyPrefix        = "freeze."
	singleChannelCacheKeyPrefix = "viewcount."

	// allChannelsCacheKey is the single memcache key under which the
	// JSON-encoded viewcounts of all channels are stored as one item.
	allChannelsCacheKey = "allchannelviewcounts"

	// viewcountCacheExpiration is the fixed TTL for both viewcount entries.
	// Unlike ratio entries, it is not read from distconf.
	viewcountCacheExpiration = 5 * time.Second

	// maxMultiSize caps how many channel IDs GetRatioMulti and
	// GetFrozenChannels accept in a single GetMulti call.
	maxMultiSize = 1000
)

// Cache stores per-channel ratios, channel freeze properties, and viewcount
// snapshots. MemcacheCache (returned by NewCache) is the memcache-backed
// implementation.
type Cache interface {
	// GetRatio returns the cached ratio for channelID. In MemcacheCache a
	// channel with no cached ratio yields 1.0.
	GetRatio(ctx context.Context, channelID string) (float64, error)
	// GetRatioMulti returns a ratio for every channel in channelIDs. In
	// MemcacheCache channels without a usable cached value get 1.0.
	GetRatioMulti(ctx context.Context, channelIDs []string) (map[string]float64, error)
	// SetRatio caches ratio for channelID.
	SetRatio(ctx context.Context, channelID string, ratio float64) error

	// GetFrozenChannels returns freeze properties for those channelIDs that
	// have a cached entry; channels without one are absent from the map.
	GetFrozenChannels(ctx context.Context, channelIDs []string) (map[string]*util.ChannelFreeze, error)
	// GetFrozenChannel returns freeze properties for channelID. In
	// MemcacheCache a cache miss yields (nil, nil).
	GetFrozenChannel(ctx context.Context, channelID string) (*util.ChannelFreeze, error)
	// SetFrozenChannel caches freezeProps for channelID.
	SetFrozenChannel(ctx context.Context, channelID string, freezeProps *util.ChannelFreeze) error

	// SetViewcountForChannel caches the viewcount of a single channel.
	SetViewcountForChannel(ctx context.Context, channelID string, views *viewcount.Viewcount) error
	// SetViewcountForAllChannels caches the viewcounts of all channels as a
	// single entry.
	SetViewcountForAllChannels(ctx context.Context, views map[uint64]*viewcount.Viewcount) error
	// GetViewcountForChannel returns the value stored by SetViewcountForChannel.
	GetViewcountForChannel(ctx context.Context, channelID string) (*viewcount.Viewcount, error)
	// GetViewcountForAllChannels returns the map stored by
	// SetViewcountForAllChannels.
	GetViewcountForAllChannels(ctx context.Context) (map[uint64]*viewcount.Viewcount, error)

	// ExpirationTime returns the TTL used when caching ratios.
	ExpirationTime() time.Duration
}

// MemcacheCache is the memcache-backed implementation of Cache.
type MemcacheCache struct {
	// MemcacheClient performs every cache read and write.
	MemcacheClient *memcache.Client
	// expirationTime supplies the duration returned by ExpirationTime,
	// which SetRatio uses as the ratio TTL.
	expirationTime *distconf.Duration
}

// NewCache builds a Cache backed by a memcache client created with
// memcache.Elasticache(addr, pollInterval). The client keeps at most
// maxIdleConns idle connections and uses timeout as its Timeout.
// expirationTime supplies the TTL that SetRatio applies to ratio entries.
//
// The signature cannot report failure, so an error from
// memcache.Elasticache is logged instead of being silently discarded.
// NOTE(review): if Elasticache returns a nil client alongside the error, the
// MaxIdleConns call below will still panic. Confirm the library's contract.
func NewCache(addr string, maxIdleConns int, pollInterval, timeout time.Duration, expirationTime *distconf.Duration) Cache {
	c, err := memcache.Elasticache(addr, pollInterval)
	if err != nil {
		log.WithFields(log.Fields{
			"addr":          addr,
			"poll_interval": pollInterval,
		}).WithError(err).Error("failed to create elasticache memcache client")
	}

	c = c.MaxIdleConns(maxIdleConns)
	c.Timeout = timeout

	return &MemcacheCache{
		MemcacheClient: c,
		expirationTime: expirationTime,
	}
}

// GetRatio looks up the cached ratio for channelID.
//
// A cache miss, or an absent item, yields the default ratio 1.0 and no error.
// Any other client error, or a stored value that does not parse as a
// float64, yields 1.0 together with that error.
func (m *MemcacheCache) GetRatio(ctx context.Context, channelID string) (float64, error) {
	const defaultRatio = 1.0

	cached, getErr := m.MemcacheClient.Get(ctx, ratiosCacheKeyPrefix+channelID)
	switch {
	case getErr == memcache.ErrCacheMiss:
		return defaultRatio, nil
	case getErr != nil:
		return defaultRatio, getErr
	case cached == nil:
		return defaultRatio, nil
	}

	parsed, parseErr := strconv.ParseFloat(string(cached.Value), 64)
	if parseErr != nil {
		return defaultRatio, parseErr
	}
	return parsed, nil
}

// GetRatioMulti fetches the ratios of up to maxMultiSize channels with a
// single GetMulti call.
//
// Every requested channel appears in the result. Channels missing from the
// cache, or whose stored value fails to parse, get 1.0. Errors from the
// multi-get and from parsing are accumulated and returned together; the
// error is nil when none occurred.
func (m *MemcacheCache) GetRatioMulti(ctx context.Context, channelIDs []string) (map[string]float64, error) {
	if len(channelIDs) > maxMultiSize {
		return nil, errors.New("memcache: too many keys for GetMulti")
	}

	keys := make([]string, 0, len(channelIDs))
	for _, id := range channelIDs {
		keys = append(keys, ratiosCacheKeyPrefix+id)
	}

	var merr *multierror.Error
	found, getErr := m.MemcacheClient.GetMulti(ctx, keys)
	if getErr != nil {
		merr = multierror.Append(merr, getErr)
	}

	out := make(map[string]float64, len(channelIDs))
	for idx, id := range channelIDs {
		out[id] = 1.0

		// Indexing a missing key (or a nil map) yields nil, so this covers
		// both "absent" and "present but nil".
		cached := found[keys[idx]]
		if cached == nil {
			continue
		}

		value, parseErr := strconv.ParseFloat(string(cached.Value), 64)
		if parseErr != nil {
			merr = multierror.Append(merr, parseErr)
			continue
		}
		out[id] = value
	}
	return out, merr.ErrorOrNil()
}

// SetRatio writes ratio for channelID to the cache. The value is encoded with
// strconv.FormatFloat(ratio, 'f', -1, 64), and the entry's TTL is
// ExpirationTime() truncated to whole seconds.
func (m *MemcacheCache) SetRatio(ctx context.Context, channelID string, ratio float64) error {
	encoded := strconv.FormatFloat(ratio, 'f', -1, 64)
	ttlSeconds := int32(m.ExpirationTime().Seconds())

	entry := &memcache.Item{
		Key:        ratiosCacheKeyPrefix + channelID,
		Value:      []byte(encoded),
		Expiration: ttlSeconds,
	}
	return m.MemcacheClient.Set(ctx, entry)
}

// SetViewcountForChannel JSON-encodes views and stores it under the
// per-channel viewcount key, with the fixed viewcountCacheExpiration TTL.
// A marshalling failure is returned without writing to the cache.
func (m *MemcacheCache) SetViewcountForChannel(ctx context.Context, channelID string, views *viewcount.Viewcount) error {
	payload, marshalErr := json.Marshal(views)
	if marshalErr != nil {
		return marshalErr
	}

	entry := &memcache.Item{
		Key:        singleChannelCacheKeyPrefix + channelID,
		Value:      payload,
		Expiration: int32(viewcountCacheExpiration.Seconds()),
	}
	return m.MemcacheClient.Set(ctx, entry)
}

// SetViewcountForAllChannels JSON-encodes the whole views map and stores it
// as the single allChannelsCacheKey item, with the fixed
// viewcountCacheExpiration TTL. A marshalling failure is returned without
// writing to the cache.
func (m *MemcacheCache) SetViewcountForAllChannels(ctx context.Context, views map[uint64]*viewcount.Viewcount) error {
	payload, marshalErr := json.Marshal(views)
	if marshalErr != nil {
		return marshalErr
	}

	entry := &memcache.Item{
		Key:        allChannelsCacheKey,
		Value:      payload,
		Expiration: int32(viewcountCacheExpiration.Seconds()),
	}
	return m.MemcacheClient.Set(ctx, entry)
}

// GetViewcountForChannel returns the viewcount stored for channelID by
// SetViewcountForChannel.
//
// Client errors, including memcache.ErrCacheMiss, are returned unchanged. A
// nil item with no error is reported as memcache.ErrCacheMiss instead of being
// dereferenced, which would panic. JSON decoding errors are returned with a
// nil result. A stored JSON null decodes to a nil viewcount with no error.
func (m *MemcacheCache) GetViewcountForChannel(ctx context.Context, channelID string) (*viewcount.Viewcount, error) {
	item, err := m.MemcacheClient.Get(ctx, singleChannelCacheKeyPrefix+channelID)
	if err != nil {
		return nil, err
	}
	if item == nil {
		return nil, memcache.ErrCacheMiss
	}

	var views *viewcount.Viewcount
	if err := json.Unmarshal(item.Value, &views); err != nil {
		return nil, err
	}
	return views, nil
}

// GetViewcountForAllChannels returns the map stored by
// SetViewcountForAllChannels.
//
// Client errors, including memcache.ErrCacheMiss, are returned unchanged. A
// nil item with no error is reported as memcache.ErrCacheMiss instead of being
// dereferenced, which would panic. JSON decoding errors are returned with a
// nil map.
func (m *MemcacheCache) GetViewcountForAllChannels(ctx context.Context) (map[uint64]*viewcount.Viewcount, error) {
	item, err := m.MemcacheClient.Get(ctx, allChannelsCacheKey)
	if err != nil {
		return nil, err
	}
	if item == nil {
		return nil, memcache.ErrCacheMiss
	}

	var views map[uint64]*viewcount.Viewcount
	if err := json.Unmarshal(item.Value, &views); err != nil {
		return nil, err
	}
	return views, nil
}

// ExpirationTime reports the TTL that SetRatio applies to ratio entries. It
// reads the current value of the distconf-backed duration on each call.
func (m *MemcacheCache) ExpirationTime() time.Duration {
	ttl := m.expirationTime.Get()
	return ttl
}

// SetFrozenChannel stores freezeProps for channelID as a JSON blob under the
// freeze key prefix, and logs the write via logFreeze.
//
// The memcache expiration is freezeProps.Expiration.Unix() truncated to
// int32. NOTE(review): this relies on memcached treating expiration values
// larger than 30 days' worth of seconds as absolute Unix timestamps. A zero
// Expiration (assuming it is a time.Time) would not convert sensibly, so
// confirm that callers always set it.
//
// A nil freezeProps is rejected with an error. Otherwise it would marshal to
// "null" without error and then be dereferenced (and panic) in logFreeze and
// Expiration.Unix. A JSON marshalling failure is logged and returned.
func (m *MemcacheCache) SetFrozenChannel(ctx context.Context, channelID string, freezeProps *util.ChannelFreeze) error {
	if freezeProps == nil {
		return errors.New("memcache: nil freeze properties for channel " + channelID)
	}

	jsonBlob, err := json.Marshal(freezeProps)
	if err != nil {
		log.WithFields(log.Fields{
			"channel_id": channelID,
			"props":      *freezeProps,
		}).WithError(err).Error("failed to marshal freeze properties to json")
		return err
	}

	m.logFreeze("memcache_set_frozen_channel", channelID, freezeProps)

	return m.MemcacheClient.Set(ctx, &memcache.Item{
		Key:        freezeCacheKeyPrefix + channelID,
		Expiration: int32(freezeProps.Expiration.Unix()),
		Value:      jsonBlob,
	})
}

// GetFrozenChannels returns the cached freeze properties for each channel in
// channelIDs that has a freeze entry. Channels without an entry are omitted
// from the map. At most maxMultiSize IDs may be requested at once.
//
// Errors from the multi-get and from decoding individual entries are
// accumulated, and the (possibly partial) map is returned alongside them. The
// returned error is a true nil when nothing went wrong.
func (m *MemcacheCache) GetFrozenChannels(ctx context.Context, channelIDs []string) (map[string]*util.ChannelFreeze, error) {
	if len(channelIDs) > maxMultiSize {
		return nil, errors.New("memcache: too many keys for GetMulti")
	}

	keys := make([]string, len(channelIDs))
	for i, channelID := range channelIDs {
		keys[i] = freezeCacheKeyPrefix + channelID
	}

	var errs *multierror.Error
	items, err := m.MemcacheClient.GetMulti(ctx, keys)
	if err != nil {
		errs = multierror.Append(errs, err)
	}

	frozenChannels := make(map[string]*util.ChannelFreeze, len(channelIDs))
	for i, channelID := range channelIDs {
		// A missing key (or a nil items map) yields nil here.
		item := items[keys[i]]
		if item == nil {
			continue
		}

		var freezeProp util.ChannelFreeze
		if err := json.Unmarshal(item.Value, &freezeProp); err != nil {
			errs = multierror.Append(errs, err)
			continue
		}
		frozenChannels[channelID] = &freezeProp
	}

	// Returning errs directly would wrap a nil *multierror.Error in a non-nil
	// error interface, and callers would see a failure on every call.
	// ErrorOrNil returns a real nil in that case.
	return frozenChannels, errs.ErrorOrNil()
}

// GetFrozenChannel returns the cached freeze properties for channelID.
//
// A cache miss, or an absent item, yields (nil, nil). Other client errors
// and JSON decoding errors are returned with a nil result. A successful
// lookup is logged via logFreeze before returning.
func (m *MemcacheCache) GetFrozenChannel(ctx context.Context, channelID string) (*util.ChannelFreeze, error) {
	cached, getErr := m.MemcacheClient.Get(ctx, freezeCacheKeyPrefix+channelID)
	switch {
	case getErr == memcache.ErrCacheMiss:
		return nil, nil
	case getErr != nil:
		return nil, getErr
	case cached == nil:
		return nil, nil
	}

	props := new(util.ChannelFreeze)
	if decodeErr := json.Unmarshal(cached.Value, props); decodeErr != nil {
		return nil, decodeErr
	}

	m.logFreeze("memcache_get_frozen_channel", channelID, props)
	return props, nil
}

// logFreeze emits an info-level logrus entry with message msg, carrying the
// channel ID and the freeze properties' fields. freezeProps is dereferenced,
// so callers must pass a non-nil value.
func (m *MemcacheCache) logFreeze(msg, channelID string, freezeProps *util.ChannelFreeze) {
	fields := logrus.Fields{}
	fields["channel_id"] = channelID
	fields["vc0"] = freezeProps.ViewcountAtCreation
	fields["length"] = freezeProps.PbyPSessionLength
	fields["expiration"] = freezeProps.Expiration
	fields["ramp_length"] = freezeProps.RampDownLength
	fields["created_at"] = freezeProps.CreatedAt

	logrus.WithFields(fields).Info(msg)
}
