package registry

import (
	"sync"
	"time"

	"code.justin.tv/devhub/e2ml/libs/stream/listener"
	"code.justin.tv/devhub/e2ml/libs/stream/protocol"
	"code.justin.tv/devhub/e2ml/libs/stream/scheduler"

	"code.justin.tv/devhub/e2ml/libs/metrics"
	"code.justin.tv/devhub/e2ml/libs/promise"
	"code.justin.tv/devhub/e2ml/libs/stream"
	"code.justin.tv/devhub/e2ml/libs/stream/history"
)

// A channel provides an address-specific binding for read/write operations that
// sits on top of the connection layer, preserving its state (listeners, writers,
// messages queued while unbound) across connection changes.
//
// The state machine between a channel and its underlying binding uses two
// counters to track whether a join has succeeded: bound and joining. bound is
// incremented each time a new binding is registered, and joining is set equal
// to bound each time a join request is sent, so joining == bound means a join
// has already been sent for the current binding. Each join request captures the
// value of bound at send time; a response that arrives after bound has changed
// came from a stale binding and is discarded (the address is parted on that
// binding), so promises are not resolved prematurely by a binding-level
// failure. Both counters are reset to zero once a join succeeds.

// emptiedFunc is invoked by checkEmpty with the channel that has no writers,
// no children and no listeners left.
type emptiedFunc func(*channel)

// cstats holds the per-channel metrics created in newChannel, tagged with the
// channel's address.
type cstats struct {
	// Incremented by tryJoin when a join is first counted, decremented by Leave
	// when a part is sent on a live binding.
	subscriptions metrics.Aggregator
	// Incremented by newWriter and decremented by closeWriter.
	writers       metrics.Aggregator
	bytesRead     metrics.Count // total fanout bytes
	// NOTE(review): bytesRead/msgsRead are not updated in this file section;
	// presumably the read path elsewhere records them.
	msgsRead      metrics.Count
}

// channel tracks the listeners, writers, child channels and connection binding
// for a single stream address. Its fields are split between two locks: lMutex
// is used for listeners, writers and children; bMutex is used for the binding
// and the join/part bookkeeping (msgs, hasJoined, hasParted, bound, joining,
// markedJoined).
type channel struct {
	// Callbacks into the owner: request a binding for the address, acknowledge
	// that a binding was installed, and release a binding from the address.
	bRequest     bindingRequest
	bAck         bindingRequest
	uRequest     unbindingRequest
	// history records messages written via onSent (created with
	// history.NewSingleEntry); scheduler delivers its updates to listeners.
	history      stream.History
	scheduler    stream.Scheduler
	// Guarded by lMutex.
	listeners    listener.Collection
	writers      int
	children     map[stream.AddressKey]*channel
	// Current binding; nil while unbound. Guarded by bMutex.
	binding      *binding
	// Channels whose listeners are also notified of this channel's updates.
	ancestors    []*channel
	// Messages queued by send while unbound; flushed by bind. Guarded by bMutex.
	msgs         []*deferredMessage
	hasJoined    promise.MutableBool
	// nil unless a part has been started since the last tryJoin.
	hasParted    promise.MutableBool
	// Invoked by checkEmpty once there are no writers, children or listeners.
	onEmptied    emptiedFunc
	// Join bookkeeping: bound counts bindings installed, joining records the
	// value of bound at the last join request; both reset on a successful join.
	bound        int32
	joining      int32
	stats        cstats
	// Whether a successful join has been counted in stats.subscriptions since
	// the last Leave cleared it.
	markedJoined bool
	lMutex       sync.RWMutex
	bMutex       sync.RWMutex
}

// Compile-time check that *channel implements stream.Reader.
var _ stream.Reader = (*channel)(nil)

