package registry

import (
	"errors"
	"sync"
	"sync/atomic"

	"code.justin.tv/devhub/e2ml/libs/logging"
	"code.justin.tv/devhub/e2ml/libs/session"
	"code.justin.tv/devhub/e2ml/libs/stream"
	"code.justin.tv/devhub/e2ml/libs/stream/protocol"
	"code.justin.tv/devhub/e2ml/libs/stream/protocol/message"
)

// Callback signatures a binding uses to report host events to its owner.
type (
	// drainFunction is the signature of the callback run for a Drain message.
	drainFunction func(binding *binding)
	// moveFunction is the signature of the callback run for a Move message;
	// it receives the scopes carried by that message.
	moveFunction func(binding *binding, scopes stream.AddressScopes)
	// expiringFunction is the signature of the callback run for an Expiring
	// message.
	expiringFunction func(binding *binding)
	// sentFunction is the signature of the callback that receives each Sent
	// message.
	sentFunction func(msg stream.Message)
	// lostFunction is the signature of the callback that receives each Lost
	// message.
	lostFunction func(desc stream.MessageDescription)
	// streamClosedFunction receives the address, source and position taken
	// from a Closed message.
	streamClosedFunction func(stream.Address, stream.SourceID, stream.Position)
	// hostClosedFunction is the signature of the callback run from OnClosed
	// with the close error.
	hostClosedFunction func(binding *binding, err error)
	// responseFunction completes a single request. On success (Ack) it
	// receives the acknowledged source and position with a nil error. On
	// failure it receives stream.None, stream.Origin and the error.
	responseFunction func(src stream.SourceID, pos stream.Position, err error)
)

// binding tracks requests across a particular Client
//
// It owns the request-ID counter and the table of pending response handlers
// for one host connection. It dispatches inbound protocol messages
// (OnBinaryMessage) to the callbacks supplied to newBinding.
type binding struct {
	host           string // URL of host
	// client is attached by initialize; it is nil until then, which Close
	// checks for.
	client         session.Client
	// logger receives Debug-level diagnostics and Trace-level SEND/RECV traces.
	logger         logging.Function
	// Owner callbacks, invoked from OnBinaryMessage (Drain, Move, Expiring,
	// Sent, Lost, Closed) and from OnClosed (onHostClosed).
	onDrain        drainFunction
	onMove         moveFunction
	onExpiring     expiringFunction
	onSent         sentFunction
	onLost         lostFunction
	onStreamClosed streamClosedFunction
	onHostClosed   hostClosedFunction
	// req is the next request ID to hand out; advanced by nextRequest.
	req            protocol.RequestID
	// requests holds the pending response handlers keyed by request ID.
	requests       *requestList
	// closed is set to 1 atomically by OnClosed.
	// NOTE(review): no reader is visible in this file; confirm whether it is
	// consulted elsewhere.
	closed         int32
	// mutex guards req, requests, and the client assignment in initialize.
	// It is not reentrant, so it must not be held while calling fulfill.
	mutex          sync.Mutex
}

// Compile-time assertion that *binding implements session.Binding.
var _ session.Binding = (*binding)(nil)

// newBinding creates a binding for the host at the given URL.
//
// The callbacks are stored as-is and run later: from OnBinaryMessage for
// host messages, and from OnClosed for onHClosed. No client is attached yet;
// initialize does that. Request IDs start at protocol.FirstRequest.
func newBinding(
	host string,
	onSent sentFunction,
	onLost lostFunction,
	onSClosed streamClosedFunction,
	onDrain drainFunction,
	onMove moveFunction,
	onExpiring expiringFunction,
	onHClosed hostClosedFunction,
	logger logging.Function,
) *binding {
	b := &binding{
		host:     host,
		logger:   logger,
		req:      protocol.FirstRequest,
		requests: newRequestList(),
	}
	b.onSent, b.onLost = onSent, onLost
	b.onStreamClosed, b.onHostClosed = onSClosed, onHClosed
	b.onDrain, b.onMove, b.onExpiring = onDrain, onMove, onExpiring
	return b
}

// OnTextMessage handles a text frame from the host. The registry protocol is
// binary (see OnBinaryMessage), so any text frame is logged at Debug level
// and otherwise ignored.
func (b *binding) OnTextMessage(_ string) {
	unexpected := errors.New("Unexpected text message")
	b.logger(logging.Debug, unexpected)
}

