package audience

import (
	"sync"
	"sync/atomic"
	"unsafe"

	"code.justin.tv/devhub/e2ml/libs/logging"
	"code.justin.tv/devhub/e2ml/libs/session"
	"code.justin.tv/devhub/e2ml/libs/stream"
	"code.justin.tv/devhub/e2ml/libs/stream/protocol"
	"code.justin.tv/devhub/e2ml/libs/stream/protocol/message"
	"code.justin.tv/devhub/e2ml/libs/stream/scheduler"
	"code.justin.tv/devhub/e2ml/libs/ticket"
)

type (
	// channelFunc resolves the *channel that serves an address, or returns an
	// error (the Send path treats protocol.ErrAddressMoved specially).
	channelFunc func(stream.Address) (*channel, error)
	// dropFunc is invoked once from OnClosed with the binding, its detached
	// listener, and the address keys reported by that listener's onClosed.
	dropFunc func(*binding, stream.Listener, []stream.AddressKey)
)

// binding holds the per-client state of an audience connection and
// implements session.Binding so the session library can deliver frames
// (OnTextMessage / OnBinaryMessage) and the close notification (OnClosed).
type binding struct {
	// channel resolves the *channel for an address (used by Join, Part,
	// Refresh and writer creation).
	channel     channelFunc
	// drop is called from OnClosed with the detached listener and the
	// address keys returned by its onClosed.
	drop        dropFunc
	// client is the underlying session client; used to close the connection
	// and to report its address in trace logs.
	client      session.Client
	// writers caches one writer per address, created lazily by writer();
	// guarded by wMutex.
	writers     map[stream.AddressKey]audienceWriter
	logger      logging.Function
	// redeemer exchanges an auth method + access code for credentials
	// during Init and Refresh.
	redeemer    ticket.Redeemer
	// credsPtr points at the current stream.Credentials; always accessed via
	// sync/atomic. Starts at &initialCreds (no permissions).
	credsPtr    unsafe.Pointer
	// listenerPtr points at the *clientListener; nil until Init succeeds and
	// swapped back to nil by OnClosed. Always accessed via sync/atomic.
	listenerPtr unsafe.Pointer
	// draining is non-zero once drain() has been called; accessed atomically.
	draining    int32
	// wMutex guards the writers map.
	wMutex      sync.RWMutex // for coordinating writers
}

var (
	// Compile-time check that *binding satisfies session.Binding.
	_ session.Binding = (*binding)(nil)
	// initialCreds is the no-permission credential set every binding starts
	// with until Init redeems a ticket.
	initialCreds = stream.NoPermissions()
)

// newBinding creates the binding for a freshly connected client. The binding
// has no listener and no permissions until an Init message is processed.
func newBinding(
	client session.Client,
	channel channelFunc,
	drop dropFunc,
	redeemer ticket.Redeemer,
	logger logging.Function,
) *binding {
	b := &binding{
		client:   client,
		channel:  channel,
		drop:     drop,
		redeemer: redeemer,
		logger:   logger,
		writers:  make(map[stream.AddressKey]audienceWriter),
	}
	// Every binding starts from the shared no-permission credentials.
	b.credsPtr = unsafe.Pointer(&initialCreds)
	return b
}

// listener returns the client's listener, loaded atomically. It is nil before
// a successful Init and after OnClosed has detached it.
func (b *binding) listener() *clientListener {
	raw := atomic.LoadPointer(&b.listenerPtr)
	return (*clientListener)(raw)
}

// credentials returns a copy of the client's current credentials, read via an
// atomic pointer load (initially the no-permission set).
func (b *binding) credentials() stream.Credentials {
	current := (*stream.Credentials)(atomic.LoadPointer(&b.credsPtr))
	return *current
}

// respond forwards msg and err to the client's listener. It silently does
// nothing when no listener exists (before Init or after OnClosed).
func (b *binding) respond(msg protocol.Message, err error) {
	l := b.listener()
	if l == nil {
		return
	}
	l.respond(msg, err)
}

