/*
Copyright 2011 Google Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package memcache

import (
	"fmt"
	"math/rand"
	"net"
	"strings"
	"sync"
)

// ServerSelector is the interface that selects a memcache server
// as a function of the item's key.
//
// All ServerSelector implementations must be safe for concurrent use
// by multiple goroutines.
type ServerSelector interface {
	// PickServer returns the server address that a given item
	// should be sharded onto.
	PickServer(key string) (net.Addr, error)
	// Each calls the given function once per server. ServerList's
	// implementation stops at, and returns, the first non-nil error.
	Each(func(net.Addr) error) error
	// PickRandomServer returns some server address without consulting the
	// key's hash placement. ServerList's implementation ignores key and picks
	// one of its servers at random. NOTE(review): confirm whether other
	// implementations are expected to use key at all.
	PickRandomServer(key string) (net.Addr, error)
	// GetServers returns the addresses of the currently configured servers.
	GetServers() []net.Addr
}

// ServerList is a simple ServerSelector. Its zero value is usable.
//
// Fields:
//   - mu guards addrs, addrsList and ring. SetServers replaces all three
//     under the write lock; code reading them should hold at least the
//     read lock.
//   - addrs maps each server string given to SetServers to its resolved
//     address. Duplicate server strings collapse to a single entry.
//   - addrsList holds the values of addrs in unspecified (map iteration)
//     order. It backs PickRandomServer and GetServers.
//   - ring is the consistent-hash ring PickServer consults. It is nil until
//     SetServers is called, in which case PickServer reports ErrNoServers.
type ServerList struct {
	mu        sync.RWMutex
	addrs     map[string]net.Addr
	addrsList []net.Addr
	ring      *HashRing
}

// staticAddr is a frozen net.Addr: it records the results of another
// address's Network() and String() once, so later calls are plain field
// reads and are unaffected by changes to the source address.
type staticAddr struct {
	ntw, str string
}

// newStaticAddr snapshots src's network name and string form into a fresh
// *staticAddr. Network() is read before String().
func newStaticAddr(src net.Addr) net.Addr {
	snap := new(staticAddr)
	snap.ntw, snap.str = src.Network(), src.String()
	return snap
}

// Network reports the cached network name (e.g. "tcp" or "unix").
func (sa *staticAddr) Network() string {
	return sa.ntw
}

// String reports the cached textual form of the address.
func (sa *staticAddr) String() string {
	return sa.str
}

// SetServers changes a ServerList's set of servers at runtime and is
// safe for concurrent use by multiple goroutines.
//
// Each server is given equal weight. A server is given more weight
// if it's listed multiple times.
// NOTE(review): the hash ring receives servers exactly as given, duplicates
// included, so any extra weight depends on NewHashRing. The random-pick list
// (PickRandomServer/GetServers) holds each distinct server string only once.
//
// A server string containing "/" is resolved as a unix socket path;
// anything else is resolved as a TCP host:port.
//
// SetServers returns an error if any of the server names fail to
// resolve. No attempt is made to connect to the server. If any error
// is returned, no changes are made to the ServerList.
func (ss *ServerList) SetServers(servers ...string) error {
	naddr := make(map[string]net.Addr, len(servers))
	for _, server := range servers {
		if strings.Contains(server, "/") {
			addr, err := net.ResolveUnixAddr("unix", server)
			if err != nil {
				return err
			}
			naddr[server] = newStaticAddr(addr)
		} else {
			tcpaddr, err := net.ResolveTCPAddr("tcp", server)
			if err != nil {
				return err
			}
			naddr[server] = newStaticAddr(tcpaddr)
		}
	}

	// Flatten the resolved addresses before taking the lock, so readers are
	// not blocked while the slice is built. Order is map-iteration order.
	addrsList := make([]net.Addr, 0, len(naddr))
	for _, addr := range naddr {
		addrsList = append(addrsList, addr)
	}

	// Initialize a consistent hasher with the requested servers
	// nodesPerServer=9000 was chosen by simulating running 10M keys through
	// the hashring with len(servers)=2,10,100,1000 and comparing the drift in
	// number of keys per server. The goal is to get drift within a specific percentage
	// while minimizing nodesPerServer, as more nodes increases memory usage and CPU cost
	// & time for each call to GetNode (due to more binary searching).
	//
	// This number can be adjusted if memory usage is too high, or drift is unbearable.
	hr := NewHashRing(servers, 9000)

	// Swap all three fields together so readers never see a mix of old and
	// new server sets.
	ss.mu.Lock()
	defer ss.mu.Unlock()
	ss.addrs = naddr
	ss.addrsList = addrsList
	ss.ring = hr

	return nil
}

// GetServers returns a snapshot of the currently configured server
// addresses, one per distinct server string passed to SetServers, in no
// particular order. It returns an empty (nil) slice when no servers are
// configured.
//
// The field is read under the read lock so it cannot race with a concurrent
// SetServers. The result is a copy, so callers cannot modify the slice that
// PickRandomServer samples from.
func (ss *ServerList) GetServers() []net.Addr {
	ss.mu.RLock()
	defer ss.mu.RUnlock()
	return append([]net.Addr(nil), ss.addrsList...)
}

// Each calls f once for every configured server address, in no particular
// order. It stops at, and returns, the first non-nil error from f. It
// returns nil when no servers are configured.
//
// f runs while the read lock is held. f must therefore not call SetServers,
// which would deadlock. Calling other ServerList methods from f can also
// deadlock if a SetServers call is waiting, because sync.RWMutex forbids
// recursive read locking.
func (ss *ServerList) Each(f func(net.Addr) error) error {
	ss.mu.RLock()
	defer ss.mu.RUnlock()
	// No separate nil check is needed: ranging over a nil map (the zero-value
	// ServerList) is a no-op. Checking ss.addrs before locking would race
	// with SetServers.
	for _, a := range ss.addrs {
		if err := f(a); err != nil {
			return err
		}
	}
	return nil
}

// PickServer maps key to a server via the consistent-hash ring and returns
// that server's resolved address. It returns ErrNoServers when SetServers has
// not been called or the ring yields no node for key. It returns a
// descriptive error if the ring names a server with no resolved address.
func (ss *ServerList) PickServer(key string) (net.Addr, error) {
	ss.mu.RLock()
	defer ss.mu.RUnlock()

	ring := ss.ring
	if ring == nil {
		return nil, ErrNoServers
	}

	node, found := ring.GetNode(key)
	if !found {
		return nil, ErrNoServers
	}

	if resolved, exists := ss.addrs[node]; exists {
		return resolved, nil
	}
	return nil, fmt.Errorf("Couldn't find net.Addr for server: %s", node)
}

// PickRandomServer returns one configured server address chosen uniformly at
// random (math/rand.Intn) from the distinct servers set by SetServers. The
// key argument is ignored; it exists to satisfy ServerSelector.
//
// It returns ErrNoServers when no servers are configured. This matches
// PickServer, so callers can detect the "no servers" case the same way for
// both selectors.
func (ss *ServerList) PickRandomServer(key string) (net.Addr, error) {
	ss.mu.RLock()
	defer ss.mu.RUnlock()

	n := len(ss.addrsList)
	if n == 0 {
		return nil, ErrNoServers
	}
	return ss.addrsList[rand.Intn(n)], nil
}
