package rediser

import (
	"context"
	"fmt"
	"log"
	"sync"
	"sync/atomic"
	"time"

	"github.com/cactus/go-statsd-client/statsd"
	"github.com/go-redis/redis"

	rc "code.justin.tv/identity/rediser/common"
)

var (
	// testMockHandler is a typed nil *MockHandler (MockHandler is declared
	// outside this section of the file).
	// NOTE(review): presumably kept so the mock type is always referenced from
	// non-test code; it does not by itself assert that MockHandler satisfies
	// Handler — confirm intended use.
	testMockHandler = (*MockHandler)(nil)
)

// Handler is the Redis client interface exposed by this package. Every
// command method takes a context.Context as its first argument; Close is the
// only method that does not.
type Handler interface {
	// Del removes the given keys from the database. Returns the number of keys
	// removed.
	Del(ctx context.Context, keys ...string) (int64, error)

	// Eval runs a Lua script against the dataset, returning arbitrary data.
	Eval(ctx context.Context, script string, keys []string, args []interface{}) (interface{}, error)

	// Exists returns the number of keys which were found in the database.
	Exists(ctx context.Context, keys ...string) (int64, error)

	// Expire deletes the given key from the database after ttl has passed.
	// Returns true when the timeout is set successfully.
	Expire(ctx context.Context, key string, ttl time.Duration) (bool, error)

	// FlushDB deletes all keys in the current database.
	FlushDB(ctx context.Context) error

	// Get retrieves the value of the given key from the database.
	Get(ctx context.Context, key string) (string, error)

	// HDel removes the specified fields from the hash stored at key. Returns the
	// number of fields removed from the hash.
	HDel(ctx context.Context, key string, fields ...string) (int64, error)

	// HGetAll returns all fields and values of the hash stored at key.
	HGetAll(ctx context.Context, key string) (map[string]string, error)

	// HKeys returns all field names in the hash stored at key.
	HKeys(ctx context.Context, key string) ([]string, error)

	// HMGet returns the values associated with the specified fields in the hash
	// stored at key.
	HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error)

	// HMSet sets the specified fields to their respective values in the hash
	// stored at key.
	HMSet(ctx context.Context, key string, fields map[string]interface{}) error

	// HSet sets field in the hash stored at key to val. Returns true when the
	// field was created, and false when the field was updated.
	HSet(ctx context.Context, key string, field string, val interface{}) (bool, error)

	// HSetNX sets field in the hash stored at key to val only if the field is
	// not already set. Returns true when the field was created, and false when
	// the field was already set.
	HSetNX(ctx context.Context, key string, field string, val interface{}) (bool, error)

	// HIncrBy increments field in the hash stored at key by val. Returns the
	// value of the field after the operation.
	HIncrBy(ctx context.Context, key string, field string, val int64) (int64, error)

	// HLen returns the number of fields stored in the hash at key.
	HLen(ctx context.Context, key string) (int64, error)

	// Incr increments the number stored at key by one. Returns its new value.
	Incr(ctx context.Context, key string) (int64, error)

	// LLen returns the length of the specified list.
	LLen(ctx context.Context, key string) (int64, error)

	// LRange returns the specified elements of the list stored at key. Results
	// are inclusive of start and stop indices.
	LRange(ctx context.Context, key string, start int64, stop int64) ([]string, error)

	// LRem removes the first count occurrences of elements equal to val from the
	// given list. Returns the number of elements removed.
	LRem(ctx context.Context, key string, count int64, val interface{}) (int64, error)

	// LTrim removes all values from the list stored at key which are outside the
	// start:stop range.
	LTrim(ctx context.Context, key string, start int64, stop int64) error

	// MGet returns the values of all specified keys.
	MGet(ctx context.Context, keys ...string) ([]interface{}, error)

	// MSet sets the given keys to their respective values.
	MSet(ctx context.Context, pairs ...interface{}) error

	// MSetNxWithTTL leverages redis pipelining to issue many SET commands with NX and TTL specified.
	// Pipelining is necessary since MSET alone does not support NX and TTL arguments.
	MSetNxWithTTL(ctx context.Context, ttl time.Duration, pairs ...interface{}) error

	// MSetWithTTL leverages redis pipelining to issue many SET commands with TTL specified.
	// Pipelining is necessary since MSET alone does not support TTL arguments.
	MSetWithTTL(ctx context.Context, ttl time.Duration, pairs ...interface{}) error

	// PipelinedGet leverages redis pipelining to issue multiple GET commands and
	// returns the values of the specified keys.
	// This can be used to perform MGETs across a cluster without client-side hashing.
	PipelinedGet(ctx context.Context, keys ...string) ([]interface{}, error)

	// PipelinedInvalidate leverages redis pipelining to issue multiple DEL commands and returns the
	// number of keys that were deleted.
	// This can be used to perform DELs on multiple keys across a cluster without client-side hashing.
	PipelinedInvalidate(ctx context.Context, keys ...string) (int64, error)

	// RPush adds the given value to the list stored at key. Returns the length
	// of the list after adding the new value.
	RPush(ctx context.Context, key string, val string) (int64, error)

	// SAdd adds elements to a set. Returns the number of elements newly inserted,
	// ignoring elements already in the set.
	SAdd(ctx context.Context, key string, vals ...interface{}) (int64, error)

	// SDiff returns the members of the set resulting from the difference between
	// the first set and all the successive sets.
	SDiff(ctx context.Context, keys ...string) ([]string, error)

	// Set writes the given value to the database.
	Set(ctx context.Context, key string, val string, ttl time.Duration) error

	// SetNX writes the given value to the database if it does not exist.
	SetNX(ctx context.Context, key string, val string, ttl time.Duration) (bool, error)

	// SIsMember returns whether the element was found in the set.
	SIsMember(ctx context.Context, key string, val interface{}) (bool, error)

	// SMembers returns all elements in a set.
	SMembers(ctx context.Context, key string) ([]string, error)

	// SRem removes elements from a set. Returns the number of elements removed,
	// ignoring elements that were not in the set.
	SRem(ctx context.Context, key string, vals ...interface{}) (int64, error)

	// SScan fetches elements in a set. Returns a cursor and list of elements.
	// Pass "*" for the match arg to match all results.
	SScan(ctx context.Context, key string, cursor uint64, match string, count int64) (uint64, []string, error)

	// TTL returns the remaining TTL of the given key.
	TTL(ctx context.Context, key string) (time.Duration, error)

	// ZAdd adds elements to a sorted set. Returns the number of added elements.
	ZAdd(ctx context.Context, key string, items ...redis.Z) (int64, error)

	// ZCard returns the number of elements in a sorted set.
	ZCard(ctx context.Context, key string) (int64, error)

	// ZRangeByLex returns all elements of the sorted set lexicographically between min and max
	// (inclusive).
	ZRangeByLex(ctx context.Context, key string, opt redis.ZRangeBy) ([]string, error)

	// ZRangeByScore returns all elements of the sorted set with a score between min and max
	// (inclusive).
	ZRangeByScore(ctx context.Context, key string, opt redis.ZRangeBy) ([]string, error)

	// ZRem removes an element from a sorted set. Returns the number of elements
	// removed.
	ZRem(ctx context.Context, key string, element string) (int64, error)

	// ZRemRangeByScore removes all elements from a sorted set whose values are
	// inside the min:max range. Returns the number of elements removed.
	ZRemRangeByScore(ctx context.Context, key, min, max string) (int64, error)

	// ZScore returns the score of an element in a sorted set.
	ZScore(ctx context.Context, key string, element string) (float64, error)

	// Cache Helpers below. All cache helpers attempt to retrieve a value from
	// Redis. If the request succeeds, the value is returned. If not, the value
	// is retrieved from a primary data source (via the fn func argument), and is
	// then cached in Redis for future requests.
	//
	// Cached hands the value back through the res argument; the typed variants
	// return it directly as the type named in the method.
	// NOTE(review): res is presumably a pointer the cached value is decoded
	// into — confirm against the implementation.

	Cached(ctx context.Context, key string, ttl time.Duration, res interface{}, fn func() (interface{}, error)) error
	CachedBool(ctx context.Context, key string, ttl time.Duration, fn func() (interface{}, error)) (bool, error)
	// CachedBoolDifferentTTLs differs from CachedBool in that it takes a
	// CachedBoolTTLs value instead of a single ttl, and fn returns a bool.
	// NOTE(review): presumably the TTL applied depends on the boolean being
	// cached — confirm against the CachedBoolTTLs definition.
	CachedBoolDifferentTTLs(ctx context.Context, key string, ttls CachedBoolTTLs, fn func() (bool, error)) (bool, error)
	CachedInt(ctx context.Context, key string, ttl time.Duration, fn func() (interface{}, error)) (int, error)
	CachedInt64(ctx context.Context, key string, ttl time.Duration, fn func() (interface{}, error)) (int64, error)
	CachedString(ctx context.Context, key string, ttl time.Duration, fn func() (interface{}, error)) (string, error)
	CachedIntSlice(ctx context.Context, key string, ttl time.Duration, fn func() (interface{}, error)) ([]int, error)
	CachedInt64Slice(ctx context.Context, key string, ttl time.Duration, fn func() (interface{}, error)) ([]int64, error)
	CachedStringSlice(ctx context.Context, key string, ttl time.Duration, fn func() (interface{}, error)) ([]string, error)

	// Invalidate deletes a single cache key with Del().
	Invalidate(ctx context.Context, key string) error

	// RateIncr is a convenience function for incrementing and optionally expiring
	// an item, using the EVAL command. Returns the new value of key.
	RateIncr(ctx context.Context, key string, ttl time.Duration) (int64, error)

	// Close closes the Handler, releasing any open resources.
	// It is rare to Close a Handler, as the Handler is meant to be long-lived and shared between many goroutines.
	Close() error
}