// OnTextMessage rejects text frames: this protocol only accepts binary
// messages, so any text frame closes the client with ErrInvalidHeader.
func (b *binding) OnTextMessage(_ string) {
	b.client.Close(protocol.ErrInvalidHeader)
}

// OnBinaryMessage decodes one binary frame from the client and dispatches it
// by opcode.
//
// Connection-level violations close the client:
//   - an undecodable frame closes with the decode error;
//   - any opcode other than Init before Init has succeeded (no listener yet)
//     closes with protocol.ErrInvalidVersion;
//   - a second Init closes with protocol.ErrInvalidOpCode;
//   - a message whose concrete type does not match its opcode closes with
//     protocol.ErrInvalidPayload(opcode);
//   - an unknown opcode closes with protocol.ErrInvalidOpCode.
//
// Request-level failures (forbidden address, channel lookup or subscription
// errors, ...) are reported back as Error responses without closing.
func (b *binding) OnBinaryMessage(bytes []byte) {
	msg, err := message.Unmarshal(bytes)
	if err != nil {
		b.logger(logging.Debug, err)
		b.client.Close(err)
		return
	}
	b.logger(logging.Trace, "Aud <=[RCV]", b.client.Address(), msg)
	listener := b.listener()
	// The listener is only created by a successful Init, so every other opcode
	// is refused until the handshake has completed.
	if listener == nil && msg.OpCode() != protocol.Init {
		b.client.Close(protocol.ErrInvalidVersion)
		return
	}
	switch msg.OpCode() {
	case protocol.Init:
		if cast, ok := msg.(message.Init); ok {
			if listener != nil {
				// Init is only valid once per connection.
				b.client.Close(protocol.ErrInvalidOpCode)
				break
			}

			if !cast.Version().IsValid() {
				b.client.Close(protocol.ErrInvalidVersion)
				break
			}
			// if this is a remote auth system, expect this to take a while
			// it is blocking to pause additional incoming requests until authorization is confirmed
			creds, err := b.redeemer.Redeem(cast.AuthMethod(), cast.AccessCode()).Result()
			if err != nil {
				b.client.Close(err)
				break
			}
			// Credentials are stored before the listener is published.
			atomic.StorePointer(&b.credsPtr, unsafe.Pointer(&creds))

			listener = newListener(b.client, cast.Version(), b.logger)
			atomic.StorePointer(&b.listenerPtr, unsafe.Pointer(listener))
			// drain() only sends a Drain message when a listener exists; if it
			// ran before the listener was stored, close here instead of acking.
			if atomic.LoadInt32(&b.draining) != 0 {
				b.client.Close(protocol.ErrDraining)
			} else {
				listener.respond(message.NewAck(cast.RequestID(), stream.None, stream.Origin))
			}
			break
		}
		b.client.Close(protocol.ErrInvalidPayload(protocol.Init))
	case protocol.Refresh:
		if cast, ok := msg.(message.Refresh); ok {
			creds, err := b.redeemer.Redeem(cast.AuthMethod(), cast.AccessCode()).Result()
			if err != nil {
				listener.respond(message.NewError(cast.RequestID(), err))
			} else {
				atomic.StorePointer(&b.credsPtr, unsafe.Pointer(&creds))
				listener.respond(message.NewAck(cast.RequestID(), stream.None, stream.Origin))
				// Close the subscriptions listInvalidated reports for the new
				// credentials (presumably those no longer permitted).
				for _, addr := range listener.listInvalidated(creds) {
					// Only a successful lookup yields a channel to detach from;
					// on error the returned channel must not be used.
					if ch, err := b.channel(addr); err == nil {
						ch.remove(listener)
					}
					listener.unsubscribe(addr)
					src, pos := listener.pos(addr).Next()
					listener.respond(message.NewClosed(addr, src, pos))
				}
				// Drop cached writers whose address is no longer sendable.
				b.cullWriters(creds)
			}
			break
		}
		b.client.Close(protocol.ErrInvalidPayload(protocol.Refresh))
	case protocol.Join:
		if cast, ok := msg.(message.Join); ok {
			creds := b.credentials()
			addr := cast.Address()
			if !creds.CanListen(addr) {
				b.logger(logging.Debug, "Aud [DENY]", creds, addr)
				listener.respond(message.NewError(cast.RequestID(), protocol.ErrForbiddenAddress))
				break
			}
			ch, err := b.channel(addr)
			if err != nil {
				listener.respond(message.NewError(cast.RequestID(), err))
			} else if err = listener.subscribe(cast); err != nil {
				listener.respond(message.NewError(cast.RequestID(), err))
			} else if _, err = ch.add(listener); err != nil {
				// Roll back the subscription if the channel refused the listener.
				listener.unsubscribe(addr)
				listener.respond(message.NewError(cast.RequestID(), err))
			} else {
				listener.respond(message.NewAck(cast.RequestID(), cast.Source(), cast.Position()))
				// Replay any stored history for this channel and its descendents.
				if history, _ := ch.history.load(); history != nil {
					scheduler.Update(history, stream.Origin, listener)
				}
				desc := make(map[stream.AddressKey]*channel)
				ch.descendents(desc)
				for _, d := range desc {
					if history, _ := d.history.load(); history != nil {
						scheduler.Update(history, stream.Origin, listener)
					}
				}
			}
			break
		}
		b.client.Close(protocol.ErrInvalidPayload(protocol.Join))
	case protocol.Part:
		if cast, ok := msg.(message.Part); ok {
			ch, err := b.channel(cast.Address())
			if err != nil {
				listener.respond(message.NewError(cast.RequestID(), err))
			} else if _, err := ch.remove(listener); err != nil {
				listener.respond(message.NewError(cast.RequestID(), err))
			} else {
				listener.unsubscribe(cast.Address())
				listener.respond(message.NewAck(cast.RequestID(), stream.None, stream.Origin))
			}
			break
		}
		b.client.Close(protocol.ErrInvalidPayload(protocol.Part))
	case protocol.Release:
		// TODO : release writer from this binding and notify
		if cast, ok := msg.(message.Release); ok {
			key := cast.Address().Key()
			// Remove under the lock, close outside it.
			b.wMutex.Lock()
			writer, ok := b.writers[key]
			delete(b.writers, key)
			b.wMutex.Unlock()
			if ok {
				if err := writer.Close(); err != nil {
					listener.respond(message.NewError(cast.RequestID(), err))
					break
				}
			}
			// Releasing an address with no writer is acknowledged as well.
			listener.respond(message.NewAck(cast.RequestID(), stream.None, stream.Origin))
			break
		}
		b.client.Close(protocol.ErrInvalidPayload(protocol.Release))
	case protocol.Send:
		// TODO : writer collection on this binding for automatic cleanup, move to writer fetch function
		// instead of a send function
		if cast, ok := msg.(message.Send); ok {
			creds := b.credentials()
			if !creds.CanSend(cast.Address()) {
				b.logger(logging.Debug, "Aud [DENY]", creds, cast.Address())
				listener.respond(message.NewError(cast.RequestID(), protocol.ErrForbiddenAddress))
				break
			}
			writer, err := b.writer(cast.Address())
			if err == protocol.ErrAddressMoved {
				// This Source is not handling this address, maybe because of a collision resolution.
				// Respond with <move> message, so the Threshold reconnects to another Source.
				listener.respond(message.NewMove(stream.AddressScopes{cast.Address()}))
			} else if err != nil {
				// Respond with <error>
				listener.respond(message.NewError(cast.RequestID(), err))
			} else {
				// Prepare a promise to <ack> or <error> that the message is being delivered to listeners
				writer.send(cast.Content(), cast.IsDelta(), newReceipt(listener, cast))
			}
			break
		}
		b.client.Close(protocol.ErrInvalidPayload(protocol.Send))
	default:
		b.client.Close(protocol.ErrInvalidOpCode)
	}
}