// OnBinaryMessage decodes one inbound binary frame and dispatches it by
// opcode:
//   - Drain, Expiring and Move are forwarded to the owner callbacks.
//   - Ack and Error complete the pending request named by ForRequestID.
//   - Closed, Sent and Lost are forwarded to the stream callbacks.
//
// Frames that fail to unmarshal, whose payload type does not match the
// opcode, or whose opcode is unknown are logged at Debug level and dropped.
func (b *binding) OnBinaryMessage(bytes []byte) {
	// TODO: error reporting
	msg, err := message.Unmarshal(bytes)
	if err != nil {
		b.logger(logging.Debug, err)
		return
	}
	b.logger(logging.Trace, "Reg - RECV", b.client.Address(), msg)

	switch msg.OpCode() {
	case protocol.Drain:
		b.onDrain(b)
	case protocol.Expiring:
		b.onExpiring(b)
	case protocol.Move:
		if cast, ok := msg.(message.Move); ok {
			b.onMove(b, cast.Scopes())
			break
		}
		// Report a mistyped Move payload the same way as every other opcode,
		// rather than dropping it without a trace.
		b.logger(logging.Debug, protocol.ErrInvalidPayload(protocol.Move))
	case protocol.Ack:
		if cast, ok := msg.(message.Ack); ok && cast.ForRequestID().IsValid() {
			b.fulfill(cast.ForRequestID(), cast.Source(), cast.Position(), nil)
			break
		}
		b.logger(logging.Debug, protocol.ErrInvalidPayload(protocol.Ack))
	case protocol.Error:
		if cast, ok := msg.(message.Error); ok {
			if cast.ForRequestID().IsValid() {
				b.fulfill(cast.ForRequestID(), stream.None, stream.Origin, cast.Unwrap())
			} else {
				// No request is waiting for this error, so no handler will ever
				// see it; log it so it is not silently lost.
				b.logger(logging.Debug, cast.Unwrap())
			}
			break
		}
		b.logger(logging.Debug, protocol.ErrInvalidPayload(protocol.Error))
	case protocol.Closed:
		if cast, ok := msg.(message.Closed); ok {
			b.onStreamClosed(cast.Address(), cast.Source(), cast.Position())
			break
		}
		b.logger(logging.Debug, protocol.ErrInvalidPayload(protocol.Closed))
	case protocol.Sent:
		if cast, ok := msg.(message.Sent); ok {
			b.onSent(cast)
			break
		}
		b.logger(logging.Debug, protocol.ErrInvalidPayload(protocol.Sent))
	case protocol.Lost:
		if cast, ok := msg.(message.Lost); ok {
			b.onLost(cast)
			break
		}
		b.logger(logging.Debug, protocol.ErrInvalidPayload(protocol.Lost))
	default:
		b.logger(logging.Debug, protocol.ErrInvalidOpCode)
	}
}

// Tick advances the pending-request timeout clock.
//
// Handlers that requestList.tick reports as expired are collected under the
// lock. They are then called with protocol.ErrServiceTimeout after the lock
// is released, because a handler may call back into binding methods that
// take the same non-reentrant mutex. Presumably a request expires only after
// surviving more than one tick (Close ticks twice to flush everything);
// confirm against requestList.tick.
func (b *binding) Tick() {
	b.mutex.Lock()
	pending := b.requests.tick()
	b.mutex.Unlock()

	for _, respond := range pending {
		respond(stream.None, stream.Origin, protocol.ErrServiceTimeout)
	}
}

// Close shuts the binding down. It always returns nil.
//
// If a client was attached by initialize, that client is closed (presumably
// it then reports back through OnClosed). If no client was ever attached,
// OnClosed is called directly. Every request still pending is then flushed
// to its handler with protocol.ErrServiceTimeout.
func (b *binding) Close() error {
	if b.client == nil {
		b.OnClosed(nil)
	} else {
		b.client.Close(nil)
	}

	// Two consecutive ticks are relied on to expire every remaining request
	// (see Tick).
	for i := 0; i < 2; i++ {
		b.Tick()
	}
	return nil
}

// OnClosed is the close notification for this binding. It is called directly
// by Close when no client is attached. It atomically marks the binding as
// closed and then forwards err to the onHostClosed callback.
func (b *binding) OnClosed(err error) {
	const closedFlag int32 = 1
	atomic.StoreInt32(&b.closed, closedFlag)
	b.onHostClosed(b, err)
}

// nextRequest returns the current request ID and advances the counter to its
// successor.
//
// The caller must hold b.mutex.
func (b *binding) nextRequest() protocol.RequestID {
	issued, following := b.req, b.req.Next()
	b.req = following
	return issued
}

// initialize attaches client to the binding and sends the Init handshake
// built from protocol.Current, auth and accessCode.
//
// onResponse is registered under a fresh request ID. It is called in one of
// three ways:
//   - when the host answers with an Ack or Error for that ID;
//   - with protocol.ErrServiceTimeout once Tick expires the request;
//   - immediately with the error if the Init message cannot be built or
//     written.
func (b *binding) initialize(client session.Client, auth stream.AuthMethod, accessCode stream.OpaqueBytes, onResponse responseFunction) {
	b.mutex.Lock()
	req := b.nextRequest()
	b.requests.set(req, onResponse)
	b.client = client
	b.mutex.Unlock()

	init, err := message.NewInit(req, protocol.Current, auth, accessCode)
	if err != nil {
		// The request is already failed and init is unusable, so stop here.
		// Falling through to write would marshal an invalid message.
		b.fulfill(req, stream.None, stream.Origin, err)
		return
	}
	b.write(req, init)
}

