package registry

import (
	"io"
	"net/url"
	"sync"
	"sync/atomic"
	"time"

	"code.justin.tv/devhub/e2ml/libs/discovery"
	"code.justin.tv/devhub/e2ml/libs/logging"
	"code.justin.tv/devhub/e2ml/libs/metrics"
	"code.justin.tv/devhub/e2ml/libs/session"
	"code.justin.tv/devhub/e2ml/libs/stream"
	"code.justin.tv/devhub/lib-lifecycle/src/lifecycle"
)

type (
	// bindingRequest is the shape of callbacks such as requestHost that are
	// handed to each channel so it can ask the registry to find a host for addr.
	bindingRequest func(addr stream.Address)

	// unbindingRequest is the shape of callbacks such as unbind that a channel
	// uses to report that addr no longer needs the binding it currently holds.
	unbindingRequest func(addr stream.Address, current *binding)
)

// Remote is a stream.Registry whose addresses are served by remote hosts
// located through discovery. Close shuts the registry down and Tick drives
// its periodic maintenance.
type Remote interface {
	io.Closer
	stream.Registry

	// Tick performs one round of periodic work on established connections.
	Tick()
}

// rstats bundles the metrics reported by the remote registry.
type rstats struct {
	// tracker is the metrics source the aggregators below were created from;
	// it is also passed to every channel the registry creates.
	tracker metrics.Tracker

	// connections is adjusted +1 when a new host binding is created and -1
	// when a host closes; addresses tracks how many channels the registry
	// currently holds.
	connections metrics.Aggregator
	addresses   metrics.Aggregator
}

// remote implements Remote. Hosts are located through disco, connections are
// opened through resolver, and each address is represented by a channel.
type remote struct {
	// Dependencies and per-registry services set up by NewRemote.
	auth     stream.AuthSource
	disco    discovery.Broker
	resolver session.ClientResolver
	logger   logging.Function
	stats    rstats
	onClose  lifecycle.Manager

	// aMutex guards the retry maps, the per-host address sets and
	// refreshAuthCount. An address set moves from connecting to connected
	// in onConnected, and from either of those into draining in onDrain.
	aMutex           sync.Mutex
	brokerRetries    map[stream.AddressKey]*retryLogic
	hostRetries      map[string]*retryLogic
	connecting       map[string]*addressSet
	connected        map[string]*addressSet
	draining         map[string]*addressSet
	refreshAuthCount int

	// cMutex guards channels, which holds one channel per address key.
	cMutex   sync.RWMutex
	channels map[stream.AddressKey]*channel

	// closed is set to 1 by Close and read atomically by requestHosts.
	closed int32
}

// NewRemote builds a Remote registry that locates hosts through disco, opens
// connections through resolver and authenticates with auth. The registry's
// Tick is handed to its lifecycle manager's TickUntilClosed together with
// timeout (presumably the tick interval — confirm against lib-lifecycle).
func NewRemote(auth stream.AuthSource, disco discovery.Broker, resolver session.ClientResolver, timeout time.Duration, stats metrics.Tracker, logger logging.Function) Remote {
	tracked := rstats{
		tracker:     stats,
		connections: stats.Aggregator("registry.connections", []string{}),
		addresses:   stats.Aggregator("registry.addresses", []string{}),
	}
	r := &remote{
		auth:          auth,
		disco:         disco,
		resolver:      resolver,
		logger:        logger,
		stats:         tracked,
		brokerRetries: make(map[stream.AddressKey]*retryLogic),
		hostRetries:   make(map[string]*retryLogic),
		connecting:    make(map[string]*addressSet),
		connected:     make(map[string]*addressSet),
		draining:      make(map[string]*addressSet),
		channels:      make(map[stream.AddressKey]*channel),
		onClose:       lifecycle.NewManager(),
	}
	r.onClose.TickUntilClosed(r.Tick, timeout)
	return r
}

// Close marks the registry closed, so requestHosts stops scheduling new
// lookups. It then runs every hook registered on the onClose manager and
// returns that result.
func (r *remote) Close() error {
	atomic.StoreInt32(&r.closed, 1)
	err := r.onClose.ExecuteAll()
	return err
}

// Tick forwards a tick to the binding of every established connection.
// aMutex is held for the whole pass.
func (r *remote) Tick() {
	r.aMutex.Lock()
	defer r.aMutex.Unlock()
	for host := range r.connected {
		r.connected[host].binding.Tick()
	}
}

