package elastimemcache

import (
	"context"
	"crypto/md5"
	"encoding/json"
	"errors"
	"fmt"
	"hash"
	"io"
	"io/ioutil"
	"net"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/serialx/hashring"
)

// Elasticache implements gomemcache's ServerSelector interface using serialx's hashring implementation to pick nodes
type Elasticache struct {
	CfgServer           string
	PollInterval        time.Duration
	UpdateTimeout       time.Duration
	FetchErrorsCallback func(error)
	HealthVerifier      HealthVerifier

	atomicLoadedServers atomicLoadedServers

	stats Stats
	done  chan struct{}
	once  sync.Once
}

// HealthVerifier decides whether an address may take part in the hash ring.
//
// Implementations are called concurrently, once per server, every time the cluster configuration is fetched.
type HealthVerifier interface {
	// CheckHealth returns nil when addr is healthy. A non-nil error excludes addr from the ring, unless every
	// server in the configuration fails, in which case all of them are kept.
	CheckHealth(ctx context.Context, addr net.Addr) error
}

// Stats returns a snapshot of the internal counters that track how the Elasticache abstraction is behaving.
// Each counter is read with an atomic load, so it is safe to call while other goroutines use the Elasticache.
func (e *Elasticache) Stats() Stats {
	var snapshot Stats
	snapshot.Reloads = atomic.LoadInt64(&e.stats.Reloads)
	snapshot.Picks = atomic.LoadInt64(&e.stats.Picks)
	return snapshot
}

// Stats is a snapshot in time of the counters kept by an Elasticache, as returned by (*Elasticache).Stats.
//
// The same type also holds the live counters inside Elasticache, where each field is read with sync/atomic.
// Both fields are int64 and Reloads comes first, so the pair stays 64-bit aligned whenever the Stats value
// itself is. Do not reorder these fields or insert smaller ones ahead of them.
type Stats struct {
	// Reloads are how many times the elasticache cluster version was changed.
	Reloads int64
	// Picks is how many times a node was picked from the ring.
	Picks int64
}

// atomicInt64 is an int64 counter that is only ever read and written through sync/atomic. Other goroutines can
// therefore update it while it is being JSON-marshalled (for example via expvar) without tripping -race.
//
// On 32-bit platforms the int64 must be 64-bit aligned, so place an atomicInt64 first in any struct that holds it.
type atomicInt64 struct{ val int64 }

var _ json.Marshaler = (*atomicInt64)(nil)

// MarshalJSON encodes the current value as a JSON number, reading it atomically.
func (a *atomicInt64) MarshalJSON() ([]byte, error) {
	return strconv.AppendInt(nil, a.Get(), 10), nil
}

// Get atomically loads the current value.
func (a *atomicInt64) Get() int64 {
	return atomic.LoadInt64(&a.val)
}

// String formats the current value in base 10.
func (a *atomicInt64) String() string {
	current := atomic.LoadInt64(&a.val)
	return strconv.FormatInt(current, 10)
}

// Add atomically adds delta and returns the new value.
func (a *atomicInt64) Add(delta int64) int64 {
	return atomic.AddInt64(&a.val, delta)
}

// ServerInfo is a cache of information about a memcache server in the ring
type ServerInfo struct {
	Addr       *staticAddr
	ServerHits atomicInt64
}

// LoadedServers is a constructed hashring of elasticache servers. It is one snapshot of the cluster configuration,
// built by initLoadedServers and published through atomicLoadedServers. Nothing in this file modifies a snapshot
// after it is stored, apart from the atomically updated ServerHits counters.
type LoadedServers struct {
	// CacheVersion is the configuration version reported by the config server.
	CacheVersion int
	// Servers maps each server address string (the same strings used as ring nodes) to its ServerInfo.
	Servers      map[string]*ServerInfo
	// servers is the deduplicated address list in the order it was fetched; being unexported, it is not
	// JSON-encoded.
	servers      []string
	// Ring is the hash ring over servers; it is excluded from JSON output.
	Ring         *hashring.HashRing `json:"-"`
}

