package audience

import (
	"sort"
	"sync"
	"sync/atomic"
	"time"

	"code.justin.tv/devhub/e2ml/libs/logging"
	"code.justin.tv/devhub/e2ml/libs/metrics"
	"code.justin.tv/devhub/e2ml/libs/session"
	"code.justin.tv/devhub/e2ml/libs/stream"
	"code.justin.tv/devhub/e2ml/libs/stream/protocol"
	"code.justin.tv/devhub/e2ml/libs/ticket"
)

// stats holds the metrics handles used by remote, plus a raw load counter.
//
// NOTE(review): load is accessed only through sync/atomic 64-bit operations,
// which on 386/ARM/32-bit MIPS require 8-byte alignment; its position here
// (and stats' position inside remote) does not obviously guarantee that —
// confirm if 32-bit targets matter.
type stats struct {
	tracker           metrics.Tracker    // registry for the handles below; also passed to newChannel
	bindings          metrics.Aggregator // +1 per binding registered in create, -1 when drop finds it registered
	timeoutOk         metrics.Count      // Tick: no expiry, or expiry after the warning window
	timeoutExpiring   metrics.Count      // Tick: expiry inside the warning window but not yet passed
	timeoutExpired    metrics.Count      // Tick: expiry already reached
	addressMovedCount metrics.Count      // channel(): address not enabled, protocol.ErrAddressMoved returned
	addresses         metrics.Aggregator // +1 per channel created in channel(), -1 per free
	load              uint64             // summed by onWritten, read-and-reset by LoadFactor; atomic access only
}

type Remote interface {
	stream.Audience
	session.Server
	Tick()
	Shutdown() error
}

// remote is the Remote implementation returned by NewRemote.
type remote struct {
	logger      logging.Function               // diagnostic logger; also handed to bindings and channels
	logic       stream.ServerLogic             // built in NewRemote by calling the factory with this remote
	enabled     stream.AddressScopes           // scopes allowed to have channels; kept sorted by Enable (guarded by mutex)
	channels    map[stream.AddressKey]*channel // open channels by address key (guarded by mutex)
	bindings    sync.Map                       // registered bindings: *binding -> struct{}
	warningTime time.Duration                  // how long before expiry Tick starts reporting a binding as expiring
	scheduler   stream.Scheduler               // passed through to newChannel
	shutdown    bool                           // set by Shutdown; create rejects new bindings once true (guarded by mutex)
	tickets     ticket.Redeemer                // passed to every new binding; NewRemote substitutes a nil-redeemer for nil
	stats       stats                          // metrics handles and the atomic load counter
	mutex       sync.RWMutex                   // guards enabled, channels and shutdown
}

// NewRemote builds a Remote that bridges session.Server and
// stream.ServerLogic, managing audiences and sessions automatically.
//
// The server logic is obtained by calling factory with the new Remote as its
// audience. A nil tickets redeemer is replaced with ticket.GetNilRedeemer().
// Tick reports a binding as expiring once its credentials expire within
// warnBeforeExpiration. All metrics are registered on tracker immediately, so
// tracker must not be nil.
func NewRemote(factory stream.ServerLogicFactory, scheduler stream.Scheduler, tickets ticket.Redeemer, warnBeforeExpiration time.Duration, tracker metrics.Tracker, logger logging.Function) Remote {
	if tickets == nil {
		tickets = ticket.GetNilRedeemer()
	}

	counters := stats{tracker: tracker}
	counters.bindings = tracker.Aggregator("audience.bindings", []string{})
	counters.addresses = tracker.Aggregator("audience.addresses", []string{})
	counters.timeoutOk = tracker.Count("audience.timeouts", []string{"status:ok"})
	counters.timeoutExpiring = tracker.Count("audience.timeouts", []string{"status:expiring"})
	counters.timeoutExpired = tracker.Count("audience.timeouts", []string{"status:expired"})
	counters.addressMovedCount = tracker.Count("audience.counts", []string{"error:address_moved"})

	r := &remote{
		logger:      logger,
		enabled:     make(stream.AddressScopes, 0),
		channels:    make(map[stream.AddressKey]*channel),
		warningTime: warnBeforeExpiration,
		scheduler:   scheduler,
		tickets:     tickets,
		stats:       counters,
	}
	// The logic needs its audience at construction time, so it is wired last.
	r.logic = factory(r)
	return r
}

// LoadFactor returns the load accumulated by onWritten since the previous
// call and atomically resets the counter to zero.
func (r *remote) LoadFactor() uint64 {
	accumulated := atomic.SwapUint64(&r.stats.load, 0)
	return accumulated
}

