package peering

import (
	"errors"
	"math/rand"
	"sync"
	"sync/atomic"
	"time"
	"unsafe"

	"code.justin.tv/devhub/e2ml/libs/logging"
	"code.justin.tv/devhub/e2ml/libs/session"
)

// Write strategies passed to manager.broadcast; each one forwards the payload
// to a single Peer method.
var (
	// peerBinary delivers the payload through Peer.WriteBinary.
	peerBinary       = func(p Peer, bytes []byte) error { return p.WriteBinary(bytes) }
	// peerBinaryAsText delivers the payload through Peer.WriteBinaryAsText.
	peerBinaryAsText = func(p Peer, bytes []byte) error { return p.WriteBinaryAsText(bytes) }
)

// PeerMap indexes peers by name. A name may be present with a nil Peer:
// doInsert registers names that way and doUnset resets an entry to nil, so a
// nil value means "known name, no peer currently set".
type PeerMap map[string]Peer

// transformFunc computes a single change to the peer map on behalf of
// manager.update. It is handed a private copy of the current map, which it may
// modify and return as the map to publish; a non-nil error aborts the update.
// nil, nil return means no update
type transformFunc func(name string, expected Peer, peers PeerMap) (PeerMap, error)

// manager is the concrete Manager. The peer map is published through peerPtr
// and replaced as a whole by update, so readers use an atomic load instead of
// the mutex; the mutex is taken around most accesses to service and listeners.
type manager struct {
	// pFactory is built by newProxyFactory in NewManager; Start uses its
	// binding factory and connect uses it to create outbound peers.
	pFactory    *proxyFactory
	// cFactory holds the clients argument of NewManager. It is not referenced
	// elsewhere in this file.
	cFactory    session.ClientResolver
	// sFactory builds the underlying service on the first call to Start.
	sFactory    session.ServiceFactory
	// resolver is applied to a peer name before a connection attempt.
	resolver    NameResolver
	// service is nil until Start creates it and is reset to nil by Shutdown.
	service     session.Service
	// listeners is the set of registered Listener values.
	listeners   map[Listener]struct{}
	// peerPtr points at the current PeerMap; it is read with atomic.LoadPointer
	// and replaced with atomic.CompareAndSwapPointer.
	peerPtr     unsafe.Pointer
	// connector is the internal listener registered on the server list in Start.
	connector   *mgrListener
	// logger receives connection and retry diagnostics.
	logger      logging.Function
	mutex       sync.Mutex
	// hasShutdown is set to 1 by Shutdown and read atomically by HasShutdown.
	hasShutdown int32
}

// mgrListener is an internal Listener used by manager to assist with automatic
// connections on startup -- manager doesn't directly implement Listener so
// that changes can't be spoofed through the public interface.
type mgrListener struct {
	// list is the server list the manager registers on in Start and closes in
	// Shutdown; it also supplies the local name.
	list ClosableServerList
	// m is the manager that receives the forwarded notifications.
	m    *manager
}

// OnPeerAdded relays an addition to the owning manager.
func (l *mgrListener) OnPeerAdded(name string) {
	l.m.onAdded(name)
}
// OnPeerRemoved relays a removal to the owning manager.
func (l *mgrListener) OnPeerRemoved(name string) {
	l.m.onRemoved(name)
}

// Manager acts as a service with the important addition that it can also send
// a broadcast to all peers and notify when a new peer has been added.
type Manager interface {
	session.Service
	ServerList
	// DiscoveredCount reports how many peer names are currently tracked.
	DiscoveredCount() int
	// Find returns the entry stored under name, or ErrPeerNotFound when the
	// name is not tracked.
	Find(name string) (Peer, error)
	// BroadcastBinary writes bytes to every peer that is currently set and
	// returns the first write error, if any.
	BroadcastBinary(bytes []byte) error
	// BroadcastText writes value to every peer that is currently set and
	// returns the first write error, if any.
	BroadcastText(value string) error
	// BroadcastBinaryAsText writes bytes through the binary-as-text path to
	// every peer that is currently set and returns the first write error.
	BroadcastBinaryAsText(bytes []byte) error
}

