package websocket

import (
	"io"
	"io/ioutil"
	"net"

	"code.justin.tv/karlpatr/message-prototype/libs/logging"
	"code.justin.tv/karlpatr/message-prototype/libs/session"
	"code.justin.tv/karlpatr/message-prototype/libs/timeout"
	"github.com/gobwas/ws"
	"github.com/gobwas/ws/wsutil"
)

var (
	// pingFrame is the keep-alive probe that execute writes to an idle peer
	// when the read timeout expires and Settings.OnTimeout is Ping. Its
	// payload is the bytes of "ping".
	pingFrame = ws.NewPingFrame([]byte("ping"))
)

// realloc returns a buffer that has room for length bytes starting at offset,
// i.e. whose len is at least offset+length.
//
// If buffer is non-nil and already long enough it is returned unchanged.
// Otherwise a new slice of exactly offset+length bytes is allocated and the
// first offset bytes of buffer (data accumulated so far, such as earlier
// fragments of a message) are copied into it. A nil buffer always yields a
// freshly allocated, non-nil slice.
func realloc(buffer []byte, offset, length int) []byte {
	need := offset + length
	if buffer != nil && len(buffer) >= need {
		return buffer
	}
	expanded := make([]byte, need)
	// Carry over only the accumulated prefix; bytes past offset are about to
	// be overwritten by the caller's next read.
	prefix := buffer
	if offset < len(prefix) {
		prefix = prefix[:offset]
	}
	copy(expanded, prefix)
	return expanded
}

