package cache

import (
	"context"
	"fmt"
	"strconv"
	"time"

	"code.justin.tv/feeds/distconf"
	"code.justin.tv/feeds/errors"
	"code.justin.tv/feeds/graphdb/cmd/graphdb/internal/accesslog"
	"code.justin.tv/feeds/graphdb/cmd/graphdb/internal/graphdbmodel"
	"code.justin.tv/feeds/graphdb/cmd/graphdb/internal/sns"
	"code.justin.tv/feeds/graphdb/cmd/graphdb/internal/storage"
	"code.justin.tv/feeds/graphdb/cmd/graphdb/internal/storage/tablelookup"
	"code.justin.tv/feeds/log"
	"code.justin.tv/hygienic/objectcache"
	"code.justin.tv/hygienic/statsdsender"
	"code.justin.tv/hygienic/workerpool"
)

// Config configures the cache. Every field is a distconf handle that Load populates and that MemcacheCache reads
// through Get at the point of use.
type Config struct {
	// DisableCache, when true, makes MemcacheCache bypass the cache and forward calls straight to Storage
	// (see isDisabled). Load registers it with a default of true.
	DisableCache *distconf.Bool
	// DefaultConfigServer is loaded from "graphdb.cache.follow.config" (default empty). It is not read in this
	// file. NOTE(review): presumably the server configuration for the default cache — confirm at the construction site.
	DefaultConfigServer *distconf.Str
	// ModerationConfigServer is loaded from "graphdb.cache.mods.config" (default empty). It is not read in this
	// file. NOTE(review): presumably the server configuration for a moderation-specific cache — confirm.
	ModerationConfigServer *distconf.Str
	// ListItemLimit is how many list items List fetches from Storage and caches per list key. List requests
	// whose Limit exceeds it skip the cache. Load registers it with a default of 2000.
	ListItemLimit *distconf.Int
}

// cachePrefixVersion is the leading segment of every cache key built in this package. Changing it changes every
// key at once, which effectively evicts everything currently cached.
const (
	cachePrefixVersion = "1"
)

// Load registers every cache setting with dconf and stores the resulting handles in c. It never fails; the error
// return exists for the caller's uniform config-loading interface.
func (c *Config) Load(dconf *distconf.Distconf) error {
	// Composite-literal fields are evaluated left to right, so settings are registered in declaration order.
	*c = Config{
		DisableCache:           dconf.Bool("graphdb.cache.disabled", true),
		DefaultConfigServer:    dconf.Str("graphdb.cache.follow.config", ""),
		ModerationConfigServer: dconf.Str("graphdb.cache.mods.config", ""),
		ListItemLimit:          dconf.Int("graphdb.cache.list_limit", 2000),
	}
	return nil
}

// ObjCache is the cache contract MemcacheCache depends on: read-through lookups, unconditional writes, and
// removal of individual keys.
type ObjCache interface {
	// Cached fills object with the value stored under key. On a miss it runs callback to produce the value,
	// stores the result under key, and fills object with it.
	Cached(ctx context.Context, key string, callback func() (interface{}, error), object interface{}) error
	// ForceCached stores object under key, replacing whatever was there.
	ForceCached(ctx context.Context, key string, object interface{}) error
	// Invalidate deletes the entry stored under s.
	Invalidate(ctx context.Context, s string) error
}

// MemcacheCache stores items with memcache.  It implements the basic Edge storage interface, intercepting requests
// and attempting to read them from memcache first. Writes (Create, UpdateOrPut, Delete) dirty the affected keys,
// forward to Storage, and then write the result back into the cache. When Config.DisableCache is true, every call
// goes straight to Storage.
type MemcacheCache struct {
	// Caches selects a cache per edge type; edge types without an entry use DefaultCache (see cache).
	Caches map[string]ObjCache `nilcheck:"nodepth"`
	// DefaultCache is the fallback cache for edge types not present in Caches.
	DefaultCache ObjCache `nilcheck:"nodepth"`
	// Stats receives hit/miss and error counters.
	Stats *statsdsender.ErrorlessStatSender `nilcheck:"nodepth"`
	// Config holds the distconf-backed cache settings.
	Config *Config
	// Storage is where we forward requests to when they aren't in cache
	Storage storage.Client
	// Lookup provides per-edge-type metadata (primary direction and reverse name).
	Lookup *tablelookup.Lookup
	// Log receives errors that are reported but not returned.
	Log *log.ElevatedLog
	// AsyncDirtyPool allows us to do memcache dirties in the background
	AsyncDirtyPool workerpool.Pool
	// BlockForDirty, when true, makes InvalidateCache wait for each dirty to finish before returning.
	BlockForDirty *distconf.Bool
	// SNSClient publishes an invalidate-cache message for every write.
	SNSClient *sns.SNSClient
}

