package balanced

import (
	"net"
	"sync"
	"sync/atomic"
	"time"

	"code.justin.tv/devhub/e2ml/libs/discovery"
	"code.justin.tv/devhub/e2ml/libs/discovery/broker"
	"code.justin.tv/devhub/e2ml/libs/discovery/broker/balanced/pick"
	"code.justin.tv/devhub/e2ml/libs/discovery/protocol"
	"code.justin.tv/devhub/e2ml/libs/discovery/protocol/message"
	"code.justin.tv/devhub/e2ml/libs/logging"
	"code.justin.tv/devhub/e2ml/libs/session"
	"code.justin.tv/devhub/e2ml/libs/stream"
)

// Callback and lookup signatures used to wire a peerBinding to its owner.
type (
	// connectHandler is invoked with the protocol version and the peer's
	// client once a Connect message has passed authorization.
	connectHandler func(protocol.Version, session.Client)

	// reporterSearch looks up the reporter binding for a network address;
	// callers in this file treat a nil result as "no such reporter".
	reporterSearch func(net.Addr) *reporterBinding

	// messageHandler receives a decoded message together with the protocol
	// version and the client it arrived on.
	messageHandler func(protocol.Version, session.Client, protocol.Message)

	// peerClosedHandler is notified with the binding whose connection closed.
	peerClosedHandler func(*peerBinding)
)

// Binding for a peered broker proxy handling many reporters.
//
// One peerBinding exists per peer connection. It decodes incoming frames,
// tracks the hosts announced through the peer, and relays allocate/reserve
// requests in both directions.
type peerBinding struct {
	on         *hostFunctions    // callbacks handed to every host created via setHost
	hosts      map[string]*host  // known hosts keyed by source address string; guarded by mutex
	version    protocol.Version  // version taken from the peer's Connect message; guarded by mutex
	onElection messageHandler    // receives Prepare/Promise/Accept/Accepted messages
	onAuthHost authHostHandler   // checks the token carried by a Connect message
	onConnect  connectHandler    // called after a Connect message passes onAuthHost
	onClosed   peerClosedHandler // called from OnClosed after the hosts have been closed
	onResolve  reporterSearch    // finds the reporter binding for a forwarded source address
	requests   *requestManager   // tracks outstanding allocate/reserve requests sent to the peer
	client     session.Client    // connection to the peer
	ready      int32             // stored as 1 (atomically) once a Ready message arrives
	queue      *forwardingQueue  // holds non-init Forward messages received before Ready
	mutex      sync.RWMutex      // guards hosts and version
}

// Compile-time check that *peerBinding satisfies session.Binding.
var _ session.Binding = (*peerBinding)(nil)

// newPeerBinding builds the binding for a newly opened peer connection.
//
// The scope handlers and the logger are bundled, together with the binding's
// own allocate/detach/reserve methods, into the hostFunctions used by hosts
// created for this peer; the remaining callbacks are stored on the binding
// itself. The binding starts at protocol.Unknown (presumably not a valid
// version, so only Connect is accepted until a handshake succeeds) and with a
// five second request manager.
func newPeerBinding(
	onAdd scopeHandler,
	onRemove scopeHandler,
	onElection messageHandler,
	onClosed peerClosedHandler,
	onAuthHost authHostHandler,
	onConnect connectHandler,
	onResolve reporterSearch,
	logger logging.Function,
	client session.Client,
) *peerBinding {
	binding := &peerBinding{
		version:    protocol.Unknown,
		client:     client,
		hosts:      make(map[string]*host),
		requests:   newRequestManager(5 * time.Second),
		queue:      newForwardingQueue(),
		onElection: onElection,
		onAuthHost: onAuthHost,
		onConnect:  onConnect,
		onClosed:   onClosed,
		onResolve:  onResolve,
	}

	logger(logging.Debug, "Peer [OPEN]", client.Address())

	// NOTE(review): onRemove fills two consecutive slots of hostFunctions; the
	// field names are not visible from here -- confirm this is intended.
	binding.on = &hostFunctions{
		onAdd,
		onRemove,
		onRemove,
		binding.allocate,
		binding.detach,
		binding.reserve,
		logger,
	}
	return binding
}

// OnTextMessage rejects text frames: the payload is ignored and the client is
// closed with protocol.ErrInvalidHeader.
func (p *peerBinding) OnTextMessage(_ string) {
	reason := protocol.ErrInvalidHeader
	p.client.Close(reason)
}

