package rediser

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"reflect"
	"strconv"
	"time"

	rc "code.justin.tv/identity/rediser/common"
)

// The following code is adapted from chat/redicache in order to support easier
// usage of Redis as a caching backend.

// Cached is a helper to cache any arbitrary type as JSON under key.
//
// On a cache hit the stored JSON is decoded into res and fn is not called.
// On a miss (or an entry that cannot be read/decoded) fn is called, its result
// is stored into *res, and its JSON encoding is written back via safeSetJSON
// (SetNX) with the given ttl. Cache read/write failures are only logged; the
// returned error is either fn's error or a usage error (bad res, or a fn
// result whose type cannot be stored in *res).
//
// Notes:
//   - Use the Cached* helpers below to cache common data types (like primitives)
//   - The type of <res> must be a non-nil pointer to the return type of <fn>
//     Examples:
//   - <res> is type *int, <fn> returns 10
//   - <res> is type *Foo, <fn> returns Foo{}
//   - <res> is type **Foo, <fn> returns nil or &Foo{}
func (r *rediser) Cached(ctx context.Context, key string, ttl time.Duration, res interface{}, fn func() (interface{}, error)) error {
	// Validate res before touching Redis or the data source:
	// reflect.Value.Elem panics on a nil or non-pointer value.
	rv := reflect.ValueOf(res)
	if rv.Kind() != reflect.Ptr || rv.IsNil() {
		return fmt.Errorf("cached %q: res must be a non-nil pointer, got %T", key, res)
	}
	target := rv.Elem()

	// Check key for cached value. res is already a pointer, so decode into it
	// directly instead of through a pointer to the interface that holds it.
	err := r.getJSON(ctx, key, res)
	if err == nil {
		return nil
	}
	// When the specified key does not exist, Redis returns the special error
	// "redis: nil" (ErrRedisNil): a plain miss. Any other error is logged and
	// also treated as a miss so the caller still gets a value.
	if err != ErrRedisNil {
		log.Printf("cache read error: %v, key: %s", err, key)
	}

	// Fetch value from true data source; use reflection to set *res.
	fetched, err := fn()
	if err != nil {
		return err
	}
	if fetched == nil {
		target.Set(reflect.Zero(target.Type()))
	} else {
		fv := reflect.ValueOf(fetched)
		// Value.Set panics on a type mismatch; report it as an error instead.
		if !fv.Type().AssignableTo(target.Type()) {
			return fmt.Errorf("cached %q: fn returned %T, which cannot be stored in %T", key, fetched, res)
		}
		target.Set(fv)
	}

	// Cache value for future lookups (best effort).
	if err := r.safeSetJSON(ctx, key, fetched, ttl); err != nil {
		log.Printf("cache write error: %v, key: %s", err, key)
	}
	return nil
}

// CachedBoolTTLs holds the per-value cache lifetimes used by
// CachedBoolDifferentTTLs below. A TTL of zero or less disables caching for
// results of that value.
type CachedBoolTTLs struct {
	// TrueTTL is how long a true result stays cached.
	TrueTTL time.Duration
	// FalseTTL is how long a false result stays cached.
	FalseTTL time.Duration
}

// CachedBoolDifferentTTLs works like CachedBool but lets the caller specify a
// different cache TTL for true and false values.
//
// On a cache hit the stored bool is returned without calling fn. On a miss (or
// an entry that cannot be read/decoded) fn is called; if fn fails, its error is
// returned and nothing is cached. Otherwise the result is written back via
// safeSetJSON (SetNX) using ttls.TrueTTL or ttls.FalseTTL depending on the
// value. A TTL of zero or less means results with that value are never
// cached, so fn runs on every call that produces them. Cache read/write
// failures are only logged.
func (r *rediser) CachedBoolDifferentTTLs(ctx context.Context, key string, ttls CachedBoolTTLs, fn func() (bool, error)) (bool, error) {
	// Check key for cached value
	var res bool
	err := r.getJSON(ctx, key, &res)
	if err == nil {
		return res, nil
	}
	// When the specified key does not exist, Redis returns the special error
	// "redis: nil" (ErrRedisNil): a plain miss. Any other error is logged and
	// also treated as a miss.
	if err != ErrRedisNil {
		log.Printf("cache read error: %v, key: %s", err, key)
	}

	// Fetch value from true data source.
	fetched, err := fn()
	if err != nil {
		return false, err
	}

	// Pick the TTL for this particular value; non-positive means "don't cache".
	ttl := ttls.FalseTTL
	if fetched {
		ttl = ttls.TrueTTL
	}
	if ttl <= 0 {
		return fetched, nil
	}

	// Cache value for future lookups (best effort).
	if err := r.safeSetJSON(ctx, key, fetched, ttl); err != nil {
		log.Printf("cache write error: %v, key: %s", err, key)
	}
	return fetched, nil
}

// CachedBool is Cached specialized to bool results: fn must return a bool
// (or nil, which yields false).
func (r *rediser) CachedBool(ctx context.Context, key string, ttl time.Duration, fn func() (interface{}, error)) (bool, error) {
	var value bool
	if err := r.Cached(ctx, key, ttl, &value, fn); err != nil {
		return value, err
	}
	return value, nil
}