// Compile-time proofs that MemcacheCache satisfies storage.Client and that *objectcache.ObjectCache can be used as
// an ObjCache. Typed nil pointers avoid allocating a throwaway value at init.
var (
	_ storage.Client = (*MemcacheCache)(nil)
	_ ObjCache       = (*objectcache.ObjectCache)(nil)
)

// cache picks the ObjCache responsible for edgeType: the dedicated entry in Caches when one exists, otherwise
// DefaultCache.
func (m *MemcacheCache) cache(edgeType string) ObjCache {
	if dedicated, found := m.Caches[edgeType]; found {
		return dedicated
	}
	return m.DefaultCache
}

// InvalidateCache dirties every cache key that could hold data derived from edge:
//   - the count and both list orderings for edge.From under edge.Type,
//   - the count and both list orderings for edge.To under the reverse edge type,
//   - the get key for the primary direction of the edge (the only direction Get stores).
//
// Each key is dirtied in the cache that its read path (Get, Count, List) consults. That cache is chosen by
// m.cache using the edge type the key was built with, so keys for the reverse direction are dirtied in the
// reverse type's cache rather than edge.Type's.
//
// Dirties are handed to AsyncDirtyPool via workerpool.OfferOrDo. When BlockForDirty is set, each one is
// awaited before the next is offered. Every dirty runs under its own timeout (see asyncDirtyRequest.do) rather
// than ctx, and its failures are logged there instead of being returned.
//
// It returns an error only when edge.Type is unknown to Lookup.
func (m *MemcacheCache) InvalidateCache(ctx context.Context, edge graphdbmodel.Edge) error {
	edgeInfo := m.Lookup.LookupEdge(edge.Type)
	if edgeInfo == nil {
		return errors.Errorf("unable to find edge kind %s", edge.Type)
	}
	reverseType := edgeInfo.ReverseName

	// Get reverses non-primary edges before touching the cache, so the get key (and its cache) always come
	// from the primary-direction edge.
	getEdge := edge
	if !edgeInfo.PrimaryEdgeType {
		getEdge = edge.Reversed(reverseType)
	}

	targets := []struct {
		key      string
		edgeType string
	}{
		{countCacheKey(edge.From, edge.Type), edge.Type},
		{countCacheKey(edge.To, reverseType), reverseType},
		{getCacheKey(getEdge), getEdge.Type},
		{listCacheKey(edge.From, edge.Type, false), edge.Type},
		{listCacheKey(edge.From, edge.Type, true), edge.Type},
		{listCacheKey(edge.To, reverseType, false), reverseType},
		{listCacheKey(edge.To, reverseType, true), reverseType},
	}

	blocking := m.BlockForDirty.Get()
	for _, target := range targets {
		a := &asyncDirtyRequest{
			k:     target.key,
			cache: m.cache(target.edgeType),
			log:   m.Log,
		}
		f := workerpool.OfferOrDo(&m.AsyncDirtyPool, a.do)
		if blocking {
			<-f.Done()
		}
	}
	return nil
}

// asyncDirtyRequest allows workerpool.Pool to do memcache dirty in the background.
// Its do method is the unit of work handed to the pool by InvalidateCache.
type asyncDirtyRequest struct {
	// k is the cache key to invalidate.
	k string
	// cache is the ObjCache the key is invalidated in.
	cache ObjCache
	// log records invalidation failures, since nobody inspects the pool result.
	log log.Logger
}

