package pick

import (
	"fmt"
	"math/rand"
	"sort"
	"time"

	"code.justin.tv/devhub/e2ml/libs/discovery/protocol/message"
	"code.justin.tv/devhub/e2ml/libs/stream"
)

// AddFlags describes how List.Add placed a host in the resulting list.
type AddFlags int

const (
	// None means Add did not append a new entry: the host was already
	// present (its entry was either left alone or replaced in place).
	None AddFlags = iota
	// Appended means the host was not present and a new entry was appended.
	Appended
)

// Host scores. A score of noScore marks a host as unavailable: it is skipped
// when computing weights and never picked. Scores of available hosts lie in
// [minScore, maxScore]; maxScore equals the largest int64 value.
const (
	noScore  uint64 = 0
	minScore uint64 = 1
	maxScore uint64 = 1<<63 - 1
)

// List is an immutable, weighted collection of hosts. Query methods never
// modify the receiver. Mutators (Tick, Add, Remove) never modify it either:
// they return the List to use from then on (possibly the receiver itself when
// nothing changed) together with a bool reporting whether anything changed.
type List interface {
	// Mutators: each returns the List to use afterwards.
	Tick() (List, bool)
	Add(Host, stream.SourceID) (List, AddFlags, bool)
	Remove(Host) (List, bool)

	// Queries: read-only views of the current list.
	IsEmpty() bool
	Find(hostname string) (Host, bool)
	Pick() (Host, bool)
	HasCollision() bool
	GetSourceEntries() []Entry
}

// list is the List implementation. Add, Remove and Tick never modify the
// receiver; they build their result on a clone.
type list struct {
	// hosts holds one entry per host, in insertion order, except that Remove
	// moves the last entry into the removed entry's slot.
	hosts   []*pickEntry
	// weight is the total pick weight computed by calc; Pick draws a random
	// value in [0, weight) and walks the entries' weights to find the host.
	weight  int64
	// period is how far past "now" setExpired places the next refresh deadline.
	period  time.Duration
	// expires is the deadline after which Tick refreshes the entries.
	expires time.Time
}

// NewList returns an empty List. period is the interval used to schedule
// refreshes: the first refresh deadline is one period from now.
func NewList(period time.Duration) List {
	now := time.Now()
	return &list{
		hosts:   []*pickEntry{},
		weight:  int64(noScore),
		period:  period,
		expires: now.Add(period),
	}
}

// IsEmpty reports whether the list holds no hosts at all, regardless of
// whether any of them are currently pickable.
func (l *list) IsEmpty() bool {
	count := len(l.hosts)
	return count < 1
}

// HasCollision reports whether more than one host in the list is currently
// flagged as a source and is not draining. It stops at the second match.
func (l *list) HasCollision() bool {
	sources := 0
	for _, e := range l.hosts {
		f := e.host.Status().Flags()
		if !f.IsSource() || f.IsDraining() {
			continue
		}
		sources++
		if sources > 1 {
			return true
		}
	}
	return false
}

// GetSourceEntries returns the entries whose host is flagged as a source and
// is not draining. The result is sorted by source ID ascending, so that every
// peer instance derives the same list. Entries that share a source ID are
// ordered by hostname. Without that tie-break, their relative order would
// depend on the local insertion/removal history and on sort.Slice's unstable
// ordering. Returns an empty, non-nil slice when nothing matches.
func (l *list) GetSourceEntries() []Entry {
	type candidate struct {
		entry    Entry
		hostname string
	}

	candidates := make([]candidate, 0, len(l.hosts))
	for _, entry := range l.hosts {
		flags := entry.host.Status().Flags()
		if flags.IsSource() && !flags.IsDraining() {
			candidates = append(candidates, candidate{
				entry:    entry,
				hostname: entry.host.Hostname(),
			})
		}
	}

	sort.Slice(candidates, func(i, j int) bool {
		a, b := candidates[i], candidates[j]
		if sa, sb := a.entry.Source(), b.entry.Source(); sa != sb {
			return sa < sb
		}
		// Deterministic tie-break so peers agree even on duplicate source IDs.
		return a.hostname < b.hostname
	})

	filtered := make([]Entry, len(candidates))
	for i, c := range candidates {
		filtered[i] = c.entry
	}
	return filtered
}

// Add returns a list that contains host tagged with source.
//
// If host is already present with the same source, nothing changes: the
// receiver is returned with None and false. If host is present with a
// different source, its entry is replaced in place and the flags are None.
// Otherwise a new entry is appended and the flags are Appended. In both of
// those cases a fresh list is returned with true, and the receiver is left
// untouched.
//
// The total pick weight is recomputed when the new entry has a non-zero score
// or when the entry it replaces had one. Previously only the new score was
// checked. Replacing a scored entry with an unscored one therefore left the
// old contribution in the list total, and Pick could then draw past the sum
// of the entries' weights and hit its out-of-bounds panic. This mirrors
// Remove, which recalculates based on the removed entry's score.
func (l *list) Add(host Host, source stream.SourceID) (List, AddFlags, bool) {
	index := -1
	for i, existing := range l.hosts {
		if existing.host != host {
			continue
		}
		if existing.source == source {
			return l, None, false
		}
		index = i
		break
	}

	score := statusToScore(host.Status())
	result := l.clone()
	entry := newEntry(host, source, score, noScore)

	flags := Appended
	recalc := score > noScore
	if index >= 0 {
		flags = None
		// The replaced entry's weight is part of result.weight; drop it.
		recalc = recalc || result.hosts[index].score > noScore
		result.hosts[index] = entry
	} else {
		result.hosts = append(result.hosts, entry)
	}

	if recalc {
		result.calc()
	}
	return result, flags, true
}

