package host

import (
	"sync"
	"time"

	"code.justin.tv/devhub/e2ml/libs/discovery/protocol"
	"code.justin.tv/devhub/e2ml/libs/discovery/protocol/message"
	"code.justin.tv/devhub/e2ml/libs/logging"
	"code.justin.tv/devhub/e2ml/libs/session"
	"code.justin.tv/devhub/e2ml/libs/stream"
)

// Callbacks a binding invokes on behalf of the host: one per host-directed
// broker message (Allocate, Detach, Reserve), plus initFunction, which receives
// the error passed to initRequestCallback (nil once the host is authorized).
type (
	allocateFunction func(message.Allocate)
	detachFunction   func(message.Detach)
	reserveFunction  func(message.Reserve)
	initFunction     func(error)
)

// binding maintains this host's session with the broker. It (re)connects
// through factory with a growing backoff, authorizes the host with an AuthHost
// message, dispatches incoming broker messages to the host callbacks, and
// tracks outstanding request promises by ack ID.
type binding struct {
	// factory opens a new broker client; it is called with the binding itself.
	factory        session.ClientFactory
	// logger receives the binding's diagnostic output.
	logger         logging.Function
	// client is the current connection, or nil while disconnected.
	client         session.Client
	// hostname is reported to the broker in the AuthHost message.
	hostname       string
	// onAllocate, onDetach and onReserve handle the corresponding broker
	// messages received in OnBinaryMessage.
	onAllocate     allocateFunction
	onDetach       detachFunction
	onReserve      reserveFunction
	// onInit receives the error passed to initRequestCallback (nil when the
	// host was authorized).
	onInit         initFunction
	// onAuth supplies the value whose Encode() is sent in the AuthHost message.
	onAuth         stream.AuthSource
	// nextAck is the next ack ID to assign to an outgoing request.
	nextAck        protocol.AckID
	// requests holds pending promises keyed by ack ID, plus validations queued
	// while the host is not authorized.
	requests       *requestMap
	// backoff is the delay before the next scheduled connection attempt.
	backoff        time.Duration
	// connectAttempt is the armed connection timer, or nil if none is pending.
	connectAttempt *time.Timer
	// authed is true once the current client has been authorized.
	authed         bool
	// closed is set by close; once set, no further connections are attempted.
	closed         bool
	// mutex guards the mutable state above (client, nextAck, requests,
	// backoff, connectAttempt, authed, closed).
	mutex          sync.Mutex
}

// Compile-time check that *binding satisfies session.Binding.
var _ session.Binding = (*binding)(nil)

// newBinding builds a binding that reports hostname to the broker and routes
// broker messages to the given callbacks. It logs the hostname being reported
// and immediately schedules the first connection attempt; the backoff starts
// at zero, so that attempt fires without delay.
func newBinding(
	hostname string,
	factory session.ClientFactory,
	onAllocate allocateFunction,
	onDetach detachFunction,
	onReserve reserveFunction,
	onInit initFunction,
	onAuth stream.AuthSource,
	logger logging.Function,
) *binding {
	bnd := new(binding)
	bnd.logger = logger
	bnd.factory = factory
	bnd.hostname = hostname
	bnd.onAuth = onAuth
	bnd.onInit = onInit
	bnd.onAllocate = onAllocate
	bnd.onDetach = onDetach
	bnd.onReserve = onReserve
	bnd.nextAck = protocol.FirstAckID
	bnd.requests = newRequestMap()

	logger(logging.Info, "Reporting hostname [", hostname, "] to broker")
	bnd.connect()
	return bnd
}

// tick advances the request timeout clock by one step. Every promise that
// requests.tick reports as expired is failed with ErrServiceTimedOut. The
// promises are settled only after the mutex is released, presumably so that a
// promise's Set callback may itself take the binding's lock.
func (b *binding) tick() {
	b.mutex.Lock()
	expired := b.requests.tick()
	b.mutex.Unlock()

	for _, pending := range expired {
		pending.Set(nil, protocol.ErrServiceTimedOut)
	}
}