// RefreshAuth asks every established connection to refresh its credentials.
// Bumping refreshAuthCount lets a connection that is still being set up
// detect the request when it completes (see onConnected).
func (r *remote) RefreshAuth() {
	r.aMutex.Lock()
	defer r.aMutex.Unlock()
	r.refreshAuthCount++
	for host := range r.connected {
		r.onExpiring(r.connected[host].binding)
	}
}

// OnClose exposes the lifecycle manager whose hooks Close executes.
func (r *remote) OnClose() lifecycle.Manager {
	return r.onClose
}

// Reader returns the channel for addr, creating it if necessary.
func (r *remote) Reader(addr stream.Address) stream.Reader {
	return r.channel(addr)
}

// Writer returns a new writer on the channel for addr, creating the channel
// if necessary.
func (r *remote) Writer(addr stream.Address) stream.Writer {
	ch := r.channel(addr)
	return ch.newWriter()
}

// channel returns the channel registered for addr. If none exists, it
// creates one, recursively creating a channel for every parent address.
//
// The map is first probed under the read lock. On a miss, the parent
// channels are resolved and their ancestor sets merged without holding
// cMutex. The map is then re-checked under the write lock, so a channel
// created concurrently for the same key wins and is returned instead. The
// new channel is counted and attached to its parents only after cMutex is
// released.
func (r *remote) channel(addr stream.Address) *channel {
	r.cMutex.RLock()
	ch, found := r.channels[addr.Key()]
	r.cMutex.RUnlock()
	if found {
		return ch
	}
	// Resolve (creating as needed) a channel for each parent that is itself
	// an Address.
	parents := []*channel{}
	for _, parent := range addr.Parents() {
		if cast, ok := parent.(stream.Address); ok {
			parents = append(parents, r.channel(cast))
		}
	}
	// Merge the parents and all of their ancestors, deduplicated by key.
	// NOTE(review): map iteration leaves the order of ancestors unspecified.
	dedupe := make(map[stream.AddressKey]*channel)
	for _, parent := range parents {
		dedupe[parent.Address().Key()] = parent
		for _, ancestor := range parent.ancestors {
			dedupe[ancestor.Address().Key()] = ancestor
		}
	}
	ancestors := make([]*channel, 0, len(dedupe))
	for _, ancestor := range dedupe {
		ancestors = append(ancestors, ancestor)
	}
	// Re-check under the write lock: another caller may have created the
	// channel while the parents were being resolved.
	r.cMutex.Lock()
	ch, found = r.channels[addr.Key()]
	if found {
		r.cMutex.Unlock()
		return ch
	}
	ch = newChannel(addr, r.requestHost, r.onBound, r.unbind, ancestors, r.free, r.stats.tracker)
	r.channels[addr.Key()] = ch

	r.cMutex.Unlock()
	r.stats.addresses.Add(1)
	for _, p := range parents {
		p.attach(ch)
	}
	return ch
}

// free drops ch from the registry. It detaches ch from each parent channel.
// It then removes ch from the channel map and decrements the address count,
// but only if the map still holds this exact channel.
//
// NOTE(review): r.channel recreates a parent that has already been freed.
func (r *remote) free(ch *channel) {
	addr := ch.Address()
	key := addr.Key()
	for _, p := range addr.Parents() {
		parentAddr, ok := p.(stream.Address)
		if !ok {
			continue
		}
		r.channel(parentAddr).detatch(ch)
	}
	r.cMutex.Lock()
	defer r.cMutex.Unlock()
	if current, ok := r.channels[key]; ok && current == ch {
		delete(r.channels, key)
		r.stats.addresses.Add(-1)
	}
}

// requestHost schedules a host lookup for one address. It is the
// bindingRequest handed to every channel.
func (r *remote) requestHost(addr stream.Address) {
	single := []stream.Address{addr}
	r.requestHosts(single)
}