// NewManager wraps a Service and ServerList to abstract the direction of
// peer-to-peer connections and provide a consistent Binding interface between
// peers regardless of which one establishes the connection -- an extra
// handshake packet is sent on initial connection before using the logic from
// the input binding factory to establish the local name of the node.
//
// The returned manager starts with an empty peer map; the underlying service
// is not created until Start is called.
func NewManager(
	bindings session.BindingFactory,
	clients session.ClientResolver,
	service session.ServiceFactory,
	list ClosableServerList,
	resolver NameResolver,
	logger logging.Function,
) Manager {
	initial := PeerMap{}
	mgr := new(manager)
	mgr.cFactory = clients
	mgr.sFactory = service
	mgr.resolver = resolver
	mgr.logger = logger
	mgr.peerPtr = unsafe.Pointer(&initial)
	// The connector and proxy factory both hold a back-reference to the
	// manager, so they are wired up after it exists.
	mgr.connector = &mgrListener{list: list, m: mgr}
	mgr.pFactory = newProxyFactory(mgr, bindings)
	return mgr
}

// Start creates the underlying service on first use, registers the internal
// connector on the server list and starts the service if it is not already
// running. It returns an error when called after Shutdown, when the service
// factory fails, or when the service itself fails to start.
func (m *manager) Start() error {
	if m.HasShutdown() {
		return errors.New("Start called after Shutdown")
	}
	m.mutex.Lock()
	service := m.service
	if service == nil {
		created, err := m.sFactory(m.pFactory.AsBindingFactory())
		if err != nil {
			m.mutex.Unlock()
			return err
		}
		// Only a successfully built service is stored on the manager.
		service = created
		m.service = created
	}
	m.mutex.Unlock()
	m.connector.list.AddListener(m.connector)
	// Use the reference captured under the lock: a concurrent Shutdown may
	// reset m.service to nil at any point after the unlock.
	if !service.IsRunning() {
		return service.Start()
	}
	return nil
}

// Stop stops the underlying service if one has been created. The service
// reference is read under the mutex, matching IsRunning and Shutdown, so a
// concurrent Shutdown cannot nil the field between the check and the call.
func (m *manager) Stop() {
	m.mutex.Lock()
	service := m.service
	m.mutex.Unlock()
	if service != nil {
		service.Stop()
	}
}

// IsRunning reports whether a service exists and says it is running. The
// service reference is copied under the mutex and queried outside it.
func (m *manager) IsRunning() bool {
	m.mutex.Lock()
	svc := m.service
	m.mutex.Unlock()
	if svc == nil {
		return false
	}
	return svc.IsRunning()
}

// HasShutdown reports whether Shutdown has been called on this manager.
func (m *manager) HasShutdown() bool {
	flag := atomic.LoadInt32(&m.hasShutdown)
	return flag != 0
}

// peers returns the currently published peer map. update publishes a fresh
// map on every change, so callers must treat the result as read-only.
func (m *manager) peers() PeerMap {
	current := (*PeerMap)(atomic.LoadPointer(&m.peerPtr))
	return *current
}

// DiscoveredCount returns the number of names in the peer map, including
// names that currently have no peer set.
func (m *manager) DiscoveredCount() int {
	return len(m.peers())
}
// LocalName returns the local name reported by the underlying server list.
func (m *manager) LocalName() string {
	return m.connector.list.LocalName()
}
// AddListener registers listener and then reports every name currently in the
// peer map to it through OnPeerAdded, so a late subscriber still learns about
// existing entries. The replay covers all tracked names, including those that
// have no peer set at the moment.
//
// The listener set is allocated on first use: nothing else initialises the
// field, and assigning into a nil map panics.
func (m *manager) AddListener(listener Listener) {
	m.mutex.Lock()
	if m.listeners == nil {
		m.listeners = make(map[Listener]struct{})
	}
	m.listeners[listener] = struct{}{}
	m.mutex.Unlock()
	// Callbacks run after the lock is released so a listener can call back
	// into the manager without deadlocking.
	for name := range m.peers() {
		listener.OnPeerAdded(name)
	}
}

// RemoveListener unregisters listener; removing one that was never added is a
// no-op.
func (m *manager) RemoveListener(listener Listener) {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	delete(m.listeners, listener)
}

// Find looks name up in the current peer map and returns ErrPeerNotFound when
// it is absent. A tracked name with no peer set is stored as nil, so the
// returned Peer can be nil even though the error is nil.
func (m *manager) Find(name string) (Peer, error) {
	peer, ok := m.peers()[name]
	if !ok {
		return nil, ErrPeerNotFound
	}
	return peer, nil
}