// close permanently shuts the binding down. It marks the binding closed, so no
// further connection attempts are scheduled, and then closes the live client
// with err. When there is no client, OnClosed is invoked directly so the same
// cleanup still runs (request expiry, cancelling a pending connect timer).
func (b *binding) close(err error) {
	b.mutex.Lock()
	live := b.client
	b.closed = true
	b.mutex.Unlock()

	if live != nil {
		live.Close(err)
		return
	}
	b.OnClosed(err)
}

// OnTextMessage is the session callback for text frames. This binding only
// handles binary messages, so any text frame is treated as a protocol
// violation and the connection is closed with ErrInvalidHeader.
func (b *binding) OnTextMessage(msg string) {
	// b.client is written under the mutex by onConnect and OnClosed on other
	// goroutines, and may already be nil if the connection was torn down.
	b.mutex.Lock()
	client := b.client
	b.mutex.Unlock()
	if client != nil {
		client.Close(protocol.ErrInvalidHeader)
	}
}

// OnBinaryMessage is the session callback for binary frames. It decodes a
// broker message and then handles it by opcode:
//   - Ack, Bind and Error settle the pending request named by ForAckID (Ack and
//     Error with NoPermissions, Error also with its wrapped error; Bind with
//     the credentials it carries). Replies for unknown ack IDs are ignored.
//   - Allocate, Detach and Reserve are forwarded to the host callbacks.
//
// An undecodable payload closes the connection. An unknown opcode, or a known
// opcode whose payload has the wrong type, is only logged at Debug level.
func (b *binding) OnBinaryMessage(bytes []byte) {
	msg, err := message.Unmarshal(bytes)
	if err != nil {
		// b.client is mutated under the mutex on other goroutines and may
		// already be nil; never dereference it unguarded.
		b.mutex.Lock()
		client := b.client
		b.mutex.Unlock()
		if client != nil {
			client.Close(err)
		}
		return
	}
	b.logger(logging.Trace, "Host <=[RCV]", msg)
	switch msg.OpCode() {
	case protocol.Ack:
		if cast, ok := msg.(message.Ack); ok {
			b.mutex.Lock()
			promise, ok := b.requests.remove(cast.ForAckID())
			b.mutex.Unlock()
			if ok {
				promise.Set(stream.NoPermissions(), nil)
			}
			return
		}
	case protocol.Bind:
		if cast, ok := msg.(message.Bind); ok {
			b.mutex.Lock()
			promise, ok := b.requests.remove(cast.ForAckID())
			b.mutex.Unlock()
			if ok {
				promise.Set(cast.Credentials(), nil)
			}
			return
		}
	case protocol.Error:
		if cast, ok := msg.(message.Error); ok {
			b.mutex.Lock()
			promise, ok := b.requests.remove(cast.ForAckID())
			b.mutex.Unlock()
			if ok {
				promise.Set(stream.NoPermissions(), cast.Unwrap())
			}
			return
		}
	case protocol.Allocate:
		if cast, ok := msg.(message.Allocate); ok {
			b.onAllocate(cast)
			return
		}
	case protocol.Detach:
		if cast, ok := msg.(message.Detach); ok {
			b.onDetach(cast)
			return
		}
	case protocol.Reserve:
		if cast, ok := msg.(message.Reserve); ok {
			b.onReserve(cast)
			return
		}
	default:
		b.logger(logging.Debug, protocol.ErrInvalidOpCode)
		return
	}
	// Reached only when a known opcode carried a payload of the wrong type.
	b.logger(logging.Debug, protocol.ErrInvalidPayload(msg.OpCode()))
}

