package dnscache

import (
	"context"
	"net"
	"strconv"
	"strings"
	"sync"
	"time"

	"expvar"
	"sync/atomic"

	"code.justin.tv/feeds/errors"
)

// errCacheAlreadyClosed is returned by Close when the cache has already been
// closed by an earlier call.
var errCacheAlreadyClosed = errors.New("cache loop has already been shutdown")

// errCacheRefreshLoopNotStarted is returned by Close when it is called before
// the first DialContext has initialised the cache and started its refresh loop.
var errCacheRefreshLoopNotStarted = errors.New("tried to close refreshLoop before it started")

// Cache caches host to ip lookups and refreshes them periodically.
// Cache's zero value is ready to use: defaults are filled in and the
// background refresh goroutine is started by the first DialContext call.
//
// A Cache contains a sync.Once and a sync.RWMutex and must not be copied
// after first use.
type Cache struct {
	// Every specifies the time interval between background refreshes of all
	// cached hosts. If zero, defaultRefreshInterval is used.
	Every time.Duration
	// OnCacheUpsert is called every time a host's entry in the cache is
	// written, synchronously on the goroutine that performed the lookup.
	// OnCacheUpsert must not be modified once the cache is in use; it is read
	// without synchronization.
	OnCacheUpsert func(host string, oldAddrs []string, newAddrs []string, lookupTime time.Duration)
	// OnErr is called every time the cache encounters an error (a failed
	// lookup or dial), together with the host involved.
	// OnErr must not be modified once the cache is in use; it is read
	// without synchronization.
	OnErr func(err error, host string)
	// Dialer is used to establish connections to a host. If nil, a default
	// dialer (30s timeout, 30s keep-alive) is installed on first use.
	Dialer *net.Dialer

	// once ensures setup runs and the refresh loop is started exactly once.
	once sync.Once

	// disabled is 1 after Disable and 0 after Enable; it is only accessed
	// through sync/atomic.
	//
	// mu guards cache, donec and done.
	// cache maps a host name to its resolved address strings.
	// donec is closed by Close to stop the background refresh loop.
	// done is set by Close; once true, DialContext bypasses the cache.
	disabled int32
	mu       sync.RWMutex
	cache    map[string][]string
	donec    chan struct{}
	done     bool

	// test hook to detect that the background loop has shutdown
	testOnClose func()
}

// defaultRefreshInterval is used for Cache.Every when the caller leaves it
// unset (zero).
const defaultRefreshInterval = 3 * time.Second

// defaultCacheSize is the initial capacity hint for the host -> addresses map.
const defaultCacheSize = 32

// setup fills in defaults for unset configuration (Every, Dialer) and
// allocates the internal state. start calls it exactly once, via c.once.
func (c *Cache) setup() {
	if c.Every == 0 {
		c.Every = defaultRefreshInterval
	}
	if c.Dialer == nil {
		// Same settings as the dialer http.DefaultTransport historically used.
		c.Dialer = &net.Dialer{
			Timeout:   30 * time.Second,
			KeepAlive: 30 * time.Second,
			DualStack: true,
		}
	}

	entries := make(map[string][]string, defaultCacheSize)
	stop := make(chan struct{})

	// Close may run concurrently with setup, so donec (and the map) are
	// published under the write lock.
	c.mu.Lock()
	c.donec = stop
	c.cache = entries
	c.mu.Unlock()
}

// start performs one-time initialisation on first use. It applies defaults
// via setup and launches refreshLoop in its own goroutine. Later calls are
// no-ops.
func (c *Cache) start() {
	initOnce := func() {
		c.setup()
		go c.refreshLoop()
	}
	c.once.Do(initOnce)
}

// shouldSkipCache reports whether a dial on network must bypass the cache.
// Only networks beginning with "tcp" are cached, and nothing is served from
// the cache once it has been closed.
func (c *Cache) shouldSkipCache(network string) bool {
	if !strings.HasPrefix(network, "tcp") {
		return true
	}
	return c.isClosed()
}