// md5Hasher adapts crypto/md5 to the hash.Hash value passed to hashring.NewWithHash.
//
// Only Sum is implemented. Unlike hash.Hash.Sum, which appends the running digest to its argument, this Sum
// returns the MD5 digest of its argument. When built as md5Hasher{}, the embedded hash.Hash is nil, so calling
// any other hash.Hash method would panic. Presumably hashring only calls Sum (TODO confirm against the hashring
// version in use).
type md5Hasher struct {
	hash.Hash
}

// Sum returns the 16-byte MD5 digest of key as a slice.
func (m md5Hasher) Sum(key []byte) []byte {
	digest := md5.Sum(key)
	return digest[:]
}

// atomicLoadedServers wraps an atomic.Value that only ever holds a *LoadedServers. Readers therefore observe
// whole snapshots without taking a lock.
type atomicLoadedServers struct {
	loadedServerStruct atomic.Value
}

// Load is a type-safe atomic.Value.Load. It returns nil until something has been stored.
func (a *atomicLoadedServers) Load() *LoadedServers {
	loaded, _ := a.loadedServerStruct.Load().(*LoadedServers)
	return loaded
}

var _ json.Marshaler = (*Elasticache)(nil)

// MarshalJSON returns a debugging JSON view of the Elasticache: its configuration plus the currently loaded
// server snapshot (null before the first load). Good for expvar.
func (e *Elasticache) MarshalJSON() ([]byte, error) {
	type debugView struct {
		CfgServer     string
		PollInterval  time.Duration
		UpdateTimeout time.Duration
		LoadedServers *LoadedServers
	}
	view := debugView{
		CfgServer:     e.CfgServer,
		PollInterval:  e.PollInterval,
		UpdateTimeout: e.UpdateTimeout,
		LoadedServers: e.atomicLoadedServers.Load(),
	}
	return json.Marshal(view)
}

// Store is a type-safe atomic.Value.Store. It publishes v so that subsequent Load calls observe it.
// Because v is always a *LoadedServers, every stored value has the same concrete type, as atomic.Value requires.
func (a *atomicLoadedServers) Store(v *LoadedServers) {
	a.loadedServerStruct.Store(v)
}

// makeAddr resolves server into a staticAddr. Any string containing a "/" is treated as a Unix socket path;
// everything else is resolved as a TCP host:port.
func makeAddr(server string) (*staticAddr, error) {
	var (
		resolved net.Addr
		err      error
	)
	if strings.Contains(server, "/") {
		resolved, err = net.ResolveUnixAddr("unix", server)
	} else {
		resolved, err = net.ResolveTCPAddr("tcp", server)
	}
	if err != nil {
		return nil, err
	}
	return newStaticAddr(resolved), nil
}

// makeServerInfo resolves every address in servers and returns a ServerInfo per address, keyed by the original
// address string. It stops at the first address that fails to resolve and returns that error.
func makeServerInfo(servers []string) (map[string]*ServerInfo, error) {
	infos := make(map[string]*ServerInfo, len(servers))
	for i := range servers {
		addr, err := makeAddr(servers[i])
		if err != nil {
			return nil, err
		}
		infos[servers[i]] = &ServerInfo{Addr: addr}
	}
	return infos, nil
}

// dedup returns the distinct entries of servers, keeping the first occurrence of each in its original order.
// The result is always a non-nil slice.
func dedup(servers []string) []string {
	seen := make(map[string]struct{}, len(servers))
	out := make([]string, 0, len(servers))
	for _, s := range servers {
		if _, dup := seen[s]; dup {
			continue
		}
		seen[s] = struct{}{}
		out = append(out, s)
	}
	return out
}

// initLoadedServers builds a LoadedServers snapshot for cacheVersion. It deduplicates servers, resolves each
// address, and builds an MD5-based hash ring over the distinct addresses. It returns the first resolution or
// ring-construction error.
func initLoadedServers(cacheVersion int, servers []string) (*LoadedServers, error) {
	distinct := dedup(servers)
	info, err := makeServerInfo(distinct)
	if err != nil {
		return nil, err
	}
	ring, err := hashring.NewWithHash(distinct, md5Hasher{})
	if err != nil {
		return nil, err
	}
	loaded := &LoadedServers{
		CacheVersion: cacheVersion,
		Servers:      info,
		servers:      distinct,
		Ring:         ring,
	}
	return loaded, nil
}