// rediser is the Handler implementation returned by NewClient. It wraps a
// RedisClient and runs a background monitor goroutine that periodically
// reports statsd metrics until Close is called.
type rediser struct {
	// Client is the underlying Redis connection (cluster or ring).
	Client RedisClient
	// Stats receives the gauges emitted by monitor.
	Stats statsd.Statter
	// Opts supplies MonitorInterval and StatPrefix for the monitor loop.
	Opts *rc.Options

	// lock guards maxActive.
	lock sync.Mutex
	// active is read atomically by monitor.
	// NOTE(review): presumably the number of in-flight commands, updated by the
	// command methods outside this section — confirm.
	active int32
	// maxActive is reported as the "max_active_commands" gauge on every monitor
	// tick and then reset to the current value of active.
	maxActive int32

	// shutdown ensures the body of Close runs only once.
	shutdown sync.Once
	// shutdownCh is closed by Close to stop the monitor goroutine.
	shutdownCh chan struct{}
}

// RedisClient is the common interface between *redis.ClusterClient and *redis.Ring.
// NewClient hands whichever of the two answered a ping to newRediserClient,
// which stores it in rediser.Client.
type RedisClient interface {
	// Redis commands, each returning the go-redis command object for the call.
	Del(...string) *redis.IntCmd
	Eval(string, []string, ...interface{}) *redis.Cmd
	Exists(...string) *redis.IntCmd
	Expire(string, time.Duration) *redis.BoolCmd
	FlushDb() *redis.StatusCmd
	Get(string) *redis.StringCmd
	HDel(string, ...string) *redis.IntCmd
	HGetAll(string) *redis.StringStringMapCmd
	HKeys(string) *redis.StringSliceCmd
	HMGet(string, ...string) *redis.SliceCmd
	HMSet(string, map[string]interface{}) *redis.StatusCmd
	HSet(string, string, interface{}) *redis.BoolCmd
	HSetNX(string, string, interface{}) *redis.BoolCmd
	HIncrBy(string, string, int64) *redis.IntCmd
	HLen(string) *redis.IntCmd
	Incr(string) *redis.IntCmd
	LLen(string) *redis.IntCmd
	LRange(string, int64, int64) *redis.StringSliceCmd
	LRem(string, int64, interface{}) *redis.IntCmd
	LTrim(string, int64, int64) *redis.StatusCmd
	MGet(...string) *redis.SliceCmd
	MSet(...interface{}) *redis.StatusCmd
	RPush(string, ...interface{}) *redis.IntCmd
	SAdd(string, ...interface{}) *redis.IntCmd
	SDiff(...string) *redis.StringSliceCmd
	Set(string, interface{}, time.Duration) *redis.StatusCmd
	SetNX(string, interface{}, time.Duration) *redis.BoolCmd
	SIsMember(string, interface{}) *redis.BoolCmd
	SMembers(string) *redis.StringSliceCmd
	SRem(string, ...interface{}) *redis.IntCmd
	SScan(string, uint64, string, int64) *redis.ScanCmd
	TTL(string) *redis.DurationCmd
	ZAdd(string, ...redis.Z) *redis.IntCmd
	ZCard(string) *redis.IntCmd
	ZRangeByLex(string, redis.ZRangeBy) *redis.StringSliceCmd
	ZRangeByScore(string, redis.ZRangeBy) *redis.StringSliceCmd
	ZRem(string, ...interface{}) *redis.IntCmd
	ZRemRangeByScore(string, string, string) *redis.IntCmd
	ZScore(string, string) *redis.FloatCmd

	// Pipeline returns a go-redis Pipeliner.
	// NOTE(review): presumably what the Pipelined*/MSet*WithTTL Handler methods
	// build on — their implementations are outside this section.
	Pipeline() redis.Pipeliner
	// PoolStats is polled by rediser.monitor to report connection pool metrics.
	PoolStats() *redis.PoolStats
	// Close is called by rediser.Close.
	Close() error
}