// DialContext dials addr on network, resolving the host through the cache.
//
// The first call starts the background loop that refreshes all cached entries.
// The cache is bypassed, and c.Dialer is used directly, when the cache is
// disabled or closed, or when network is not a TCP network. For "tcp4" and
// "tcp6" only cached addresses of the matching family are tried.
//
// If resolving or dialing the cached addresses fails, the error is reported via
// the OnErr callback and the dial falls back once to c.Dialer with the original
// addr.
//
// NOTE: It is the caller's responsibility to call Close() on the cache to stop
// the background refresh loop.
func (c *Cache) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	// start background refresh loop if required
	c.start()
	if atomic.LoadInt32(&c.disabled) == 1 || c.shouldSkipCache(network) {
		return c.Dialer.DialContext(ctx, network, addr)
	}

	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, err
	}

	// handle port like :http or :https
	portnum, err := net.DefaultResolver.LookupPort(ctx, network, port)
	if err != nil {
		return nil, err
	}

	// lookup inserts host into the cache if required
	ips, err := c.lookup(ctx, host)
	if err != nil {
		c.onErr(err, host)
		// try a hail mary
		return c.Dialer.DialContext(ctx, network, addr)
	}

	// dialIPs always dials "tcp", so narrow the candidates here to honour an
	// explicit tcp4/tcp6 request. With no usable address left, let the standard
	// dialer resolve addr itself rather than dialing zero addresses, which
	// would otherwise produce a nil conn with a nil error.
	ips = ipsForNetwork(ips, network)
	if len(ips) == 0 {
		return c.Dialer.DialContext(ctx, network, addr)
	}

	conn, err := c.dialIPs(ctx, ips, strconv.Itoa(portnum))
	if cerr := ctx.Err(); cerr != nil {
		// The context ended during the dial. A connection may still have been
		// established just before that; close it instead of leaking it.
		if conn != nil {
			_ = conn.Close() // best effort: the caller only sees cerr
		}
		return nil, cerr
	}

	if err != nil {
		c.onErr(err, host)
		// try a hail mary
		return c.Dialer.DialContext(ctx, network, addr)
	}
	return conn, nil
}

// ipsForNetwork returns the entries of ips usable on network. For "tcp4" only
// IPv4 addresses are kept, and for "tcp6" only IPv6 addresses. For any other
// network, ips is returned as is. Filtering always builds a new slice, so a
// slice shared with the cache is never modified.
func ipsForNetwork(ips []string, network string) []string {
	var wantV4 bool
	switch network {
	case "tcp4":
		wantV4 = true
	case "tcp6":
		wantV4 = false
	default:
		return ips
	}

	kept := make([]string, 0, len(ips))
	for _, s := range ips {
		ip := net.ParseIP(s)
		if ip == nil {
			continue
		}
		if (ip.To4() != nil) == wantV4 {
			kept = append(kept, s)
		}
	}
	return kept
}

// dialIPs dials each ip (joined with port) over "tcp", in order, and returns
// the first connection that succeeds. If every attempt fails, it returns the
// error from the last attempt, wrapped with the address that was tried.
//
// An empty ips slice is reported as an error rather than as (nil, nil), which
// callers would otherwise treat as a successful dial with a nil conn. Once
// ctx is done, the remaining addresses are skipped.
func (c *Cache) dialIPs(ctx context.Context, ips []string, port string) (net.Conn, error) {
	if len(ips) == 0 {
		return nil, errors.New("no addresses to dial")
	}

	var lastErr error
	for _, ip := range ips {
		ipPort := net.JoinHostPort(ip, port)
		conn, err := c.Dialer.DialContext(ctx, "tcp", ipPort)
		if err == nil {
			return conn, nil
		}
		// %v, not %s: ctx.Err() is usually nil here and %s would render it
		// as "%!s(<nil>)".
		lastErr = errors.Wrapf(err, "dial failed: ip:port = %s ctxErr: %v", ipPort, ctx.Err())
		if ctx.Err() != nil {
			// Further attempts on a finished context cannot succeed.
			break
		}
	}
	return nil, lastErr
}

// isClosed reports whether Close has successfully shut the cache down.
func (c *Cache) isClosed() bool {
	c.mu.RLock()
	closed := c.done
	c.mu.RUnlock()
	return closed
}

// Close stops the background loop that refreshes the cached ips. Afterwards,
// DialContext bypasses the cache and dials through the Dialer directly.
//
// Close returns errCacheRefreshLoopNotStarted when no DialContext call has
// started the loop yet, and errCacheAlreadyClosed on a repeated call.
func (c *Cache) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	switch {
	case c.donec == nil:
		return errCacheRefreshLoopNotStarted
	case c.done:
		return errCacheAlreadyClosed
	}

	c.done = true
	close(c.donec)
	return nil
}

// lookup returns the cached addresses for host. On a cache miss it resolves
// host through upsertHost, which also stores the result.
func (c *Cache) lookup(ctx context.Context, host string) ([]string, error) {
	c.mu.RLock()
	addrs, found := c.cache[host]
	// The read lock must be released before upsertHost takes the write lock.
	c.mu.RUnlock()

	if found {
		return addrs, nil
	}
	return c.upsertHost(ctx, host)
}