// BroadcastText converts value to bytes and sends it to every set peer using
// the binary-as-text writer.
func (m *manager) BroadcastText(value string) error {
	payload := []byte(value)
	return m.broadcast(payload, peerBinaryAsText)
}

// BroadcastBinary sends bytes to every set peer using the binary writer.
func (m *manager) BroadcastBinary(bytes []byte) error {
	err := m.broadcast(bytes, peerBinary)
	return err
}

// BroadcastBinaryAsText sends bytes to every set peer using the
// binary-as-text writer.
func (m *manager) BroadcastBinaryAsText(bytes []byte) error {
	err := m.broadcast(bytes, peerBinaryAsText)
	return err
}

// broadcast applies write to every non-nil peer in the current map. A failed
// write does not stop the loop; the first error seen is returned once all
// peers have been attempted.
func (m *manager) broadcast(bytes []byte, write func(Peer, []byte) error) error {
	var first error
	for _, target := range m.peers() {
		if target == nil {
			// Name is tracked but no peer is set for it.
			continue
		}
		if e := write(target, bytes); e != nil && first == nil {
			first = e
		}
	}
	return first
}

// WaitForDrainingConnections delegates to the underlying service when one
// exists and does nothing otherwise.
func (m *manager) WaitForDrainingConnections(until time.Time) {
	m.mutex.Lock()
	svc := m.service
	m.mutex.Unlock()
	if svc == nil {
		return
	}
	svc.WaitForDrainingConnections(until)
}

// Shutdown marks the manager as shut down, closes the server list, detaches
// the underlying service from the manager and shuts that service down. It
// returns the service's shutdown error, or nil when no service was created.
func (m *manager) Shutdown() error {
	atomic.StoreInt32(&m.hasShutdown, 1)
	m.connector.list.Close()
	// TODO : spin until no listeners?
	m.mutex.Lock()
	svc := m.service
	m.service = nil
	m.mutex.Unlock()
	if svc == nil {
		return nil
	}
	return svc.Shutdown()
}

// getListeners copies the registered listeners into a slice. It reads
// m.listeners directly; the callers in this file hold m.mutex around the call.
func (m *manager) getListeners() []Listener {
	snapshot := make([]Listener, len(m.listeners))
	i := 0
	for listener := range m.listeners {
		snapshot[i] = listener
		i++
	}
	return snapshot
}

// onAdded registers name in the peer map and, when the name was not already
// tracked, starts an immediate connection attempt.
func (m *manager) onAdded(name string) {
	if err := m.update(name, nil, doInsert); err != nil {
		return
	}
	m.connect(name, 0)
}

// onRemoved drops name from the peer map and, if it was tracked, tells every
// registered listener so it can handle the disconnection gracefully. The
// current connection itself is not closed here.
func (m *manager) onRemoved(name string) {
	if err := m.update(name, nil, doRemove); err == nil {
		m.mutex.Lock()
		targets := m.getListeners()
		m.mutex.Unlock()
		// Notify outside the lock.
		for i := range targets {
			targets[i].OnPeerRemoved(name)
		}
	}
}

// onConnected records p as the peer for its remote name. If the map rejects
// the change (unknown name, or another peer already set) the proxy is closed
// with that error; otherwise the connection is logged and every registered
// listener is told the peer was added.
func (m *manager) onConnected(p *proxy) {
	name := p.RemoteName()
	err := m.update(name, p, doSet)
	if err != nil {
		m.logger(logging.Info, "Peer [REJECT]", m.LocalName(), "<->", name, " (", p.Address(), ")", err)
		p.Close(err)
		return
	}
	// The label distinguishes connections we dialled from ones we accepted.
	label := "Peer [ACCEPT]"
	if p.isClient {
		label = "Peer [OPENED]"
	}
	m.logger(logging.Info, label, m.LocalName(), "<->", name, " (", p.Address(), ")")
	m.mutex.Lock()
	targets := m.getListeners()
	m.mutex.Unlock()
	for i := range targets {
		targets[i].OnPeerAdded(name)
	}
}

// onClose clears p from the peer map. When p was the peer on record and this
// side dialled the connection, a reconnect is scheduled using the proxy's
// retry delay; only the client side reconnects, to prevent exponential growth.
func (m *manager) onClose(p *proxy, err error) {
	if !m.unset(p.RemoteName(), p) {
		return
	}
	if !p.isClient {
		return
	}
	m.connect(p.RemoteName(), p.retryDelay)
}