// expiring warns the client that its access is about to expire. The Expiring
// message is only sent to listeners speaking protocol version Four or later.
func (b *binding) expiring() {
	l := b.listener()
	if l == nil || l.version < protocol.Four {
		return
	}
	l.respond(message.NewExpiring())
}

// expired terminates the connection because the client's access has expired.
func (b *binding) expired() {
	reason := protocol.ErrExpiredAccess
	b.client.Close(reason)
}

// reject marks the binding as draining (which sends a Drain message if a
// listener exists) and then closes the client with err.
func (b *binding) reject(err error) {
	b.drain()
	b.client.Close(err)
}

// drain atomically flags the binding as draining and, when a listener already
// exists, tells the client with a Drain message. A client whose Init completes
// after this flag is set is closed with ErrDraining by OnBinaryMessage.
func (b *binding) drain() {
	atomic.AddInt32(&b.draining, 1)
	l := b.listener()
	if l == nil {
		return
	}
	l.respond(message.NewDrain())
}

// writer returns the cached writer for addr, creating it from the address's
// channel on first use.
//
// The channel lookup happens outside the lock, so two callers may race to
// create the same writer. The map is re-checked under the write lock, and the
// first writer stored wins. Whenever a writer is present in the cache it is
// returned with a nil error, matching the fast path. A channel lookup error
// is only reported when no writer exists. On error the returned writer is the
// zero value.
func (b *binding) writer(addr stream.Address) (audienceWriter, error) {
	key := addr.Key()
	b.wMutex.RLock()
	w, ok := b.writers[key]
	b.wMutex.RUnlock()
	if ok {
		return w, nil
	}
	ch, err := b.channel(addr)
	b.wMutex.Lock()
	defer b.wMutex.Unlock()
	if w, ok = b.writers[key]; ok {
		// Another caller created it while we were resolving the channel.
		return w, nil
	}
	if err != nil {
		// w is the zero value from the failed lookup above.
		return w, err
	}
	w = ch.newWriter(b)
	b.writers[key] = w
	return w, nil
}