// Remove returns a copy of the list without host, and true. When host is not
// present, the receiver is returned unchanged with false.
//
// The removed slot is filled with the last entry, so the order of the
// remaining entries is not preserved. The total weight is recomputed only if
// the removed entry had a non-zero score, since an unscored entry contributes
// nothing to the total.
func (l *list) Remove(host Host) (List, bool) {
	for idx, candidate := range l.hosts {
		if candidate.host != host {
			continue
		}

		result := l.clone()
		last := len(result.hosts) - 1
		removedScore := result.hosts[idx].score
		result.hosts[idx] = result.hosts[last]
		result.hosts = result.hosts[:last]

		if removedScore > noScore {
			result.calc()
		}
		return result, true
	}
	return l, false
}

// Tick returns a refreshed copy of the list, and true, when the list is
// non-empty and its refresh deadline has passed. Otherwise it returns the
// receiver unchanged with false.
func (l *list) Tick() (List, bool) {
	if len(l.hosts) > 0 && l.hasExpired() {
		refreshed := l.clone().refresh()
		return refreshed, true
	}
	return l, false
}

// Pick chooses a host at random, with probability proportional to its entry's
// weight. It returns false when the list's total weight is zero.
//
// It panics if the entries' weights sum to less than the list's total weight.
// That can only happen if the weights are internally inconsistent.
func (l *list) Pick() (Host, bool) {
	if l.weight == 0 {
		return nil, false
	}
	remaining := uint64(rand.Int63n(l.weight))
	for i := range l.hosts {
		w := l.hosts[i].weight
		if remaining < w {
			return l.hosts[i].host, true
		}
		remaining -= w
	}
	panic("Out of bounds selection from pick list")
}

// Find returns the first host whose Hostname equals hostname, scanning the
// list linearly. It returns false when no host matches.
func (l *list) Find(hostname string) (Host, bool) {
	for i := range l.hosts {
		if h := l.hosts[i].host; h.Hostname() == hostname {
			return h, true
		}
	}
	return nil, false
}

// clone returns a copy of l in which every entry is copied via the entry's
// clone method. The copy can then be modified without affecting l. Every field
// is carried over, including period. The previous version dropped period, so
// each list derived through Add, Remove or Tick had a zero refresh interval:
// setExpired then set the deadline to "now", and every later Tick refreshed.
func (l *list) clone() *list {
	result := &list{
		hosts:   make([]*pickEntry, len(l.hosts)),
		weight:  l.weight,
		period:  l.period,
		expires: l.expires,
	}
	for i, entry := range l.hosts {
		result.hosts[i] = entry.clone()
	}
	return result
}

// refresh calls refresh on every entry, then recomputes the list's weights
// and deadline via calc. It modifies l in place and returns it. Callers are
// expected to apply it to a clone. The work done by an entry's refresh
// (presumably re-reading its host's score) lives in pickEntry; verify there.
func (l *list) refresh() *list {
	for i := range l.hosts {
		l.hosts[i].refresh()
	}
	return l.calc()
}

// calc recomputes the list's total pick weight and resets its refresh
// deadline. It modifies l in place and returns it.
//
// Entries scored noScore are ignored when finding the score range. The sum of
// the lowest and highest remaining scores is the balancing target. The target
// and a scale factor are passed to every entry's balance method, and the
// results are summed into l.weight, which is the range Pick draws from. If the
// int64 running total overflows (turns negative), the whole pass is redone
// with the next scale. Presumably each step scales the values down further
// until the total fits.
//
// When no entry has a score, the weight is set to zero, so Pick reports no
// host, and balance is not called.
func (l *list) calc() *list {
	lo, hi := maxScore, minScore
	active := false
	for _, entry := range l.hosts {
		if entry.score == noScore {
			continue
		}
		active = true
		if entry.score < lo {
			lo = entry.score
		}
		if entry.score > hi {
			hi = entry.score
		}
	}
	if !active {
		l.weight = 0
		l.setExpired()
		return l
	}

	target := lo + hi
	for scale := 0; ; scale++ {
		var total int64
		overflowed := false
		for _, entry := range l.hosts {
			total += entry.balance(target, scale)
			if total < 0 {
				// Signed wrap-around: retry the whole pass at a coarser scale.
				overflowed = true
				break
			}
		}
		if !overflowed {
			l.weight = total
			break
		}
	}
	l.setExpired()
	return l
}

// hasExpired reports whether the current time is strictly past the list's
// refresh deadline.
func (l *list) hasExpired() bool {
	now := time.Now()
	return now.After(l.expires)
}

// setExpired schedules the next refresh by moving the deadline to one period
// from now. Despite its name, it does not mark the list as already expired.
func (l *list) setExpired() {
	deadline := time.Now().Add(l.period)
	l.expires = deadline
}

// String renders the hostnames in list order, in fmt's %v slice form
// (for example "[a b]").
func (l *list) String() string {
	names := make([]string, 0, len(l.hosts))
	for _, entry := range l.hosts {
		names = append(names, entry.host.Hostname())
	}
	return fmt.Sprintf("%v", names)
}

// statusToScore converts a host status into a pick score. Unavailable hosts
// score noScore, which excludes them from picking. Available hosts score one
// more than their load factor, so they are always distinguishable from
// noScore.
func statusToScore(status message.Status) uint64 {
	if !status.Flags().IsAvailable() {
		return noScore
	}
	// Widen before adding. If the +1 were evaluated in LoadFactor's own type,
	// a maximal value of a narrower integer type would wrap to 0 (== noScore)
	// and silently exclude an available host. For in-range values the result
	// is identical.
	return uint64(status.LoadFactor()) + 1
}