// unset runs doUnset for name and reports whether it succeeded, i.e. whether
// the entry for name exists and currently equals expected.
func (m *manager) unset(name string, expected Peer) bool {
	err := m.update(name, expected, doUnset)
	return err == nil
}

// connect schedules an outbound connection attempt to name after delay.
// Nothing is scheduled when name is the local node or when the resolver
// returns an error. When the timer fires, the attempt is skipped unless name
// is still tracked with no peer set; an attempt that fails or panics is
// rescheduled with the next back-off delay.
func (m *manager) connect(name string, delay time.Duration) {
	if name == m.LocalName() {
		return
	}
	parsed, err := m.resolver(name)
	if err != nil {
		return
	}
	time.AfterFunc(delay, func() {
		if !m.unset(name, nil) { // Already set? We're done.
			return
		}
		next := nextDelay(delay)
		// recover only takes effect when called directly by a deferred
		// function; without the defer a panic in create would escape the
		// timer goroutine and terminate the process.
		defer func() {
			if r := recover(); r != nil {
				m.logger(logging.Debug, "Panic creating peer", r)
				m.connect(name, next)
			}
		}()
		if err := m.pFactory.create(parsed, next); err != nil {
			m.logger(logging.Debug, "Failed to create peer", err)
			m.connect(name, next)
		}
	})
}

// nextDelay returns the retry delay that follows prev. A zero prev yields a
// randomized first delay of 450-549ms; a prev above ten seconds is returned
// unchanged; any other value is doubled.
func nextDelay(prev time.Duration) time.Duration {
	switch {
	case prev == 0:
		jitter := time.Duration(rand.Int31n(100))
		return (450 + jitter) * time.Millisecond
	case prev > 10*time.Second:
		return prev
	default:
		return prev * 2
	}
}

// doInsert registers name with no peer set. It returns
// ErrPeerAlreadyConnected when the name is already tracked.
func doInsert(name string, _ Peer, peers PeerMap) (PeerMap, error) {
	_, exists := peers[name]
	if exists {
		return nil, ErrPeerAlreadyConnected
	}
	peers[name] = nil
	return peers, nil
}

// doRemove deletes name from the map, whatever peer is stored for it. It
// returns ErrPeerNotFound when the name is not tracked.
func doRemove(name string, _ Peer, peers PeerMap) (PeerMap, error) {
	if _, exists := peers[name]; exists {
		delete(peers, name)
		return peers, nil
	}
	return nil, ErrPeerNotFound
}

// doSet stores value as the peer for a tracked name that has none. It returns
// ErrPeerNotFound for an untracked name, requests no update when value is
// already stored, and returns ErrPeerAlreadyConnected when a different peer
// occupies the entry.
func doSet(name string, value Peer, peers PeerMap) (PeerMap, error) {
	current, exists := peers[name]
	switch {
	case !exists:
		return nil, ErrPeerNotFound
	case current == value:
		return nil, nil
	case current != nil:
		return nil, ErrPeerAlreadyConnected
	}
	peers[name] = value
	return peers, nil
}

// doUnset resets the entry for name to nil, but only when it currently holds
// value; otherwise it returns ErrPeerNotFound. With a nil value it acts as a
// pure check: the entry must exist and be nil, and no update is requested.
func doUnset(name string, value Peer, peers PeerMap) (PeerMap, error) {
	current, exists := peers[name]
	if !exists || current != value {
		return nil, ErrPeerNotFound
	}
	if value != nil {
		peers[name] = nil
		return peers, nil
	}
	return nil, nil
}

// update applies transform to a private copy of the peer map and publishes
// the result with a compare-and-swap on peerPtr, retrying from a fresh copy
// when another update was published in the meantime. It returns the
// transform's error, or nil when the change was published or the transform
// requested no update.
func (m *manager) update(name string, value Peer, transform transformFunc) error {
	for {
		snapshot := atomic.LoadPointer(&m.peerPtr)
		current := *(*PeerMap)(snapshot)
		working := make(PeerMap)
		for key, peer := range current {
			working[key] = peer
		}
		next, err := transform(name, value, working)
		if err != nil {
			return err
		}
		if next == nil {
			// nil, nil from the transform means there is nothing to publish.
			return nil
		}
		if atomic.CompareAndSwapPointer(&m.peerPtr, snapshot, unsafe.Pointer(&next)) {
			return nil
		}
	}
}