// do invalidates a.k in a.cache under a two-second timeout that is independent of any request context. A failure
// is logged here, because the caller never reads the workerpool result. The error (or nil) is still returned as
// that result.
func (a *asyncDirtyRequest) do() interface{} {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	if err := a.cache.Invalidate(ctx, a.k); err != nil {
		a.log.Log("err", err, "key", a.k, "could not delete key from cache")
		return err
	}
	return nil
}

// isDisabled reports whether the DisableCache switch is currently on, in which case callers skip the cache and
// go straight to Storage.
func (m *MemcacheCache) isDisabled() bool {
	disabled := m.Config.DisableCache
	return disabled.Get()
}

// Create creates edge in Storage and keeps the cache coherent around the write.
//
// With the cache disabled, the call is forwarded to Storage unchanged. Otherwise Create:
//   - sends an invalidate-cache SNS message in a background goroutine (a failure only increments a stat),
//   - dirties every cache key the edge can affect via InvalidateCache (a failure is logged, not returned),
//   - creates the edge in Storage, returning any Storage error as-is,
//   - writes the created edge back into the get-key cache entry so the next Get is a hit (a failure only
//     increments a stat).
//
// Get only stores and reads the primary edge direction. A created edge of a non-primary type is therefore
// reversed before the write-back; otherwise it would land on a key Get never reads.
func (m *MemcacheCache) Create(ctx context.Context, edge graphdbmodel.Edge, data *graphdbmodel.DataBag, creationTime time.Time) (*graphdbmodel.LoadedEdge, error) {
	if m.isDisabled() {
		return m.Storage.Create(ctx, edge, data, creationTime)
	}

	go func() {
		// Fire-and-forget: the write is not delayed by SNS; failures only show up in stats.
		err := m.SNSClient.SendInvalidateCacheMessage(edge)
		if err != nil {
			m.Stats.IncC("send_invalidate_cache_message.count.err", 1, 1.0)
		}
	}()
	if err := m.InvalidateCache(ctx, edge); err != nil {
		m.Log.Log("err", err)
	}
	loadedEdge, err := m.Storage.Create(ctx, edge, data, creationTime)
	if err != nil {
		return nil, err
	}
	if loadedEdge != nil {
		// Repopulate the created edge in the direction (and cache) that Get reads.
		toCache := loadedEdge
		if info := m.Lookup.LookupEdge(loadedEdge.Edge.Type); info != nil && !info.PrimaryEdgeType {
			primary := loadedEdge.Reverse(info.ReverseName)
			toCache = &primary
		}
		if err := m.cache(toCache.Edge.Type).ForceCached(ctx, getCacheKey(toCache.Edge), toCache); err != nil {
			m.Stats.IncC("backfill.create.cache.err", 1, 1)
		}
	}
	return loadedEdge, nil
}

// Start is a no-op that always returns nil. It performs no work at runtime; NOTE(review): presumably present to
// satisfy a lifecycle interface.
func (m *MemcacheCache) Start() error { return nil }

// Close shuts down the background dirty pool and reports its Close error.
func (m *MemcacheCache) Close() error {
	pool := &m.AsyncDirtyPool
	return pool.Close()
}

// UpdateOrPut updates (or inserts) edge in Storage and keeps the cache coherent around the write.
//
// With the cache disabled, the call is forwarded to Storage unchanged. Otherwise UpdateOrPut:
//   - sends an invalidate-cache SNS message in a background goroutine (a failure only increments a stat),
//   - dirties every cache key the edge can affect via InvalidateCache (a failure is logged, not returned),
//   - performs the Storage update, returning any Storage error as-is,
//   - writes the resulting edge back into the get-key cache entry (a failure only increments a stat).
//
// Get only stores and reads the primary edge direction. A resulting edge of a non-primary type is therefore
// reversed before the write-back; otherwise it would land on a key Get never reads.
func (m *MemcacheCache) UpdateOrPut(ctx context.Context, edge graphdbmodel.Edge, data *graphdbmodel.DataBag, updateDatabag bool, creationTime *time.Time, requiredVersion *time.Time) (*graphdbmodel.LoadedEdge, error) {
	if m.isDisabled() {
		return m.Storage.UpdateOrPut(ctx, edge, data, updateDatabag, creationTime, requiredVersion)
	}

	go func() {
		// Fire-and-forget: the write is not delayed by SNS; failures only show up in stats.
		err := m.SNSClient.SendInvalidateCacheMessage(edge)
		if err != nil {
			m.Stats.IncC("send_invalidate_cache_message.count.err", 1, 1.0)
		}
	}()
	if err := m.InvalidateCache(ctx, edge); err != nil {
		m.Log.Log("err", err)
	}
	loadedEdge, err := m.Storage.UpdateOrPut(ctx, edge, data, updateDatabag, creationTime, requiredVersion)
	if err != nil {
		return nil, err
	}
	if loadedEdge != nil {
		// Repopulate the updated edge in the direction (and cache) that Get reads.
		toCache := loadedEdge
		if info := m.Lookup.LookupEdge(loadedEdge.Edge.Type); info != nil && !info.PrimaryEdgeType {
			primary := loadedEdge.Reverse(info.ReverseName)
			toCache = &primary
		}
		if err := m.cache(toCache.Edge.Type).ForceCached(ctx, getCacheKey(toCache.Edge), toCache); err != nil {
			m.Stats.IncC("backfill.update.cache.err", 1, 1)
		}
	}
	return loadedEdge, nil
}

