package websocket

import (
	"io"
	"io/ioutil"
	"net"
	"time"

	"code.justin.tv/devhub/gdaas-ingest/libs/logging"
	"code.justin.tv/devhub/gdaas-ingest/libs/session"
	"github.com/gobwas/ws"
	"github.com/gobwas/ws/wsutil"
)

// realloc returns a buffer holding at least offset+length bytes whose first
// offset bytes equal buffer[:offset].
//
// If buffer is already large enough it is returned unchanged. Otherwise a new
// zeroed slice of exactly offset+length bytes is allocated, and the retained
// prefix (up to offset bytes) is copied into it. A nil buffer always yields a
// fresh slice of exactly offset+length bytes.
func realloc(buffer []byte, offset, length int) []byte {
	need := offset + length
	if buffer == nil {
		return make([]byte, need)
	}
	if len(buffer) >= need {
		return buffer
	}
	expanded := make([]byte, need)
	// Preserve the data already accumulated before offset (e.g. earlier
	// fragments of a message). copy stops at min(offset, len(buffer)).
	copy(expanded[:offset], buffer)
	return expanded
}

// tryPing writes a ping frame with the payload "ping" to link.
//
// It reports true if the write succeeded, or if it failed with a net.Error
// that reports itself as temporary. Any other failure yields false.
func tryPing(link *link) bool {
	writeErr := link.writeFrame(ws.NewPingFrame([]byte("ping")))
	if writeErr == nil {
		return true
	}
	if netErr, isNet := writeErr.(net.Error); isNet {
		return netErr.Temporary()
	}
	return false
}

