package e2topics

import (
	"context"
	"fmt"
	"time"

	"github.com/mediocregopher/radix/v3"

	"github.com/graph-gophers/dataloader"
)

// RedisBatchLoader coalesces concurrent Redis reads and writes into batches.
// Reads are served by one get loader per key group tag (each batch becomes a
// single MGET) and writes by one set loader per key group tag (each batch
// becomes a pipeline of SET commands). RedisCli executes those commands and
// Stats receives the batch-size gauges.
//
// The loader maps are only populated by NewRedisBatchLoader, so values must be
// built with that constructor rather than as a struct literal.
type RedisBatchLoader struct {
	RedisCli   radix.Client
	Stats      Statter
	getLoaders map[string]*dataloader.Loader
	setLoaders map[string]*dataloader.Loader
}

// NewRedisBatchLoader returns a RedisBatchLoader that sends its batches through
// redisCli and reports batch sizes to stats.
//
// It creates a dedicated get loader and set loader for every key group tag.
// Because a loader only ever receives keys from its own group, every batch it
// produces targets a single shard.
func NewRedisBatchLoader(redisCli radix.Client, stats Statter) *RedisBatchLoader {
	groups := len(keyGroupTags)
	loader := &RedisBatchLoader{
		RedisCli:   redisCli,
		Stats:      stats,
		getLoaders: make(map[string]*dataloader.Loader, groups),
		setLoaders: make(map[string]*dataloader.Loader, groups),
	}
	for _, tag := range keyGroupTags {
		loader.getLoaders[tag] = newGetLoader(loader.BatchGetFn)
		loader.setLoaders[tag] = newSetLoader(loader.BatchSetFn)
	}
	return loader
}

// newGetLoader wraps batchGetFn in a batching loader for reads.
//
// The loader is configured as follows:
//   - input and batch capacity are both 100 keys;
//   - it waits up to 16ms before dispatching a partial batch;
//   - results are never cached.
func newGetLoader(batchGetFn dataloader.BatchFunc) *dataloader.Loader {
	const (
		readCapacity = 100
		readWait     = 16 * time.Millisecond
	)
	return dataloader.NewBatchedLoader(batchGetFn,
		dataloader.WithInputCapacity(readCapacity),
		dataloader.WithBatchCapacity(readCapacity),
		dataloader.WithWait(readWait),
		dataloader.WithClearCacheOnBatch(),
		dataloader.WithCache(&dataloader.NoCache{}),
	)
}

// newSetLoader wraps batchSetFn in a batching loader for writes.
//
// The loader is configured as follows:
//   - input and batch capacity are both 50 entries;
//   - it waits up to 16ms before dispatching a partial batch;
//   - results are never cached.
func newSetLoader(batchSetFn dataloader.BatchFunc) *dataloader.Loader {
	const (
		writeCapacity = 50
		writeWait     = 16 * time.Millisecond
	)
	return dataloader.NewBatchedLoader(batchSetFn,
		dataloader.WithInputCapacity(writeCapacity),
		dataloader.WithBatchCapacity(writeCapacity),
		dataloader.WithWait(writeWait),
		dataloader.WithClearCacheOnBatch(),
		dataloader.WithCache(&dataloader.NoCache{}),
	)
}

// Get returns the bytes stored in Redis under key.
//
// The key is tagged with its key group and queued on that group's get loader.
// Get blocks until the batch containing it has been fetched. A batch is
// dispatched once it reaches the loader's capacity or its wait window elapses,
// so a batch does not have to be full.
//
// Keys that MGET reports as missing come back as nil values. Assuming the
// client decodes those nil replies to a nil []byte, Get then returns (nil, nil).
func (d *RedisBatchLoader) Get(key string) ([]byte, error) {
	keyGroupTag := keyGroupTags[keyGroupIdx([]byte(key))]
	getLoader := d.getLoaders[keyGroupTag]
	taggedKey := taggedGroupKey(keyGroupTag, key)

	thunk := getLoader.Load(context.Background(), dataloader.StringKey(taggedKey))
	result, err := thunk() // blocks until the batch containing taggedKey is loaded
	if err != nil {
		return nil, err
	}
	value, ok := result.([]byte)
	if !ok {
		// Report the dynamic type of result. value is only the zero []byte here,
		// so formatting it would always print []uint8.
		return nil, fmt.Errorf("RedisBatchLoader.Get: result type %T is not []byte", result)
	}
	return value, nil
}

