package audience

import (
	"sync"

	"code.justin.tv/devhub/e2ml/libs/logging"
	"code.justin.tv/devhub/e2ml/libs/metrics"
	"code.justin.tv/devhub/e2ml/libs/stream"
	"code.justin.tv/devhub/e2ml/libs/stream/listener"
	"code.justin.tv/devhub/e2ml/libs/stream/protocol"
	"code.justin.tv/devhub/e2ml/libs/stream/protocol/message"
)

// Callbacks a channel uses to report back to its owner.
type (
	// emptiedFunction is called with a channel that may have become empty.
	// The channel invokes it while holding its own mutex.
	emptiedFunction func(*channel)
	// writtenFunction receives the total count reported by one Forward fan-out.
	writtenFunction func(int64)
	// sendFunction receives the arguments given to channel.send, prefixed with
	// the channel's address.
	sendFunction func(stream.Address, []byte, bool, stream.MutableTrackerPromise)
)

// cstats bundles the metric handles a channel reports to. newChannel creates
// each of them with an empty tag list.
type cstats struct {
	// Aggregators adjusted by +1/-1 as listeners (add/remove) and writers
	// (newWriter/closeWriter) come and go.
	subscriptions   metrics.Aggregator
	writers         metrics.Aggregator
	bytesWritten    metrics.Count // total fanout bytes
	bytesSent       metrics.Count // only count at top level (no fan-out)
	msgsSent        metrics.Count
	// History failures counted by Forward before any fan-out happens.
	loadErrorCount  metrics.Count
	writeErrorCount metrics.Count
}

type (
	// childMap indexes a channel's direct children by their address key.
	childMap map[stream.AddressKey]*channel
	// writerMap is the set of writers currently registered on a channel.
	writerMap map[audienceWriter]struct{}
	// moveData is the move notice recorded by revoke, together with any error
	// from building it. It is replayed to writers and listeners that arrive
	// after the channel was revoked.
	moveData struct {
		msg protocol.Message
		err error
	}
)

// channel is the per-address node in a hierarchy of channels: it keeps the
// listeners subscribed to its address, the writers publishing to it and its
// child channels, and fans history updates out to its own listeners and to
// the listeners of its ancestors.
type channel struct {
	addr      stream.Address
	logic     stream.ServerLogic
	scheduler stream.Scheduler
	// Owner callbacks; onEmptied is invoked while mutex is held.
	onWritten writtenFunction
	onEmptied emptiedFunction
	onSend    sendFunction
	// Channels whose listeners also receive this channel's history updates
	// and close notices (see Forward and Close).
	ancestors []*channel
	// listeners, children, writers and moved are read and written under mutex.
	listeners listener.Collection
	// NOTE(review): history is used without holding mutex (Current, Forward);
	// historySocket is presumably safe for concurrent use — confirm.
	history   historySocket
	children  childMap
	writers   writerMap
	stats     cstats
	// moved is set by revoke; once non-nil, new writers and listeners are
	// refused and handed the stored move notice instead.
	moved     *moveData
	mutex     sync.RWMutex
	logger    logging.Function
}

// newChannel builds a channel serving addr.
//
// ancestors are the channels whose listeners also receive this channel's
// history updates and close notices. The on* callbacks let the owner observe
// fan-out totals, possible emptiness and outgoing sends. Metric handles are
// obtained from tracker under fixed "audience.*" names with no tags.
func newChannel(logic stream.ServerLogic, addr stream.Address, scheduler stream.Scheduler, ancestors []*channel, onWritten writtenFunction, onEmptied emptiedFunction, onSend sendFunction, tracker metrics.Tracker, logger logging.Function) *channel {
	// Named "tracker" (not "metrics") so the metrics package is not shadowed.
	stats := cstats{
		subscriptions:   tracker.Aggregator("audience.subscriptions", []string{}),
		writers:         tracker.Aggregator("audience.writers", []string{}),
		bytesWritten:    tracker.Count("audience.bytes.written", []string{}),
		bytesSent:       tracker.Count("audience.bytes.sent", []string{}),
		msgsSent:        tracker.Count("audience.msgs.sent", []string{}),
		loadErrorCount:  tracker.Count("audience.history.load.errCount", []string{}),
		writeErrorCount: tracker.Count("audience.history.write.errCount", []string{}),
	}
	return &channel{
		addr:      addr,
		logic:     logic,
		scheduler: scheduler,
		onWritten: onWritten,
		onEmptied: onEmptied,
		onSend:    onSend,
		ancestors: ancestors,
		listeners: listener.NewCollection(),
		children:  make(childMap),
		writers:   make(writerMap),
		stats:     stats,
		logger:    logger,
	}
}