// refresh sends a Refresh request carrying the encoded auth, using the
// stream.Validation method.
//
// onResponse is registered under a fresh request ID. It is called when the
// host answers, or with protocol.ErrServiceTimeout from Tick, or
// immediately with the error if the message cannot be built or written.
func (b *binding) refresh(auth stream.AuthRequest, onResponse responseFunction) {
	b.mutex.Lock()
	req := b.nextRequest()
	b.requests.set(req, onResponse)
	b.mutex.Unlock()

	ref, err := message.NewRefresh(req, stream.Validation, auth.Encode())
	if err != nil {
		// The request is already failed and ref is unusable, so stop here.
		// This matches join, part and release.
		b.fulfill(req, stream.None, stream.Origin, err)
		return
	}
	b.write(req, ref)
}

// join sends a Join request for addr starting at pos (with source
// stream.None).
//
// onResponse is registered under a fresh request ID. It is called when the
// host answers, or with protocol.ErrServiceTimeout from Tick, or
// immediately with the error if the message cannot be built or written.
func (b *binding) join(addr stream.Address, pos stream.Position, onResponse responseFunction) {
	id := func() protocol.RequestID {
		b.mutex.Lock()
		defer b.mutex.Unlock()
		next := b.nextRequest()
		b.requests.set(next, onResponse)
		return next
	}()

	if msg, err := message.NewJoin(id, addr, stream.None, pos); err != nil {
		b.fulfill(id, stream.None, stream.Origin, err)
	} else {
		b.write(id, msg)
	}
}

// part sends a Part request for addr.
//
// onResponse is registered under a fresh request ID. It is called when the
// host answers, or with protocol.ErrServiceTimeout from Tick, or
// immediately with the error if the message cannot be built or written.
func (b *binding) part(addr stream.Address, onResponse responseFunction) {
	id := func() protocol.RequestID {
		b.mutex.Lock()
		defer b.mutex.Unlock()
		next := b.nextRequest()
		b.requests.set(next, onResponse)
		return next
	}()

	if msg, err := message.NewPart(id, addr); err != nil {
		b.fulfill(id, stream.None, stream.Origin, err)
	} else {
		b.write(id, msg)
	}
}

// release sends a Release request for addr.
//
// onResponse is registered under a fresh request ID. It is called when the
// host answers, or with protocol.ErrServiceTimeout from Tick, or
// immediately with the error if the message cannot be built or written.
func (b *binding) release(addr stream.Address, onResponse responseFunction) {
	id := func() protocol.RequestID {
		b.mutex.Lock()
		defer b.mutex.Unlock()
		next := b.nextRequest()
		b.requests.set(next, onResponse)
		return next
	}()

	if msg, err := message.NewRelease(id, addr); err != nil {
		b.fulfill(id, stream.None, stream.Origin, err)
	} else {
		b.write(id, msg)
	}
}

// send publishes content to addr (as a delta when isDelta is true). It
// returns a promise that resolves when the request completes.
//
// The promise is set from the request's response:
//   - on Ack, a tracker for the acknowledged source and position;
//   - on failure (host Error, timeout via Tick, or a build/write error), a
//     tracker for stream.None/stream.Origin together with the error.
//
// The mutex is released before the message is written. write may call
// fulfill, which locks the same non-reentrant mutex; holding it here would
// deadlock on any write error.
func (b *binding) send(addr stream.Address, content []byte, isDelta bool) stream.TrackerPromise {
	out := stream.NewTrackerPromise()

	b.mutex.Lock()
	req := b.nextRequest()
	b.requests.set(req, func(src stream.SourceID, pos stream.Position, err error) {
		out.Set(stream.CreateTracker(src, pos), err)
	})
	b.mutex.Unlock()

	msg, err := message.NewSend(req, addr, content, isDelta)
	if err != nil {
		// Fail the promise rather than writing an unusable message.
		b.fulfill(req, stream.None, stream.Origin, err)
		return out
	}
	b.write(req, msg)
	return out
}

// write traces msg, marshals it with protocol.Current and sends it over the
// attached client.
//
// If either marshaling or the write fails, request req is completed at once
// with that error. Otherwise it would be left waiting for Tick to time it
// out, or a garbage frame would be sent.
//
// write must not be called while holding b.mutex, because fulfill acquires it.
func (b *binding) write(req protocol.RequestID, msg protocol.Message) {
	b.logger(logging.Trace, "Reg - SEND", b.client.Address(), msg)
	data, err := msg.Marshal(protocol.Current)
	if err == nil {
		err = b.client.WriteBinary(data)
	}
	if err != nil {
		b.fulfill(req, stream.None, stream.Origin, err)
	}
}

// fulfill completes request req by calling its registered handler with src,
// pos and err. An unknown req is ignored.
//
// The handler is looked up under b.mutex but called after the mutex is
// released, so the handler may re-enter the binding. Whether the lookup also
// removes the entry depends on requestList.get (not visible here; confirm).
func (b *binding) fulfill(req protocol.RequestID, src stream.SourceID, pos stream.Position, err error) {
	b.mutex.Lock()
	handler, found := b.requests.get(req)
	b.mutex.Unlock()

	if !found {
		return
	}
	handler(src, pos, err)
}