// getCacheKey builds the cache key for a single edge as "<version>:<from>:<to>:<type>". Callers are expected to
// pass the primary-direction edge, because that is the only direction Get stores.
func getCacheKey(edge graphdbmodel.Edge) string {
	from, to := edge.From.Encode(), edge.To.Encode()
	return fmt.Sprintf("%s:%s:%s:%s", cachePrefixVersion, from, to, edge.Type)
}

// Get returns an edge from the cache, falling back to Storage on a miss (the Storage result is then cached).
//
// Only the primary edge direction is cached. For a non-primary edge type, Get looks up the reversed
// (primary) edge and reverses the result back to the requested type before returning it.
//
// A nil *LoadedEdge with a nil error means no edge was found. That includes a nil entry forced into the
// cache by Delete.
func (m *MemcacheCache) Get(ctx context.Context, edge graphdbmodel.Edge) (*graphdbmodel.LoadedEdge, error) {
	if m.isDisabled() {
		return m.Storage.Get(ctx, edge)
	}

	edgeInfo := m.Lookup.LookupEdge(edge.Type)
	if edgeInfo == nil {
		return nil, errors.Errorf("unable to find edge kind %s", edge.Type)
	}

	// We only store the primary edge direction in the cache. For the reverse (followed_by vs follows), look up
	// the primary edge and reverse it back before returning.
	if !edgeInfo.PrimaryEdgeType {
		primary, err := m.Get(ctx, edge.Reversed(edgeInfo.ReverseName))
		if err != nil || primary == nil {
			return nil, err
		}
		reversed := primary.Reverse(edge.Type)
		return &reversed, nil
	}

	// Cached decodes into *assoc, so assoc itself may come back nil for a cached nil entry.
	assoc := &graphdbmodel.LoadedEdge{}
	cacheHit := true
	err := m.cache(edge.Type).Cached(ctx, getCacheKey(edge), func() (interface{}, error) {
		cacheHit = false
		return m.Storage.Get(ctx, edge)
	}, &assoc)
	m.cacheResult(ctx, "get", edge.Type, cacheHit)
	if err != nil {
		return nil, err
	}
	return assoc, nil
}

// countCacheKey builds the key under which the number of edgeKind edges leaving from is cached:
// "<version>:<from>:<edgeKind>:count".
func countCacheKey(from graphdbmodel.Node, edgeKind string) string {
	node := from.Encode()
	return fmt.Sprintf("%s:%s:%s:count", cachePrefixVersion, node, edgeKind)
}

// Count returns how many edgeKind edges leave from. It consults the cache first; on a miss it asks Storage and
// caches the answer. A hit or miss is recorded regardless of the outcome.
func (m *MemcacheCache) Count(ctx context.Context, from graphdbmodel.Node, edgeKind string) (int64, error) {
	if m.isDisabled() {
		return m.Storage.Count(ctx, from, edgeKind)
	}

	var total int64
	hit := true
	load := func() (interface{}, error) {
		hit = false
		n, err := m.Storage.Count(ctx, from, edgeKind)
		if err != nil {
			return nil, err
		}
		return n, nil
	}
	err := m.cache(edgeKind).Cached(ctx, countCacheKey(from, edgeKind), load, &total)
	m.cacheResult(ctx, "count", edgeKind, hit)
	return total, err
}