// staticAddr is a net.Addr whose Network and String results are captured once, at construction time.
type staticAddr struct {
	ntw, str string
}

var _ net.Addr = (*staticAddr)(nil)

// newStaticAddr snapshots a's Network() and String() into a new staticAddr.
func newStaticAddr(a net.Addr) *staticAddr {
	s := new(staticAddr)
	s.ntw = a.Network()
	s.str = a.String()
	return s
}

var _ json.Marshaler = (*staticAddr)(nil)

// MarshalJSON renders the address as an object with "Network" and "String" keys.
func (s *staticAddr) MarshalJSON() ([]byte, error) {
	type addrJSON struct {
		Network string
		String  string
	}
	return json.Marshal(addrJSON{Network: s.Network(), String: s.String()})
}

// Network returns the cached network name (for example "tcp" or "unix").
func (s *staticAddr) Network() string {
	return s.ntw
}

// String returns the cached address string.
func (s *staticAddr) String() string {
	return s.str
}

// pollInterval returns how long Start waits between refreshes. A non-positive PollInterval falls back to the
// default of 5 seconds. A negative interval would otherwise make time.After fire immediately and turn Start into
// a busy loop against the configuration endpoint.
func (e *Elasticache) pollInterval() time.Duration {
	if e.PollInterval <= 0 {
		return 5 * time.Second
	}
	return e.PollInterval
}

// updateTimeout returns the time budget for one background refresh in Start. A non-positive UpdateTimeout falls
// back to the default of 15 seconds. A negative timeout would otherwise produce an already-expired context and
// make every refresh fail.
func (e *Elasticache) updateTimeout() time.Duration {
	if e.UpdateTimeout <= 0 {
		return 15 * time.Second
	}
	return e.UpdateTimeout
}

// Init performs a synchronous fetch of the cluster configuration using ctx and installs the resulting ring.
// Call it after creating an Elasticache and before using it (for example before PickServer). Any error from
// the fetch is returned.
func (e *Elasticache) Init(ctx context.Context) error {
	err := e.fetchLatestVersion(ctx)
	return err
}

// checkErr forwards a non-nil err to FetchErrorsCallback when one is configured; otherwise the error is dropped.
func (e *Elasticache) checkErr(err error) {
	if err == nil || e.FetchErrorsCallback == nil {
		return
	}
	e.FetchErrorsCallback(err)
}

// fetchLatestVersion asks CfgServer for the current cluster configuration. If the result differs from the loaded
// snapshot, it builds a new hash ring and installs it atomically.
//
// ctx bounds the dial and, when it carries a deadline, the whole request/response exchange. The loaded ring is
// left untouched when the fetched version is not newer and the set of (healthy) servers is unchanged. Otherwise a
// new snapshot is stored and the Reloads counter reported by Stats is incremented.
func (e *Elasticache) fetchLatestVersion(ctx context.Context) error {
	e.setup()
	addr, err := net.ResolveTCPAddr("tcp", e.CfgServer)
	if err != nil {
		return fmt.Errorf("invalid TCP address: %s: %w", e.CfgServer, err)
	}
	var d net.Dialer
	nc, err := d.DialContext(ctx, addr.Network(), addr.String())
	if err != nil {
		return err
	}
	defer func() {
		e.checkErr(nc.Close())
	}()
	// DialContext only honours ctx while connecting. Without a deadline, a stalled server would block the reads
	// in fetch forever and wedge the Start loop despite UpdateTimeout.
	if deadline, ok := ctx.Deadline(); ok {
		if err := nc.SetDeadline(deadline); err != nil {
			return err
		}
	}
	version, servers, err := e.fetch(ctx, nc)
	if err != nil {
		return err
	}
	servers = dedup(servers)
	if current := e.atomicLoadedServers.Load(); current != nil {
		// Even with an unchanged version, the set of healthy servers can change (one node down, another back up),
		// so compare the actual membership rather than just the count.
		if current.CacheVersion >= version && sameServerSet(current.servers, servers) {
			return nil
		}
	}
	newServers, err := initLoadedServers(version, servers)
	if err != nil {
		return err
	}
	e.atomicLoadedServers.Store(newServers)
	atomic.AddInt64(&e.stats.Reloads, 1)
	return nil
}

