package websocket

import (
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net"

	"code.justin.tv/devhub/e2ml/libs/logging"
	"code.justin.tv/devhub/e2ml/libs/session"
	"code.justin.tv/devhub/e2ml/libs/timeout"
	"github.com/gobwas/ws"
	"github.com/gobwas/ws/wsutil"
)

var (
	// pingFrame is the ping frame (payload "ping") that execute writes when a
	// header read times out and the link's timeout policy is Ping.
	pingFrame          = ws.NewPingFrame([]byte("ping"))
	// statusRawError is a close code taken from ws.StatusRangePrivate. It marks
	// a close frame whose reason is the plain err.Error() text of the failure.
	statusRawError     = ws.StatusRangePrivate.Min + 1
	// statusEncodedError is a close code taken from ws.StatusRangePrivate. It
	// marks a close frame whose reason was produced by link.errorMap.Marshal,
	// so the receiver can try link.errorMap.Unmarshal on it.
	statusEncodedError = ws.StatusRangePrivate.Min + 2
)

// realloc returns a buffer that can hold at least offset+length bytes.
//
// The first offset bytes are treated as already-filled data and are preserved.
// When buffer is already large enough it is returned unchanged (same backing
// array); otherwise a new slice of exactly offset+length bytes is allocated
// and the existing prefix is copied into it. A nil buffer always yields a
// freshly allocated slice.
func realloc(buffer []byte, offset, length int) []byte {
	needed := offset + length
	if buffer == nil {
		return make([]byte, needed)
	}
	if len(buffer) >= needed {
		return buffer
	}
	expanded := make([]byte, needed)
	// copy is bounded by the shorter operand, so this keeps at most the first
	// offset bytes and never reads past the end of the old buffer.
	copy(expanded[:offset], buffer)
	return expanded
}