// newWriter registers a new writer bound to b and returns it.
//
// If the channel has been revoked, nothing is registered: a movedWriter is
// returned and b immediately receives the stored move notice. Otherwise the
// writer is added to the writer set, the writers metric is incremented, and,
// when it is the first writer, logic.OnAddressRequested is called outside the
// lock. NOTE(review): any result of OnAddressRequested is discarded here.
func (c *channel) newWriter(b *binding) audienceWriter {
	c.mutex.Lock()
	if mv := c.moved; mv != nil {
		w := &movedWriter{c.addr}
		c.mutex.Unlock()
		b.respond(mv.msg, mv.err)
		return w
	}
	first := len(c.writers) == 0
	w := newWriter(c, b)
	c.writers[w] = struct{}{}
	c.mutex.Unlock()

	c.stats.writers.Add(1)
	if first {
		c.logic.OnAddressRequested(c.Address())
	}
	return w
}

// closeWriter unregisters w. It is a no-op returning nil when w is not (or no
// longer) registered.
//
// When w was the last writer, logic.OnAddressReleased is called outside the
// lock and its error is returned. If listeners and children are also gone at
// that point, onEmptied is invoked while c.mutex is still held.
func (c *channel) closeWriter(w *writer) error {
	c.mutex.Lock()
	_, known := c.writers[w]
	if !known {
		c.mutex.Unlock()
		return nil
	}
	delete(c.writers, w)
	last := len(c.writers) == 0
	if last && c.listeners.IsEmpty() && len(c.children) == 0 {
		c.onEmptied(c)
	}
	c.mutex.Unlock()

	c.stats.writers.Add(-1)
	if !last {
		return nil
	}
	return c.logic.OnAddressReleased(c.Address())
}

// attach registers child under its address key, replacing any channel already
// stored for that key.
func (c *channel) attach(child *channel) {
	key := child.Address().Key()
	c.mutex.Lock()
	c.children[key] = child
	c.mutex.Unlock()
}

// detach unregisters child from this channel.
//
// When the removal leaves the channel with no children, no listeners and no
// writers, onEmptied is invoked while c.mutex is still held so the owner can
// reclaim the channel.
func (c *channel) detach(child *channel) {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	delete(c.children, child.Address().Key())
	// Use the same emptiness test as closeWriter and remove: a channel that
	// still has registered writers is in use even if nobody is listening, and
	// reporting it as emptied would let the owner discard a live channel.
	if len(c.children) == 0 && c.listeners.IsEmpty() && len(c.writers) == 0 {
		c.onEmptied(c)
	}
}

// descendents adds every transitive child of c to into, keyed by address key.
//
// Each channel is read-locked while its children are visited, so a parent's
// lock is held while each child's lock is taken. NOTE(review): onEmptied runs
// with a child's mutex held; if that callback locks an ancestor (e.g. via
// detach), the two paths take locks in opposite orders — confirm this cannot
// deadlock.
func (c *channel) descendents(into map[stream.AddressKey]*channel) {
	c.mutex.RLock()
	defer c.mutex.RUnlock()
	for key, kid := range c.children {
		into[key] = kid
		kid.descendents(into)
	}
}

// send hands payload for this channel's address to the owner-supplied onSend
// callback, passing isDelta and promise through unchanged.
func (c *channel) send(payload []byte, isDelta bool, promise stream.MutableTrackerPromise) {
	// The parameter is not called "message" to avoid shadowing that package.
	addr := c.Address()
	c.onSend(addr, payload, isDelta, promise)
}

// update pushes history, from stream.Origin, to this channel's listeners via
// the scheduler. The read lock keeps the listener set stable during the call.
// It returns whatever the scheduler reports; Forward sums the count into its
// bytes-written statistics.
func (c *channel) update(history stream.History) (int, error) {
	c.mutex.RLock()
	defer c.mutex.RUnlock()
	targets := c.listeners.AsSlice()
	return c.scheduler.Update(history, stream.Origin, targets)
}

// Address returns the address this channel serves.
func (c *channel) Address() stream.Address {
	return c.addr
}
// Current reports the last source and position recorded in this channel's
// history, or (stream.None, stream.Origin) when no history is available.
// Any error from loading the history is deliberately ignored and treated as
// "no history".
func (c *channel) Current() (stream.SourceID, stream.Position) {
	h, _ := c.history.load()
	if h == nil {
		return stream.None, stream.Origin
	}
	return h.Last()
}

// Send wraps data in a "sent" message for this channel's address, attributed
// to src at segment at, and passes it to Forward. An error from building the
// message is returned without forwarding anything.
func (c *channel) Send(src stream.SourceID, at stream.Segment, data []byte) error {
	sent, err := message.NewSent(c.addr, src, at, data)
	if err != nil {
		return err
	}
	return c.Forward(sent)
}