// cullWriters closes and forgets every cached writer whose address creds no
// longer permits sending to. Close errors are ignored (best effort).
func (b *binding) cullWriters(creds stream.Credentials) {
	b.wMutex.Lock()
	defer b.wMutex.Unlock()
	for k, w := range b.writers {
		if creds.CanSend(w.Address()) {
			continue
		}
		_ = w.Close()
		// Deleting the current key while ranging over a map is permitted.
		delete(b.writers, k)
	}
}

// OnClosed tears the binding down when the session ends. It performs these
// steps in order:
//   - atomically detaches the listener, so later listener() calls see nil;
//   - collects the address keys returned by the listener's onClosed;
//   - swaps out and closes every cached writer;
//   - hands the binding, the detached listener and the addresses to drop.
//
// NOTE(review): if Init never succeeded, the detached listener is a nil
// *clientListener. Stored in the stream.Listener parameter (presumably an
// interface), it compares non-nil. Confirm drop does not rely on a
// `listener != nil` check.
func (b *binding) OnClosed(error) {
	detached := (*clientListener)(atomic.SwapPointer(&b.listenerPtr, nil))
	// Non-nil empty slice when there is no listener, as before.
	closedAddrs := make([]stream.AddressKey, 0)
	if detached != nil {
		closedAddrs = detached.onClosed()
	}

	// Replace the map under the lock; close the old writers outside it.
	b.wMutex.Lock()
	orphaned := b.writers
	b.writers = make(map[stream.AddressKey]audienceWriter)
	b.wMutex.Unlock()
	for _, w := range orphaned {
		w.Close()
	}

	b.drop(b, detached, closedAddrs)
}
