package service_common

import (
	"encoding/json"
	"time"

	"reflect"

	"code.justin.tv/feeds/distconf"
	"code.justin.tv/feeds/errors"
	"code.justin.tv/feeds/log"
	"github.com/garyburd/redigo/redis"
	"golang.org/x/net/context"
)

// RedisConfig holds the distconf-backed settings used to dial Redis.
// Verify populates the connection fields and VerifyAuth populates Password;
// the pointer fields are nil until those methods have been called.
type RedisConfig struct {
	// Address is dialed over "tcp" by RedisConn; Verify requires it to be non-empty.
	Address        *distconf.Str
	// ConnectTimeout, ReadTimeout and WriteTimeout are passed to
	// redis.DialTimeout for every new connection.
	ConnectTimeout *distconf.Duration
	ReadTimeout    *distconf.Duration
	WriteTimeout   *distconf.Duration

	// MaxConns is used as both MaxIdle and MaxActive by NewRedisPool.
	MaxConns *distconf.Int
	// Password, when non-nil and non-empty, is sent via AUTH by RedisConn.
	Password *distconf.Str
}

// Verify binds the Redis connection settings found under prefix in d and
// checks that an address is configured.
//
// Keys read (defaults in parentheses):
//   - <prefix>.address (required, no default)
//   - <prefix>.connect_timeout (1s)
//   - <prefix>.read_timeout (100ms)
//   - <prefix>.write_timeout (100ms)
//   - <prefix>.max_conns (32)
//
// It returns an error if <prefix>.address is empty; in that case the
// remaining fields are left untouched. Password is bound separately by
// VerifyAuth.
func (r *RedisConfig) Verify(prefix string, d *distconf.Distconf) error {
	r.Address = d.Str(prefix+".address", "")
	if r.Address.Get() == "" {
		return errors.Errorf("unable to find redis address at %s.address", prefix)
	}
	r.ConnectTimeout = d.Duration(prefix+".connect_timeout", 1*time.Second)
	r.ReadTimeout = d.Duration(prefix+".read_timeout", 100*time.Millisecond)
	r.WriteTimeout = d.Duration(prefix+".write_timeout", 100*time.Millisecond)
	r.MaxConns = d.Int(prefix+".max_conns", 32)
	return nil
}

// VerifyAuth binds the optional Redis password from <prefix>.password
// (default "", meaning no AUTH is sent). It never fails.
func (r *RedisConfig) VerifyAuth(prefix string, d *distconf.Distconf) error {
	passwordKey := prefix + ".password"
	r.Password = d.Str(passwordKey, "")
	return nil
}

// RedisConn dials a new TCP connection to r.Address using the configured
// connect, read and write timeouts.
//
// When Password is non-nil and non-empty it sends AUTH on the new connection.
// If AUTH fails, the connection is closed and the AUTH error is returned
// (combined with any Close error). The exception is a reply saying the server
// has no password configured. That reply is tolerated, so the same config
// also works against password-less servers.
func (r *RedisConfig) RedisConn() (redis.Conn, error) {
	redisConn, err := redis.DialTimeout(
		"tcp",
		r.Address.Get(),
		r.ConnectTimeout.Get(),
		r.ReadTimeout.Get(),
		r.WriteTimeout.Get(),
	)
	if err != nil {
		return nil, err
	}

	if r.Password == nil {
		return redisConn, nil
	}
	password := r.Password.Get()
	if password == "" {
		return redisConn, nil
	}

	if _, err := redisConn.Do("AUTH", password); err != nil && !isRedisNoPasswordSetErr(err) {
		// Don't leak the half-initialised connection.
		if closeErr := redisConn.Close(); closeErr != nil {
			return nil, ConsolidateErrors([]error{err, closeErr})
		}
		return nil, err
	}
	return redisConn, nil
}

// isRedisNoPasswordSetErr reports whether err is the reply a Redis server
// sends to AUTH when it has no password configured. Redis before 6.0 and
// Redis 6.0+ word this reply differently, so both forms are recognised.
func isRedisNoPasswordSetErr(err error) bool {
	redisErr, ok := err.(redis.Error)
	if !ok {
		return false
	}
	switch redisErr {
	case "ERR Client sent AUTH, but no password is set",
		"ERR AUTH <password> called without any password configured for the default user. Are you sure your configuration is correct?":
		return true
	}
	return false
}

// NewRedisPool builds a redis.Pool whose connections are created by
// c.RedisConn. MaxConns is applied as both the idle and the active limit.
func NewRedisPool(c *RedisConfig) *redis.Pool {
	limit := int(c.MaxConns.Get())
	pool := &redis.Pool{Dial: c.RedisConn}
	pool.MaxIdle = limit
	pool.MaxActive = limit
	return pool
}

// RedisCacheConfig holds the distconf-backed settings for RedisCache.
type RedisCacheConfig struct {
	// DefaultTTL is the expiry (sent as PX milliseconds) applied to every
	// value RedisCache stores; Verify binds it with a 24h default.
	DefaultTTL *distconf.Duration
}

// Verify binds <prefix>.default_ttl (default 24h) into DefaultTTL.
// It never fails.
func (c *RedisCacheConfig) Verify(prefix string, d *distconf.Distconf) error {
	const fallbackTTL = 24 * time.Hour
	ttlKey := prefix + ".default_ttl"
	c.DefaultTTL = d.Duration(ttlKey, fallbackTTL)
	return nil
}