// OnBinaryMessage decodes a binary frame from the peer and dispatches it by
// opcode.
//
// A frame that fails to unmarshal closes the client with the decode error.
// While the stored version is not valid, any opcode other than Connect closes
// the client with protocol.ErrForbidden. Otherwise:
//
//   - Connect: the token is checked through onAuthHost. On failure the error
//     is logged, a Rejected message is sent and the client is closed; on
//     success the announced version is stored and onConnect is invoked.
//   - Forward: non-init messages received before Ready are queued; everything
//     else is handed to onForward.
//   - Ready: marks the link ready and replays the queued messages.
//   - Rejected: logs and closes the client.
//   - Prepare/Promise/Accept/Accepted: passed to onElection.
//
// An opcode whose payload has an unexpected concrete type is reported with
// protocol.ErrInvalidPayload; unknown opcodes with protocol.ErrInvalidOpCode.
func (p *peerBinding) OnBinaryMessage(bytes []byte) {
	msg, err := message.Unmarshal(bytes)
	if err != nil {
		p.on.log(logging.Debug, err)
		p.client.Close(err)
		return
	}

	p.mutex.RLock()
	version := p.version
	p.mutex.RUnlock()

	// Nothing but Connect is accepted before a valid version is stored.
	if !version.IsValid() && msg.OpCode() != protocol.Connect {
		p.client.Close(protocol.ErrForbidden)
		return
	}

	p.on.log(logging.Trace, "Peer <=[RCV]", p.client.Address(), msg)
	switch msg.OpCode() {
	case protocol.Connect:
		if cast, ok := msg.(message.Connect); ok {
			if authErr := p.onAuthHost(cast.Token()); authErr != nil {
				// Report why the peer is refused; distinct names keep the
				// authorization error from being shadowed and lost.
				p.on.log(logging.Debug, "Rejecting peer", p.client.Address(), authErr)
				rejected, buildErr := message.NewRejected()
				p.client.Close(p.send(cast.Version(), rejected, buildErr))
				return
			}
			p.mutex.Lock()
			version = cast.Version()
			p.version = version
			p.mutex.Unlock()
			p.onConnect(version, p.client)
			return
		}
		// An unexpected payload type leaves the switch and is reported by the
		// invalid-payload log below, like every other opcode.
	case protocol.Forward:
		if cast, ok := msg.(message.Forward); ok {
			if !cast.IsInit() && atomic.LoadInt32(&p.ready) == 0 {
				// Hold back regular traffic until the peer signals Ready.
				p.queue.append(cast)
			} else {
				p.onForward(version, cast.Source(), cast.Message())
			}
			return
		}
	case protocol.Ready:
		if _, ok := msg.(message.Ready); ok {
			atomic.StoreInt32(&p.ready, 1)
			for _, queued := range p.queue.Flush() {
				p.onForward(version, queued.Source(), queued.Message())
			}
			return
		}
	case protocol.Rejected:
		p.on.log(logging.Debug, "Rejected by peer", p.client.Address())
		p.client.Close(nil)
		return
	case protocol.Prepare: // paxos 1
		fallthrough
	case protocol.Promise: // paxos 2
		fallthrough
	case protocol.Accept: // paxos 3
		fallthrough
	case protocol.Accepted: // paxos 4
		p.onElection(version, p.client, msg)
		return
	default:
		p.on.log(logging.Debug, protocol.ErrInvalidOpCode)
		return
	}
	p.on.log(logging.Debug, protocol.ErrInvalidPayload(msg.OpCode()))
}

// OnClosed tears the binding down after the peer connection has closed. The
// host table is swapped for an empty one under the lock, each detached host is
// closed outside the lock, then the onClosed callback is invoked and the
// request manager is closed. The close reason is ignored.
func (p *peerBinding) OnClosed(error) {
	p.on.log(logging.Debug, "Peer [CLOSE]", p.client.Address())

	p.mutex.Lock()
	orphaned := p.hosts
	p.hosts = make(map[string]*host)
	p.mutex.Unlock()

	for key := range orphaned {
		orphaned[key].close()
	}

	p.onClosed(p)
	p.requests.close()
}