// sameServerSet reports whether a and b hold exactly the same addresses, ignoring order. Both slices must be free
// of duplicates (as produced by dedup).
func sameServerSet(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	set := make(map[string]struct{}, len(a))
	for _, s := range a {
		set[s] = struct{}{}
	}
	for _, s := range b {
		if _, ok := set[s]; !ok {
			return false
		}
	}
	return true
}

// setup lazily creates the done channel exactly once. Start, Close and the fetch path all call it, so they work
// on a zero-value Elasticache regardless of which runs first.
func (e *Elasticache) setup() {
	initDone := func() { e.done = make(chan struct{}) }
	e.once.Do(initDone)
}

// Start runs the background refresh loop and blocks until Close is called, at which point it returns nil.
//
// Each iteration waits pollInterval(), then re-fetches the cluster configuration with a timeout of
// updateTimeout(). Errors are handed to FetchErrorsCallback via checkErr instead of stopping the loop. The first
// fetch happens only after one interval, so call Init beforehand to load the ring up front.
func (e *Elasticache) Start() error {
	e.setup()
	for {
		select {
		case <-e.done:
			return nil
		case <-time.After(e.pollInterval()):
		}
		func() {
			ctx, cancel := context.WithTimeout(context.Background(), e.updateTimeout())
			defer cancel()
			e.checkErr(e.fetchLatestVersion(ctx))
		}()
	}
}

// Close ends the Start polling loop by closing the done channel. It always returns nil.
//
// A second, sequential call is a no-op instead of panicking on a close of an already-closed channel. Concurrent
// calls to Close are not synchronized with each other and should still be avoided.
func (e *Elasticache) Close() error {
	e.setup()
	select {
	case <-e.done:
		// Already closed by an earlier call.
	default:
		close(e.done)
	}
	return nil
}

// Configuration returns the current elasticache configuration: the cache version and the server list.
//
// It returns (0, nil) before the first successful load. The server list is a copy, so callers may modify it
// without racing against, or corrupting, the snapshot shared with concurrent readers.
func (e *Elasticache) Configuration() (int, []string) {
	current := e.atomicLoadedServers.Load()
	if current == nil {
		return 0, nil
	}
	servers := make([]string, len(current.servers))
	copy(servers, current.servers)
	return current.CacheVersion, servers
}

// fetch sends "config get cluster" over rw and parses the reply into the configuration version and the node
// addresses. Nodes that fail the health check are then dropped (see filterUnhealthyServers).
//
// The expected reply is
//
//	CONFIG cluster 0 <size>\r\n
//	<version>\n
//	<host>|<ip>|<port> <host>|<ip>|<port> ...\n
//	\r\n
//
// where <size> counts the bytes of the version and node lines. Only those <size> bytes plus the following "\r\n"
// are read; anything the server sends afterwards is left unread on rw.
func (e *Elasticache) fetch(ctx context.Context, rw io.ReadWriter) (int, []string, error) {
	if _, err := io.WriteString(rw, "config get cluster\r\n"); err != nil {
		return 0, nil, err
	}
	var (
		key   string
		flags int
		size  int
	)
	if _, err := fmt.Fscanf(rw, "CONFIG %s %d %d\r\n", &key, &flags, &size); err != nil {
		return 0, nil, err
	}
	if key != "cluster" || flags != 0 {
		return 0, nil, fmt.Errorf("unknown key or flags. key=%q flags=%d", key, flags)
	}
	value, err := ioutil.ReadAll(io.LimitReader(rw, int64(size)+2))
	if err != nil {
		return 0, nil, err
	}
	version, servers, err := parseClusterConfig(value)
	if err != nil {
		return 0, nil, err
	}
	return version, e.filterUnhealthyServers(ctx, servers), nil
}

