package peering

import (
	"encoding/json"
	"net"
	"net/url"
	"runtime"
	"sync"
	"sync/atomic"
	"time"

	"code.justin.tv/devhub/e2ml/libs/logging"
	"code.justin.tv/devhub/e2ml/libs/session"
	"code.justin.tv/devhub/lib-lifecycle/src/lifecycle"
)

// Bit flags stored in proxy.status. In this file the status word is only
// read and modified through sync/atomic.
const (
	// sent: our handshake has been written; later writes bypass the queue.
	sent uint32 = 1 << iota
	// recv: the remote handshake (first text message) has been processed.
	recv
	// appending: a writer is currently appending to the pre-handshake queue.
	appending
	_ // bit 3 is unused
	// closed: the connection has closed; writes fail with ErrPeerClosed.
	closed
)

// proxy sits between a peer connection and the binding produced by the
// factory's BindingFactory. It satisfies session.Binding (it receives the
// connection's callbacks) and Peer. The first text message it receives is
// consumed as the remote handshake; writes issued before our own handshake
// has been sent are held in queue and replayed by handshake().
//
// Fields:
//   - manager: owner used for logging, LocalName, onConnected and onClose.
//   - client: connection used for outbound writes. When isClient is true it
//     points at the proxy itself until create installs the real connection.
//   - inner: binding built by the factory; receives traffic after the
//     handshake.
//   - queue: writes (and their write functions) buffered before our
//     handshake is sent.
//   - name: remote peer name taken from the handshake.
//   - isClient: true when the proxy was created without a client (the
//     dialing side).
//   - retryDelay: delay supplied at construction; not read by the code shown
//     here — presumably consumed by the manager for reconnects (TODO confirm).
//   - status: bitmask of sent/recv/appending/closed, accessed via sync/atomic.
//   - mutex: held when client, inner or name are assigned, and when name or
//     client are read in RemoteName/Address.
type proxy struct {
	manager    *manager
	client     session.Client
	inner      session.Binding
	queue      queue
	name       string
	isClient   bool
	retryDelay time.Duration
	status     uint32
	mutex      sync.RWMutex
}

// proxyFactory creates proxies on behalf of a manager. bFactory builds the
// inner binding that each proxy forwards post-handshake traffic to.
// Field order matters: newProxyFactory builds this with a positional literal.
type proxyFactory struct {
	manager  *manager
	bFactory session.BindingFactory
}

// handshake is the JSON payload each side sends first, announcing its peer
// name under the "peer_name" key. An empty name is rejected on receipt.
type handshake struct {
	Name string `json:"peer_name"`
}

// Compile-time checks that *proxy implements both session.Binding and Peer.
var (
	_ session.Binding = (*proxy)(nil)
	_ Peer            = (*proxy)(nil)
)

// newProxyFactory returns a proxyFactory that builds proxies for manager,
// using bFactory to create each proxy's inner binding.
func newProxyFactory(manager *manager, bFactory session.BindingFactory) *proxyFactory {
	return &proxyFactory{
		manager:  manager,
		bFactory: bFactory,
	}
}

// newProxy builds a proxy around client and creates its inner binding.
//
// A nil client marks the proxy as the dialing side (isClient). Until create
// installs the real connection, the proxy stands in as its own client. Any
// writes the inner binding makes in the meantime therefore go through
// writeOrQueue and are queued until the handshake is sent.
func (p *proxyFactory) newProxy(client session.Client, retryDelay time.Duration) *proxy {
	px := &proxy{
		manager:    p.manager,
		client:     client,
		retryDelay: retryDelay,
		isClient:   client == nil,
	}
	if px.isClient {
		px.client = px
	}
	inner := p.bFactory(px)

	px.mutex.Lock()
	defer px.mutex.Unlock()
	px.inner = inner
	return px
}

// create dials target, attaches the resulting connection to a new proxy and
// sends our handshake. delay is stored as the proxy's retryDelay.
//
// If dialing fails, only the proxy's inner binding is notified (via
// OnClosed(err)) and err is returned. The proxy's own status is not marked
// closed and the manager is not notified here.
func (p *proxyFactory) create(target *url.URL, delay time.Duration) error {
	px := p.newProxy(nil, delay)
	conn, err := p.manager.cFactory(target, px)
	if err == nil {
		px.mutex.Lock()
		px.client = conn
		px.mutex.Unlock()
		return px.handshake()
	}
	px.inner.OnClosed(err)
	return err
}