// onForward handles a message that the peer relayed on behalf of the reporter
// at src. Dispatch is by opcode; a payload whose concrete type does not match
// its opcode leaves the switch and is logged as an invalid payload, while
// Validate and unknown opcodes are logged as invalid opcodes.
func (p *peerBinding) onForward(version protocol.Version, src net.Addr, msg protocol.Message) {
	switch msg.OpCode() {
	case protocol.AuthHost:
		auth, ok := msg.(message.AuthHost)
		if !ok {
			break
		}
		p.setHost(src, auth)
		return

	case protocol.Scopes:
		update, ok := msg.(message.Scopes)
		if !ok {
			break
		}
		if target, found := p.getHost(src); found {
			if update.Remove() {
				target.removeScopes(update.Scopes())
			} else {
				target.addScopes(update.Scopes())
			}
		}
		return

	case protocol.Status:
		status, ok := msg.(message.Status)
		if !ok {
			break
		}
		// A draining host is dropped; any other status updates a known host.
		if status.Flags().IsDraining() {
			p.clearHost(src)
			return
		}
		if target, found := p.getHost(src); found {
			target.setStatus(status)
		}
		return

	case protocol.Failure:
		failure, ok := msg.(message.Failure)
		if !ok {
			break
		}
		if _, known := p.getHost(src); known {
			p.requests.onFailure(failure.ForRequestID(), failure)
		}
		return

	case protocol.Allocate:
		request, ok := msg.(message.Allocate)
		if !ok {
			break
		}
		respond := p.forwardResponse(version, src)
		if resolver := p.onResolve(src); resolver != nil {
			if target := resolver.getHost(); target != nil {
				resolver.allocate(target, request.Address(), func(scopes stream.AddressScopes, err error) {
					if err != nil {
						respond(message.NewFailure(request.RequestID(), err))
						return
					}
					respond(message.NewAllocation(request.RequestID(), scopes))
				})
				return
			}
		}
		// No reporter (or no host behind it) for the source address.
		respond(message.NewFailure(request.RequestID(), protocol.ErrInvalidHost))
		return

	case protocol.Allocation:
		allocation, ok := msg.(message.Allocation)
		if !ok {
			break
		}
		p.requests.onScopesSuccess(allocation.ForRequestID(), allocation.Scopes())
		return

	case protocol.Ticket:
		ticket, ok := msg.(message.Ticket)
		if !ok {
			break
		}
		if target, found := p.getHost(src); found {
			p.requests.onTicketSuccess(ticket.ForRequestID(), target.createTicket(ticket))
		}
		return

	case protocol.Reserve:
		request, ok := msg.(message.Reserve)
		if !ok {
			break
		}
		respond := p.forwardResponse(version, src)
		if resolver := p.onResolve(src); resolver != nil {
			if target := resolver.getHost(); target != nil {
				resolver.reserve(target, request.Credentials(), request.Address(), func(ticket discovery.Ticket, err error) {
					if err != nil {
						respond(message.NewFailure(request.RequestID(), err))
						return
					}
					respond(message.NewTicket(request.RequestID(), ticket.AccessCode(), ticket.Scopes()))
				})
				return
			}
		}
		// No reporter (or no host behind it) for the source address.
		respond(message.NewFailure(request.RequestID(), protocol.ErrInvalidHost))
		return

	case protocol.Validate:
		// should never be handled by peer links
		fallthrough
	default:
		p.on.log(logging.Debug, protocol.ErrInvalidOpCode, msg.OpCode())
		return
	}

	p.on.log(logging.Debug, protocol.ErrInvalidPayload(msg.OpCode()))
}

// setHost registers a host for src, keyed by the address string, using the
// supplied AuthHost message. If an entry already exists whose header does not
// equal msg, nothing is changed and (nil, false) is returned; otherwise a new
// host is created, stored, and returned with true.
//
// NOTE(review): when an entry with an equal header already exists it is
// replaced without close() being called on the previous host -- confirm that
// this is intended.
func (p *peerBinding) setHost(src net.Addr, msg message.AuthHost) (*host, bool) {
	key := src.String()

	p.mutex.Lock()
	defer p.mutex.Unlock()

	if existing, found := p.hosts[key]; found && !existing.header.Equals(msg) {
		return nil, false
	}

	created := newHost(msg, src, p.on)
	p.hosts[key] = created
	return created, true
}

// getHost returns the host registered for src (keyed by its address string)
// and whether such an entry exists. The lookup runs under the read lock.
func (p *peerBinding) getHost(src net.Addr) (*host, bool) {
	key := src.String()

	p.mutex.RLock()
	defer p.mutex.RUnlock()

	found, ok := p.hosts[key]
	return found, ok
}