// CachedInt is Cached specialized to int results: fn must return an int
// (or nil, which yields 0).
func (r *rediser) CachedInt(ctx context.Context, key string, ttl time.Duration, fn func() (interface{}, error)) (int, error) {
	var value int
	if err := r.Cached(ctx, key, ttl, &value, fn); err != nil {
		return value, err
	}
	return value, nil
}

// CachedInt64 is Cached specialized to int64 results: fn must return an int64
// (or nil, which yields 0).
func (r *rediser) CachedInt64(ctx context.Context, key string, ttl time.Duration, fn func() (interface{}, error)) (int64, error) {
	var value int64
	if err := r.Cached(ctx, key, ttl, &value, fn); err != nil {
		return value, err
	}
	return value, nil
}

// CachedString is Cached specialized to string results: fn must return a
// string (or nil, which yields "").
func (r *rediser) CachedString(ctx context.Context, key string, ttl time.Duration, fn func() (interface{}, error)) (string, error) {
	var value string
	if err := r.Cached(ctx, key, ttl, &value, fn); err != nil {
		return value, err
	}
	return value, nil
}

// CachedIntSlice is Cached specialized to []int results: fn must return an
// []int (or nil, which yields a nil slice).
func (r *rediser) CachedIntSlice(ctx context.Context, key string, ttl time.Duration, fn func() (interface{}, error)) ([]int, error) {
	var values []int
	if err := r.Cached(ctx, key, ttl, &values, fn); err != nil {
		return values, err
	}
	return values, nil
}

// CachedInt64Slice is Cached specialized to []int64 results: fn must return an
// []int64 (or nil, which yields a nil slice).
func (r *rediser) CachedInt64Slice(ctx context.Context, key string, ttl time.Duration, fn func() (interface{}, error)) ([]int64, error) {
	var values []int64
	if err := r.Cached(ctx, key, ttl, &values, fn); err != nil {
		return values, err
	}
	return values, nil
}

// CachedStringSlice is Cached specialized to []string results: fn must return
// a []string (or nil, which yields a nil slice).
func (r *rediser) CachedStringSlice(ctx context.Context, key string, ttl time.Duration, fn func() (interface{}, error)) ([]string, error) {
	var values []string
	if err := r.Cached(ctx, key, ttl, &values, fn); err != nil {
		return values, err
	}
	return values, nil
}

// Invalidate deletes key from Redis so the next Cached* lookup for it misses
// and calls the data source again. The count returned by Del is discarded;
// only its error is returned.
func (r *rediser) Invalidate(ctx context.Context, key string) error {
	// Del is variadic over keys; pass the single key directly rather than
	// building a one-element slice just to spread it.
	_, err := r.Del(ctx, key)
	return err
}

// RateIncr increments the counter stored at key and returns its new value.
// When the increment creates the counter (new value 1), the same Lua script
// sets an expiry of ttl on it. The count therefore resets once ttl has
// elapsed since the first increment (a fixed-window counter).
//
// The expiry is applied with PEXPIRE at millisecond precision, rounded up,
// rather than EXPIRE at whole seconds. Truncating to seconds would turn any
// ttl under one second into "EXPIRE key 0", which deletes the key at once and
// pins the counter at 1. A non-positive ttl is rejected with
// rc.ErrInvalidArguments for the same reason. rc.ErrInvalidArguments is also
// returned if Redis replies with something other than an int64.
func (r *rediser) RateIncr(ctx context.Context, key string, ttl time.Duration) (int64, error) {
	const script = `
local current
current = tonumber(redis.call("incr", KEYS[1]))
if current == 1 then
	redis.call("pexpire", KEYS[1], ARGV[1])
end
return current`
	if ttl <= 0 {
		return 0, rc.ErrInvalidArguments
	}
	// Round up to whole milliseconds so every positive ttl yields a positive
	// PEXPIRE argument. Computed without adding to ttl, to avoid overflow near
	// the maximum Duration.
	ms := int64(ttl / time.Millisecond)
	if ttl%time.Millisecond != 0 {
		ms++
	}
	out, err := r.Eval(ctx, script, []string{key}, []interface{}{strconv.FormatInt(ms, 10)})
	if err != nil {
		return 0, err
	}
	count, ok := out.(int64)
	if !ok {
		return 0, rc.ErrInvalidArguments
	}
	return count, nil
}

// getJSON reads key from Redis and decodes its JSON payload into res, which
// should be a pointer. The error from Get is returned unchanged; callers treat
// ErrRedisNil as a cache miss. A payload that is not valid JSON for res yields
// the json.Unmarshal error.
func (r *rediser) getJSON(ctx context.Context, key string, res interface{}) error {
	raw, err := r.Get(ctx, key)
	if err != nil {
		return err
	}
	return json.Unmarshal([]byte(raw), res)
}

// safeSetJSON JSON-encodes val and stores it at key with the given ttl through
// SetNX, i.e. only if the key is not already set (so a concurrent writer's
// value is not overwritten). It returns the marshal error, or the SetNX error.
// Whether SetNX actually wrote the key is ignored.
func (r *rediser) safeSetJSON(ctx context.Context, key string, val interface{}, ttl time.Duration) error {
	payload, err := json.Marshal(val)
	if err == nil {
		_, err = r.SetNX(ctx, key, string(payload), ttl)
	}
	return err
}
