package mock_service

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"io"
	"io/ioutil"
	"net/http"
	"time"

	"golang.org/x/net/context/ctxhttp"

	"code.justin.tv/websocket-edge/server/internal/logs"
	"code.justin.tv/websocket-edge/server/internal/metrics"
	"code.justin.tv/websocket-edge/server/loadtest/loadtestenv"
	"code.justin.tv/websocket-edge/server/protocol"

	"golang.org/x/sync/errgroup"

	"go.uber.org/zap"
)

// Package-level values shared by the handlers in this file.
var (
	// ok is the response body written by both the /health and /message
	// handlers.
	ok = []byte("OK")

	// msgBody is the filler text embedded in every outgoing load-test
	// message. It is populated once by init.
	msgBody string
)

// Metric names emitted through metrics.Statter by this mock service.
// NOTE(review): metricAllMessageSendErr's value ends in "Err" rather than
// "Error" like its siblings; left unchanged since dashboards may key on it.
const (
	metricPayloadIngestErr      = "GQLSubsPayloadIngestError"
	metricPayloadIngestSuccess  = "GQLSubsPayloadIngestSuccess"
	metricPayloadIngestIgnore   = "GQLSubsPayloadIngestIgnore"
	metricMessageSendErr        = "GQLSubsMessageSendError"
	metricMessageSendSuccess    = "GQLSubsMessageSendSuccess"
	metricAllMessageSendSuccess = "GQLSubsAllMessageSendSuccess"
	metricAllMessageSendErr     = "GQLSubsAllMessageSendErr"
	metricHealthErr             = "GQLSubsHealthErr"
)

// someBytes returns n bytes cycling through the lowercase alphabet
// ("abc...xyzabc..."). It panics if n is negative, as make does.
func someBytes(n int) []byte {
	out := make([]byte, n)
	for idx := range out {
		out[idx] = 'a' + byte(idx%26)
	}
	return out
}

// init fills msgBody with 1024 bytes of repeating a-z filler produced by
// someBytes.
func init() {
	filler := someBytes(1024)
	msgBody = string(filler)
}

// handler serves this mock service's HTTP endpoints (/health and /message,
// registered by NewHandler) and, for qualifying /message payloads, posts a
// series of messages back toward the sender.
type handler struct {
	logger          logs.Logger
	statter         metrics.Statter
	// httpClient is used for the outbound POSTs made by sendMessage.
	httpClient      *http.Client
	// mux holds the routes registered through handle; ServeHTTP delegates to it.
	mux             *http.ServeMux
	// numMessages is how many messages sendMessages schedules per trigger.
	numMessages     int
	// messageInterval is the ticker period between scheduled sends.
	messageInterval time.Duration
}

// NewHandler builds the mock service handler and registers its routes:
// /health answers "OK", and /message ingests payloads and may trigger a burst
// of numMessages outbound messages spaced messageInterval apart, sent with
// httpClient.
func NewHandler(logger logs.Logger, statter metrics.Statter, httpClient *http.Client, numMessages int, messageInterval time.Duration) *handler {
	h := new(handler)
	h.logger = logger
	h.statter = statter
	h.httpClient = httpClient
	h.mux = http.NewServeMux()
	h.numMessages = numMessages
	h.messageInterval = messageInterval

	routes := []struct {
		path string
		fn   http.HandlerFunc
	}{
		{"/health", h.health},
		{"/message", h.payloadIngest},
	}
	for _, rt := range routes {
		h.handle(rt.path, rt.fn)
	}
	return h
}

// health answers every request with "OK". A failed write is logged and
// counted under metricHealthErr.
func (h *handler) health(w http.ResponseWriter, _ *http.Request) {
	if _, writeErr := w.Write(ok); writeErr != nil {
		h.logger.Error("health error", zap.Error(writeErr))
		h.statter.Increment(metricHealthErr)
	}
}

// ServeHTTP implements http.Handler by dispatching to the routes registered
// on h.mux.
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	var router http.Handler = h.mux
	router.ServeHTTP(w, r)
}

// handle registers f on h.mux for both path+"/" (a ServeMux subtree pattern,
// so sub-paths also reach f) and the exact path.
func (h *handler) handle(path string, f http.HandlerFunc) {
	for _, pattern := range [...]string{path + "/", path} {
		h.mux.HandleFunc(pattern, f)
	}
}

// buildMsg encodes one outbound load-test message addressed to sessionID.
//
// The inner payload is a protocol.LoadTestMessageBody holding the current time
// (via time.Time.MarshalText) and the package-level msgBody filler. It is
// JSON-encoded and carried as a string in the Body of a
// protocol.ServiceToClientMessage, which is JSON-encoded again and returned.
//
// Failures are logged and counted under metricMessageSendErr here (callers
// such as sendMessage rely on this) and returned with a nil slice.
func (h *handler) buildMsg(sessionID string) ([]byte, error) {
	now, err := time.Now().MarshalText()
	if err != nil {
		h.logger.Error("timestamp generation failed", zap.Error(err))
		h.statter.IncrementErr(metricMessageSendErr, "TimestampGen")
		return nil, err
	}

	// Named bodyJSON rather than msgBody so it does not shadow the
	// package-level filler string it embeds.
	bodyJSON, err := json.Marshal(protocol.LoadTestMessageBody{
		Timestamp: string(now),
		Msg:       msgBody,
	})
	if err != nil {
		h.logger.Error("marshal failed", zap.Error(err))
		h.statter.IncrementErr(metricMessageSendErr, "MarshalJSON")
		return nil, err
	}

	msgJSON, err := json.Marshal(protocol.ServiceToClientMessage{
		SessionID: sessionID,
		Body:      string(bodyJSON),
	})
	if err != nil {
		h.logger.Error("marshal failed", zap.Error(err))
		h.statter.IncrementErr(metricMessageSendErr, "MarshalJSON")
		return nil, err
	}

	return msgJSON, nil
}