// requestHosts schedules a discovery lookup for each address in addrs. It
// does nothing once the registry has been closed.
//
// Lookups for the same address share one retryLogic. It is created on first
// use and its cancel is registered as an onClose hook. A failed lookup
// reschedules itself through requestHost.
func (r *remote) requestHosts(addrs []stream.Address) {
	if atomic.LoadInt32(&r.closed) != 0 {
		return
	}
	// TODO : bulk broker requests
	r.logger(logging.Trace, "Reg [FIND]", addrs)
	for _, addr := range addrs {
		// The closure below runs later via schedule. Before Go 1.22 every
		// iteration shares one loop variable, which would make every
		// scheduled lookup target the last address, so take a copy.
		addr := addr
		r.aMutex.Lock()
		brokerRetries, ok := r.brokerRetries[addr.Key()]
		if !ok {
			brokerRetries = newRetryLogic(r.onClose)
			r.brokerRetries[addr.Key()] = brokerRetries
			r.onClose.RegisterHook(brokerRetries, brokerRetries.cancel)
		}
		r.aMutex.Unlock()
		brokerRetries.schedule(func() {
			r.onClose.ExecuteHook(brokerRetries)
			success := r.performRequestHost(addr)
			if !success {
				r.logger(logging.Trace, "Reg [FIND] -- RETRY --", addr)
				r.requestHost(addr) // retry
			}
		})
	}
}

// performRequestHost finds or establishes a host binding for addr. It
// returns false only when discovery fails or returns an unparsable URL; in
// that case requestHosts schedules a retry. A true result means addr has
// been inserted into some address set.
//
// Resolution order:
//  1. Scan the known address sets (connected, then connecting) with
//     addrSet.find. find appears to return an improved score and true when
//     the set matches addr better than the best so far. A best match in
//     connected is bound right away; a best match in connecting just holds
//     addr until that connection completes.
//  2. Otherwise ask discovery for a host ticket. If a set for the ticket's
//     URL is already connected or connecting, piggyback on it.
//  3. Otherwise create a new binding and address set for the URL, register
//     binding.Close as an onClose hook, and start connecting.
func (r *remote) performRequestHost(addr stream.Address) bool {
	r.logger(logging.Trace, "Reg [FIND] performRequestHost", addr.Key())
	var best *addressSet
	var binding *binding
	var found bool
	score := stream.NoAddressMatch
	r.aMutex.Lock()
	for _, addrSet := range r.connected {
		if score, found = addrSet.find(addr, score); found {
			binding = addrSet.binding
			best = addrSet
		}
	}
	// A better match among pending connections overrides an established
	// one; addr then waits for that connection instead of binding now.
	for _, addrSet := range r.connecting {
		if score, found = addrSet.find(addr, score); found {
			binding = nil
			best = addrSet
		}
	}
	if best != nil {
		best.insert(addr)
	}
	r.aMutex.Unlock()

	if binding != nil {
		r.bind(addr, binding)
	}

	if best != nil {
		return true // host for addr was already requested
	}

	ticket, err := r.disco.FindHost(r.auth(), addr)
	if err != nil {
		r.logger(logging.Warning, "Reg - Unable to get host from discovery", addr.Key(), err)
		return false
	}

	// NOTE(review): this local shadows the net/url package for the rest of
	// the function.
	url, err := url.Parse(ticket.URL())
	if err != nil {
		r.logger(logging.Warning, "Reg - Invalid URL from discovery host ticket", addr.Key(), err)
		return false
	}

	r.aMutex.Lock()
	active := true
	addrSet, found := r.connected[ticket.URL()]
	if !found {
		addrSet, found = r.connecting[ticket.URL()]
		active = false
	}

	// piggyback on any connection in progress and return early
	if found {
		addrSet.addScopes(ticket.Scopes())
		addrSet.insert(addr)
		r.aMutex.Unlock()
		if active {
			r.bind(addr, addrSet.binding)
		}
		return true
	}

	// No set exists for this host yet: create one in connecting and dial.
	binding = newBinding(ticket.URL(), r.onSent, r.onLost, r.onStreamClosed, r.onDrain, r.onMove, r.onExpiring, r.onHostClosed, r.logger)
	r.onClose.RegisterHook(binding, binding.Close)
	addrSet = newAddressSet(binding, ticket.Scopes())
	r.connecting[binding.host] = addrSet
	addrSet.insert(addr)
	r.aMutex.Unlock()
	r.stats.connections.Add(1)

	r.connect(url, ticket, addrSet)
	return true
}

// connect schedules performConnect on the retryLogic for the set's host.
// That retryLogic is created and cached on first use and is dropped by
// onConnected once the host is connected.
func (r *remote) connect(target *url.URL, ticket discovery.Ticket, set *addressSet) {
	host := set.binding.host
	r.aMutex.Lock()
	retries, cached := r.hostRetries[host]
	if !cached {
		retries = newRetryLogic(r.onClose)
		r.hostRetries[host] = retries
	}
	r.aMutex.Unlock()
	retries.schedule(func() { r.performConnect(target, ticket, set) })
}