// NewClient builds a Handler for the Redis endpoint described by opts.
//
// After validating opts it first tries a cluster client; if the cluster
// answers a ping, that client is used. Otherwise the cluster client is
// closed, the ping failure is logged, and a ring client is tried the same
// way. When neither answers, an error describing the ring failure is
// returned.
//
// On every error path the returned Handler is an empty, non-nil *rediser; it
// carries no client and must not be used.
func NewClient(opts *rc.Options, stats statsd.Statter) (Handler, error) {
	if err := opts.Validate(); err != nil {
		return &rediser{}, err
	}

	// First choice: cluster mode.
	clusterClient := redis.NewClusterClient(opts.GetClusterOptions())
	clusterErr := clusterClient.Ping().Err()
	if clusterErr == nil {
		return newRediserClient(opts, clusterClient, stats)
	}
	if closeErr := clusterClient.Close(); closeErr != nil {
		return &rediser{}, closeErr
	}
	log.Println("unable to initialize redis cluster. falling back to redis ring: ", clusterErr)

	// Fallback: ring mode.
	ringClient := redis.NewRing(opts.GetRingOptions())
	ringErr := ringClient.Ping().Err()
	if ringErr == nil {
		return newRediserClient(opts, ringClient, stats)
	}
	if closeErr := ringClient.Close(); closeErr != nil {
		return &rediser{}, closeErr
	}

	return &rediser{}, fmt.Errorf("unable to initialize redis ring: %q", ringErr)
}