// execute runs the read loop for one WebSocket connection and returns when
// the connection ends.
//
// Each iteration arms a read deadline of settings.Timeout.Next(), reads one
// frame header, validates it against state, and then handles the frame:
//   - Ping: replies with a Pong carrying the same payload.
//   - Pong: clears the outstanding-ping flag and discards the payload.
//   - Close: validates the peer's close code/reason, writes a close frame
//     back, and returns (nil when the close data is valid).
//   - Text / Binary / Continuation: reads the payload, with UTF-8 checking
//     for text. It buffers fragments in payload until the final frame, then
//     hands the complete message to binding.OnTextMessage or
//     binding.OnBinaryMessage.
//
// On a read timeout, a ping is sent if none is outstanding. If one already
// is, or the ping write fails with a non-temporary error, a "Timed out" close
// frame is written and the timeout error is returned. Every other read or
// protocol failure writes a close frame and returns the error.
func execute(binding session.Binding, link *link, state ws.State, settings *Settings, shutdown func() bool) error {
	settings.Logger(logging.Trace, "Opened", link.Address())
	inPing := false      // a keep-alive ping was sent and no pong has arrived since
	textPending := false // a fragmented text message is being assembled
	utf8Reader := wsutil.NewUTF8Reader(nil)
	cipherReader := wsutil.NewCipherReader(nil, [4]byte{0, 0, 0, 0})
	var pongBuffer []byte
	payload := make([]byte, 32) // reused for every data/close frame; grown by realloc
	read := 0                   // bytes of the current fragmented message already in payload

	for {
		link.conn.SetReadDeadline(time.Now().Add(settings.Timeout.Next()))
		header, err := ws.ReadHeader(link.conn)
		if err != nil {
			if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
				// First idle timeout: probe the peer with a ping and keep reading.
				if !inPing && tryPing(link) {
					inPing = true
					continue
				}
				link.writeClose(ws.StatusProtocolError, "Timed out")
			} else {
				link.writeClose(closeForError(shutdown(), err, "Unable to read header", settings.Logger), "")
			}
			return err
		}

		if err = ws.CheckHeader(header, state); err != nil {
			if !isClosedError(err) {
				link.writeClose(closeForError(shutdown(), err, "Unable to check header", settings.Logger), "")
			}
			return err
		}

		// Limit reads to this frame's payload and apply header.Mask to it.
		cipherReader.Reset(
			io.LimitReader(link.conn, header.Length),
			header.Mask,
		)

		// utf8Fin is true only when this frame completes a text message. Such
		// a frame needs a final UTF-8 validity check and is delivered as text.
		var utf8Fin bool
		var r io.Reader = cipherReader

		// NOTE(review): header.Length is peer-controlled, and nothing in this
		// function bounds it, per frame or across fragments. Consider enforcing
		// a maximum message size before realloc.
		length := int(header.Length)

		switch header.OpCode {
		case ws.OpPing:
			header.OpCode = ws.OpPong
			if state.ClientSide() {
				// mask and send pong: buffer the payload and send a new pong
				// frame through writeFrame.
				pongBuffer = realloc(pongBuffer, 0, length)
				// A single Read may return fewer than length bytes. A short read
				// would truncate the pong and leave unread payload in the stream,
				// corrupting the next header read.
				if _, err = io.ReadFull(cipherReader, pongBuffer[:length]); err != nil {
					settings.Logger(logging.Debug, "Unable to read ping payload", err)
					link.writeClose(ws.StatusProtocolError, "")
					return err
				}
				link.writeFrame(ws.NewPongFrame(pongBuffer[:length]))
			} else {
				// optimized : send pong directly. Echo the header unmasked and
				// stream the unmasked payload back.
				header.Masked = false
				ws.WriteHeader(link.conn, header)
				io.CopyN(link.conn, cipherReader, header.Length)
			}
			continue

		case ws.OpPong:
			inPing = false
			_, _ = io.CopyN(ioutil.Discard, link.conn, header.Length)
			continue

		case ws.OpClose:
			// The payload is read raw below and validated by
			// ws.CheckCloseFrameData. utf8Fin stays false: utf8Reader is not
			// used for this frame, and its state may belong to an interrupted
			// fragmented text message.

		case ws.OpContinuation:
			if textPending {
				// Swap only the source (no Reset), so decoder state carries
				// across fragments of the same text message.
				utf8Reader.Source = cipherReader
				r = utf8Reader
			}
			if header.Fin {
				state = state.Clear(ws.StateFragmented)
				// Only a fragmented text message completes as text. A
				// fragmented binary message must still go to OnBinaryMessage.
				utf8Fin = textPending
				textPending = false
			}

		case ws.OpText:
			utf8Reader.Reset(cipherReader)
			r = utf8Reader

			if !header.Fin {
				state = state.Set(ws.StateFragmented)
				textPending = true
			} else {
				utf8Fin = true
			}

		case ws.OpBinary:
			if !header.Fin {
				state = state.Set(ws.StateFragmented)
			}
		}

		payload = realloc(payload, read, length)

		// r is limited to at most header.Length above, making this an exact read
		_, err = io.ReadAtLeast(r, payload[read:], length)
		if err == nil && utf8Fin && !utf8Reader.Valid() {
			err = wsutil.ErrInvalidUTF8
		}
		if err != nil {
			settings.Logger(logging.Debug, "Unable to read payload", err)
			if err == wsutil.ErrInvalidUTF8 {
				link.writeClose(ws.StatusInvalidFramePayloadData, "")
			} else {
				link.writeClose(ws.StatusProtocolError, "")
			}
			return err
		}

		if header.OpCode == ws.OpClose {
			// The close payload was read at offset read. That offset is non-zero
			// when the close interrupts a fragmented message.
			code, reason := ws.ParseCloseFrameData(payload[read : read+length])
			if code.Empty() {
				code = ws.StatusNoStatusRcvd
			}
			// NOTE(review): gobwas/ws may classify StatusNoStatusRcvd (1005) as
			// reserved in CheckCloseFrameData. If so, a legal empty close frame
			// is answered with a protocol-error close. Confirm against the
			// vendored version.
			if err = ws.CheckCloseFrameData(code, reason); err != nil {
				settings.Logger(logging.Debug, "Invalid close data", err)
				link.writeClose(ws.StatusProtocolError, err.Error())
			} else {
				link.writeClose(code, reason)
			}
			return err
		}

		if state.Is(ws.StateFragmented) {
			// More fragments follow; keep accumulating in payload.
			read += length
			continue
		}

		// If the binding implements a rate limit checker, first check whether
		// the connection has reached its allowed rate.
		// NOTE(review): any result of RateLimitChecker is not used here; confirm intent.
		if rateLimiter, ok := binding.(session.RateLimiter); ok {
			rateLimiter.RateLimitChecker()
		}
		content := payload[:read+length]
		read = 0
		if utf8Fin {
			binding.OnTextMessage(string(content))
		} else {
			// content aliases payload, which may be overwritten by the next frame.
			// NOTE(review): bindings presumably must not retain it after returning; confirm.
			binding.OnBinaryMessage(content)
		}
	}
}