// newChannel constructs the channel for addr. The binding callbacks, ancestor
// list and emptied callback are stored as given, and per-address metrics are
// created from tracker using the tags from metricTagsForAddress. The returned
// channel starts unbound with a fresh history, inline scheduler, listener
// collection, children map and hasJoined promise.
func newChannel(addr stream.Address, bRequest, bAck bindingRequest, uRequest unbindingRequest, ancestors []*channel, onEmptied emptiedFunc, tracker metrics.Tracker) *channel {
	tags := metricTagsForAddress(addr)
	stats := cstats{
		writers:       tracker.Aggregator("registry.writers", tags),
		subscriptions: tracker.Aggregator("registry.subscriptions", tags),
		bytesRead:     tracker.Count("registry.bytes.read", tags),
		msgsRead:      tracker.Count("registry.msgs.read", tags),
	}

	ch := &channel{
		bRequest:  bRequest,
		bAck:      bAck,
		uRequest:  uRequest,
		onEmptied: onEmptied,
		ancestors: ancestors,
		stats:     stats,
	}
	ch.history = history.NewSingleEntry(addr)
	ch.scheduler = scheduler.NewInline()
	ch.listeners = listener.NewCollection()
	ch.children = make(map[stream.AddressKey]*channel)
	ch.hasJoined = promise.NewBool()
	return ch
}

// Address returns the stream address this channel serves, as recorded by its
// history.
func (c *channel) Address() stream.Address {
	h := c.history
	return h.Address()
}

// newWriter registers a new writer on the channel and returns it. The writer
// count is bumped under lMutex; the writers metric is updated after the lock
// is released.
func (c *channel) newWriter() stream.Writer {
	c.lMutex.Lock()
	c.writers++
	w := newWriter(c)
	c.lMutex.Unlock()

	c.stats.writers.Add(1)
	return w
}

// closeWriter unregisters a writer previously returned by newWriter and checks
// whether the channel has become empty. When the last writer goes away and a
// binding exists, the binding is asked to release the address; once the
// release completes, checkUnbind re-evaluates that binding under lMutex.
// It always returns nil.
func (c *channel) closeWriter(w *writer) error {
	c.lMutex.Lock()
	c.writers--
	lastWriter := c.writers == 0 // TODO : ideally we'd only release after an active send
	c.checkEmpty()
	c.lMutex.Unlock()

	c.stats.writers.Add(-1)

	if !lastWriter {
		return nil
	}

	c.bMutex.RLock()
	current := c.binding
	c.bMutex.RUnlock()
	if current == nil {
		return nil
	}

	current.release(c.history.Address(), func(stream.SourceID, stream.Position, error) {
		c.lMutex.RLock()
		c.checkUnbind(current)
		c.lMutex.RUnlock()
	})
	return nil
}

// attach records child as a child of this channel, keyed by the child's own
// address. While any child is attached, checkEmpty treats the channel as
// non-empty.
func (c *channel) attach(child *channel) {
	// Key by the child's address: keying by the parent's own address made every
	// child collide on a single map entry.
	key := child.history.Address().Key()
	c.lMutex.Lock()
	c.children[key] = child
	c.lMutex.Unlock()
}

// detatch removes child from this channel's children, keyed by the child's own
// address, and then checks whether the channel has become empty (which may
// invoke onEmptied). The entry is removed only if it still refers to child, so
// a different channel registered under the same address is left in place.
//
// NOTE: the misspelled name is kept for compatibility with existing callers.
func (c *channel) detatch(child *channel) {
	key := child.history.Address().Key()
	c.lMutex.Lock()
	if existing, ok := c.children[key]; ok && existing == child {
		delete(c.children, key)
	}
	c.checkEmpty()
	c.lMutex.Unlock()
}

