package jobrunner

import (
	"code.justin.tv/event-engineering/acm-ca-go/pkg/certgen"
	logging "code.justin.tv/event-engineering/golibs/pkg/logging"
	aws "code.justin.tv/event-engineering/moonlight-api/pkg/aws"
	"code.justin.tv/event-engineering/moonlight-api/pkg/ca"
	db "code.justin.tv/event-engineering/moonlight-api/pkg/db"
	"context"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"github.com/aws/aws-lambda-go/events"
	"net/http"
	"strconv"
)

// GlobalMaxReceiveCount caps how many times any SQS message is processed. Once a message's
// ApproximateReceiveCount exceeds this value, Handle logs a warning and drops it instead of
// running its job again. Individual job types may also enforce their own limit
// (for example AllocateServerMaxReceiveCount).
const (
	GlobalMaxReceiveCount = 50
)

// Client processes job events, which arrive either as SQS messages or as CloudWatch
// scheduled events.
type Client interface {
	// Handle processes one event payload, which may carry SQS records, a CloudWatch
	// scheduled event, or both. A non-nil error reports that the job failed.
	Handle(ctx context.Context, event HybridEvent) error
}

// client is the Client implementation constructed by New.
type client struct {
	// logger receives warnings about dropped or failed jobs.
	logger logging.Logger

	// daemonHttpClient is configured by New to trust the private CA's root certificates.
	// NOTE(review): presumably used by job implementations to call daemon instances — confirm.
	daemonHttpClient *http.Client

	// db and aws are the data-store and AWS API dependencies handed to New.
	db  db.MoonlightDB
	aws *aws.Client

	// daemonASGName is the daemon auto scaling group name passed to CheckScaling.
	daemonASGName string
}

// New builds a job-runner Client.
//
// The root certificates of the private CA behind certGen are fetched once and installed as the
// trusted roots of the HTTPS client used for daemon requests. An error fetching them is
// returned unchanged. daemonASGName is the daemon auto scaling group name that scheduled
// CheckScaling jobs are run against.
func New(db db.MoonlightDB, aws *aws.Client, certGen certgen.Generator, daemonASGName string, logger logging.Logger) (Client, error) {
	// We're using acm-ca-go to generate TLS certificates signed by a private CA, so we need to grab the private CA root cert and add it to the pool
	rootCAs, err := ca.GetRootCAs(certGen)
	if err != nil {
		return nil, err
	}

	// A zero-value http.Transport has no dial timeout, TLS-handshake timeout or idle-connection
	// limits, so an unresponsive daemon could hold a request open until the whole invocation
	// times out. Start from the standard library's default transport to inherit those settings,
	// falling back to the zero value if DefaultTransport has been replaced by another type.
	transport := &http.Transport{}
	if base, ok := http.DefaultTransport.(*http.Transport); ok {
		transport = base.Clone()
	}
	// Preserve the previous connection behaviour: direct connections (no proxy taken from the
	// environment) over HTTP/1.1, with certificates verified against the private CA roots.
	transport.Proxy = nil
	transport.ForceAttemptHTTP2 = false
	transport.TLSClientConfig = &tls.Config{
		RootCAs: rootCAs,
	}

	return &client{
		logger:           logger,
		daemonHttpClient: &http.Client{Transport: transport},
		db:               db,
		aws:              aws,
		daemonASGName:    daemonASGName,
	}, nil
}

// HybridEvent is the event payload accepted by Handle. It embeds both an SQS event and a
// CloudWatch event so one handler can receive either kind of delivery. Handle walks Records
// for SQS messages and inspects Source/Detail for scheduled CloudWatch events.
// NOTE(review): decoding into this type relies on the two embedded types not sharing any
// JSON field names — confirm when upgrading aws-lambda-go.
type HybridEvent struct {
	events.SQSEvent        // provides Records
	events.CloudWatchEvent // provides Source, Detail and the other CloudWatch fields
}