// newRediserClient wraps an already-connected client in a rediser and starts
// its monitor goroutine. The goroutine runs until Close is called on the
// returned Handler. The returned error is always nil.
func newRediserClient(opts *rc.Options, client RedisClient, stats statsd.Statter) (Handler, error) {
	// The active/maxActive counters start at their zero values.
	h := new(rediser)
	h.Client = client
	h.Stats = stats
	h.Opts = opts
	h.shutdownCh = make(chan struct{})

	go h.monitor()
	return h, nil
}

// monitor is the background reporting loop started by newRediserClient. Every
// Opts.MonitorInterval it publishes the "<StatPrefix>.max_active_commands"
// gauge and the connection pool stats, and it exits once shutdownCh is
// closed.
func (r *rediser) monitor() {
	tick := time.NewTicker(r.Opts.MonitorInterval)
	defer tick.Stop()

	for {
		// Block until either the next tick or shutdown.
		select {
		case <-r.shutdownCh:
			return
		case <-tick.C:
		}

		// Take the recorded peak and restart the window from the current
		// number of active commands.
		r.lock.Lock()
		peak := r.maxActive
		r.maxActive = atomic.LoadInt32(&r.active)
		r.lock.Unlock()

		name := fmt.Sprintf("%s.max_active_commands", r.Opts.StatPrefix)
		if err := r.Stats.Gauge(name, int64(peak), 1); err != nil {
			log.Printf("failed to report statsd metric: %q", err)
		}

		// Report connection pool stats
		poolStats := r.Client.PoolStats()
		rc.ReportPoolStats(r.Stats, r.Opts.StatPrefix, poolStats)
	}
}

// Close stops the monitor goroutine and closes the underlying Redis client.
//
// Only the first call does any work and returns the client's Close error;
// later calls return nil.
//
// NewClient returns a zero-value rediser alongside its error, so Close must
// tolerate a nil shutdownCh (closing a nil channel panics) and a nil Client.
// For such a value Close is a no-op that returns nil.
func (r *rediser) Close() error {
	var err error
	r.shutdown.Do(func() {
		if r.shutdownCh != nil {
			// Signals monitor to return.
			close(r.shutdownCh)
		}
		if r.Client != nil {
			err = r.Client.Close()
		}
	})
	return err
}