// send writes content to the channel's address and returns a promise tracking
// the write.
//
// With a binding in place the write goes straight to binding.send. Without
// one, the content is queued as a deferredMessage (bind flushes the queue once
// a binding arrives) and a binding is requested via bRequest. A non-delta
// message supersedes everything already queued: older deferred messages are
// tied to its outcome and the queue is replaced by it alone, so only the
// latest full message goes out, which avoids tripping connection rate limits.
func (c *channel) send(content []byte, isDelta bool) stream.TrackerPromise {
	addr := c.history.Address()

	// Fast path under the read lock.
	c.bMutex.RLock()
	current := c.binding
	c.bMutex.RUnlock()
	if current != nil { // TODO : need to react properly when connection is closed
		return current.send(addr, content, isDelta)
	}

	// Re-check under the write lock: a binding may have been installed
	// between the two lock acquisitions.
	c.bMutex.Lock()
	if current = c.binding; current != nil {
		c.bMutex.Unlock()
		return current.send(addr, content, isDelta)
	}

	pending := newDeferredMessage(content, isDelta)
	if !isDelta {
		for _, older := range c.msgs {
			older.set(pending) // tie success/failure of old messages to current message
		}
		c.msgs = nil
	}
	c.msgs = append(c.msgs, pending)
	c.bMutex.Unlock()

	c.bRequest(addr)
	return pending
}

// Join adds listener to the channel. A newly added listener triggers tryJoin,
// whose promise is returned. If the listener was already registered, the
// result is built by newDeferredBool from the current hasJoined promise and
// false (see newDeferredBool for how the two are combined).
func (c *channel) Join(listener stream.Listener) promise.Bool {
	c.lMutex.Lock()
	isNew := c.listeners.Add(listener)
	c.lMutex.Unlock()

	if isNew {
		return c.tryJoin()
	}

	c.bMutex.RLock()
	current := c.hasJoined
	c.bMutex.RUnlock()
	return newDeferredBool(current, false)
}

// Leave removes listener from the channel's listener set.
//
// If the listener was not registered it returns a synchronous false. If other
// listeners remain it returns a synchronous true without contacting the
// binding. When the last listener leaves:
//   - the channel is checked for emptiness (which may invoke onEmptied);
//   - if a part has already been recorded in hasParted, that promise is
//     returned as-is;
//   - otherwise a new hasParted promise is recorded and markedJoined is
//     cleared. With no binding the promise resolves to true immediately and is
//     returned. With a binding, a part request is sent, the subscriptions
//     metric is decremented, and a synchronous true is returned. When the part
//     completes, whatever hasParted promise is current at that time (tryJoin
//     may have cleared it) is resolved, and checkUnbind re-evaluates the
//     binding.
//
// NOTE(review): in the live-binding case callers receive a synchronous true
// rather than the hasParted promise, so they cannot observe when the part
// finishes. The subscriptions decrement is also not gated on markedJoined;
// confirm both are intended.
func (c *channel) Leave(listener stream.Listener) promise.Bool {
	shouldLeave := false
	c.lMutex.Lock()
	removed := c.listeners.Remove(listener)
	if removed && c.listeners.IsEmpty() {
		shouldLeave = true
		c.checkEmpty()
	}
	c.lMutex.Unlock()
	if !removed {
		return promise.NewSyncBool(false, nil)
	}

	if shouldLeave {
		c.bMutex.Lock()
		// A part is already in progress (or done) since the last tryJoin; reuse it.
		result := c.hasParted
		if result != nil {
			c.bMutex.Unlock()
			return result
		}
		binding := c.binding
		result = promise.NewBool()
		c.hasParted = result
		// Lets the next successful join in tryJoin count as a new subscription.
		c.markedJoined = false
		c.bMutex.Unlock()

		if binding == nil {
			result.Set(true, nil)
			return result
		}

		addr := c.history.Address()
		binding.part(addr, func(src stream.SourceID, pos stream.Position, err error) {
			c.bMutex.Lock()
			// tryJoin sets hasParted to nil when a rejoin supersedes this part.
			if c.hasParted != nil {
				c.hasParted.Set(true, err)
			}
			c.bMutex.Unlock()
			c.lMutex.RLock()
			c.checkUnbind(binding)
			c.lMutex.RUnlock()
		})
		c.stats.subscriptions.Add(-1)
	}

	return promise.NewSyncBool(true, nil)
}

