package user

import (
	"context"
	"encoding/json"
	"net"
	"net/http"
	"net/url"
	"sync"
	"time"

	"github.com/gorilla/websocket"
	"github.com/pkg/errors"
	"go.uber.org/zap"
	"golang.org/x/sync/errgroup"

	"code.justin.tv/websocket-edge/server/internal/logs"
	"code.justin.tv/websocket-edge/server/internal/metrics"
	"code.justin.tv/websocket-edge/server/loadtest/loadtestenv"
	"code.justin.tv/websocket-edge/server/protocol"
)

// Metric names reported by BasicUser through its metrics.Statter.
const (
	// Connection lifecycle.
	metricConnSuccess      = "ConnSuccess"
	metricConnError        = "ConnError"
	metricConnOpenDuration = "ConnOpenDuration"
	metricConnCloseError   = "ConnCloseError"
	metricPingErr          = "PingError"

	// Outbound subscription messages.
	metricSendMsgSuccess = "SendMsgSuccess"
	metricSendMsgError   = "SendMsgError"

	// Inbound messages.
	metricMessagesReceived         = "TotalMessagesReceived"
	metricMessageRoundTripDuration = "MessageRoundTripDuration"
	metricAllMessagesReceived      = "AllMessagesReceived"
	metricReadMsgError             = "ReadMsgError"
)

// firstMsgBody and subsequentMsgBody are the pre-encoded payloads written to
// the websocket edge: the first is sent once when subscriptions start, the
// second on every later subscription tick. Both are filled in by init.
var firstMsgBody, subsequentMsgBody []byte

// init encodes the static message bodies once at package load so senders
// never have to marshal while the load test is running.
func init() { buildStaticMsgBodies() }

// BasicUser is a single simulated load-test client. NewBasicUser constructs
// it and Simulate drives it through one connect / subscribe / read / close
// cycle over a websocket connection to the edge.
//
// Fields:
//   - log: logger tagged with the user's ID (see NewBasicUser).
//   - statter: sink for the metric* counters and timings recorded by the methods.
//   - wg: Simulate calls Done on it when it returns.
//   - userID: identifier supplied to NewBasicUser.
//   - http: caller-supplied HTTP client. NOTE(review): no use of it is visible
//     alongside this type; confirm whether it is still needed.
//   - conn: websocket connection established by connect.
type BasicUser struct {
	log     logs.Logger
	statter metrics.Statter
	wg      *sync.WaitGroup
	userID  string
	http    *http.Client
	conn    *websocket.Conn
}

// connect dials the websocket edge's /v1/connect endpoint and stores the
// resulting connection in u.conn. It records the dial latency under
// metricConnOpenDuration. It also installs a ping handler that answers every
// ping with a pong and counts pong-write failures under metricPingErr.
//
// A dial failure is counted under metricConnError and returned wrapped.
func (u *BasicUser) connect() error {
	host := loadtestenv.WebsocketEdgeAddress()
	// Named endpoint rather than url so the net/url package is not shadowed.
	endpoint := url.URL{Scheme: loadtestenv.WSScheme(), Host: host, Path: "/v1/connect"}
	u.log.Info("establishing websocket connection", zap.String("host", host))

	start := time.Now()
	c, _, err := websocket.DefaultDialer.Dial(endpoint.String(), nil)
	if err != nil {
		u.log.Error("dial error", zap.Error(err))
		u.statter.IncrementErr(metricConnError, "DialError")
		return errors.Wrap(err, "Err establishing websocket connection")
	}
	u.statter.Timing(metricConnOpenDuration, time.Since(start))

	u.conn = c
	c.SetPingHandler(func(message string) error {
		err := c.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second))
		if err == websocket.ErrCloseSent {
			// A close frame was already written; treat the failed pong as non-fatal.
			u.log.Error("wrote pong after close was sent", zap.Error(err))
			u.statter.IncrementErr(metricPingErr, "WritePong")
			return nil
		}
		// net.Error.Temporary is deprecated, but it is kept here so transient
		// network failures remain tolerated instead of surfacing as read errors.
		if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
			u.log.Error("Temporary error writing pong", zap.Error(err))
			u.statter.IncrementErr(metricPingErr, "NetworkError")
			return nil
		}
		if err != nil {
			u.log.Error("Error writing pong", zap.Error(err))
			u.statter.IncrementErr(metricPingErr, "Error")
		}
		return err
	})

	return nil
}

