package balanced

import (
	"fmt"
	"sync"
	"time"

	"code.justin.tv/devhub/e2ml/libs/discovery/broker"
	"code.justin.tv/devhub/e2ml/libs/discovery/broker/balanced/pick"
	"code.justin.tv/devhub/e2ml/libs/discovery/protocol"
	"code.justin.tv/devhub/e2ml/libs/discovery/protocol/message"
	"code.justin.tv/devhub/e2ml/libs/logging"
	"code.justin.tv/devhub/e2ml/libs/session"
	"code.justin.tv/devhub/e2ml/libs/stream"
)

type (
	// serializeHandler receives messages accepted from a reporter, together
	// with the host that reporter is bound to.
	serializeHandler func(host pick.Host, msgs ...protocol.Message)
	// authHostHandler checks the token of an AuthHost message; a non-nil
	// error rejects the host.
	authHostHandler func(stream.OpaqueBytes) error
	// validateHandler resolves the token of a Validate message into
	// credentials, or returns an error.
	validateHandler func(stream.OpaqueBytes) (stream.Credentials, error)
	// reporterClosedHandler is notified when a reporter binding closes.
	reporterClosedHandler func(*reporterBinding)
)

// reporterBinding is the broker-side session binding for a single reporter
// connection. It binds the reporter to a host once an AuthHost message is
// accepted, dispatches the reporter's binary messages, and sends
// allocate/detach/reserve requests to the reporter on the broker's behalf.
type reporterBinding struct {
	// host is created from the first accepted AuthHost message and cleared on
	// close; nil means the reporter has not authenticated yet. Guarded by mutex.
	host       *host
	// addrKey caches client.Address().String(); returned by key().
	addrKey    string
	// onAuthHost verifies the token carried by an AuthHost message.
	onAuthHost authHostHandler
	// onValidate turns a Validate token into credentials for the Bind reply.
	onValidate validateHandler
	// onMessages is forwarded AuthHost, Scopes and Status messages once they
	// have been applied.
	onMessages serializeHandler
	// onClosed is notified from OnClosed.
	onClosed   reporterClosedHandler
	// on bundles the onAdd/onRemove/onHostClosed handlers, this binding's
	// allocate/detach/reserve methods and the logger; it is handed to the host.
	on         *hostFunctions
	// requests tracks outstanding allocate/detach/reserve requests and matches
	// the reporter's replies to them by request ID.
	requests   *requestManager
	// client is the session connection to the reporter.
	client     session.Client
	// mutex guards host.
	mutex      sync.Mutex
}

// Compile-time assertion that *reporterBinding satisfies session.Binding.
var _ session.Binding = (*reporterBinding)(nil)

// newReporterBinding builds the binding for a newly opened reporter
// connection and logs the open at Debug level. The binding starts without a
// host; one is bound when the reporter sends an accepted AuthHost message.
// The scope handlers and logger are bundled with the binding's own
// allocate/detach/reserve methods into the hostFunctions given to that host.
func newReporterBinding(
	onAdd scopeHandler,
	onRemove scopeHandler,
	onHostClosed scopeHandler,
	onMessages serializeHandler,
	onReporterClosed reporterClosedHandler,
	onAuthHost authHostHandler,
	onValidate validateHandler,
	logger logging.Function,
	client session.Client,
) *reporterBinding {
	// Duration handed to the request manager for outstanding requests.
	const requestWindow = 5 * time.Second

	logger(logging.Debug, "Brk [OPEN]", client.Address())

	binding := new(reporterBinding)
	binding.onAuthHost = onAuthHost
	binding.onValidate = onValidate
	binding.onMessages = onMessages
	binding.onClosed = onReporterClosed
	binding.requests = newRequestManager(requestWindow)
	binding.client = client
	binding.addrKey = client.Address().String()

	binding.on = &hostFunctions{
		onAdd,
		onRemove,
		onHostClosed,
		binding.allocate,
		binding.detach,
		binding.reserve,
		logger,
	}
	return binding
}

// OnTextMessage rejects text frames. This binding only understands binary
// messages, so any text frame closes the connection with ErrInvalidHeader.
func (r *reporterBinding) OnTextMessage(msg string) {
	reason := protocol.ErrInvalidHeader
	r.client.Close(reason)
}