// bind installs binding as the channel's current binding; a nil binding clears
// it.
//
// Calling bind again with the binding that is already current is a no-op,
// which handles re-entry due to asynchronous request completion. Otherwise:
//   - hasJoined is replaced with a fresh promise only when the existing one no
//     longer blocks (presumably: it has already resolved), so callers waiting
//     on an outstanding join keep the same promise;
//   - for a non-nil binding, bound is incremented (so join responses from older
//     bindings are treated as stale by tryJoin) and the messages queued by send
//     are taken for flushing;
//   - the previous binding, if any, is passed to uRequest, and a non-nil new
//     binding is acknowledged with bAck;
//   - if the channel has listeners it rejoins via tryJoin; otherwise, if the
//     binding was cleared while writers remain, a new binding is requested;
//   - queued messages are re-sent on the new binding, each deferred message
//     being tied to the promise returned by the real send.
//
// Listener/writer state is sampled under lMutex before bMutex is taken; the two
// locks are never held at the same time here.
func (c *channel) bind(binding *binding) {
	var msgs []*deferredMessage
	c.bMutex.RLock()
	if binding != nil && c.binding == binding { // handle reentry due to async request completion
		c.bMutex.RUnlock()
		return
	}
	c.bMutex.RUnlock()

	c.lMutex.RLock()
	shouldJoin := !c.listeners.IsEmpty()
	shouldSend := c.writers > 0
	c.lMutex.RUnlock()

	c.bMutex.Lock()
	prev := c.binding
	if !c.hasJoined.WouldBlock() { // we need to establish a new promise for our new connection
		c.hasJoined = promise.NewBool()
	}
	c.binding = binding
	if binding != nil {
		c.bound++
		msgs = c.msgs
		c.msgs = nil
	}
	c.bMutex.Unlock()

	addr := c.history.Address()
	if prev != nil {
		c.uRequest(addr, prev)
	}

	if binding != nil {
		c.bAck(addr)
	}

	if shouldJoin {
		c.tryJoin()
	} else if binding == nil && shouldSend {
		c.bRequest(addr)
	}

	// msgs is only populated when binding != nil, so this loop never calls
	// send on a nil binding.
	for _, msg := range msgs {
		msg.set(binding.send(addr, msg.content, msg.isDelta))
	}

}

// onSent records msg in the channel's history and, if that succeeds, notifies
// this channel's and its ancestors' listeners. A history write error is
// returned without notifying anyone.
func (c *channel) onSent(msg stream.Message) error {
	err := c.history.Write(msg)
	if err == nil {
		c.notify(c.history, stream.Origin, false)
	}
	return err
}

// onLost notifies listeners up to the end position of the lost message's
// range, without writing anything to history. It always returns nil.
func (c *channel) onLost(desc stream.MessageDescription) error {
	lostUntil := desc.At().End
	c.notify(c.history, lostUntil, false)
	return nil
}

// onClosed notifies listeners that the stream ended at position at (marking it
// closed) and then closes the channel's history, returning its error.
func (c *channel) onClosed(src stream.SourceID, at stream.Position) error {
	const closing = true
	c.notify(c.history, at, closing)
	return c.history.Close()
}

// notifyLocal pushes an update for hist (up to end) to this channel's own
// listeners through the scheduler and, when close is set, also marks hist's
// address as closed for them. Ancestors are not notified here.
func (c *channel) notifyLocal(hist stream.History, end stream.Position, close bool) {
	// TODO : if a listener attempts to remove itself while in this readlock
	// we have a deadlock problem... solve that problem. Serialize read/writes in
	// a queue?  Use a second collection as a staging area when iterating?
	c.lMutex.RLock()
	c.scheduler.Update(hist, end, c.listeners.AsSlice())
	if !close {
		c.lMutex.RUnlock()
		return
	}
	c.scheduler.SetClosed(hist.Address(), nil, c.listeners.AsSlice())
	c.lMutex.RUnlock()
}