// readPump reads messages from u.conn until numMessages messages have been
// processed, a read/decode error occurs, the overall read deadline
// (now + totalTimeout) passes, or ctx is cancelled.
//
// Each message is decoded as a protocol.LoadTestMessageBody. Its RFC 3339
// Timestamp is used to record round-trip latency under
// metricMessageRoundTripDuration.
//
// readPump returns:
//   - nil once exactly numMessages messages have been handled;
//   - ctx.Err() if ctx was cancelled;
//   - otherwise, the first read, unmarshal, or timestamp-parse error.
//
// If numMessages <= 0 the success condition can never match, so the pump runs
// until an error or the deadline.
func (u *BasicUser) readPump(ctx context.Context, numMessages int, totalTimeout time.Duration) error {
	defer func() {
		u.log.Info("readPump shutting down")
	}()

	messagesReceived := 0
	// Best-effort: a failure here only means the pump relies on errors/ctx to stop.
	_ = u.conn.SetReadDeadline(time.Now().Add(totalTimeout))

	// ReadMessage does not observe ctx. Without this watcher, a cancelled
	// context would leave the pump blocked until the overall deadline.
	// The watcher starts after the deadline above is set, so the deadline
	// cannot overwrite the cancellation. It expires the deadline on the
	// underlying net.Conn, whose methods may be called concurrently;
	// websocket.Conn's read methods must not be.
	stop := make(chan struct{})
	defer close(stop)
	go func() {
		select {
		case <-ctx.Done():
			_ = u.conn.UnderlyingConn().SetReadDeadline(time.Now())
		case <-stop:
		}
	}()

	for {
		_, message, err := u.conn.ReadMessage()
		// context could be cancelled while waiting for this blocking operation
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}
		if err != nil {
			u.log.Error("read message error", zap.Error(err))
			u.statter.IncrementErr(metricReadMsgError, "ReadMessage")
			return err
		}
		u.statter.Increment(metricMessagesReceived)
		messagesReceived++

		msg := &protocol.LoadTestMessageBody{}
		if err := json.Unmarshal(message, msg); err != nil {
			u.log.Error("Error unmarshalling message", zap.Error(err))
			u.statter.IncrementErr(metricReadMsgError, "Unmarshal")
			return err
		}
		sendTime, err := time.Parse(time.RFC3339, msg.Timestamp)
		if err != nil {
			u.log.Error("invalid timestamp", zap.Error(err))
			u.statter.IncrementErr(metricReadMsgError, "ParseTime")
			return err
		}
		u.statter.Timing(metricMessageRoundTripDuration, time.Since(sendTime))

		if messagesReceived == numMessages {
			u.statter.Increment(metricAllMessagesReceived)
			u.log.Info("all messages received by client")
			return nil
		}
	}
}

// NewBasicUser returns a BasicUser for userID. Its logger is derived from log
// and tagged with the "userID" field. The WaitGroup is stored so Simulate can
// signal completion, and the HTTP client is kept for the user's use.
func NewBasicUser(log logs.Logger, statter metrics.Statter, wg *sync.WaitGroup, httpClient *http.Client, userID string) *BasicUser {
	tagged := log.With(zap.String("userID", userID))

	u := &BasicUser{userID: userID, wg: wg}
	u.log = tagged
	u.statter = statter
	u.http = httpClient
	return u
}

// sendMessage writes one pre-encoded subscription payload as a text frame.
// The payload is firstMsgBody when first is true and subsequentMsgBody
// otherwise. Each write gets a 5-second deadline. Failures are logged and
// counted under metricSendMsgError; success is counted under
// metricSendMsgSuccess. ctx is currently unused.
func (u *BasicUser) sendMessage(ctx context.Context, first bool) {
	payload := subsequentMsgBody
	if first {
		payload = firstMsgBody
	}

	const writeTimeout = 5 * time.Second
	if err := u.conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
		u.log.Error("could not set write deadline", zap.Error(err))
		u.statter.IncrementErr(metricSendMsgError, "SetWriteDeadline")
		return
	}

	if err := u.conn.WriteMessage(websocket.TextMessage, payload); err != nil {
		u.log.Error("message write error", zap.Error(err))
		u.statter.IncrementErr(metricSendMsgError, "WriteMessage")
		return
	}

	u.statter.Increment(metricSendMsgSuccess)
}

