package wsconnection

import (
	"context"
	"net"
	"sync"
	"time"

	"github.com/gorilla/websocket"
	"go.uber.org/zap"

	"code.justin.tv/websocket-edge/server/internal/gqlsubs"
	"code.justin.tv/websocket-edge/server/internal/logs"
	"code.justin.tv/websocket-edge/server/internal/metrics"
	"code.justin.tv/websocket-edge/server/protocol"
)

const (
	// writeTimeout bounds how long a single outbound write (data frame or
	// control frame) may take before it is treated as failed.
	writeTimeout = 10 * time.Second
	// idleTimeout is the read deadline. A connection is closed when no pong
	// has been received from the client within this window. At 70s it is
	// longer than two pingPeriods, so roughly one missed pong is tolerated.
	idleTimeout = 70 * time.Second
	// pingPeriod is how often the server pings each client.
	pingPeriod = 30 * time.Second
	// maxInboundMessageSize is the read limit for a single client message, in bytes.
	maxInboundMessageSize = 40 * 1024
	// sendBufferSize is the capacity of each per-connection message queue
	// (client-bound and GQL-Subs-bound).
	sendBufferSize = 200
	// gqlSubsMsgSendTimeout bounds each attempt to forward a client message
	// to GQL Subs.
	gqlSubsMsgSendTimeout = 1 * time.Second
)

// Metric names reported through metrics.Statter.
const (
	// Connection lifecycle, errors and liveness.
	metricClose                = "ConnClose"
	metricErr                  = "ConnError"
	metricSecondsSinceLastPong = "SecondsSinceLastPong"

	// Client-bound traffic (sendBuffer / writePump).
	metricMessageSend      = "MessageSend"
	metricMessageSize      = "MessageSize"
	metricSendBufferClosed = "SendBufferClosed"
	metricSendBufferFull   = "SendBufferFull"

	// Traffic read from the client (readPump / subsSender).
	metricInboundMessageSize    = "InboundMessageSize"
	metricBinaryMessageReceived = "BinaryMessageReceived"
	metricSubsSendBufferClosed  = "SubsSendBufferClosed"
	metricSubsSendBufferFull    = "SubsSendBufferFull"
)

// Connection is the server-side handle for one client websocket.
type Connection interface {
	// Forward delivers the body of a service message to the client.
	Forward(wm protocol.ServiceToClientMessage) error
	// GracefulClose closes the connection; implementations are expected to
	// call wg.Done when they are finished.
	GracefulClose(wg *sync.WaitGroup)
}

// connection is the Connection implementation for a single client websocket.
// Its work is split across the goroutines started by New (writePump,
// heartbeat, readPump and subsSender). Each of them stops once cleanup closes
// shutdown.
type connection struct {
	// Conn is the underlying websocket connection.
	Conn Conn

	// sendBuffer queues client-bound messages; it is drained by writePump.
	sendBuffer chan []byte

	// subsSendBuffer queues text messages read from the client; it is
	// drained by subsSender, which forwards them to GQL Subs.
	subsSendBuffer chan []byte

	// shutdown is closed (once, by cleanup) to stop every goroutine.
	shutdown chan interface{}

	// ID is the session ID passed to New.
	ID string

	// clientIP is the client's address, pre-formatted as a string.
	clientIP string

	// cleanupOnce makes cleanup safe to call from any goroutine, any number of times.
	cleanupOnce *sync.Once

	// lastPongReceived is updated by readPump's pong handler and feeds the
	// SecondsSinceLastPong timing.
	lastPongReceived time.Time

	// cleanupHook is run at the end of cleanup; per New it should remove
	// references to this connection.
	cleanupHook func()

	logger     logs.Logger
	statter    metrics.Statter
	subsClient gqlsubs.Client
}

// cleanup tears the connection down exactly once, however many goroutines
// call it. It closes shutdown (which stops writePump, heartbeat, readPump and
// subsSender) and closes the underlying websocket. It then records a metric
// for the outcome and finally runs cleanupHook.
//
// sendBuffer is deliberately NOT closed. sendMessage can still be invoked by
// a caller that obtained this connection before cleanupHook removed it, and a
// send on a closed channel panics. writePump exits via shutdown instead, and
// any queued messages are garbage-collected along with the connection.
func (c *connection) cleanup() {
	c.cleanupOnce.Do(func() {
		close(c.shutdown)

		if err := c.Conn.Close(); err != nil {
			c.logger.Error("Error closing connection in cleanup", zap.Error(err))
			c.statter.IncrementErr(metricErr, "CloseConn")
		} else {
			c.statter.Increment(metricClose)
		}

		c.cleanupHook()
	})
}

// GracefulClose first sends the client a normal-closure close frame and then
// tears the connection down. wg.Done is always called on return, so a caller
// can wait for many connections to finish closing.
func (c *connection) GracefulClose(wg *sync.WaitGroup) {
	defer wg.Done()

	// Best effort: writeClose does not return an error, and teardown happens regardless.
	c.writeClose()
	// Idempotent, so this is safe even if another goroutine already began cleanup.
	c.cleanup()
}