// OnClosed is the session callback for a closed connection; close also calls
// it directly when there is no client. It fails every outstanding request,
// forgets the client and its authorization, and cancels any pending
// connection timer if the binding has been closed. It then calls connect,
// which schedules a reconnect unless the binding is closed.
func (b *binding) OnClosed(err error) {
	b.logger(logging.Trace, "Host [CLOSED]", err)

	// Two ticks expire any current requests. Presumably requestMap ages
	// entries across ticks, so a single tick is not enough (confirm in requestMap).
	for round := 0; round < 2; round++ {
		b.tick()
	}

	b.mutex.Lock()
	b.client = nil
	b.authed = false
	if timer := b.connectAttempt; b.closed && timer != nil {
		timer.Stop()
		b.connectAttempt = nil
	}
	b.mutex.Unlock()

	b.connect()
}

// validate submits a Validate request to the broker and returns a promise for
// the resulting credentials. The promise is registered under msg.AckID(), so
// the broker's Ack/Bind/Error reply (or a timeout from tick) settles it.
// Calling validate again with the same ack ID reuses the registered promise
// and refreshes its timeout.
//
// While the host is not connected and authorized, the request is queued
// rather than sent. A closed binding fails immediately with
// ErrServiceShuttingDown, a marshal failure fails immediately with that
// error, and a write failure unregisters the request and fails the promise.
func (b *binding) validate(msg message.Validate) stream.CredentialsPromise {
	// Marshalling does not touch binding state, so do it outside the lock.
	bytes, marshalErr := msg.Marshal(protocol.Current)

	b.mutex.Lock()
	if b.closed {
		b.mutex.Unlock()
		b.logger(logging.Trace, "Host [DROP]", msg.OpCode())
		return stream.NewSyncCredentialsPromise(nil, protocol.ErrServiceShuttingDown)
	}
	if marshalErr != nil {
		b.mutex.Unlock()
		return stream.NewSyncCredentialsPromise(nil, marshalErr)
	}
	// Check authorization and queue in the same critical section. If the check
	// ran under an earlier lock, initRequestCallback could mark the host
	// authorized (and the queue could be drained) in between, stranding this
	// request in the queue until it timed out.
	client := b.client
	authed := b.authed && client != nil
	p, found := b.requests.remove(msg.AckID()) // refresh timing if this is a retry
	if !found {
		p = stream.NewCredentialsPromise()
	}
	b.requests.add(msg.AckID(), p)
	if !authed {
		b.requests.queue(msg)
	}
	b.mutex.Unlock()

	if !authed {
		b.logger(logging.Trace, "Host [QUEUE]", msg.OpCode(), msg.AckID())
		return p
	}
	b.logger(logging.Trace, "Host [SND]=>", msg)
	if err := client.WriteBinary(bytes); err != nil {
		b.mutex.Lock()
		b.requests.remove(msg.AckID())
		b.mutex.Unlock()
		p.Set(nil, err)
	}
	return p
}

// send writes msg to the broker without tracking a reply.
//
// NOTE: messages sent this way are never queued. While the host is
// disconnected or not yet authorized, msg is dropped and nil is returned; a
// current status should be sent on (re)connection instead. This prevents
// unnecessary memory use and keeps out-of-date information from confusing the
// system. Marshal and write errors are returned to the caller.
func (b *binding) send(msg protocol.Message) error {
	b.mutex.Lock()
	conn, ready := b.client, b.authed
	b.mutex.Unlock()

	if ready && conn != nil {
		payload, err := msg.Marshal(protocol.Current)
		if err != nil {
			return err
		}
		b.logger(logging.Trace, "Host [SND]=>", msg)
		return conn.WriteBinary(payload)
	}
	b.logger(logging.Trace, "Host [DROP]", msg.OpCode())
	return nil
}

// nextAckID reserves the next ack ID for an outgoing request, advancing the
// counter under the binding's lock, and returns the reserved ID.
func (b *binding) nextAckID() protocol.AckID {
	b.mutex.Lock()
	defer b.mutex.Unlock()
	reserved := b.nextAck
	b.nextAck = reserved.Next()
	return reserved
}