// Forward appends msg to this channel's history, creating the history through
// logic on first use. It then pushes the updated history to the listeners of
// this channel and of every ancestor.
//
// A failure to load or write the history is logged, counted and returned
// before any fan-out happens. Otherwise every update is attempted. The send
// statistics and onWritten are always recorded, and the first non-nil update
// error is returned.
func (c *channel) Forward(msg stream.Message) error {
	hist, err := c.history.loadOrCreate(c.addr, c.logic)
	if err != nil {
		c.logger(logging.Error, "Channel.history.loadOrCreate: ", err)
		c.stats.loadErrorCount.Add(1)
		return err
	}
	if err = hist.Write(msg); err != nil {
		c.logger(logging.Error, "Channel.history.Write: ", err)
		c.stats.writeErrorCount.Add(1)
		return err
	}

	total, firstErr := c.update(hist)
	for _, a := range c.ancestors {
		n, updErr := a.update(hist)
		total += n
		// Keep going after a failure, but report only the earliest one.
		if firstErr == nil {
			firstErr = updErr
		}
	}

	c.stats.msgsSent.Add(1)
	c.stats.bytesSent.Add(int64(len(msg.Data())))
	c.stats.bytesWritten.Add(int64(total))
	c.onWritten(int64(total))
	return firstErr
}

// Close reports that this channel's address has closed with cause err. The
// notice goes first to this channel's listeners, then to the listeners of each
// ancestor. It always returns nil.
func (c *channel) Close(err error) error {
	addr := c.addr
	c.notifyClosed(addr, err)
	for i := range c.ancestors {
		c.ancestors[i].notifyClosed(addr, err)
	}
	return nil
}

// notifyClosed asks the scheduler to tell every listener of this channel that
// addr closed with cause. The read lock keeps the listener set stable for the
// duration of the call.
func (c *channel) notifyClosed(addr stream.Address, cause error) {
	c.mutex.RLock()
	defer c.mutex.RUnlock()
	subscribers := c.listeners.AsSlice()
	c.scheduler.SetClosed(addr, cause, subscribers)
}

// add subscribes l to this channel and reports whether it was newly added.
//
// If the channel has been revoked, l is not added. When l is a
// *clientListener, it receives the stored move notice. When l is the first
// listener, logic.OnAddressJoined is called outside the lock. If that call
// fails, l is removed again and the error is returned with false. The
// subscriptions metric is only incremented for a successful add.
func (c *channel) add(l stream.Listener) (bool, error) {
	c.mutex.Lock()
	if mv := c.moved; mv != nil {
		c.mutex.Unlock()
		if client, ok := l.(*clientListener); ok {
			client.respond(mv.msg, mv.err)
		}
		return false, nil
	}
	wasEmpty := c.listeners.IsEmpty()
	added := c.listeners.Add(l)
	c.mutex.Unlock()

	if !added {
		return false, nil
	}
	if wasEmpty {
		if err := c.logic.OnAddressJoined(c.Address()); err != nil {
			// Roll back so the listener set matches what logic accepted.
			c.mutex.Lock()
			c.listeners.Remove(l)
			c.mutex.Unlock()
			return false, err
		}
	}
	c.stats.subscriptions.Add(1)
	return true, nil
}

// remove unsubscribes l and reports whether it was actually subscribed.
//
// If the channel is left with no listeners, writers or children, onEmptied is
// invoked while c.mutex is still held. logic.OnAddressParted is called, outside
// the lock, only when this call removed the last listener, mirroring add, which
// only calls OnAddressJoined when it actually added the first listener. Its
// error is returned. Removing a listener that is not subscribed (a repeated or
// concurrent remove, or one whose join was rolled back) no longer produces a
// spurious "parted" notification.
func (c *channel) remove(l stream.Listener) (bool, error) {
	c.mutex.Lock()
	removed := c.listeners.Remove(l)
	emptied := c.listeners.IsEmpty()
	if emptied && len(c.writers) == 0 && len(c.children) == 0 {
		c.onEmptied(c)
	}
	c.mutex.Unlock()

	if !removed {
		return false, nil
	}
	c.stats.subscriptions.Add(-1)
	if !emptied {
		return true, nil
	}
	return true, c.logic.OnAddressParted(c.Address())
}

// revoke is used when this channel (or a parent) has moved. It records a move
// notice so later newWriter/add calls are refused, then sends that notice to
// the current subscribers and writers. Each *clientListener receives the
// notice, and every listener is removed. Each writer receives the notice and
// is then closed. It returns the error, if any, from building the move notice.
// Errors from remove and Close are ignored.
func (c *channel) revoke() error {
	notice, noticeErr := message.NewMove(stream.AddressScopes{c.Address()})

	c.mutex.Lock()
	c.moved = &moveData{msg: notice, err: noticeErr}
	subscribers := c.listeners.AsSlice()
	// Snapshot the writers: closing them mutates c.writers via closeWriter.
	active := make([]audienceWriter, 0, len(c.writers))
	for w := range c.writers {
		active = append(active, w)
	}
	c.mutex.Unlock()

	// "l" rather than "listener" so the listener package is not shadowed.
	for _, l := range subscribers {
		if client, ok := l.(*clientListener); ok {
			client.respond(notice, noticeErr)
		}
		c.remove(l)
	}
	for _, w := range active {
		w.respond(notice, noticeErr)
		w.Close()
	}
	return noticeErr
}