// parseClusterConfig parses the payload of a "config get cluster" reply ("<version>\n<nodes>\n\r\n") into the
// version number and one host:port address per node. Malformed node entries are reported as errors instead of
// panicking. An empty node list is also rejected, so the caller keeps its previously loaded ring.
func parseClusterConfig(value []byte) (int, []string, error) {
	p := strings.SplitN(string(value), "\n", 3)
	if len(p) < 3 || p[2] != "\r\n" {
		return 0, nil, fmt.Errorf("malformed response: %s", string(value))
	}
	version, err := strconv.Atoi(p[0])
	if err != nil {
		return 0, nil, err
	}
	entries := strings.Fields(p[1])
	if len(entries) == 0 {
		return 0, nil, fmt.Errorf("malformed response: no nodes in cluster config %q", p[1])
	}
	servers := make([]string, 0, len(entries))
	for _, entry := range entries {
		fields := strings.Split(entry, "|")
		if len(fields) < 3 {
			return 0, nil, fmt.Errorf("malformed node entry %q", entry)
		}
		host := fields[1]
		if host == "" {
			// No IP address reported: use the hostname rather than ":port", which would resolve to the local host.
			host = fields[0]
		}
		if host == "" || fields[2] == "" {
			return 0, nil, fmt.Errorf("malformed node entry %q", entry)
		}
		// JoinHostPort brackets IPv6 literals so the result stays resolvable by net.ResolveTCPAddr.
		servers = append(servers, net.JoinHostPort(host, fields[2]))
	}
	return version, servers, nil
}

// filterUnhealthyServers returns the servers that pass e.HealthVerifier, preserving their order.
//
// All servers are checked concurrently, and address or health-check errors are reported through checkErr. With
// no verifier configured, the input is returned unchanged. When every server fails, the full input is returned
// as well, since a health check that rejects the whole cluster is more likely broken than the cluster itself.
func (e *Elasticache) filterUnhealthyServers(ctx context.Context, servers []string) []string {
	verifier := e.HealthVerifier
	if verifier == nil {
		return servers
	}
	// Every entry starts false (unhealthy) and is flipped only by a passing check.
	healthy := make([]bool, len(servers))
	var wg sync.WaitGroup
	for idx := range servers {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			addr, err := makeAddr(servers[idx])
			if err == nil {
				err = verifier.CheckHealth(ctx, addr)
			}
			if err != nil {
				e.checkErr(err)
				return
			}
			healthy[idx] = true
		}(idx)
	}
	wg.Wait()

	kept := make([]string, 0, len(servers))
	for idx, ok := range healthy {
		if ok {
			kept = append(kept, servers[idx])
		}
	}
	if len(kept) == 0 {
		return servers
	}
	return kept
}

// PickServer returns the server address that a given item should be sharded onto. It uses a hash ring to calculate
// this result.
//
// On success, it increments both the chosen server's ServerHits counter and the Picks counter reported by Stats.
// It returns an error if no configuration has been loaded yet or the ring cannot map the key to a known server.
func (e *Elasticache) PickServer(key string) (net.Addr, error) {
	currentVal := e.atomicLoadedServers.Load()
	if currentVal == nil {
		return nil, errors.New("server struct not loaded")
	}
	nodeName, ok := currentVal.Ring.GetNode(key)
	if !ok {
		return nil, fmt.Errorf("unable to find node for key %s", key)
	}
	v, ok := currentVal.Servers[nodeName]
	if !ok {
		return nil, fmt.Errorf("unable to find server for key %s", nodeName)
	}
	v.ServerHits.Add(1)
	atomic.AddInt64(&e.stats.Picks, 1)
	return v.Addr, nil
}

// Each calls f with ctx and the address of every server in the loaded configuration. It stops at the first
// non-nil error from f and returns it.
//
// Servers are visited in Go map iteration order, which is unspecified. It returns an error if no configuration
// has been loaded yet.
func (e *Elasticache) Each(ctx context.Context, f func(context.Context, net.Addr) error) error {
	loaded := e.atomicLoadedServers.Load()
	if loaded == nil {
		return errors.New("server struct not loaded")
	}
	for _, info := range loaded.Servers {
		if err := f(ctx, info.Addr); err != nil {
			return err
		}
	}
	return nil
}