// execute runs the read loop for link until the connection is closed or an
// unrecoverable error occurs.
//
// Each iteration reads one frame header from the link's connection (through a
// timeout reader), validates it against state, and then handles the frame:
//
//   - ping frames are answered with a pong carrying the same payload;
//   - pong frames clear the outstanding-ping flag and are discarded;
//   - close frames are validated, acknowledged, and translated into the error
//     (or nil) returned to the caller;
//   - text, binary and continuation frames are accumulated into a reusable
//     buffer and handed to binding.OnTextMessage or binding.OnBinaryMessage
//     once the final fragment has arrived.
//
// When a header read times out, link.onTimeout decides the reaction: Ping
// sends one ping and keeps waiting (a second timeout without any pong closes
// the link); Disconnect closes the link immediately.
//
// shutdown is consulted when an error has to be converted into a close code
// (see closeForError). The returned error is nil for a normal closure by the
// peer and non-nil otherwise.
func execute(binding session.Binding, link *link, state ws.State, shutdown func() bool) error {
	link.logger(logging.Trace, "Opened", link.Address())
	inPing := false      // a ping was sent and no pong has been seen since
	textPending := false // a fragmented text message is being assembled
	src := timeout.NewReader(link.timeout, link.conn)
	var r io.Reader
	limitReader := &io.LimitedReader{R: src}
	cipherReader := wsutil.NewCipherReader(nil, [4]byte{0, 0, 0, 0})
	utf8Reader := wsutil.NewUTF8Reader(nil)
	pongBuffer := make([]byte, 8)
	payload := make([]byte, 32)
	read := 0 // bytes of the current fragmented message already in payload

	for {
		header, err := ws.ReadHeader(src)
		if err != nil {
			var nerr net.Error
			if errors.As(err, &nerr) && nerr.Timeout() {
				switch link.onTimeout {
				case Ping:
					if !inPing {
						if err = link.writeFrame(pingFrame); err != nil {
							link.writeClose(closeForError(link, shutdown(), fmt.Errorf("Unable to send ping: %w", err)))
							// Return here so a second, misleading "Timed out"
							// close frame is not written after this one.
							return err
						}
						inPing = true
						continue
					}
					// A ping is already outstanding: treat as a timeout.
				case Disconnect:
					// pass through to error
				}
				link.writeClose(ws.StatusProtocolError, "Timed out")
			} else {
				link.writeClose(closeForError(link, shutdown(), fmt.Errorf("Unable to read header: %w", err)))
			}
			return err
		}

		if err = ws.CheckHeader(header, state); err != nil {
			if !isClosedError(err) {
				link.writeClose(closeForError(link, shutdown(), fmt.Errorf("Unable to check header: %w", err)))
			}
			return err
		}

		// Bound the payload reader to this frame and unmask it if required.
		limitReader.N = header.Length
		if header.Masked {
			cipherReader.Reset(limitReader, header.Mask)
			r = cipherReader
		} else {
			r = limitReader
		}

		length := int(header.Length)
		// utf8Fin marks a frame that completes a UTF-8 payload: the decoder
		// state is validated after the read, and a completed data message is
		// delivered as text rather than binary.
		utf8Fin := false

		if link.maxMsgLength != 0 && read+length > link.maxMsgLength {
			link.writeClose(ws.StatusMessageTooBig, "Message too large")
			return errors.New("Write length exceeded") // TODO : convert to structured error
		}

		switch header.OpCode {
		case ws.OpPing:
			pongBuffer = realloc(pongBuffer, 0, length)
			if _, err = io.ReadAtLeast(r, pongBuffer, length); err != nil {
				link.logger(logging.Debug, "Unable to read ping payload", err)
				link.writeClose(ws.StatusUnsupportedData, "Unable to read payload")
				return err
			}
			// Best effort: the result of the pong write is not checked here.
			link.writeFrame(ws.NewPongFrame(pongBuffer[:length]))
			continue

		case ws.OpPong:
			inPing = false
			if _, err = io.CopyN(ioutil.Discard, r, header.Length); err != nil {
				link.logger(logging.Debug, "Unable to read pong payload", err)
				link.writeClose(ws.StatusProtocolError, "Unable to read payload")
				return err
			}
			continue

		case ws.OpClose:
			utf8Fin = true

		case ws.OpContinuation:
			if textPending {
				// Keep the decoder state from the earlier fragments so runes
				// split across frame boundaries validate correctly.
				utf8Reader.Source = r
				r = utf8Reader
			}
			if header.Fin {
				state = state.Clear(ws.StateFragmented)
				// Only a fragmented *text* message finishes as text; a
				// fragmented binary message must be delivered as binary.
				utf8Fin = textPending
				textPending = false
			}

		case ws.OpText:
			utf8Reader.Reset(r)
			r = utf8Reader

			if !header.Fin {
				state = state.Set(ws.StateFragmented)
				textPending = true
			} else {
				utf8Fin = true
			}

		case ws.OpBinary:
			if !header.Fin {
				state = state.Set(ws.StateFragmented)
			}
		}

		payload = realloc(payload, read, length)

		// r is limited to at most header.Length above, making this an exact read
		_, err = io.ReadAtLeast(r, payload[read:], length)
		if err == nil && utf8Fin && !utf8Reader.Valid() {
			err = wsutil.ErrInvalidUTF8
		}
		if err != nil {
			link.logger(logging.Debug, "Unable to read payload:", err)
			if err == wsutil.ErrInvalidUTF8 {
				link.writeClose(ws.StatusInvalidFramePayloadData, "Invalid UTF8")
			} else {
				link.writeClose(ws.StatusProtocolError, "Unable to read payload")
			}
			return err
		}

		if header.OpCode == ws.OpClose {
			code, reason := ws.ParseCloseFrameData(payload[:length])
			if code.Empty() {
				code = ws.StatusNormalClosure
			}
			if err = ws.CheckCloseFrameData(code, reason); err != nil {
				link.logger(logging.Debug, "Invalid close data", err)
				link.writeClose(ws.StatusProtocolError, err.Error())
				return err
			}
			link.writeClose(code, reason) // acknowledge receipt
			// TODO : create standard structured errors for these cases
			// err is nil at this point; codes that match no case return nil.
			switch code {
			case ws.StatusNormalClosure, ws.StatusGoingAway:
				// no error
			case ws.StatusProtocolError, ws.StatusUnsupportedData, ws.StatusPolicyViolation:
				err = errors.New("Invalid message content")
			case ws.StatusInvalidFramePayloadData, ws.StatusMessageTooBig:
				err = errors.New("Invalid message payload")
			case ws.StatusMandatoryExt:
				err = errors.New("Missing websocket extension")
			case ws.StatusInternalServerError:
				err = errors.New("Internal server error")
			case statusRawError:
				err = errors.New(reason)
			case statusEncodedError:
				// Try to rebuild the peer's error; fall back to the raw text
				// when the reason cannot be decoded.
				var parsed error
				if perr := link.errorMap.Unmarshal([]byte(reason), &parsed); perr != nil {
					err = errors.New(reason)
				} else {
					err = parsed
				}
			}
			return err
		}

		if state.Is(ws.StateFragmented) {
			// More fragments to come: keep accumulating after this one.
			read += length
		} else {
			content := payload[:read+length]
			read = 0
			if utf8Fin {
				binding.OnTextMessage(string(content))
			} else {
				binding.OnBinaryMessage(content)
			}
		}
	}
}

// closeForError picks the close status code and reason to send for err.
//
// A nil error, an end-of-stream (io.EOF, including when wrapped), or a
// "use of closed network connection" error is treated as an ordinary close:
// ws.StatusGoingAway when shutdown is true, ws.StatusNormalClosure otherwise,
// both with an empty reason. Any other error is sent with an application
// close code: statusEncodedError with the link.errorMap encoding of err when
// marshaling succeeds, otherwise statusRawError with err.Error().
func closeForError(link *link, shutdown bool, err error) (ws.StatusCode, string) {
	// errors.Is is required because callers wrap the underlying error with
	// %w, so its message is no longer exactly "EOF". The string comparison is
	// kept for errors that only carry the text.
	if err == nil || errors.Is(err, io.EOF) || err.Error() == "EOF" || isClosedError(err) {
		if shutdown {
			return ws.StatusGoingAway, ""
		}
		return ws.StatusNormalClosure, ""
	}
	if encoded, issue := link.errorMap.Marshal(err); issue == nil {
		// application close code hints we want to unmarshal
		return statusEncodedError, string(encoded)
	}
	return statusRawError, err.Error()
}

// isClosedError reports whether err is, or wraps, a *net.OpError whose
// underlying error message is "use of closed network connection".
// It returns false for a nil error and for an OpError with no underlying
// error.
func isClosedError(err error) bool {
	var opErr *net.OpError
	if !errors.As(err, &opErr) || opErr.Err == nil {
		return false
	}
	return opErr.Err.Error() == "use of closed network connection"
}

// filterClosed suppresses closed-connection errors: it yields nil when
// isClosedError recognizes err and hands every other value back untouched.
func filterClosed(err error) error {
	if !isClosedError(err) {
		return err
	}
	return nil
}
