package websocket

import (
	"errors"
	"net"
	"sync"
	"sync/atomic"
	"time"
	"unsafe"

	"code.justin.tv/devhub/lib-lifecycle/src/lifecycle"
	"code.justin.tv/devhub/e2ml/libs/logging"
	"code.justin.tv/devhub/e2ml/libs/session"
	"github.com/gobwas/ws"
)

// maskFunc transforms a frame before it is compiled and written to the wire.
type maskFunc func(frame ws.Frame) ws.Frame

// link wraps a single net.Conn and exposes it as a session.Client.
//
// Shutdown state within this file: causePtr starts out equal to sentinel and
// is switched, at most once, by atomic compare-and-swap (in Close or
// gracefulClose) to point at the recorded cause. exit makes writeClose send
// its close frame at most once.
type link struct {
	// Embedded settings; presumably supply the mask, timeout and logger used
	// by writeFrame — TODO confirm against linkSettings' definition.
	*linkSettings
	// Lifecycle manager holding the close hook registered in newLink.
	mgr      lifecycle.Manager
	// Underlying connection that every frame is written to.
	conn     net.Conn
	// Actually a *error; only accessed through sync/atomic. Equals sentinel
	// until a cause is recorded.
	causePtr unsafe.Pointer
	// Guards writeClose so at most one close frame is transmitted.
	exit     sync.Once
}

// Compile-time check that *link satisfies session.Client.
var _ session.Client = (*link)(nil)

// sentinelErr exists only to give sentinel a unique, non-nil address; cause()
// never returns it.
var sentinelErr = errors.New("Sentinel Error")

// sentinel is the value link.causePtr holds while no cause has been recorded.
var sentinel = unsafe.Pointer(&sentinelErr)

// newLink wraps conn in a link configured by settings, marks it as having no
// recorded cause yet, and registers with mgr a hook keyed by the new link.
// That hook builds a close code and reason from the link's recorded cause via
// closeForError and sends them through writeClose. It returns the
// filterClosed-filtered result.
func newLink(settings *linkSettings, conn net.Conn, mgr lifecycle.Manager) *link {
	l := new(link)
	l.linkSettings = settings
	l.mgr = mgr
	l.conn = conn
	l.causePtr = sentinel

	onExit := func() error {
		code, reason := closeForError(l, true, l.cause())
		return filterClosed(l.writeClose(code, reason))
	}
	mgr.RegisterHook(l, onExit)
	return l
}

// Address reports the remote address of the underlying connection.
func (l *link) Address() net.Addr {
	return l.conn.RemoteAddr()
}
// WriteText sends content as the payload of one text frame.
func (l *link) WriteText(content string) error {
	payload := []byte(content)
	return l.writeFrame(ws.NewTextFrame(payload))
}
// WriteBinaryAsText sends the raw bytes in content as the payload of one text
// frame. This method performs no UTF-8 check of its own.
func (l *link) WriteBinaryAsText(content []byte) error {
	frame := ws.NewTextFrame(content)
	return l.writeFrame(frame)
}
// WriteBinary sends content as the payload of one binary frame.
func (l *link) WriteBinary(content []byte) error {
	frame := ws.NewBinaryFrame(content)
	return l.writeFrame(frame)
}
// Close records cause (after passing it through filterClosed) as the link's
// shutdown reason, but only if no cause has been recorded yet. A later call
// never overwrites the first value. Close then asks the lifecycle manager to
// execute the hook keyed by l; newLink registers such a hook, which writes
// the close frame.
func (l *link) Close(cause error) {
	filtered := filterClosed(cause)
	recorded := unsafe.Pointer(&filtered)
	// First writer wins: the swap only succeeds while causePtr is still sentinel.
	atomic.CompareAndSwapPointer(&l.causePtr, sentinel, recorded)
	l.mgr.ExecuteHook(l)
}

// writeClose sends a close frame carrying code and reason. Only the first
// call on a link transmits anything and returns that write's error. Every
// later call, including one racing with the first, does nothing and returns
// nil. The racing call waits for the first to finish, per sync.Once.
func (l *link) writeClose(code ws.StatusCode, reason string) error {
	var result error
	sendOnce := func() {
		payload := ws.NewCloseFrameBody(code, reason)
		result = l.writeFrame(ws.NewCloseFrame(payload))
	}
	l.exit.Do(sendOnce)
	return result
}

// writeFrame masks and compiles frame, sets a write deadline of now plus
// l.timeout.Next(), and writes the compiled bytes to the connection.
//
// Any failure is returned to the caller. Failures are also logged at debug
// level, except when the frame is a close frame and filterClosed maps the
// error to nil. That case is treated as expected and not logged.
func (l *link) writeFrame(frame ws.Frame) error {
	body, err := ws.CompileFrame(l.mask(frame))
	if err != nil {
		return err
	}

	deadline := time.Now().Add(l.timeout.Next())
	err = l.conn.SetWriteDeadline(deadline)
	if err == nil {
		// net.Conn allows concurrent Write calls (https://golang.org/pkg/net/#Conn).
		_, err = l.conn.Write(body)
	}
	if err == nil {
		return nil
	}

	quietClose := frame.Header.OpCode == ws.OpClose && filterClosed(err) == nil
	if !quietClose {
		l.logger(logging.Debug, "Write error", l.Address(), frame.Header.OpCode, err)
	}
	return err
}

// cause returns the error recorded by Close or gracefulClose. It returns nil
// when nothing has been recorded yet (causePtr still sentinel, or nil on a
// zero-value link) or when the recorded error was itself nil.
func (l *link) cause() error {
	p := atomic.LoadPointer(&l.causePtr)
	if p == nil || p == sentinel {
		return nil
	}
	return *(*error)(p)
}

// gracefulClose closes the underlying connection. If no cause has been
// recorded yet, it records the close result (after filterClosed) as the cause.
func (l *link) gracefulClose() {
	closeErr := l.conn.Close()
	closeErr = filterClosed(closeErr)
	atomic.CompareAndSwapPointer(&l.causePtr, sentinel, unsafe.Pointer(&closeErr))
}
