package message

import (
	"fmt"
	"net"

	"code.justin.tv/devhub/e2ml/libs/discovery/protocol"
)

// Byte offsets inside an encoded Forward message. The layout is:
//
//	[header: headerLength bytes][flags: 1 byte][network]\x00[address]\x00[wrapped message]
const (
	// forwardFlagsOffset is the index of the single flags byte; bit 0 is the init flag.
	forwardFlagsOffset = headerLength
	// forwardNetworkOffset is where the NUL-terminated network name begins,
	// immediately after the flags byte.
	forwardNetworkOffset = headerLength + 1
)

// Forward wraps a message from a host for peer-to-peer forwarding.
// It is itself a protocol.Message, so it can be marshaled and sent like any
// other message.
type Forward interface {
	protocol.Message

	// Message returns the wrapped message being forwarded.
	Message() protocol.Message
	// Source returns the address recorded as the origin of the wrapped message.
	Source() net.Addr
	// IsInit reports whether the forward carries the init flag.
	IsInit() bool
}

// forwardMessage is the concrete implementation of Forward returned by
// NewForward and filled in by Unmarshal.
type forwardMessage struct {
	// init is the init flag; on the wire it is bit 0 of the flags byte.
	init bool
	// src is the origin address; on the wire it is its Network() and
	// String() values, each NUL-terminated.
	src  net.Addr
	// msg is the wrapped message, encoded after the address terminator.
	msg  protocol.Message
}

// NewForward builds a Forward that wraps msg as having originated from src.
// init sets the forward's init flag (reported by IsInit).
//
// It returns protocol.ErrInvalidAddress if src is nil, and
// protocol.ErrInvalidHeader if msg is nil; src is checked first.
func NewForward(init bool, src net.Addr, msg protocol.Message) (Forward, error) {
	switch {
	case src == nil:
		return nil, protocol.ErrInvalidAddress
	case msg == nil:
		return nil, protocol.ErrInvalidHeader
	}
	fwd := &forwardMessage{
		init: init,
		src:  src,
		msg:  msg,
	}
	return fwd, nil
}

// OpCode identifies this message as protocol.Forward.
func (f *forwardMessage) OpCode() protocol.OpCode {
	return protocol.Forward
}

// IsInit reports whether the init flag is set.
func (f *forwardMessage) IsInit() bool {
	return f.init
}

// Source returns the address recorded as the origin of the wrapped message.
func (f *forwardMessage) Source() net.Addr {
	return f.src
}

// Message returns the wrapped message being forwarded.
func (f *forwardMessage) Message() protocol.Message {
	return f.msg
}

// Marshal encodes the forward message for protocol version ver.
//
// The output is the header (written by injectHeader), one flags byte
// (bit 0 = init), the source network name and NUL, the source address
// string and NUL, and finally the wrapped message marshaled with the same
// version.
//
// It returns protocol.ErrInvalidVersion for an invalid version,
// protocol.ErrInvalidAddress when the source is missing or its network or
// address text contains a NUL byte, protocol.ErrInvalidHeader when the
// wrapped message is missing, or any error from marshaling the wrapped
// message.
func (f *forwardMessage) Marshal(ver protocol.Version) ([]byte, error) {
	if !ver.IsValid() {
		return nil, protocol.ErrInvalidVersion
	}
	// A forwardMessage not built through NewForward (e.g. the zero value)
	// has nil fields; report the same errors NewForward would instead of
	// panicking on a nil interface call.
	if f.src == nil {
		return nil, protocol.ErrInvalidAddress
	}
	if f.msg == nil {
		return nil, protocol.ErrInvalidHeader
	}

	network := f.src.Network()
	addr := f.src.String()
	// Both strings are NUL-terminated on the wire; an embedded NUL would make
	// the decoder split the fields at the wrong place.
	hasNUL := func(s string) bool {
		for i := 0; i < len(s); i++ {
			if s[i] == 0 {
				return true
			}
		}
		return false
	}
	if hasNUL(network) || hasNUL(addr) {
		return nil, protocol.ErrInvalidAddress
	}

	payload, err := f.msg.Marshal(ver)
	if err != nil {
		return nil, err
	}

	var flags byte
	if f.init {
		flags |= 1 // bit 0: init forward
	}

	addrOffset := forwardNetworkOffset + len(network)
	msgOffset := addrOffset + 1 + len(addr)
	buf := make([]byte, msgOffset+1+len(payload))
	injectHeader(buf, ver, f.OpCode())
	buf[forwardFlagsOffset] = flags
	copy(buf[forwardNetworkOffset:], network)
	buf[addrOffset] = 0
	copy(buf[addrOffset+1:], addr)
	buf[msgOffset] = 0
	copy(buf[msgOffset+1:], payload)
	return buf, nil
}

// Unmarshal decodes an encoded Forward message into f.
//
// data holds the header, the flags byte, the NUL-terminated network name,
// the NUL-terminated address, and then the encoded wrapped message, which is
// decoded with the package-level Unmarshal. f is modified only when every
// part decodes successfully.
//
// It returns protocol.ErrInvalidVersion for an invalid ver,
// protocol.ErrInvalidLength when data is too short to contain the flags
// byte, protocol.ErrEOF when either NUL terminator is missing, or the error
// from decoding the wrapped message.
func (f *forwardMessage) Unmarshal(ver protocol.Version, data []byte) error {
	if !ver.IsValid() {
		return protocol.ErrInvalidVersion
	}
	size := len(data)
	if size < forwardNetworkOffset {
		return protocol.ErrInvalidLength(forwardNetworkOffset, size)
	}

	// nulFrom returns the index of the first NUL byte at or after start, or -1.
	nulFrom := func(start int) int {
		for i := start; i < size; i++ {
			if data[i] == 0 {
				return i
			}
		}
		return -1
	}

	networkEnd := nulFrom(forwardNetworkOffset)
	if networkEnd < 0 {
		return protocol.ErrEOF
	}
	addrEnd := nulFrom(networkEnd + 1)
	if addrEnd < 0 {
		return protocol.ErrEOF
	}

	inner, err := Unmarshal(data[addrEnd+1:])
	if err != nil {
		return err
	}

	// Commit only after everything decoded, so a failure leaves f unchanged.
	network := string(data[forwardNetworkOffset:networkEnd])
	address := string(data[networkEnd+1 : addrEnd])
	f.init = data[forwardFlagsOffset]&1 != 0
	f.src = NewForwardedAddr(network, address)
	f.msg = inner
	return nil
}

// String renders the message for logging: its opcode, source address and the
// wrapped message. The init flag is not included.
func (f *forwardMessage) String() string {
	op, src, inner := f.OpCode(), f.src, f.msg
	return fmt.Sprintf("<%v> source=%v, msg=(%v)", op, src, inner)
}
