package balanced

import (
	"net"
	"sync"

	"code.justin.tv/devhub/e2ml/libs/discovery/protocol"
	"code.justin.tv/devhub/e2ml/libs/discovery/protocol/message"
	"code.justin.tv/devhub/e2ml/libs/logging"
	"code.justin.tv/devhub/e2ml/libs/metrics"
	"code.justin.tv/devhub/e2ml/libs/session"
	"code.justin.tv/devhub/e2ml/libs/stream"
)

// Binding constructors injected into bindingManager.
type (
	// reporterFactory builds a reporter binding for a session client; it is
	// invoked by bindingManager.createReporter.
	reporterFactory func(session.Client) *reporterBinding

	// peerFactory builds a peer binding for a session client; it is invoked
	// by bindingManager.createPeer.
	peerFactory func(session.Client) *peerBinding
)

// bstats holds the aggregators bindingManager adjusts by +1/-1 as bindings are
// registered and removed: peers counts registered peer bindings and hosts
// counts registered reporter bindings.
type bstats struct {
	peers, hosts metrics.Aggregator
}

// bindingManager owns the tables of active reporter and peer bindings and the
// metrics describing them. The reporters and peers maps are only accessed while
// holding mutex; the stats aggregators are updated outside the lock.
type bindingManager struct {
	// newReporter constructs the binding returned by createReporter.
	newReporter reporterFactory
	// newPeer constructs the binding returned by createPeer.
	newPeer peerFactory
	// peerAuth supplies the credentials encoded into the Connect message
	// sent to each new peer.
	peerAuth stream.AuthSource
	// reporters maps reporterBinding.key() to the registered binding.
	reporters map[string]*reporterBinding
	// peers is the set of peer bindings whose Connect message was written.
	peers map[*peerBinding]struct{}
	// stats tracks the number of registered peers and hosts.
	stats bstats
	// mutex guards reporters and peers.
	mutex sync.RWMutex
}

// newBindingManager returns an empty bindingManager.
//
// tracker provides the "broker.peers" and "broker.hosts" aggregators. newReporter
// and newPeer construct bindings for incoming session clients. peerAuth supplies
// the credentials sent to peers in their Connect message.
//
// The tracker parameter is deliberately not named "metrics" so it does not
// shadow the imported metrics package inside this function.
func newBindingManager(tracker metrics.Tracker, newReporter reporterFactory, newPeer peerFactory, peerAuth stream.AuthSource) *bindingManager {
	return &bindingManager{
		stats: bstats{
			peers: tracker.Aggregator("broker.peers", []string{}),
			hosts: tracker.Aggregator("broker.hosts", []string{}),
		},
		newReporter: newReporter,
		newPeer:     newPeer,
		peerAuth:    peerAuth,
		reporters:   make(map[string]*reporterBinding),
		peers:       make(map[*peerBinding]struct{}),
	}
}

// createReporter builds a reporter binding for client and registers it under
// binding.key().
//
// If a reporter is already registered under the same key, the new binding
// replaces it. A warning is logged on the previous binding and its client is
// closed. The hosts aggregator is only incremented when no existing reporter
// was displaced, so the count stays equal to the number of table entries.
func (b *bindingManager) createReporter(client session.Client) *reporterBinding {
	binding := b.newReporter(client)
	key := binding.key()

	b.mutex.Lock()
	prev, found := b.reporters[key]
	// Always install the new binding. Previously a colliding binding was left
	// out of the table, so findReporter could never return it and the key
	// ended up empty once the closed predecessor was removed.
	b.reporters[key] = binding
	b.mutex.Unlock()

	if !found {
		b.stats.hosts.Add(1)
		return binding
	}

	// Log and close outside the lock. removeReporter only deletes an entry
	// whose value is the binding being removed, so a later removeReporter(prev)
	// leaves this replacement in place and does not decrement hosts.
	prev.on.log(logging.Warning, "Name collision in reporter table:", key)
	prev.client.Close(nil)
	return binding
}