// clearHost removes the host registered for src, if any, and closes it. The
// close call happens after the lock has been released.
func (p *peerBinding) clearHost(src net.Addr) {
	key := src.String()

	p.mutex.Lock()
	removed, found := p.hosts[key]
	if found {
		delete(p.hosts, key)
	}
	p.mutex.Unlock()

	if !found {
		return
	}
	removed.close()
}

// allocate asks the peer to allocate address on the given host. The callback
// is bound to a new request id, an Allocate message is built and forwarded to
// the peer using the host's remote address as the source. If the message
// cannot be built or sent, the pending request is failed immediately with that
// error; otherwise the callback is completed through the request manager.
func (p *peerBinding) allocate(host pick.Host, address stream.Address, callback broker.ScopesCallback) {
	id := p.requests.bindAlloc(callback)
	req, err := message.NewAllocate(id, address)
	if err == nil {
		// The version is only read here, so the shared lock is enough and
		// does not block concurrent readers such as getHost.
		p.mutex.RLock()
		version := p.version
		p.mutex.RUnlock()
		err = p.forward(version, false, host.RemoteAddress(), req)
	}
	if err != nil {
		p.requests.onFailure(id, err)
	}
}

// detach is deliberately a no-op for peer links: the cluster member that is
// directly connected to the host performs the detach, so other members leave
// their listing unchanged. The callback is invoked right away with an empty
// scope set and no error.
func (p *peerBinding) detach(host pick.Host, address stream.Address, callback broker.ScopesCallback) {
	unchanged := stream.AddressScopes{}
	callback(unchanged, nil)
}

// reserve asks the peer to reserve address on the given host with the supplied
// credentials. The callback is bound to a new request id, a Reserve message is
// built and forwarded to the peer using the host's remote address as the
// source. If the message cannot be built or sent, the pending request is
// failed immediately with that error.
func (p *peerBinding) reserve(host pick.Host, creds stream.Credentials, address stream.Address, callback broker.TicketCallback) {
	id := p.requests.bind(callback)
	req, err := message.NewReserve(id, address, creds)
	if err == nil {
		// The version is only read here, so the shared lock is enough and
		// does not block concurrent readers such as getHost.
		p.mutex.RLock()
		version := p.version
		p.mutex.RUnlock()
		err = p.forward(version, false, host.RemoteAddress(), req)
	}
	if err != nil {
		p.requests.onFailure(id, err)
	}
}

// forwardResponse returns a function that relays a response back to the peer
// as a non-init Forward message attributed to src, marshalled with version.
//
// The returned function accepts the (response, error) pair produced by the
// message constructors. If the response could not be built, the error is
// passed on to send, which logs it and writes nothing, instead of wrapping a
// response that was never constructed.
func (p *peerBinding) forwardResponse(version protocol.Version, src net.Addr) func(resp protocol.Response, err error) {
	return func(resp protocol.Response, err error) {
		var fwd protocol.Message
		if err == nil {
			fwd, err = message.NewForward(false, src, resp)
		}
		// send reports any failure itself, so its result is not needed here.
		p.send(version, fwd, err)
	}
}

// forward wraps msg in a Forward message attributed to src (flagged as init
// when requested) and sends it to the peer. Any error from building the
// wrapper is handed to send, whose result is returned.
func (p *peerBinding) forward(version protocol.Version, init bool, src net.Addr, msg protocol.Message) error {
	envelope, buildErr := message.NewForward(init, src, msg)
	return p.send(version, envelope, buildErr)
}

// send marshals msg for the given version and writes it to the peer as a
// binary frame. An incoming non-nil err short-circuits both steps, so callers
// can pass a message constructor's error straight through. The outcome is
// logged (trace on success, debug on failure) and the resulting error, if any,
// is returned.
func (p *peerBinding) send(version protocol.Version, msg protocol.Message, err error) error {
	if err == nil {
		var payload []byte
		if payload, err = msg.Marshal(version); err == nil {
			err = p.client.WriteBinary(payload)
		}
	}

	if err == nil {
		p.on.log(logging.Trace, "Peer [SND]=>", p.client.Address(), msg)
		return nil
	}

	p.on.log(logging.Debug, "Unable to send message", msg, err)
	return err
}