// upsertHost resolves host and keeps only the addresses usable with the
// dialer's LocalAddr. It stores them in the cache, reports the change via
// OnCacheUpsert, and returns the stored slice. The slice is shared with the
// cache and must not be modified by callers.
//
// If the lookup fails, or no resolved address survives filtering, upsertHost
// returns an error and leaves any existing cache entry for host untouched.
func (c *Cache) upsertHost(ctx context.Context, host string) ([]string, error) {
	start := time.Now()
	ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
	if err != nil {
		return nil, err
	}
	lookupTime := time.Since(start)

	addrs := make([]string, 0, len(ips))
	for _, ip := range ips {
		// from resolveAddrList() in net/dial.go: drop addresses whose family
		// does not match the dialer's local address.
		if !isWildcard(ip.IP) && !matchAddrFamily(ip.IP, c.Dialer.LocalAddr) {
			continue
		}
		addrs = append(addrs, ip.IP.String())
	}
	if len(addrs) == 0 {
		// An empty entry would make every later dial of host try zero
		// addresses; surface the problem instead of caching it.
		return nil, errors.New("no suitable address found")
	}

	c.mu.Lock()
	oldAddrs := c.cache[host]
	c.cache[host] = addrs
	c.mu.Unlock()

	if c.OnCacheUpsert != nil {
		// Argument order follows the declared callback signature:
		// (host, oldAddrs, newAddrs, lookupTime).
		c.OnCacheUpsert(host, oldAddrs, addrs, lookupTime)
	}
	return addrs, nil
}

// refreshLoop re-resolves every cached host, waiting c.Every between the end
// of one refresh pass and the start of the next. It runs until donec is closed
// by Close, then invokes the testOnClose hook (if set) and returns.
func (c *Cache) refreshLoop() {
	// One timer is reused for the whole loop, instead of allocating a new one
	// per iteration with time.After. Reset is only called after the timer has
	// fired and its channel has been drained, which is safe on every Go
	// version. The deferred Stop releases the pending timer on shutdown.
	timer := time.NewTimer(c.Every)
	defer timer.Stop()

	for {
		select {
		case <-c.donec:
			if c.testOnClose != nil {
				c.testOnClose()
			}
			return
		case <-timer.C:
			c.refresh()
			timer.Reset(c.Every)
		}
	}
}

// refresh re-resolves every currently cached host once and reports each
// failure through onErr. Lookups use context.Background(), so Close does not
// cancel a refresh pass that is already in progress.
func (c *Cache) refresh() {
	hosts := c.Hosts()
	for i := range hosts {
		host := hosts[i]
		_, err := c.upsertHost(context.Background(), host)
		if err == nil {
			continue
		}
		c.onErr(err, host)
	}
}

// Var exposes the internal cache as an expvar.Var (e.g. for /debug/vars).
// Each read takes a snapshot of the host -> addresses map under the read lock.
// The map is copied, but the address slices are shared with the cache.
func (c *Cache) Var() expvar.Var {
	snapshot := func() interface{} {
		c.mu.RLock()
		defer c.mu.RUnlock()

		out := make(map[string][]string, len(c.cache))
		for host, addrs := range c.cache {
			out[host] = addrs
		}
		return out
	}
	return expvar.Func(snapshot)
}

// Enable turns the cache back on after a call to Disable, so DialContext
// resumes resolving through the cache. The flag is stored atomically.
func (c *Cache) Enable() {
	const enabled int32 = 0
	atomic.StoreInt32(&c.disabled, enabled)
}

// Disable makes DialContext bypass the cache and forward every call to the
// Dialer. The flag is stored atomically, so it may be toggled while dials are
// in flight.
func (c *Cache) Disable() {
	const disabled int32 = 1
	atomic.StoreInt32(&c.disabled, disabled)
}

// Hosts returns the host names currently held in the cache, in no particular
// order (map iteration order).
func (c *Cache) Hosts() []string {
	c.mu.RLock()
	names := make([]string, 0, len(c.cache))
	for name := range c.cache {
		names = append(names, name)
	}
	c.mu.RUnlock()
	return names
}

// onErr forwards err, together with the host it relates to, to the
// user-supplied OnErr callback. It does nothing when no callback is set.
func (c *Cache) onErr(err error, host string) {
	if c.OnErr != nil {
		c.OnErr(err, host)
	}
}

// isWildcard reports whether ip is nil or an unspecified address
// (0.0.0.0 or ::).
func isWildcard(ip net.IP) bool {
	return ip == nil || ip.IsUnspecified()
}

// matchAddrFamily reports whether the remote address ip may be used with the
// local address laddr, i.e. whether both are IPv4 or both are IPv6.
//
// These cases impose no restriction and match every ip:
//   - a nil laddr
//   - a laddr that is not a *net.IPAddr (including a typed-nil *net.IPAddr)
//   - a wildcard *net.IPAddr, i.e. one with a nil or unspecified IP
//
// Previously a nil local IP rejected every address, and an unspecified local
// IP (0.0.0.0 / ::) limited dials to its own family.
//
// NOTE(review): for TCP dials net.Dialer.LocalAddr is normally a *net.TCPAddr,
// which this function does not inspect, so no filtering happens in that case.
func matchAddrFamily(ip net.IP, laddr net.Addr) bool {
	if laddr == nil {
		return true
	}

	x, ok := laddr.(*net.IPAddr)
	if !ok || x == nil || x.IP == nil || x.IP.IsUnspecified() {
		return true
	}

	// taken from: func (ip IP) matchAddrFamily(x IP) bool in net/ip.go
	return ip.To4() != nil && x.IP.To4() != nil || ip.To16() != nil && ip.To4() == nil && x.IP.To16() != nil && x.IP.To4() == nil
}