// connect arms a timer that runs onConnect after the current backoff delay.
// It does nothing when the binding is closed, already has a client, or
// already has an attempt pending. Each scheduled attempt grows the delay
// 0s -> 1s -> 3s -> 7s, where it stays; onConnect resets it to zero once the
// factory returns a client.
func (b *binding) connect() {
	// This cap must be a Duration. The bare constant 7 would mean 7ns, which
	// would freeze the backoff at 1s after the first attempt.
	const maxBackoff = 7 * time.Second

	b.mutex.Lock()
	defer b.mutex.Unlock()
	if b.closed || b.connectAttempt != nil || b.client != nil {
		return
	}
	b.connectAttempt = time.AfterFunc(b.backoff, b.onConnect)
	if b.backoff < maxBackoff {
		b.backoff = b.backoff*2 + time.Second
	}
}

// finishInit retries validation requests that were queued during a connection
// outage. It collects the queued messages from requests.onFinishInit under
// the lock, then resubmits each one through validate outside the lock.
// validate looks up any promise already registered under the message's ack ID
// and reuses it.
func (b *binding) finishInit() {
	b.mutex.Lock()
	queued := b.requests.onFinishInit()
	b.mutex.Unlock()

	for _, pending := range queued {
		b.validate(pending)
	}
}

// onConnect runs on the timer armed by connect. It does the following:
//   - Opens a client through the factory. On failure it schedules another
//     attempt with backoff.
//   - Registers a host-authorization request under a fresh ack ID and resets
//     the backoff.
//   - Sends the AuthHost handshake. If building or writing the handshake
//     fails, it settles the authorization request with the error.
//
// If the binding was closed while the factory ran, the new client is closed
// immediately.
func (b *binding) onConnect() {
	b.mutex.Lock()
	if b.connectAttempt == nil || b.client != nil {
		// The attempt was cancelled by OnClosed, or a client already exists.
		b.mutex.Unlock()
		return
	}
	b.connectAttempt = nil
	b.mutex.Unlock()

	client, err := b.factory(b)
	if err != nil {
		b.logger(logging.Debug, "Host - connect failed, retrying", err)
		b.connect()
		return
	}

	b.mutex.Lock()
	closed := b.closed
	ackID := b.nextAck
	if !closed {
		b.client = client
		b.requests.add(ackID, &hostAuthorizationPromise{b})
		b.nextAck = ackID.Next()
	}
	b.backoff = time.Duration(0)
	b.mutex.Unlock()

	if closed {
		client.Close(nil)
		b.logger(logging.Warning, "Host - Connect after shutdown, review shutdown ordering")
		return
	}

	// Build, marshal and write the handshake, surfacing every error. A failed
	// NewAuthHost would otherwise leave a nil message to be marshalled.
	sendAuth := func() error {
		auth, err := message.NewAuthHost(ackID, protocol.Current, b.hostname, b.onAuth().Encode())
		if err != nil {
			return err
		}
		b.logger(logging.Trace, "Host [AUTH]", client.Address(), auth)
		bytes, err := auth.Marshal(protocol.Current)
		if err != nil {
			return err
		}
		return client.WriteBinary(bytes)
	}
	if err = sendAuth(); err != nil {
		// requests is shared with other goroutines; every access holds the mutex.
		b.mutex.Lock()
		promise, ok := b.requests.remove(ackID)
		b.mutex.Unlock()
		if ok {
			promise.Set(stream.NoPermissions(), err)
		}
	}
}

// initRequestCallback records the outcome of host authorization. It marks the
// binding authorized when err is nil and unauthorized otherwise, then passes
// err to onInit. On failure it also closes the current client, if any, with
// err.
func (b *binding) initRequestCallback(err error) {
	b.mutex.Lock()
	b.authed = err == nil
	b.mutex.Unlock()
	b.onInit(err) // sends post-auth signals
	if err == nil {
		return
	}
	// b.client may be cleared concurrently (e.g. by OnClosed on another
	// goroutine), so read it under the lock and tolerate nil.
	b.mutex.Lock()
	client := b.client
	b.mutex.Unlock()
	if client != nil {
		client.Close(err)
	}
}