// findReporter returns the reporter binding registered under src.String(), or
// nil if there is none.
func (b *bindingManager) findReporter(src net.Addr) *reporterBinding {
	lookupKey := src.String()

	b.mutex.RLock()
	defer b.mutex.RUnlock()
	return b.reporters[lookupKey]
}

// removeReporter unregisters binding, but only if it is still the entry stored
// under binding.key(). This keeps a stale binding from evicting a newer
// reporter that reused the same key. The hosts aggregator is decremented only
// when an entry was actually deleted.
func (b *bindingManager) removeReporter(binding *reporterBinding) {
	key := binding.key()

	removed := func() bool {
		b.mutex.Lock()
		defer b.mutex.Unlock()
		if current, exists := b.reporters[key]; exists && current == binding {
			delete(b.reporters, key)
			return true
		}
		return false
	}()

	if removed {
		b.stats.hosts.Add(-1)
	}
}

// createPeer builds a peer binding for client and sends it a Connect message
// carrying the encoded peerAuth credentials. The binding is registered, and the
// peers aggregator incremented, only when that write succeeds. On failure the
// writer has already closed the client. The binding is returned in both cases.
func (b *bindingManager) createPeer(client session.Client) session.Binding {
	peer := b.newPeer(client)
	send := clientWriter(protocol.BroadcastSafe, client)
	if err := send(message.NewConnect(protocol.BroadcastSafe, b.peerAuth().Encode())); err != nil {
		return peer
	}

	b.mutex.Lock()
	b.peers[peer] = struct{}{}
	b.mutex.Unlock()

	b.stats.peers.Add(1)
	return peer
}

// removePeer unregisters binding and decrements the peers aggregator if it was
// registered. Removing an unknown binding is a no-op.
func (b *bindingManager) removePeer(binding *peerBinding) {
	b.mutex.Lock()
	_, registered := b.peers[binding]
	if registered {
		delete(b.peers, binding)
	}
	b.mutex.Unlock()

	if !registered {
		return
	}
	b.stats.peers.Add(-1)
}

// initializePeer replays the current state of every local reporter host to a
// newly connected peer, then sends a Ready message. All messages are marshaled
// with the given protocol version.
//
// For each reporter that has a host, every message returned by
// host.Summarize() is forwarded via message.NewForward. If any Summarize call
// fails, the client is closed with that error and no Ready is sent. If a write
// fails, clientWriter has already closed the client, and initializePeer
// stops without closing it a second time.
func (b *bindingManager) initializePeer(version protocol.Version, client session.Client) {
	// Snapshot the table so the lock is not held during network writes.
	b.mutex.RLock()
	reporters := make([]*reporterBinding, 0, len(b.reporters))
	for _, r := range b.reporters {
		reporters = append(reporters, r)
	}
	b.mutex.RUnlock()

	writer := clientWriter(version, client)
	for _, r := range reporters {
		host := r.getHost()
		if host == nil {
			continue
		}
		msgs, err := host.Summarize()
		if err != nil {
			// Abort immediately. Previously a later successful Summarize
			// overwrote this error, so whether the peer was dropped depended on
			// map iteration order.
			client.Close(err)
			return
		}
		for _, msg := range msgs {
			if writer(message.NewForward(true, host.RemoteAddress(), msg)) != nil {
				return // writer already closed the client
			}
		}
	}
	// On failure writer closes the client itself; nothing more to do here.
	_ = writer(message.NewReady())
}

// clientWriter returns a function that marshals a message with version and
// writes it to client as a binary frame.
//
// The returned function accepts a (message, error) pair so it can wrap a
// message constructor call directly. A non-nil incoming error skips the write.
// Any error (incoming, marshal or write) closes client with that error and is
// returned to the caller.
func clientWriter(version protocol.Version, client session.Client) func(msg protocol.Message, err error) error {
	send := func(msg protocol.Message) error {
		payload, marshalErr := msg.Marshal(version)
		if marshalErr != nil {
			return marshalErr
		}
		return client.WriteBinary(payload)
	}

	return func(msg protocol.Message, err error) error {
		if err == nil {
			err = send(msg)
		}
		if err != nil {
			client.Close(err)
		}
		return err
	}
}