// notify delivers an update for hist to this channel's listeners first, then
// to the listeners of each ancestor in order.
func (c *channel) notify(hist stream.History, end stream.Position, close bool) {
	c.notifyLocal(hist, end, close)
	for i := range c.ancestors {
		c.ancestors[i].notifyLocal(hist, end, close)
	}
}

// tryJoin makes sure a join request has been sent on the channel's current
// binding and returns the hasJoined promise for it.
//
// Any recorded part (hasParted) is discarded because the channel is
// (re)joining. With no binding, a binding is requested via bRequest. If a join
// has already been sent for the current binding (joining == bound), no new
// request is sent. In both cases the existing hasJoined promise is returned.
//
// When the join response arrives:
//   - if bound changed since the request was sent, a newer binding exists, so
//     the response is stale and the address is parted on the old binding;
//   - on protocol.ErrAddressMoved, binding.onMove is scheduled after bound
//     seconds (bound counts bindings installed since the last successful join,
//     so the delay grows with repeated attempts); hasJoined stays unresolved;
//   - otherwise hasJoined is resolved with the outcome. On success the bound
//     and joining counters are reset, and if the channel was not already marked
//     joined the subscriptions metric is incremented.
func (c *channel) tryJoin() promise.Bool {
	c.bMutex.Lock()
	c.hasParted = nil
	hasJoined := c.hasJoined
	addr := c.history.Address()
	binding := c.binding
	if c.joining == c.bound || binding == nil {
		c.bMutex.Unlock()
		if binding == nil {
			c.bRequest(addr)
		}
		return hasJoined
	}
	c.joining = c.bound
	bound := c.bound
	c.bMutex.Unlock()

	_, pos := c.history.Last()
	binding.join(addr, pos, func(src stream.SourceID, pos stream.Position, err error) {
		// A nil error means the join succeeded. (This was previously inverted,
		// which resolved failed joins as successes and counted them as
		// subscriptions.)
		success := err == nil
		joined := false
		c.bMutex.Lock()
		switch {
		case bound != c.bound:
			// drop out-of-date connection attempts
			binding.part(addr, func(src stream.SourceID, pos stream.Position, err error) {})
		case err == protocol.ErrAddressMoved:
			// try to find a replacement server since the address has moved
			time.AfterFunc(time.Duration(bound)*time.Second, func() {
				binding.onMove(binding, stream.AddressScopes{addr})
			})
		default:
			if success {
				joined = !c.markedJoined
				c.markedJoined = true
				c.bound = 0
				c.joining = 0
			}
			hasJoined.Set(success, err)
		}
		c.bMutex.Unlock()
		if joined {
			c.stats.subscriptions.Add(1)
		}
	})
	return hasJoined
}

// checkEmpty invokes onEmptied once the channel has no writers, no children
// and no listeners.
//
// Callers must hold lMutex, which guards the fields inspected here.
func (c *channel) checkEmpty() {
	if c.writers != 0 || len(c.children) != 0 || !c.listeners.IsEmpty() {
		return
	}
	c.onEmptied(c)
}

// checkUnbind asks uRequest to release b from this channel's address when b is
// no longer the channel's current binding, or when the channel has neither
// writers nor listeners left.
//
// Callers must hold lMutex (read or write), which guards writers and
// listeners. The current binding is guarded by bMutex instead, so it is
// snapshotted under its own read lock. Reading it under lMutex alone raced with
// bind. bMutex is released before calling uRequest so that callback never runs
// with bMutex held.
func (c *channel) checkUnbind(b *binding) {
	c.bMutex.RLock()
	current := c.binding
	c.bMutex.RUnlock()

	if b != current || (c.writers == 0 && c.listeners.IsEmpty()) {
		c.uRequest(c.history.Address(), b)
	}
}