// execute runs the frame-reading loop for one websocket connection until the
// peer closes it or an error occurs.
//
// Frames are read from link.conn through a timeout reader. Ping frames are
// answered with pongs, pong frames clear the outstanding-ping flag, and a
// close frame is echoed back (or answered with a protocol error if its payload
// is invalid). Data frames are reassembled across fragments; each complete
// message is delivered to binding.OnTextMessage (text, UTF-8 validated) or
// binding.OnBinaryMessage (binary).
//
// On a read timeout, behaviour depends on settings.OnTimeout: with Ping, the
// first timeout sends pingFrame and keeps reading, and a second timeout before
// a pong arrives closes the connection; otherwise the connection is closed
// with "Timed out".
//
// state is the gobwas protocol state used by ws.CheckHeader; it is updated
// locally to track fragmentation. shutdown() is evaluated and passed to
// closeForError when choosing the close status for an unexpected error.
//
// The returned error is whatever ended the loop; it is nil only when the peer
// sent a valid close frame.
func execute(binding session.Binding, link *link, state ws.State, settings *Settings, shutdown func() bool) error {
	settings.Logger(logging.Trace, "Opened", link.Address())
	inPing := false      // a ping was sent after a timeout and no pong has arrived yet
	textPending := false // a fragmented text message is in progress
	src := timeout.NewReader(settings.Timeout, link.conn)
	var r io.Reader
	limitReader := &io.LimitedReader{R: src}
	cipherReader := wsutil.NewCipherReader(nil, [4]byte{0, 0, 0, 0})
	utf8Reader := wsutil.NewUTF8Reader(nil)
	pongBuffer := make([]byte, 8)
	payload := make([]byte, 32)
	read := 0 // bytes of the current (possibly fragmented) message already in payload
	for {
		header, err := ws.ReadHeader(src)
		if err != nil {
			if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
				// NOTE(review): continuing after a timeout assumes the timeout
				// fired before any header bytes were consumed; confirm
				// timeout.Reader guarantees this, or the stream desynchronizes.
				if settings.OnTimeout == Ping && !inPing {
					// First idle timeout: probe the peer instead of disconnecting.
					if err = link.writeFrame(pingFrame); err != nil {
						link.writeClose(closeForError(shutdown(), err, "Unable to send ping", settings.Logger), "")
						return err
					}
					inPing = true
					continue
				}
				// Disconnect policy, or the previous ping went unanswered.
				link.writeClose(ws.StatusProtocolError, "Timed out")
			} else {
				link.writeClose(closeForError(shutdown(), err, "Unable to read header:", settings.Logger), "")
			}
			return err
		}

		if err = ws.CheckHeader(header, state); err != nil {
			if !isClosedError(err) {
				link.writeClose(closeForError(shutdown(), err, "Unable to check header:", settings.Logger), "")
			}
			return err
		}

		// Bound reads to this frame's payload and unmask it if necessary.
		limitReader.N = header.Length
		if header.Masked {
			cipherReader.Reset(limitReader, header.Mask)
			r = cipherReader
		} else {
			r = limitReader
		}

		// NOTE(review): header.Length is peer-controlled and is not capped
		// here before buffers are sized from it; consider enforcing a
		// maximum message size.
		length := int(header.Length)
		// utf8Fin is true when this frame completes a text message: it
		// triggers the final UTF-8 validity check and text delivery.
		utf8Fin := false

		switch header.OpCode {
		case ws.OpPing:
			pongBuffer = realloc(pongBuffer, 0, length)
			if _, err = io.ReadAtLeast(r, pongBuffer, length); err != nil {
				settings.Logger(logging.Debug, "Unable to read ping payload", err)
				link.writeClose(ws.StatusUnsupportedData, "Unable to read payload")
				return err
			}
			// Best effort: a dead connection will surface on the next read.
			if err = link.writeFrame(ws.NewPongFrame(pongBuffer[:length])); err != nil {
				settings.Logger(logging.Debug, "Unable to send pong", err)
			}
			continue

		case ws.OpPong:
			inPing = false
			if _, err = io.CopyN(ioutil.Discard, r, header.Length); err != nil {
				settings.Logger(logging.Debug, "Unable to read pong payload", err)
				link.writeClose(ws.StatusProtocolError, "Unable to read payload")
				return err
			}
			continue

		case ws.OpClose:
			// The close payload is not routed through utf8Reader (its first
			// two bytes are a binary status code), so the decoder state must
			// not be consulted here; it may belong to an interrupted
			// fragmented text message. The reason string is checked by
			// ws.CheckCloseFrameData below.

		case ws.OpContinuation:
			if textPending {
				// Keep decoder state from earlier fragments: a code point may
				// span frame boundaries.
				utf8Reader.Source = r
				r = utf8Reader
			}
			if header.Fin {
				state = state.Clear(ws.StateFragmented)
				// Only the end of a text message is text; the final fragment
				// of a binary message must still be delivered as binary.
				utf8Fin = textPending
				textPending = false
			}

		case ws.OpText:
			utf8Reader.Reset(r)
			r = utf8Reader

			if !header.Fin {
				state = state.Set(ws.StateFragmented)
				textPending = true
			} else {
				utf8Fin = true
			}

		case ws.OpBinary:
			if !header.Fin {
				state = state.Set(ws.StateFragmented)
			}
		}

		// Append this frame after any fragments already accumulated.
		payload = realloc(payload, read, length)

		// r is limited to at most header.Length above, making this an exact read
		_, err = io.ReadAtLeast(r, payload[read:], length)
		if err == nil && utf8Fin && !utf8Reader.Valid() {
			// The message ended in the middle of a UTF-8 sequence.
			err = wsutil.ErrInvalidUTF8
		}
		if err != nil {
			settings.Logger(logging.Debug, "Unable to read payload:", err)
			if err == wsutil.ErrInvalidUTF8 {
				link.writeClose(ws.StatusInvalidFramePayloadData, "Invalid UTF8")
			} else {
				link.writeClose(ws.StatusProtocolError, "Unable to read payload")
			}
			return err
		}

		if header.OpCode == ws.OpClose {
			code, reason := ws.ParseCloseFrameData(payload[read : read+length])
			if code.Empty() {
				code = ws.StatusNormalClosure
			}
			if err = ws.CheckCloseFrameData(code, reason); err != nil {
				settings.Logger(logging.Debug, "Invalid close data", err)
				link.writeClose(ws.StatusProtocolError, err.Error())
			} else {
				link.writeClose(code, reason)
			}
			return err
		}

		if state.Is(ws.StateFragmented) {
			// Message not finished yet: keep accumulating fragments.
			read += length
			continue
		}

		content := payload[:read+length]
		read = 0
		if utf8Fin {
			binding.OnTextMessage(string(content))
		} else {
			// NOTE(review): content aliases the reusable payload buffer and
			// is overwritten by the next message; confirm OnBinaryMessage
			// does not retain it past the call.
			binding.OnBinaryMessage(content)
		}
	}
}