// Set writes value under key.
//
// The write is queued on the set loader for key's group. Set blocks until the
// batch containing the write has been sent to Redis, then returns the error
// reported for this entry, if any.
func (d *RedisBatchLoader) Set(key string, value []byte) error {
	tag := keyGroupTags[keyGroupIdx([]byte(key))]
	entry := &valueAsDataloaderKey{
		key:   taggedGroupKey(tag, key),
		value: value,
	}
	// Calling the returned thunk waits until the whole batch has been written.
	_, err := d.setLoaders[tag].Load(context.Background(), entry)()
	return err
}

// BatchGetFn is the dataloader.BatchFunc used by the get loaders. It fetches
// every key of a batch with a single MGET and returns one result per key, in
// the same order as keys.
//
// All keys of a batch must map to the same shard. This holds because keys are
// tagged with their key group tag and each group has its own loader.
//
// MGET reports missing keys as nil values rather than as errors, so individual
// results never carry an error. The batch as a whole can still fail, either
// because the MGET command errors (e.g. a connection problem) or because the
// reply length does not match the number of keys. In that case the error is
// passed to dataloaderResultsWithNilData together with the keys.
//
// ctx is not used: the radix v3 client call does not accept a context.
func (d *RedisBatchLoader) BatchGetFn(ctx context.Context, keys dataloader.Keys) []*dataloader.Result {
	var mgetResp [][]byte
	if err := d.RedisCli.Do(radix.Cmd(&mgetResp, "MGET", keys.Keys()...)); err != nil {
		return dataloaderResultsWithNilData(keys, err)
	}
	if len(mgetResp) != len(keys) {
		// The loader needs exactly one result per key. Give a precise error
		// instead of dataloader's generic length-mismatch failure.
		return dataloaderResultsWithNilData(keys, fmt.Errorf(
			"RedisBatchLoader.BatchGetFn: MGET returned %d values for %d keys",
			len(mgetResp), len(keys)))
	}

	results := make([]*dataloader.Result, len(mgetResp))
	for i, value := range mgetResp {
		results[i] = &dataloader.Result{Data: value} // []byte; nil for missing keys
	}
	d.Stats.Gauge("RedisGetBatchSize", len(results))
	return results
}

// BatchSetFn is the dataloader.BatchFunc used by the set loaders.
//
// It turns every key into "SET <key> <value> EX 10" and sends all of them in
// one pipeline. It records the batch size and returns the pipeline's error
// (nil on success) for every key via dataloaderResultsWithNilData.
//
// Each key's Raw() must be the []byte value to store, as produced by Set. Any
// other type makes the unchecked type assertion panic.
//
// ctx is not used.
func (d *RedisBatchLoader) BatchSetFn(ctx context.Context, dlKeys dataloader.Keys) []*dataloader.Result {
	const expireSecs = 10

	pipeline := make([]radix.CmdAction, 0, len(dlKeys))
	for _, entry := range dlKeys {
		pipeline = append(pipeline,
			radix.FlatCmd(nil, "SET", entry.String(), entry.Raw().([]byte), "EX", expireSecs))
	}

	pipeErr := d.RedisCli.Do(radix.Pipeline(pipeline...))
	d.Stats.Gauge("RedisSetBatchSize", len(dlKeys))
	return dataloaderResultsWithNilData(dlKeys, pipeErr)
}

// valueAsDataloaderKey bundles a Redis key with the value to write to it.
//
// It implements dataloader.Key so that Set can pass the value through the
// loader to BatchSetFn. String yields the Redis key and Raw yields the value
// bytes.
type valueAsDataloaderKey struct {
	key   string
	value []byte
}

// String returns the Redis key this entry is written to.
func (k *valueAsDataloaderKey) String() string { return k.key }

// Raw returns the bytes to store, boxed as a []byte.
func (k *valueAsDataloaderKey) Raw() interface{} { return k.value }