// listCacheKey builds the key for the cached head of the edgeKind list leaving from, with one key per sort
// order: "<version>:<from>:<edgeKind>:<A|D>:list".
func listCacheKey(from graphdbmodel.Node, edgeKind string, descendingOrder bool) string {
	order := "A"
	if descendingOrder {
		order = "D"
	}
	return fmt.Sprintf("%s:%s:%s:%s:list", cachePrefixVersion, from.Encode(), edgeKind, order)
}

// trimListResult cuts the page the caller asked for out of a cached ListResult (the cached head of a list).
//
// If cursor is non-empty, the page starts right after the item whose Cursor equals cursor. nil is returned when
// that cursor is not among the cached items, or when it is the last cached item (the page would be empty). In
// both cases the caller should fall back to storage. A positive limit caps the page size; when the cap
// truncates the page, the returned Cursor is that of the last returned item. A limit <= 0 means no cap.
// A nil result yields nil.
//
// result itself is never modified. A trimmed page is a new ListResult sharing result's backing array.
func trimListResult(result *graphdbmodel.ListResult, cursor string, limit int64) *graphdbmodel.ListResult {
	if result == nil {
		return nil
	}

	items := result.To
	if cursor != "" {
		start := -1
		// Do not match the last item: the page after it would be empty, and an empty page would look like the
		// end of the list. Let storage answer instead.
		for i := 0; i < len(items)-1; i++ {
			if items[i].Cursor == cursor {
				start = i + 1
				break
			}
		}
		// If the cursor isn't cached, then go to storage.
		if start < 0 {
			return nil
		}
		items = items[start:]
	}

	// Once we have the right starting point, we may still need to trim to the requested size.
	if limit > 0 && int64(len(items)) > limit {
		return &graphdbmodel.ListResult{
			Cursor: items[limit-1].Cursor,
			To:     items[:limit],
		}
	}
	if cursor == "" {
		return result
	}
	trimmed := *result
	trimmed.To = items
	return &trimmed
}

// List returns a page of edgeKind edges leaving from.
//
// Storage answers directly in any of these cases:
//   - the cursor parses as an integer (an RDS cursor, which is never cached),
//   - the cache is disabled,
//   - page.Limit exceeds Config.ListItemLimit,
//   - edgeKind is unknown to Lookup.
//
// Otherwise the first ListItemLimit items in the requested order are cached under a single key, and the page
// is cut from that cached head. This makes follow-up pages inside the head cheap. If the requested cursor lies
// outside the cached head, Storage answers instead.
func (m *MemcacheCache) List(ctx context.Context, from graphdbmodel.Node, edgeKind string, page graphdbmodel.PagedRequest) (*graphdbmodel.ListResult, error) {
	maxCached := m.Config.ListItemLimit.Get()

	// RDS cursors are not cached.
	if _, convErr := strconv.Atoi(page.Cursor); convErr == nil {
		m.Stats.IncC("list."+edgeKind+".rdsCursor", 1, 1.0)
		return m.Storage.List(ctx, from, edgeKind, page)
	}

	if m.isDisabled() || page.Limit > maxCached || m.Lookup.LookupEdge(edgeKind) == nil {
		return m.Storage.List(ctx, from, edgeKind, page)
	}

	cached := new(graphdbmodel.ListResult)
	hit := true
	fetchHead := func() (interface{}, error) {
		hit = false
		// Fetch the whole cacheable head from the start: large limit, no cursor.
		return m.Storage.List(ctx, from, edgeKind, graphdbmodel.PagedRequest{
			Limit:           maxCached,
			DescendingOrder: page.DescendingOrder,
		})
	}
	err := m.cache(edgeKind).Cached(ctx, listCacheKey(from, edgeKind, page.DescendingOrder), fetchHead, &cached)
	m.cacheResult(ctx, "list", edgeKind, hit)

	// An error is only fatal when nothing usable came back alongside it.
	if err != nil && (cached == nil || len(cached.To) == 0) {
		return nil, err
	}

	if page := trimListResult(cached, page.Cursor, page.Limit); page != nil {
		return page, nil
	}
	// The caller is scrolling past the cached head.
	return m.Storage.List(ctx, from, edgeKind, page)
}