// AsBindingFactory adapts the factory for the accepting side. Every inbound
// client is wrapped in a new proxy (with a zero retry delay) and our handshake
// is sent immediately. A handshake error is logged at Debug level; the proxy
// is returned regardless.
func (p *proxyFactory) AsBindingFactory() session.BindingFactory {
	accept := func(client session.Client) session.Binding {
		px := p.newProxy(client, 0)
		err := px.handshake()
		if err != nil {
			p.manager.logger(logging.Debug, "Error during service side handshake", err)
		}
		return px
	}
	return accept
}
// Binding returns the inner binding that this proxy forwards traffic to.
func (p *proxy) Binding() session.Binding {
	inner := p.inner
	return inner
}

// OnTextMessage handles an inbound text message.
//
// The first text message on a connection must be the remote handshake. It is
// decoded to learn the remote peer's name, the recv bit is set, and the
// manager is notified via onConnected. A message that does not decode or
// carries an empty name closes the connection with ErrInvalidHandshake. In
// that case nothing else happens: no name is recorded, recv stays unset, and
// the manager is not told about the peer. Every later text message is
// forwarded to the inner binding.
func (p *proxy) OnTextMessage(value string) {
	status := atomic.LoadUint32(&p.status)
	if status&recv != 0 {
		p.inner.OnTextMessage(value)
		return
	}

	var hs handshake
	if err := json.Unmarshal([]byte(value), &hs); err != nil || hs.Name == "" {
		// Must stop here: falling through would register a peer with an
		// empty name on a connection we are closing.
		p.Close(ErrInvalidHandshake)
		return
	}

	p.mutex.Lock()
	p.name = hs.Name
	p.mutex.Unlock()

	// Other bits (e.g. sent, appending) may change concurrently; retry until
	// recv is OR-ed into the current value.
	for !atomic.CompareAndSwapUint32(&p.status, status, status|recv) {
		status = atomic.LoadUint32(&p.status)
	}
	p.manager.onConnected(p)
}

// OnBinaryMessage forwards bytes to the inner binding once the remote
// handshake has been processed. Binary data arriving before the handshake is
// rejected by closing the connection with ErrMissingHandshake.
func (p *proxy) OnBinaryMessage(bytes []byte) {
	handshaken := atomic.LoadUint32(&p.status)&recv != 0
	if !handshaken {
		p.Close(ErrMissingHandshake)
		return
	}
	p.inner.OnBinaryMessage(bytes)
}

// OnClosed reacts to the connection closing. It first sets the closed status
// bit so subsequent writes fail with ErrPeerClosed. It then notifies the inner
// binding and finally the manager, passing err to both.
func (p *proxy) OnClosed(err error) {
	for done := false; !done; {
		cur := atomic.LoadUint32(&p.status)
		done = atomic.CompareAndSwapUint32(&p.status, cur, cur|closed)
	}
	p.inner.OnClosed(err)
	p.manager.onClose(p, err)
}

// RemoteName returns the peer name announced in the remote handshake, or ""
// if no handshake name has been recorded yet.
func (p *proxy) RemoteName() string {
	p.mutex.RLock()
	defer p.mutex.RUnlock()
	return p.name
}

// Address reports the address of the underlying connection. While no real
// connection is attached, it returns dummyAddr. That covers a nil client and
// the case where the proxy is still standing in as its own client.
func (p *proxy) Address() net.Addr {
	p.mutex.RLock()
	defer p.mutex.RUnlock()
	if c := p.client; c != nil && c != p {
		return c.Address()
	}
	return dummyAddr
}

// WriteText sends value to the remote peer as bytes via the client's
// WriteBinaryAsText. Before our handshake has been sent, the write is queued
// instead (see writeOrQueue).
func (p *proxy) WriteText(value string) error {
	// create assigns p.client under p.mutex, possibly while the connection is
	// already delivering callbacks, so snapshot it under the same lock.
	p.mutex.RLock()
	client := p.client
	p.mutex.RUnlock()
	return p.writeOrQueue([]byte(value), client.WriteBinaryAsText)
}