// performConnect opens a client connection to url for the set's binding. On
// success it initializes the binding with the ticket's method and access
// code, and onConnected completes the hand-off when initialization responds.
//
// refreshAuthCount is snapshotted before dialing so onConnected can tell
// whether RefreshAuth ran while the connection was being established.
//
// If the resolver fails, the error is logged and the attempt is rescheduled
// through connect, which reuses the host's retryLogic. This only happens if
// the registry is still open and set is still the pending connection for
// its host. Otherwise the set would stay in connecting forever, and no
// address piggybacking on it would ever bind.
func (r *remote) performConnect(url *url.URL, ticket discovery.Ticket, set *addressSet) {
	r.aMutex.Lock()
	expectedCount := r.refreshAuthCount
	r.aMutex.Unlock()
	host := set.binding.host
	r.logger(logging.Debug, "Reg [OPEN]", host)
	client, err := r.resolver(url, set.binding)
	if err != nil {
		r.logger(logging.Warning, "Reg - Unable to connect to host", host, err)
		if atomic.LoadInt32(&r.closed) != 0 {
			return
		}
		r.aMutex.Lock()
		current, pending := r.connecting[host]
		r.aMutex.Unlock()
		if pending && current == set {
			r.connect(url, ticket, set) // retry
		}
		return
	}
	set.binding.initialize(client, ticket.Method(), ticket.AccessCode(), r.onConnected(set, expectedCount))
}

// onConnected builds the response callback that performConnect passes to
// binding.initialize for set. When the callback fires:
//   - if set is still the connecting entry for its host, it is moved to
//     connected and the host's retryLogic is forgotten;
//   - if set was not that entry, or it holds no addresses, the binding is
//     closed by executing its onClose hook (registered as binding.Close in
//     performRequestHost);
//   - otherwise, if RefreshAuth ran after performConnect snapshotted
//     expectedCount, the connection is asked to refresh its auth, and every
//     address in the set is bound to the binding.
//
// NOTE(review): the err passed to the callback is ignored, so a failed
// initialization is still promoted to connected. Confirm that binding
// reports such failures through onHostClosed instead.
func (r *remote) onConnected(set *addressSet, expectedCount int) responseFunction {
	return func(_ stream.SourceID, _ stream.Position, err error) {
		r.aMutex.Lock()
		current, found := r.connecting[set.binding.host]
		found = found && current == set
		list := set.list()
		if found {
			r.connected[set.binding.host] = set
			delete(r.connecting, set.binding.host)
			delete(r.hostRetries, set.binding.host)
		}
		actualCount := r.refreshAuthCount
		r.aMutex.Unlock()

		if !found || len(list) == 0 {
			r.onClose.ExecuteHook(set.binding)
			return
		}
		// handle auth refresh requested during the connection process
		if expectedCount != actualCount {
			r.onExpiring(set.binding)
		}
		for _, addr := range list {
			r.bind(addr, set.binding)
		}
	}
}

// onExpiring asks b to refresh its credentials using a fresh token from the
// registry's auth source. The refresh outcome is currently discarded.
func (r *remote) onExpiring(b *binding) {
	// TODO: track auth refresh success/failure
	ignoreResult := func(src stream.SourceID, pos stream.Position, err error) {}
	b.refresh(r.auth(), ignoreResult)
}

// onDrain is the binding callback for a host that has started draining. It
// moves the binding's address set from connected (or connecting) into
// draining. It then issues a fresh host lookup for each of the set's
// addresses, so replacements can be found while the draining connection is
// still up.
//
// Only an address set owned by this exact binding is moved or re-requested.
// A map entry for the same host that belongs to a different binding is left
// untouched, matching the identity checks in onMove and unbind.
func (r *remote) onDrain(binding *binding) {
	var list []stream.Address
	r.aMutex.Lock()
	for _, sets := range []map[string]*addressSet{r.connected, r.connecting} {
		if addrSet, ok := sets[binding.host]; ok && addrSet.binding == binding {
			delete(sets, binding.host)
			r.draining[binding.host] = addrSet
			list = addrSet.list()
			break
		}
	}
	r.aMutex.Unlock()
	// seek new bindings while current connection is up but draining
	r.requestHosts(list)
}