// OnBinaryMessage decodes one frame from the reporter and dispatches it by
// op code.
//
// Frames that fail to decode close the connection. Until a host is bound,
// any op code other than AuthHost closes the connection with ErrForbidden.
// Op codes this binding does not handle, and messages whose payload type
// does not match their op code, are only logged.
func (r *reporterBinding) OnBinaryMessage(bytes []byte) {
	msg, err := message.Unmarshal(bytes)
	if err != nil {
		r.on.log(logging.Debug, err)
		r.client.Close(err)
		return
	}

	op := msg.OpCode()
	current := r.getHost()
	if current == nil && op != protocol.AuthHost {
		r.client.Close(protocol.ErrForbidden)
		return
	}

	r.on.log(logging.Trace, "Brk <=[RCV]", r.client.Address(), msg)

	// Every case returns once it has handled the message. A failed type
	// assertion breaks out of the switch to the invalid-payload log below.
	switch op {
	case protocol.AuthHost:
		auth, ok := msg.(message.AuthHost)
		if !ok {
			break
		}
		if authErr := r.onAuthHost(auth.Token()); authErr != nil {
			r.on.log(logging.Debug, "Brk [AUTH FAILURE]", authErr)
			r.onBadHost(auth, authErr)
			return
		}
		bound, accepted := r.setHost(auth)
		if !accepted {
			r.onBadHost(auth, protocol.ErrInvalidHost)
			return
		}
		r.onMessages(bound, msg)
		r.ack(bound, auth.AckID())
		return

	case protocol.Scopes:
		scopes, ok := msg.(message.Scopes)
		if !ok {
			break
		}
		if scopes.Remove() {
			current.removeScopes(scopes.Scopes())
		} else {
			current.addScopes(scopes.Scopes())
		}
		r.onMessages(current, msg)
		r.ack(current, scopes.AckID())
		return

	case protocol.Status:
		status, ok := msg.(message.Status)
		if !ok {
			break
		}
		current.setStatus(status)
		r.onMessages(current, msg)
		return

	case protocol.Failure:
		failure, ok := msg.(message.Failure)
		if !ok {
			break
		}
		r.requests.onFailure(failure.ForRequestID(), failure)
		return

	case protocol.Ticket:
		ticket, ok := msg.(message.Ticket)
		if !ok {
			break
		}
		r.requests.onTicketSuccess(ticket.ForRequestID(), current.createTicket(ticket))
		return

	case protocol.Allocation:
		alloc, ok := msg.(message.Allocation)
		if !ok {
			break
		}
		r.requests.onScopesSuccess(alloc.ForRequestID(), alloc.Scopes())
		return

	case protocol.Detached:
		detached, ok := msg.(message.Detached)
		if !ok {
			break
		}
		r.requests.onScopesSuccess(detached.ForRequestID(), detached.Scopes())
		return

	case protocol.Validate:
		validate, ok := msg.(message.Validate)
		if !ok {
			break
		}
		var reply protocol.Message
		creds, verr := r.onValidate(validate.Token())
		if verr == nil {
			reply, verr = message.NewBind(validate.AckID(), validate.Token(), creds)
		}
		if verr != nil {
			// A failure to build the Error reply leaves reply nil; nothing is sent then.
			reply, _ = message.NewError(validate.AckID(), verr)
		}
		if reply != nil {
			r.write(current, reply)
		}
		return

	default:
		r.on.log(logging.Debug, protocol.ErrInvalidOpCode)
		return
	}

	r.on.log(logging.Debug, protocol.ErrInvalidPayload(op))
}

// OnClosed tears the binding down after the reporter connection closes. It
// unbinds and closes the host (if one was bound), notifies onClosed, logs
// the close reason, and finally closes the request manager.
func (r *reporterBinding) OnClosed(err error) {
	// Swap the host out under the lock; close it after the lock is released.
	detached := func() *host {
		r.mutex.Lock()
		defer r.mutex.Unlock()
		previous := r.host
		r.host = nil
		return previous
	}()
	if detached != nil {
		detached.close()
	}

	r.onClosed(r)
	r.on.log(logging.Debug, "Brk [CLOSE]", r.client.Address(), err)
	r.requests.close()
}

// getHost returns the host this reporter is bound to. It returns nil before
// an AuthHost has been accepted and after OnClosed. The read is done under
// the same mutex that setHost and OnClosed use.
func (r *reporterBinding) getHost() *host {
	r.mutex.Lock()
	defer r.mutex.Unlock()
	return r.host
}

// ack sends an Ack for id to the reporter, encoded for host's protocol
// version. If the Ack cannot be built, the error is logged at Debug level
// and nothing is sent.
func (r *reporterBinding) ack(host *host, id protocol.AckID) {
	reply, err := message.NewAck(id)
	if err == nil {
		r.write(host, reply)
		return
	}
	r.on.log(logging.Debug, err, reply)
}