// WriteBinaryAsText sends bytes through the client's WriteBinaryAsText. Before
// our handshake has been sent, the write is queued instead (see writeOrQueue).
func (p *proxy) WriteBinaryAsText(bytes []byte) error {
	// create assigns p.client under p.mutex, possibly while the connection is
	// already delivering callbacks, so snapshot it under the same lock.
	p.mutex.RLock()
	client := p.client
	p.mutex.RUnlock()
	return p.writeOrQueue(bytes, client.WriteBinaryAsText)
}

// WriteBinary sends bytes through the client's WriteBinary. Before our
// handshake has been sent, the write is queued instead (see writeOrQueue).
func (p *proxy) WriteBinary(bytes []byte) error {
	// create assigns p.client under p.mutex, possibly while the connection is
	// already delivering callbacks, so snapshot it under the same lock.
	p.mutex.RLock()
	client := p.client
	p.mutex.RUnlock()
	return p.writeOrQueue(bytes, client.WriteBinary)
}

// Close closes the underlying connection with err.
//
// On the dialing side, the proxy is its own client until create installs the
// real connection. Delegating to p.client in that window would call Close on
// ourselves forever. With no real connection to close, Close is a no-op then.
func (p *proxy) Close(err error) {
	p.mutex.RLock()
	client := p.client
	p.mutex.RUnlock()
	if client == nil || client == p {
		return
	}
	client.Close(err)
}

// writeOrQueue delivers bytes with write once our handshake has been sent.
// Before that, it appends (bytes, write) to p.queue so handshake can replay
// them in order. It returns ErrPeerClosed once the proxy has been closed.
//
// The appending status bit works as a spin lock around p.queue.Append. It
// ensures that only one writer appends at a time, and that handshake does not
// flush the queue while an append is in progress.
func (p *proxy) writeOrQueue(bytes []byte, write func([]byte) error) error {
	for {
		status := atomic.LoadUint32(&p.status)
		switch {
		case status&closed != 0:
			return ErrPeerClosed
		case status&sent != 0:
			return write(bytes)
		case status&appending != 0:
			// Another writer holds the queue. Yield, then re-read status;
			// a stale local copy would spin forever.
			runtime.Gosched()
			continue
		}
		if !atomic.CompareAndSwapUint32(&p.status, status, status|appending) {
			continue // status changed underneath us; re-evaluate from scratch
		}
		p.queue.Append(bytes, write)
		for {
			cur := atomic.LoadUint32(&p.status)
			if atomic.CompareAndSwapUint32(&p.status, cur, cur&^appending) {
				return nil
			}
		}
	}
}

// handshake announces our local peer name and then releases everything that
// was written before it:
//  1. the handshake JSON is written straight to p.client, bypassing
//     writeOrQueue;
//  2. the sent bit is set, so later writes go directly to the connection;
//  3. it waits until no writer is mid-append to p.queue;
//  4. the queue is flushed, replaying each message with the write function
//     it was queued with.
//
// Entries queued while the proxy was standing in as its own client carry the
// proxy's own Write* methods. Replaying them re-enters writeOrQueue, which now
// sees sent and writes to the real connection (or returns ErrPeerClosed if the
// proxy was closed). Errors from the handshake write and every replayed write
// are combined with lifecycle.CombineErrors.
func (p *proxy) handshake() error {
	// The Marshal error is ignored: a struct with a single string field always
	// marshals successfully.
	bytes, _ := json.Marshal(&handshake{p.manager.LocalName()})
	err := p.client.WriteBinaryAsText(bytes) // bypass sent check
	status := atomic.LoadUint32(&p.status)
	for !atomic.CompareAndSwapUint32(&p.status, status, status|sent) {
		status = atomic.LoadUint32(&p.status)
	}
	// Wait for an in-flight append to finish. Yield while spinning so the
	// appending goroutine can make progress even if it cannot run in parallel.
	for status&appending == appending {
		runtime.Gosched()
		status = atomic.LoadUint32(&p.status)
	}
	for _, queued := range p.queue.Flush() {
		err = lifecycle.CombineErrors(err, queued.write(queued.msg))
	}
	return err
}