// Tick classifies every registered binding by its credential expiry.
//
//   - no expiry, or expiry after now+warningTime: counted as ok
//   - expiry after now but within the warning window: counted, expiring() called
//   - expiry at or before now: counted, expired() called
func (r *remote) Tick() {
	now := time.Now()
	warnAfter := now.Add(r.warningTime)
	r.bindings.Range(func(k, _ interface{}) bool {
		b := k.(*binding)
		deadline := b.credentials().Expires()
		switch {
		case deadline == nil || deadline.After(warnAfter):
			r.stats.timeoutOk.Add(1)
		case deadline.After(now):
			r.stats.timeoutExpiring.Add(1)
			b.expiring()
		default:
			r.stats.timeoutExpired.Add(1)
			b.expired()
		}
		return true // keep visiting every binding
	})
}

// Shutdown stops accepting bindings, asks every registered binding to drain,
// and finally shuts down the server logic, returning the logic's error.
func (r *remote) Shutdown() error {
	// Flip the flag first so create() rejects any binding that arrives while
	// the existing ones are being drained.
	func() {
		r.mutex.Lock()
		defer r.mutex.Unlock()
		r.shutdown = true
	}()

	r.bindings.Range(func(k, _ interface{}) bool {
		k.(*binding).drain()
		return true
	})

	return r.logic.Shutdown()
}

// Factory returns the function the session layer uses to create a binding
// for each new client; it is this remote's create method.
func (r *remote) Factory() session.BindingFactory {
	var factory session.BindingFactory = r.create
	return factory
}

// ForAddress returns the topic for addr, creating its channel (and any
// missing parent channels) on first use.
//
// On failure the returned Topic is a literal nil interface. Returning the
// result of r.channel directly would wrap a nil *channel in a non-nil
// stream.Topic, so a caller comparing the topic against nil would be misled.
func (r *remote) ForAddress(addr stream.Address) (stream.Topic, error) {
	ch, err := r.channel(addr)
	if err != nil {
		return nil, err
	}
	return ch, nil
}

// Enable adds each scope not already present to the enabled list and keeps
// the list sorted. Duplicates within scopes are added only once, because
// membership is checked against the list as it grows.
func (r *remote) Enable(scopes stream.AddressScopes) {
	r.mutex.Lock()
	defer r.mutex.Unlock()
	for _, s := range scopes {
		if r.enabled.Contains(s) {
			continue
		}
		r.enabled = append(r.enabled, s)
	}
	sort.Sort(r.enabled)
}

// Disable removes scopes from the enabled list and tears down the channel
// subtrees rooted at those scopes.
//
// A disabled scope's open channel is pruned, together with all of its
// descendants, only if isEnabled no longer accepts its address; a larger
// scope that is still enabled keeps the subtree alive. Pruned channels are
// removed from the channel map under the lock and revoked after it is
// released.
func (r *remote) Disable(scopes stream.AddressScopes) {
	r.mutex.Lock()

	// Rebuild the enabled list without the scopes being disabled.
	kept := stream.AddressScopes{}
	for _, candidate := range r.enabled {
		if !scopes.Contains(candidate) {
			kept = append(kept, candidate)
		}
	}
	r.enabled = kept

	// Roots of the subtrees to prune: open channels sitting exactly at a
	// disabled scope whose address is no longer enabled.
	var roots []*channel
	for _, scope := range scopes {
		root, open := r.channels[scope.Key()]
		if open && !r.isEnabled(root.Address()) {
			roots = append(roots, root)
		}
	}

	// Expand each root into its whole subtree; keying by address collapses
	// subtrees that overlap.
	doomed := make(childMap)
	for _, root := range roots {
		doomed[root.Address().Key()] = root
		root.descendents(doomed)
	}

	for key := range doomed {
		delete(r.channels, key)
	}
	r.mutex.Unlock()

	// Revoke (clean up listeners and writers) outside the lock, presumably
	// because revocation may call back into remote, e.g. via free.
	for _, ch := range doomed {
		ch.revoke()
	}
}

// create is the binding factory returned by Factory. It builds a binding
// wired to this remote. While the remote is running, the binding is
// registered and counted. Once Shutdown has been called, it is instead
// rejected with protocol.ErrDraining and not registered.
func (r *remote) create(client session.Client) session.Binding {
	b := newBinding(client, r.channel, r.drop, r.tickets, r.logger)

	r.mutex.Lock()
	accepting := !r.shutdown
	if accepting {
		r.bindings.Store(b, struct{}{})
	}
	r.mutex.Unlock()

	if !accepting {
		b.reject(protocol.ErrDraining)
		return b
	}
	r.stats.bindings.Add(1)
	return b
}