// Handle processes every SQS record carried by event. Then, if the event came from CloudWatch
// Events (Source "aws.events"), it runs the scheduled job named in its Detail payload.
//
// Returning an error causes an SQS-delivered message to be processed again. Messages that can
// never succeed, or have been retried too often, are therefore dropped by logging a warning and
// returning nil instead.
func (c *client) Handle(ctx context.Context, event HybridEvent) error {
	// We're using a batch size of 1 so that we can use standard error handling rather than manually
	// resending the message. Some things come in as SQS records for retries etc.
	// Each record is handled independently so that dropping one message never silently skips the
	// rest of a larger batch.
	for _, message := range event.Records {
		if err := c.handleQueueMessage(ctx, message); err != nil {
			return err
		}
	}

	// Other things come in as Cloudwatch events because they happen on a schedule we don't want to retry
	if event.Source == "aws.events" {
		return c.handleScheduledEvent(ctx, event)
	}

	return nil
}

// handleQueueMessage runs the job described by a single SQS message. It returns nil for
// messages that should be dropped rather than retried.
func (c *client) handleQueueMessage(ctx context.Context, message events.SQSMessage) error {
	receiveCount, err := strconv.ParseInt(message.Attributes["ApproximateReceiveCount"], 10, 32)
	if err != nil {
		c.logger.Warnf("Could not establish receive count: %v", err)
		// We can't return err here otherwise we would process this message forever
		return nil
	}

	if receiveCount > GlobalMaxReceiveCount {
		c.logger.Warnf("Reached global maximum number of retries (%v/%v)", receiveCount, GlobalMaxReceiveCount)
		return nil
	}

	// A job_type attribute without a string value (e.g. a Binary attribute) is treated like a
	// missing one rather than dereferencing a nil pointer.
	jobTypeAttr, ok := message.MessageAttributes["job_type"]
	if !ok || jobTypeAttr.StringValue == nil {
		// The marshal error is deliberately ignored: the JSON only enriches the error message.
		msgStr, _ := json.Marshal(message)
		return fmt.Errorf("Got unexpected message %v", string(msgStr))
	}

	jobType := *jobTypeAttr.StringValue

	switch jobType {
	case AllocateServer.String():
		if receiveCount > AllocateServerMaxReceiveCount {
			c.logger.Warnf("Reached maximum number of retries (%v/%v) for instance server allocation", receiveCount, AllocateServerMaxReceiveCount)
			return nil
		}

		asm := &AllocateServerMessage{}
		if err := json.Unmarshal([]byte(message.Body), asm); err != nil {
			c.logger.Warnf("Failed to execute job: unable to deserialise body, %v", err)
			return err
		}

		if err := c.AllocateServer(ctx, *asm); err != nil {
			c.logger.Warnf("Failed to execute AllocateServer, %v", err)
			return err
		}

		return nil
	default:
		return fmt.Errorf("Got unexpected job type %v", jobType)
	}
}

// handleScheduledEvent runs the job named by a CloudWatch scheduled event's Detail payload.
// Unknown job types are logged and ignored.
func (c *client) handleScheduledEvent(ctx context.Context, event HybridEvent) error {
	eventPayload := &EventPayload{}
	if err := json.Unmarshal([]byte(event.Detail), eventPayload); err != nil || eventPayload.JobType == "" {
		// The marshal error is deliberately ignored: the JSON only enriches the error message.
		eventStr, _ := json.Marshal(event)
		return fmt.Errorf("Got unexpected event %v", string(eventStr))
	}

	switch eventPayload.JobType {
	case CheckConsistency.String():
		if err := c.CheckConsistency(ctx); err != nil {
			c.logger.Warnf("Failed to execute CheckConsistency, %v", err)
			return err
		}
	case CheckScaling.String():
		if err := c.CheckScaling(ctx, c.daemonASGName); err != nil {
			c.logger.Warnf("Failed to execute CheckScaling, %v", err)
			return err
		}
	default:
		// Previously these were ignored silently; surface them so misconfigured schedules are noticed.
		c.logger.Warnf("Got unexpected scheduled job type %v", eventPayload.JobType)
	}

	return nil
}