// sendMessage builds one message for sessionID and POSTs it as JSON to addr.
//
// If ctx is already done, it returns ctx.Err() without sending. The request
// is bounded by a 5s timeout derived from ctx. A transport error or any
// non-200 status is returned as an error. Every outcome is counted under
// metricMessageSendErr (with a reason) or metricMessageSendSuccess.
func (h *handler) sendMessage(ctx context.Context, addr string, sessionID string) error {
	if err := ctx.Err(); err != nil {
		return err
	}

	msgJSON, err := h.buildMsg(sessionID)
	if err != nil {
		// Already logged and counted by buildMsg.
		return err
	}

	rCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	resp, err := ctxhttp.Post(rCtx, h.httpClient, addr, "application/json", bytes.NewBuffer(msgJSON))
	if err != nil {
		h.logger.Error("post failed", zap.Error(err))
		h.statter.IncrementErr(metricMessageSendErr, "Network")
		return err
	}
	defer func() {
		// Drain and close the body so the transport can reuse the connection.
		if _, drainErr := io.Copy(ioutil.Discard, resp.Body); drainErr != nil {
			h.statter.IncrementErr(metricMessageSendErr, "DiscardResponseBody")
		}
		if closeErr := resp.Body.Close(); closeErr != nil {
			h.statter.IncrementErr(metricMessageSendErr, "CloseResponseBody")
		}
	}()

	if resp.StatusCode != http.StatusOK {
		// err is nil on this path, so record the status code instead; logging
		// zap.Error(err) here would carry no diagnostic information.
		h.logger.Error("bad response from websocket edge", zap.Int("status", resp.StatusCode))
		h.statter.IncrementErr(metricMessageSendErr, "BadResp")
		return errors.New("bad response from websocket edge")
	}

	h.statter.Increment(metricMessageSendSuccess)
	return nil
}

// sendMessages schedules h.numMessages sends of sessionID's message to addr,
// one per h.messageInterval tick. Each send runs in its own errgroup goroutine.
// The first failure cancels the group's context, which stops scheduling
// further sends and is passed to in-flight ones. After all sends return, it
// logs the outcome and increments metricAllMessageSendSuccess or
// metricAllMessageSendErr.
func (h *handler) sendMessages(addr string, sessionID string) {
	if h.messageInterval <= 0 {
		// time.NewTicker panics on a non-positive period. This method runs on
		// a detached goroutine, so that panic would take down the process.
		h.logger.Error("failed sending messages to client", zap.Duration("messageInterval", h.messageInterval))
		h.statter.Increment(metricAllMessageSendErr)
		return
	}

	ticker := time.NewTicker(h.messageInterval)
	defer ticker.Stop()

	errGroup, msgCtx := errgroup.WithContext(context.Background())
schedule:
	for i := 0; i < h.numMessages; i++ {
		select {
		case <-ticker.C:
			errGroup.Go(func() error {
				return h.sendMessage(msgCtx, addr, sessionID)
			})
		case <-msgCtx.Done():
			// A send failed. A bare `break` would only leave the select and
			// keep iterating, so exit the scheduling loop explicitly.
			break schedule
		}
	}

	if err := errGroup.Wait(); err != nil {
		h.logger.Error("failed sending messages to client", zap.Error(err))
		h.statter.Increment(metricAllMessageSendErr)
		return
	}
	h.logger.Info("all messages sent to client")
	h.statter.Increment(metricAllMessageSendSuccess)
}

// payloadIngest handles /message. The reply is always "OK".
//
// The request body is decoded as a protocol.ClientToServiceMessage whose Body
// is itself JSON for a protocol.MockSubscribeMessage. Decode failures are
// logged and counted under metricPayloadIngestErr. Payloads without First set
// are counted as ignored. A payload with First set starts a background burst
// of messages (sendMessages) to the address derived from its HostAddress.
func (h *handler) payloadIngest(w http.ResponseWriter, r *http.Request) {
	// Consume the request body before writing any response. For HTTP/1.x the
	// http.ResponseWriter contract warns that writing may prevent further
	// reads of Request.Body.
	var receivedMsg protocol.ClientToServiceMessage
	decodeErr := json.NewDecoder(r.Body).Decode(&receivedMsg)

	// The reply is "OK" regardless of the payload; write errors are
	// deliberately ignored here.
	_, _ = w.Write(ok)

	if decodeErr != nil {
		h.statter.IncrementErr(metricPayloadIngestErr, "DecodeJSON")
		h.logger.Warn("failed to ingest payload", zap.Error(decodeErr))
		return
	}

	var subsMsg protocol.MockSubscribeMessage
	if err := json.Unmarshal([]byte(receivedMsg.Body), &subsMsg); err != nil {
		h.statter.IncrementErr(metricPayloadIngestErr, "DecodeSubsJSON")
		h.logger.Warn("failed to ingest payload", zap.Error(err))
		return
	}

	h.statter.Increment(metricPayloadIngestSuccess)

	if !subsMsg.First {
		h.statter.Increment(metricPayloadIngestIgnore)
		return
	}

	sendMsgAddr := loadtestenv.ForwardAddress(receivedMsg.HostAddress)

	// sendMessages uses its own background context, so the burst continues
	// after this handler has already returned its 200.
	go h.sendMessages(sendMsgAddr, receivedMsg.SessionID)
}