// channel returns the channel for addr, creating it on first use together
// with (recursively) every parent channel it hangs beneath.
//
// The fast path is a read-locked cache lookup. Parent channels are resolved
// outside the lock because each resolution may itself take it. Creation then
// re-checks the map under the write lock, so concurrent callers for the same
// address share one channel.
//
// It returns protocol.ErrAddressMoved when addr is not enabled, or the first
// error met while resolving a parent.
func (r *remote) channel(addr stream.Address) (*channel, error) {
	addrKey := addr.Key()
	r.mutex.RLock()
	ch, ok := r.channels[addrKey]
	r.mutex.RUnlock()
	if ok {
		return ch, nil // cached channel
	}

	parents := []*channel{}
	for _, parent := range addr.Parents() {
		cast, ok := parent.(stream.Address)
		if !ok {
			continue
		}
		p, err := r.channel(cast)
		if err != nil {
			r.logger(logging.Error, "Remote.channel: ", err)
			return nil, err
		}
		parents = append(parents, p)
	}

	// The new channel's ancestors are the union of every direct parent and
	// all of that parent's ancestors. Each entry is keyed by its OWN address
	// so shared ancestors collapse to one entry. The previous code keyed
	// ancestors by the parent's key, which overwrote the parent and kept only
	// the last ancestor.
	dedupe := make(map[stream.AddressKey]*channel)
	for _, parent := range parents {
		dedupe[parent.Address().Key()] = parent
		for _, ancestor := range parent.ancestors {
			dedupe[ancestor.Address().Key()] = ancestor
		}
	}
	ancestors := make([]*channel, 0, len(dedupe))
	for _, ancestor := range dedupe {
		ancestors = append(ancestors, ancestor)
	}

	r.mutex.Lock()
	if existing, found := r.channels[addrKey]; found {
		r.mutex.Unlock()
		return existing, nil // created concurrently by another caller
	}
	if !r.isEnabled(addr) {
		r.mutex.Unlock()
		r.stats.addressMovedCount.Add(1)
		// This will cause a protocol.Move message to be returned to Threshold, so it can go find another Source
		return nil, protocol.ErrAddressMoved
	}
	// Still under the write lock and the key is known to be absent, so no
	// second lookup is needed before inserting.
	ch = newChannel(r.logic, addr, r.scheduler, ancestors, r.onWritten, r.free, r.logic.OnSendRequested, r.stats.tracker, r.logger)
	r.channels[addrKey] = ch
	r.mutex.Unlock()

	r.stats.addresses.Add(1)
	for _, p := range parents {
		p.attach(ch)
	}
	return ch, nil
}

// free is the callback passed to newChannel for releasing a channel.
//
// It removes ch from the channel map only if the map still holds this exact
// channel; after Disable removes an entry, channel() may register a newer one
// under the same key. It then decrements the address count and detaches ch
// from whichever of its parent channels are still open.
func (r *remote) free(ch *channel) {
	addr := ch.Address()
	key := addr.Key()

	var openParents []*channel
	r.mutex.Lock()
	if current, exists := r.channels[key]; exists && current == ch {
		delete(r.channels, key)
	}
	for _, p := range addr.Parents() {
		parentAddr, isAddr := p.(stream.Address)
		if !isAddr {
			continue
		}
		if parent, open := r.channels[parentAddr.Key()]; open {
			openParents = append(openParents, parent)
		}
	}
	r.mutex.Unlock()

	r.stats.addresses.Add(-1)
	for _, parent := range openParents {
		parent.detach(ch)
	}
}

// onWritten is the write callback passed to newChannel. It atomically adds
// count to the load total that LoadFactor reads and resets.
// NOTE(review): a negative count would wrap when converted to uint64;
// presumably callers never pass one — confirm.
func (r *remote) onWritten(count int64) {
	delta := uint64(count)
	atomic.AddUint64(&r.stats.load, delta)
}

// drop is the callback passed to newBinding. It unregisters b and removes
// listener l from each channel named in keys that is still open.
//
// The bindings aggregator is decremented only if b was actually registered;
// create skips registration once the remote is shutting down.
func (r *remote) drop(b *binding, l stream.Listener, keys []stream.AddressKey) {
	_, registered := r.bindings.Load(b)
	r.bindings.Delete(b)

	var open []*channel
	r.mutex.RLock()
	for _, key := range keys {
		ch, ok := r.channels[key]
		if ok {
			open = append(open, ch)
		}
	}
	r.mutex.RUnlock()

	if registered {
		r.stats.bindings.Add(-1)
	}
	for _, ch := range open {
		ch.remove(l)
	}
}

// isEnabled reports whether addr may have a channel given the enabled scopes.
// It accepts an address in either of two cases:
//  1. It lies inside an enabled subtree, i.e. some enabled scope matches it.
//  2. It is an ancestor of an enabled address, which is needed to build the
//     channel tree from the root down to that address.
//
// The caller must hold r.mutex (read or write).
func (r *remote) isEnabled(addr stream.Address) bool {
	if _, inside := r.enabled.HasBetterMatch(addr, stream.NoAddressMatch); inside {
		return true
	}
	for _, scope := range r.enabled {
		enabledAddr, isAddr := scope.(stream.Address)
		if isAddr && addr.Includes(enabledAddr) {
			return true
		}
	}
	return false
}