// onBadHost rejects a reporter's AuthHost message. It first tries to reply
// with an Error message that carries err, correlated by src's AckID and
// encoded for src's protocol version. It then closes the connection with err.
//
// The connection is closed even if the Error reply cannot be built. A
// rejected reporter must not stay connected just because the reply failed
// to encode.
func (r *reporterBinding) onBadHost(src message.AuthHost, err error) {
	if reply, merr := message.NewError(src.AckID(), err); merr != nil {
		r.on.log(logging.Debug, merr, reply)
	} else {
		// write only needs a header to pick the wire version, so src's
		// header is wrapped in a temporary host rather than a bound one.
		r.write(&host{header: src}, reply)
	}
	r.client.Close(err)
}

// allocate sends an Allocate request for address to the reporter behind h.
// The callback is registered with the request manager, which is fed by the
// reporter's Allocation and Failure replies. If h is not a *host from this
// package, callback fails immediately with ErrInvalidHost. If the request
// cannot be built, it is failed through the request manager.
func (r *reporterBinding) allocate(h pick.Host, address stream.Address, callback broker.ScopesCallback) {
	target, isLocal := h.(*host)
	if !isLocal {
		callback(nil, protocol.ErrInvalidHost)
		return
	}

	id := r.requests.bindAlloc(callback)
	req, err := message.NewAllocate(id, address)
	if err != nil {
		r.requests.onFailure(id, err)
		return
	}
	r.write(target, req)
}

// detach sends a Detach request for address to the reporter behind h. The
// callback is registered with the request manager, which is fed by the
// reporter's Detached and Failure replies. The callback fails immediately
// with ErrInvalidHost if h is not a *host from this package, or if the host
// speaks a protocol version older than protocol.Two.
func (r *reporterBinding) detach(h pick.Host, address stream.Address, callback broker.ScopesCallback) {
	target, isLocal := h.(*host)
	if !isLocal {
		callback(nil, protocol.ErrInvalidHost)
		return
	}
	if target.header.Version() < protocol.Two {
		// Hosts on older protocol versions do not support detach.
		callback(nil, protocol.ErrInvalidHost)
		return
	}

	id := r.requests.bindAlloc(callback)
	req, err := message.NewDetach(id, address)
	if err != nil {
		r.requests.onFailure(id, err)
		return
	}
	r.write(target, req)
}

// reserve sends a Reserve request for address and creds to the reporter
// behind h. The callback is registered with the request manager, which is
// fed by the reporter's Ticket and Failure replies. If h is not a *host from
// this package, callback fails immediately with ErrInvalidHost.
func (r *reporterBinding) reserve(h pick.Host, creds stream.Credentials, address stream.Address, callback broker.TicketCallback) {
	target, isLocal := h.(*host)
	if !isLocal {
		callback(nil, protocol.ErrInvalidHost)
		return
	}

	id := r.requests.bind(callback)
	req, err := message.NewReserve(id, address, creds)
	if err != nil {
		r.requests.onFailure(id, err)
		return
	}
	r.write(target, req)
}

// setHost binds this reporter to the host described by msg. The first call
// creates the host. Later calls succeed only if msg equals the header of the
// host that is already bound. It returns the bound host and true on success,
// or nil and false on a mismatch.
func (r *reporterBinding) setHost(msg message.AuthHost) (*host, bool) {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	if r.host == nil {
		r.host = newHost(msg, r.client.Address(), r.on)
		return r.host, true
	}
	if !msg.Equals(r.host.header) {
		return nil, false
	}
	return r.host, true
}

// write marshals msg using the protocol version from host's header and sends
// it to the reporter as a binary frame. Marshal and send failures are logged
// at Debug level and otherwise dropped. Successful sends are logged at Trace
// level. host must be non-nil, because its header selects the wire version.
func (r *reporterBinding) write(host *host, msg protocol.Message) {
	bytes, err := msg.Marshal(host.header.Version())
	if err != nil {
		r.on.log(logging.Debug, fmt.Errorf("Brk - Unable to marshal %v %v", msg, err))
		return
	}
	if err := r.client.WriteBinary(bytes); err != nil {
		// Log message fixed: previously read "Unable so send".
		r.on.log(logging.Debug, fmt.Errorf("Brk - Unable to send %v %v", msg, err))
		return
	}
	r.on.log(logging.Trace, "Brk [SND]=>", r.client.Address(), msg)
}

// key identifies this binding by the reporter's address string, captured
// when the binding was created.
func (r *reporterBinding) key() string {
	return r.addrKey
}