// writeClose sends a normal-closure close frame to the client, bounded by
// writeTimeout. It uses WriteControl, which (assuming Conn is backed by
// gorilla's *websocket.Conn) may be called concurrently with writePump's
// writes. A failure is logged and counted but not returned, because the
// caller tears the connection down either way. websocket.ErrCloseSent (a
// close frame was already sent) is not treated as an error, matching the
// other write paths in this file.
func (c *connection) writeClose() {
	closeMsg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
	err := c.Conn.WriteControl(websocket.CloseMessage, closeMsg, time.Now().Add(writeTimeout))
	if err != nil && err != websocket.ErrCloseSent {
		c.logger.Error("could not send close message", zap.Error(err))
		c.statter.IncrementErr(metricErr, "SendCloseMsg")
	}
}

// readPump owns the read side of the websocket and returns once the
// connection dies, running cleanup on the way out.
//
// It installs the read limit and an idleTimeout read deadline that only the
// pong handler extends. A client that stops answering heartbeat pings
// therefore makes ReadMessage fail. Reading happens on a helper goroutine that
// forwards text messages to subsSendBuffer and reports its terminal read error
// on errChan. This goroutine waits for either that error or shutdown.
func (c *connection) readPump() {
	defer func() {
		c.logger.Info("Shutting down connection in readPump")
		c.cleanup()
	}()

	c.Conn.SetReadLimit(maxInboundMessageSize)

	// Treat the start of reading as the most recent liveness signal. Without
	// this, a connection that never received a pong would report time.Since
	// of the zero time (a saturated duration) as SecondsSinceLastPong.
	c.lastPongReceived = time.Now()
	if err := c.Conn.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
		c.logger.Error("could not set read deadline", zap.Error(err))
		c.statter.IncrementErr(metricErr, "SetReadDeadline")
		return
	}

	// Each pong proves the client is alive: record it and push the deadline out.
	c.Conn.SetPongHandler(func(string) error {
		c.lastPongReceived = time.Now()
		if err := c.Conn.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
			c.logger.Error("could not set read deadline", zap.Error(err))
			c.statter.IncrementErr(metricErr, "SetReadDeadline")
			// Returning the error makes the in-progress ReadMessage fail. That
			// ends the reader goroutine and, via errChan, this pump (and so
			// triggers cleanup).
			return err
		}
		return nil
	})

	// Buffered so the single error send below never blocks, even if this pump
	// has already returned because of shutdown.
	errChan := make(chan error, 1)

	// Reader goroutine. ReadMessage returns an error once the connection dies,
	// gracefully or otherwise (including when cleanup closes it). That error
	// is handed to errChan and the goroutine exits.
	go func() {
		for {
			msgType, msgBody, err := c.Conn.ReadMessage()
			if err != nil {
				errChan <- err
				return
			}

			c.statter.Size(metricInboundMessageSize, len(msgBody))
			if msgType != websocket.TextMessage {
				// With a nil error, ReadMessage only yields text or binary messages.
				c.statter.Increment(metricBinaryMessageReceived)
				c.logger.Warn("unexpected binary message")
				continue
			}

			// Forward the message to GraphQL subscriptions without blocking.
			select {
			case c.subsSendBuffer <- msgBody:
			default:
				// The buffer is full: stop reading and tear the connection down.
				c.statter.Increment(metricSubsSendBufferFull)
				go c.cleanup()
				return
			}
		}
	}()

	select {
	case <-c.shutdown:
		c.statter.Increment(metricClose)
	case err := <-errChan:
		// Only non-nil errors are ever sent on errChan.
		if _, ok := err.(*websocket.CloseError); ok {
			// A close frame was received from the client.
			c.statter.Increment(metricClose)
			if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
				c.statter.IncrementErr(metricErr, "CloseErr")
				c.logger.Error("close error", zap.Error(err))
			}
			return
		}
		c.statter.IncrementErr(metricErr, "ReadMessage")
		// Reading lastPongReceived here is race-free: the reader goroutine's
		// last write (in the pong handler) happens before its send on errChan.
		c.statter.Timing(metricSecondsSinceLastPong, time.Since(c.lastPongReceived))
		c.logger.Error("read message error", zap.Error(err))
	}
}

// subsSender forwards client text messages from subsSendBuffer to GQL Subs,
// one SendMessage call at a time. Each call is bounded by
// gqlSubsMsgSendTimeout. It returns when shutdown is closed or when
// subsSendBuffer is closed.
func (c *connection) subsSender() {
	parent := context.Background()
	for {
		var (
			body []byte
			ok   bool
		)
		select {
		case <-c.shutdown:
			return
		case body, ok = <-c.subsSendBuffer:
		}

		if !ok {
			c.statter.Increment(metricSubsSendBufferClosed)
			return
		}

		reqCtx, cancel := context.WithTimeout(parent, gqlSubsMsgSendTimeout)
		if err := c.subsClient.SendMessage(reqCtx, c.ID, c.clientIP, body); err != nil {
			// Failures are only logged; the connection is deliberately kept open.
			c.logger.Error("error sending subs message", zap.Error(err))
		}
		cancel()
	}
}