// onMove is the binding callback for scopes that have moved off this host.
// It drops those scopes from the binding's address set (connected or
// connecting) and detaches the binding from the affected channels.
//
// The channel map is read under cMutex, the lock that guards it everywhere
// else. Reading it while holding only aMutex raced with channel()/free()
// and could trigger Go's fatal concurrent map access error.
func (r *remote) onMove(binding *binding, scopes stream.AddressScopes) {
	var list []stream.Address
	r.aMutex.Lock()
	addrSet, found := r.connected[binding.host]
	if !found || addrSet.binding != binding {
		addrSet, found = r.connecting[binding.host]
	}
	if found && addrSet.binding == binding {
		// Remove the association with the current binding; this will
		// cause a new find request instead of picking the established
		// server.  We don't bother doing this if the service has
		// already been marked draining, since that would be redundant.
		list = addrSet.dropScopes(scopes)
	}
	r.aMutex.Unlock()

	channels := make([]*channel, 0, len(list))
	r.cMutex.RLock()
	for _, addr := range list {
		if ch, ok := r.channels[addr.Key()]; ok {
			channels = append(channels, ch)
		}
	}
	r.cMutex.RUnlock()
	// detach the binding from affected channels, which will cause auto-reconnection if necessary
	for _, ch := range channels {
		ch.bind(nil)
	}
}

// onStreamClosed forwards a stream-closed notification to the channel for
// addr, creating the channel if it does not exist yet.
func (r *remote) onStreamClosed(addr stream.Address, src stream.SourceID, at stream.Position) {
	// TODO : kick listeners and reclaim channels if there are no descendants
	ch := r.channel(addr)
	ch.onClosed(src, at)
}

// onHostClosed is the binding callback for a closed host connection. It
// removes the binding's address set from whichever map still tracks it:
// connected, connecting or draining, checked in that order. It then unbinds
// every address the set held; per the note below, this causes reconnection
// attempts.
//
// Only an address set owned by this exact binding is touched. An entry for
// the same host that belongs to a different binding is left alone, so a
// late close of a stale connection cannot strip addresses from a live one.
func (r *remote) onHostClosed(binding *binding, err error) {
	var list []stream.Address
	r.aMutex.Lock()

	// TODO : if we never initialized, trigger retry and return

	for _, sets := range []map[string]*addressSet{r.connected, r.connecting, r.draining} {
		if addrSet, ok := sets[binding.host]; ok && addrSet.binding == binding {
			delete(sets, binding.host)
			list = addrSet.list()
			break
		}
	}
	r.aMutex.Unlock()

	r.stats.connections.Add(-1)
	r.logger(logging.Debug, "Reg [CLOSED]", binding.host, err)
	if len(list) > 0 {
		r.logger(logging.Debug, "Reg [UNBIND]", list)
		for _, addr := range list {
			r.bind(addr, nil) // will cause reconnection attempts
		}
	}
}

// onSent routes a sent-message notification to the channel for its address.
func (r *remote) onSent(msg stream.Message) {
	target := r.channel(msg.Address())
	target.onSent(msg)
}

// onLost routes a lost-message notification to the channel for its address.
func (r *remote) onLost(desc stream.MessageDescription) {
	target := r.channel(desc.Address())
	target.onLost(desc)
}

// bind hands b to the channel for addr. Callers in this file pass nil to drop
// the current binding, which their comments say triggers reconnection.
func (r *remote) bind(addr stream.Address, b *binding) {
	target := r.channel(addr)
	target.bind(b)
}

// onBound is the callback passed to each channel for when addr becomes bound.
// It forgets the address's broker retryLogic, if any, and executes the
// onClose hook registered for it (brokerRetries.cancel in requestHosts).
func (r *remote) onBound(addr stream.Address) {
	key := addr.Key()
	r.aMutex.Lock()
	retries, pending := r.brokerRetries[key]
	if pending {
		delete(r.brokerRetries, key)
	}
	r.aMutex.Unlock()
	if !pending {
		return
	}
	r.onClose.ExecuteHook(retries)
}

// unbind removes addr from the address set owned by b, searching connected,
// connecting and draining in that order. If removing addr leaves nothing
// for the set to serve (as reported by addrSet.remove), the binding's
// onClose hook is executed so the old connection can close once its address
// count hits 0.
func (r *remote) unbind(addr stream.Address, b *binding) {
	var owner *addressSet
	r.aMutex.Lock()
	for _, sets := range []map[string]*addressSet{r.connected, r.connecting, r.draining} {
		if candidate, ok := sets[b.host]; ok && candidate.binding == b {
			owner = candidate
			break
		}
	}
	shouldClose := owner != nil && owner.remove(addr)
	r.aMutex.Unlock()
	if !shouldClose {
		return
	}
	_ = r.onClose.ExecuteHook(owner.binding) // use default error reporter
}