// cacheResult records a cache hit or miss in two places: the request's access-log trace ("cache_hit" /
// "cache_miss"), and a sampled statsd counter named "<method>.<edgeType>.<hit|miss>".
func (m *MemcacheCache) cacheResult(ctx context.Context, method string, edgeType string, isHit bool) {
	outcome, traceKey := "miss", "cache_miss"
	if isHit {
		outcome, traceKey = "hit", "cache_hit"
	}
	accesslog.TraceIntInc(ctx, traceKey, 1)
	m.Stats.IncC(method+"."+edgeType+"."+outcome, 1, .1)
}

// Delete removes edge from Storage and clears the cache for it.
//
// With the cache disabled, the call is forwarded to Storage unchanged. Otherwise Delete:
//   - sends an invalidate-cache SNS message in the background,
//   - dirties every affected key via InvalidateCache,
//   - deletes in Storage, returning any Storage error as-is,
//   - if the delete actually happened, forces a nil entry into the get-key cache slot.
//
// Get only reads the primary edge direction. The nil entry for a non-primary edge type is therefore written
// under the reversed (primary) edge's key, in that type's cache.
func (m *MemcacheCache) Delete(ctx context.Context, edge graphdbmodel.Edge, requiredVersion *time.Time) (*graphdbmodel.LoadedEdge, error) {
	if m.isDisabled() {
		return m.Storage.Delete(ctx, edge, requiredVersion)
	}

	go func() {
		// Fire-and-forget: the delete is not delayed by SNS; failures only show up in stats.
		err := m.SNSClient.SendInvalidateCacheMessage(edge)
		if err != nil {
			m.Stats.IncC("send_invalidate_cache_message.count.err", 1, 1.0)
		}
	}()
	if err := m.InvalidateCache(ctx, edge); err != nil {
		m.Log.Log("err", err)
	}
	ret, err := m.Storage.Delete(ctx, edge, requiredVersion)
	if err != nil {
		return nil, err
	}

	// ret is nil if the delete operation did not happen at the dynamodb layer.
	// This can happen if the edge did not exist or if the version was a mismatch.
	// Only remove the item from the cache if the operation actually happened.
	// Otherwise, assume that what is currently in the cache is consistent.
	if ret != nil {
		key, edgeType := getCacheKey(edge), edge.Type
		if info := m.Lookup.LookupEdge(edge.Type); info != nil && !info.PrimaryEdgeType {
			primary := edge.Reversed(info.ReverseName)
			key, edgeType = getCacheKey(primary), primary.Type
		}
		var nilItem *graphdbmodel.LoadedEdge
		if err := m.cache(edgeType).ForceCached(ctx, key, nilItem); err != nil {
			m.Stats.IncC("backfill.delete.cache.err", 1, 1)
		}
	}
	return ret, nil
}

// OverrideCount sets the stored count for (from, edgeKind) to count and returns the previous value. It is used
// to repair edge counts (admin action, not normally called). After a successful override with the cache
// enabled, the new count is pushed straight into the cache.
func (m *MemcacheCache) OverrideCount(ctx context.Context, from graphdbmodel.Node, edgeKind string, count int64) (oldCount int64, err error) {
	oldCount, err = m.Storage.OverrideCount(ctx, from, edgeKind, count)
	if err == nil && !m.isDisabled() {
		m.NewCount(from, edgeKind, count)
	}
	return oldCount, err
}

// NewCount pre-populates the cached count for (from, edgeKind) with newCount. It is called from the storage layer
// after an edge is created or removed. The write uses its own one-second timeout rather than any request
// context. A failure only increments a stat.
func (m *MemcacheCache) NewCount(from graphdbmodel.Node, edgeKind string, newCount int64) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	target := m.cache(edgeKind)
	if err := target.ForceCached(ctx, countCacheKey(from, edgeKind), newCount); err != nil {
		m.Stats.IncC("backfill.update.newcount.err", 1, 1)
	}
}