// writePump is the goroutine that writes data frames to the client. It
// drains sendBuffer and writes each message as a text frame under a
// writeTimeout deadline. It returns (running cleanup) on shutdown, if
// sendBuffer is closed, or when a write fails. Pings are sent by heartbeat
// through WriteControl, so this pump needs no ticker of its own.
func (c *connection) writePump() {
	defer func() {
		c.logger.Info("Shutting down connection in writePump")
		c.cleanup()
	}()

	logExit := func() {
		c.logger.Info("write pump closed.")
		c.statter.Increment(metricSendBufferClosed)
	}

	for {
		select {
		case <-c.shutdown:
			logExit()
			return

		case message, ok := <-c.sendBuffer:
			if !ok {
				logExit()
				return
			}

			if err := c.Conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
				c.logger.Error("could not set write deadline", zap.Error(err))
				c.statter.IncrementErr(metricErr, "SetWriteDeadline")
				return
			}

			// Forward the message to the client. If sending fails or times out,
			// assume a disconnect: stop sending and close the connection.
			err := c.Conn.WriteMessage(websocket.TextMessage, message)
			switch {
			case err == websocket.ErrCloseSent:
				// A close frame was already sent (e.g. by GracefulClose), so the
				// message was not delivered. Don't count it as sent; keep
				// draining until shutdown.
				continue
			case err != nil:
				c.logger.Error("message write error", zap.Error(err))
				c.statter.IncrementErr(metricErr, "WriteMessage")
				return
			}
			c.statter.Increment(metricMessageSend)
			c.statter.Size(metricMessageSize, len(message))
		}
	}
}

// heartbeat pings the client every pingPeriod so that readPump's pong handler
// keeps extending the read deadline. It returns (running cleanup) on shutdown
// or when a ping cannot be sent.
//
// Pings go through WriteControl, which takes its own deadline. Assuming Conn
// is backed by gorilla's *websocket.Conn, WriteControl may run concurrently
// with writePump. The write deadline is deliberately not touched here:
// SetWriteDeadline is one of the write methods that must not be called
// concurrently with writePump's WriteMessage.
func (c *connection) heartbeat() {
	ticker := time.NewTicker(pingPeriod)
	defer func() {
		c.logger.Info("stopping heartbeats")
		ticker.Stop()
		c.cleanup()
	}()

	for {
		select {
		case <-c.shutdown:
			return
		case <-ticker.C:
			err := c.Conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeTimeout))
			if err != nil && err != websocket.ErrCloseSent {
				c.logger.Error("could not send ping", zap.Error(err))
				c.statter.IncrementErr(metricErr, "SendPing")
				return
			}
		}
	}
}

// New wraps conn in a connection and starts its worker goroutines.
//
// cleanupHook runs once, at the end of teardown. It is responsible for
// removing references to this connection so that no further sendMessage
// calls are made.
func New(
	conn Conn,
	sessID string,
	clientIP net.IP,
	cleanupHook func(),
	logger logs.Logger,
	statter metrics.Statter,
	subsClient gqlsubs.Client) *connection {
	c := &connection{
		Conn:     conn,
		ID:       sessID,
		clientIP: clientIP.String(),

		sendBuffer:     make(chan []byte, sendBufferSize),
		subsSendBuffer: make(chan []byte, sendBufferSize),
		shutdown:       make(chan interface{}),

		cleanupOnce: new(sync.Once),
		cleanupHook: cleanupHook,

		logger:     logger,
		statter:    statter,
		subsClient: subsClient,
	}

	// Each websocket connection does several things in parallel. When any
	// read or write fails, connection.cleanup() stops all of them.
	go c.writePump()  // writing messages to the client
	go c.heartbeat()  // sending websocket PINGs
	go c.readPump()   // reading client messages and receiving PONGs
	go c.subsSender() // forwarding client messages to GQL Subs

	return c
}

// sendMessage queues message for writePump without blocking.
//
// Once the connection has been shut down, the message is dropped silently.
// If sendBuffer is full, the client is assumed to be disconnected (or unable
// to keep up), and teardown is started.
func (c *connection) sendMessage(message []byte) {
	// Callers can race with teardown (they may hold a reference obtained
	// before cleanupHook ran). Once shutdown is closed, don't queue, don't
	// report a full buffer, and don't spawn another cleanup goroutine.
	select {
	case <-c.shutdown:
		return
	default:
	}

	select {
	case c.sendBuffer <- message:
	default:
		c.statter.Increment(metricSendBufferFull)
		// cleanup runs asynchronously, presumably so a caller holding a lock
		// that cleanupHook needs cannot deadlock (TODO confirm against the
		// hook's implementation). The shutdown check above stops later calls
		// from piling up extra goroutines once teardown has begun.
		go c.cleanup()
	}
}

// Forward hands the body of wm to the connection's outbound queue. Delivery
// is asynchronous and best-effort (see sendMessage), so the returned error
// is always nil.
func (c *connection) Forward(wm protocol.ServiceToClientMessage) error {
	payload := []byte(wm.Body)
	c.sendMessage(payload)
	return nil
}