// RedisCache is a read-through cache that stores JSON-encoded values in Redis.
type RedisCache struct {
	// Pool supplies a connection for each Cached/Invalidate call.
	Pool      *redis.Pool
	// Config supplies the TTL applied when values are stored.
	Config    *RedisCacheConfig
	// Stats receives counters such as cache.hit, cache.miss and close_err.
	Stats     *StatSender
	// KeyPrefix is prepended to every caller-supplied key before it reaches Redis.
	KeyPrefix string
	// NOTE(review): Log is not referenced by the methods in this section.
	Log       log.Logger
	// DebugLog receives one entry per Cached call.
	DebugLog  log.Logger
}

// key maps a caller-supplied key to the Redis key: KeyPrefix followed by key.
func (r *RedisCache) key(key string) string {
	prefixed := r.KeyPrefix
	prefixed += key
	return prefixed
}

// Invalidate deletes the cached entry for key (after applying KeyPrefix).
// A failure to return the connection to the pool is only counted as
// close_err. ctx is currently unused.
func (r *RedisCache) Invalidate(ctx context.Context, key string) error {
	redisKey := r.key(key)
	conn := r.Pool.Get()
	defer func() {
		if closeErr := conn.Close(); closeErr != nil {
			r.Stats.IncC("close_err", 1, 1.0)
		}
	}()
	if _, delErr := redis.Int(conn.Do("DEL", redisKey)); delErr != nil {
		return delErr
	}
	return nil
}

// Cached loads the value for key into storeIntoPtr (a non-nil pointer),
// calling callback to compute and cache it on a miss.
//
// It borrows a connection from Pool for the duration of the call. A failure to
// return that connection is only counted as close_err. See cached for the
// hit/miss/store semantics.
func (r *RedisCache) Cached(ctx context.Context, key string, callback func() (interface{}, error), storeIntoPtr interface{}) error {
	r.DebugLog.Log("key", key, "redis.fetch")
	c := r.Pool.Get()
	defer func() {
		if err := c.Close(); err != nil {
			r.Stats.IncC("close_err", 1, 1.0)
		}
	}()
	// callback is passed straight through: cached already stops on any
	// callback error, so no wrapper is needed to normalise its result.
	return r.cached(ctx, c, key, storeIntoPtr, callback)
}

// Ints is a typed wrapper around Cached for []int64 values. On any error it
// returns a nil slice.
func (r *RedisCache) Ints(ctx context.Context, key string, callback func() ([]int64, error)) ([]int64, error) {
	var result []int64
	fetch := func() (interface{}, error) {
		return callback()
	}
	if err := r.Cached(ctx, key, fetch, &result); err != nil {
		return nil, err
	}
	return result, nil
}

// cached implements the read-through logic behind Cached using conn.
//
// It GETs the prefixed key. If the entry exists, it is JSON-decoded into
// storeInto, which must be a non-nil pointer.
//
// On a miss it:
//   - calls callback;
//   - assigns the result to *storeInto (the zero value when callback returns nil);
//   - JSON-encodes the result;
//   - stores it with a PX expiry of Config.DefaultTTL.
//
// The store uses NX so a value written concurrently by another caller is not
// clobbered. A failed store is only counted in Stats, since the caller already
// has its value.
//
// If an entry exists but cannot be decoded into storeInto (corrupt data or a
// changed schema), it is treated as a miss and counted as cache.decode_err.
// It is then overwritten with the freshly computed value, rather than failing
// every read until it expires.
//
// Returned errors:
//   - GET errors other than a miss;
//   - json.InvalidUnmarshalError (a bad storeInto);
//   - callback errors;
//   - errors encoding the fresh value.
//
// ctx is currently unused.
func (r *RedisCache) cached(ctx context.Context, conn redis.Conn, key string, storeInto interface{}, callback func() (interface{}, error)) error {
	key = r.key(key)

	// replaceCorrupt is set when the existing entry failed to decode. The
	// fresh value must then overwrite it, which SET ... NX would refuse to do.
	replaceCorrupt := false
	str, err := redis.String(conn.Do("GET", key))
	switch {
	case err == nil:
		decodeErr := json.Unmarshal([]byte(str), storeInto)
		if decodeErr == nil {
			r.Stats.IncC("cache.hit", 1, 1.0)
			return nil
		}
		// A nil or non-pointer storeInto is a caller bug, not a corrupt
		// entry; report it instead of panicking in reflect below.
		if _, ok := decodeErr.(*json.InvalidUnmarshalError); ok {
			return decodeErr
		}
		r.Stats.IncC("cache.decode_err", 1, 1.0)
		replaceCorrupt = true
	case err == redis.ErrNil:
		r.Stats.IncC("cache.miss", 1, 1.0)
	default:
		return err
	}

	vals, err := callback()
	if err != nil {
		return err
	}
	rv := reflect.ValueOf(storeInto).Elem()
	if vals == nil {
		rv.Set(reflect.Zero(rv.Type()))
	} else {
		rv.Set(reflect.ValueOf(vals))
	}
	bytes, err := json.Marshal(vals)
	if err != nil {
		r.Stats.IncC("cache.json_err", 1, 1.0)
		return err
	}

	ttlMillis := int64(r.Config.DefaultTTL.Get() / time.Millisecond)
	args := []interface{}{key, string(bytes), "PX", ttlMillis}
	if !replaceCorrupt {
		// NX: keep any value another caller stored in the meantime.
		args = append(args, "NX")
	}
	_, err = redis.String(conn.Do("SET", args...))
	switch {
	case err == redis.ErrNil:
		// NX refused the write because the key already existed.
		r.Stats.IncC("already_exists", 1, 1.0)
	case err != nil:
		// Best effort: the caller already has its value, so a failed store
		// is only counted.
		r.Stats.IncC("error", 1, 1.0)
	}
	return nil
}