// issueSubscriptionRequests sends the first subscription message immediately.
// It then sends a subsequent one on every tick of subscriptionInterval until
// ctx is done. It returns without sending anything if ctx is already done.
// time.NewTicker panics if subscriptionInterval is not positive.
func (u *BasicUser) issueSubscriptionRequests(ctx context.Context, subscriptionInterval time.Duration) {
	if ctx.Err() != nil {
		return
	}

	t := time.NewTicker(subscriptionInterval)
	// Release the ticker on every exit path; before Go 1.23 an unstopped
	// ticker is never reclaimed.
	defer t.Stop()

	u.sendMessage(ctx, true)
	for {
		select {
		case <-t.C:
			// Both cases can be ready at once and select picks randomly, so
			// re-check ctx to avoid sending after cancellation.
			if ctx.Err() != nil {
				return
			}
			u.sendMessage(ctx, false)
		case <-ctx.Done():
			return
		}
	}
}

// someBytes returns n bytes of filler text that cycles through the lowercase
// alphabet ("abc...xyzabc..."). It panics if n is negative.
func someBytes(n int) []byte {
	const alphabet = "abcdefghijklmnopqrstuvwxyz"
	out := make([]byte, n)
	for i := range out {
		out[i] = alphabet[i%len(alphabet)]
	}
	return out
}

// buildStaticMsgBodies JSON-encodes the two protocol.MockSubscribeMessage
// payloads (First=true and First=false) into firstMsgBody and
// subsequentMsgBody. Both are padded with the same 1024 bytes of filler text.
//
// It panics if encoding fails. It runs from init, and a silently nil body
// would make every subscription write an empty message.
func buildStaticMsgBodies() {
	padding := string(someBytes(1024))

	var err error
	firstMsgBody, err = json.Marshal(&protocol.MockSubscribeMessage{First: true, Msg: padding})
	if err != nil {
		panic(errors.Wrap(err, "encoding first subscribe message"))
	}
	subsequentMsgBody, err = json.Marshal(&protocol.MockSubscribeMessage{First: false, Msg: padding})
	if err != nil {
		panic(errors.Wrap(err, "encoding subsequent subscribe message"))
	}
}

// Simulate runs one simulated user end to end:
//  1. It connects to the websocket edge.
//  2. It reads until numMessages messages arrive, or until an error,
//     totalTimeout, or ctx cancellation ends the read.
//  3. Meanwhile it sends subscription messages every subscriptionInterval.
//  4. It writes a normal-closure frame and closes the connection.
//
// It always calls u.wg.Done before returning. The subscription sender is
// stopped and awaited before the close frame is written, so no subscription
// write can race with shutdown or outlive this call.
func (u *BasicUser) Simulate(ctx context.Context, numMessages int, totalTimeout time.Duration, subscriptionInterval time.Duration) {
	defer u.wg.Done()

	// Establish websocket connection
	if err := u.connect(); err != nil {
		u.log.Error("could not connect", zap.Error(err))
		return
	}
	defer func() {
		if err := u.conn.Close(); err != nil {
			u.statter.IncrementErr(metricConnCloseError, "ConnClose")
		}
	}()

	u.statter.Increment(metricConnSuccess)
	u.log.Info("connection established")

	// Start reading any inbound messages.
	errGroup, errCtx := errgroup.WithContext(ctx)
	errGroup.Go(func() error {
		return u.readPump(errCtx, numMessages, totalTimeout)
	})

	// The sender runs outside the errgroup: it only stops on cancellation, and
	// errCtx is not cancelled on success until Wait returns. It gets its own
	// cancel so it can be stopped and joined explicitly below.
	sendCtx, stopSending := context.WithCancel(errCtx)
	senderDone := make(chan struct{})
	go func() {
		defer close(senderDone)
		u.issueSubscriptionRequests(sendCtx, subscriptionInterval)
	}()

	// Wait for all expected messages to be received
	if err := errGroup.Wait(); err != nil {
		u.log.Error("Received error while waiting for sendMsg goroutines", zap.Error(err))
	}

	// Stop the sender and wait for any in-flight write (bounded by its write
	// deadline) before starting the close handshake.
	stopSending()
	<-senderDone

	// Close websocket connection
	err := u.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(2*time.Second))
	if err != nil {
		u.log.Error("failed to emit close message", zap.Error(err))
		u.statter.IncrementErr(metricConnCloseError, "WriteCloseMsg")
	}
	u.log.Info("All done for this user.")
}
