package StarfruitSECProducer

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"strings"
	"time"

	"code.justin.tv/amzn/StarfruitSECTwirp"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/endpoints"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/sqs"
	"github.com/golang/protobuf/ptypes/timestamp"
	"github.com/google/uuid"
	"github.com/pkg/errors"
	"google.golang.org/protobuf/encoding/protojson"
)

// Feature switches for SEC event publishing.
const (
	// enableSendingEventsToSEC is the master switch consulted by shouldSendEventsToSEC.
	enableSendingEventsToSEC = true
	// forceSendEventsForTwitchChannels turns SEC events on for channels that
	// shouldSendEventsToSEC treats as Twitch channels (names without a ".").
	forceSendEventsForTwitchChannels = false
	// enableTranscodeHeartbeat gates SendTranscodeHeartbeatEvent.
	enableTranscodeHeartbeat = true
	// enableStarvationHeartbeat gates SendStarvationHeartbeatEvent.
	enableStarvationHeartbeat = false
)

// Service-side stage names that the stages map aliases onto SEC stages.
const (
	stageDev        = "dev"
	stageStaging    = "staging"
	stageProduction = "production"
	stageN2T        = "n2t"
)

// SEC stage names (values of the stages map; beta, gamma and prod are also
// accepted directly as service stages).
const (
	stageBeta  = "beta"
	stageTest  = "test"
	stageGamma = "gamma"
	stageProd  = "prod"
)

// transcodeHeartbeatTime is the heartbeat interval, in seconds, used by
// SendPerodicTranscodeHeartbeat. It is an untyped constant so it can be
// multiplied directly by time.Second.
const transcodeHeartbeatTime = 300

// stages maps a service stage name to the SEC stage whose producer should be
// used. The constructors lower-case the caller's stage before looking it up,
// and reject stages that are not listed here. Update this table when a new
// environment is added.
var stages = map[string]string{
	// Service stages that share their name with an SEC stage.
	stageBeta:  stageBeta,
	stageGamma: stageGamma,
	stageProd:  stageProd,

	// Service-specific stage names aliased onto an SEC stage.
	stageDev:        stageGamma,
	stageStaging:    stageGamma,
	stageN2T:        stageTest,
	stageProduction: stageProd,
}

// secClient is the SEC implementation returned by NewSECClient (bound to one
// stream) and NewSECBaseClient (not bound to a stream; channel and session are
// supplied per call).
type secClient struct {
	// originDc is the POP passed to the constructor; sent as OriginDc on every event.
	originDc    string
	// channelArn is GetChannelArn(channel); set only by NewSECClient.
	channelArn  string
	// sessionId is the broadcast session id; set only by NewSECClient.
	sessionId   string
	// channel is the raw channel string; set only by NewSECClient and used for limit-breach events.
	channel     string
	// channelName is the content id derived from channel; set only by NewSECClient.
	channelName string
	// customerId is derived from channel; set only by NewSECClient.
	customerId  string
	// streamId is set through SetStreamId.
	streamId    string
	// stage is the SEC stage resolved through the stages map.
	stage       string
	// client publishes single-stream events; nil for clients built by NewSECBaseClient.
	client      SECProducer
	// clients caches one SQS producer per home region, filled lazily by getSQSProducer.
	// NOTE(review): there is no lock around this map — confirm a secClient is not shared across goroutines.
	clients     map[string]ProtobufSQSProducer
}

// NewSECClient returns an SEC client bound to a single stream, for data plane
// services that process one stream (channel + broadcast session).
//
// It fails when SEC is disabled for the channel, the stage is not in the
// stages map, the channel has no resolvable home region, sessionId is empty,
// or the AWS session / SEC producer cannot be created.
func NewSECClient(pop string, stage string, channel string, sessionId string) (SEC, error) {
	if enabled, why := shouldSendEventsToSEC(channel); !enabled {
		return nil, fmt.Errorf("SEC is disabled. Reason: %s", why)
	}

	sfStage, known := stages[strings.ToLower(stage)]
	if !known {
		return nil, fmt.Errorf("stage is not configured for the stage %s", stage)
	}

	arn := GetChannelArn(channel)
	region, err := GetHomeRegionForChannel(arn)
	if err != nil {
		return nil, errors.Wrap(err, "Channel name is not a valid ARN or does not have home region. ")
	}

	if sessionId == "" {
		return nil, fmt.Errorf("SessionId is required to initialize SEC client")
	}

	// The SQS client talks to the channel's home region.
	sess, err := session.NewSession(&aws.Config{
		Region:              aws.String(region),
		MaxRetries:          aws.Int(3),
		STSRegionalEndpoint: endpoints.RegionalSTSEndpoint,
	})
	if err != nil {
		return nil, errors.Wrap(err, "Could not create the session to init SQS client")
	}

	producer, err := NewSECProducer(sfStage, region, sqs.New(sess))
	if err != nil {
		return nil, errors.Wrap(err, "Could not create the SEC producer")
	}

	customerId, contentId := GetCustomerIdContentIdFromChannel(channel)
	bound := &secClient{
		originDc:    pop,
		stage:       sfStage,
		channel:     channel,
		channelArn:  arn,
		channelName: contentId,
		customerId:  customerId,
		sessionId:   sessionId,
		client:      producer,
		clients:     make(map[string]ProtobufSQSProducer),
	}
	return bound, nil
}

// NewSECBaseClient returns an SEC client that is not bound to any stream, for
// control plane services that process many streams. Callers pass the channel
// and session to each Send* method that accepts them; SQS producers are
// created per home region on first use.
//
// It fails only when the stage is not in the stages map.
func NewSECBaseClient(pop string, stage string) (SEC, error) {
	sfStage, found := stages[strings.ToLower(stage)]
	if !found {
		return nil, fmt.Errorf("stage is not configured for the stage %s", stage)
	}

	base := &secClient{originDc: pop, stage: sfStage}
	base.clients = make(map[string]ProtobufSQSProducer)
	return base, nil
}

// createSQSProducer builds a protobuf SQS producer for the SEC queue of the
// given stage and home region. It validates the parameters, opens an AWS
// session in homeRegion, and resolves the destination queue URL.
func (c *secClient) createSQSProducer(stage string, homeRegion string) (ProtobufSQSProducer, error) {
	if err := ValidateProducerParams(stage, homeRegion); err != nil {
		return nil, errors.Wrap(err, "stage or home region is invalid")
	}

	sess, err := session.NewSession(&aws.Config{
		Region:              aws.String(homeRegion),
		MaxRetries:          aws.Int(3),
		STSRegionalEndpoint: endpoints.RegionalSTSEndpoint,
	})
	if err != nil {
		return nil, errors.Wrap(err, "Could not create the session to init SQS client")
	}

	queueUrl, err := GetQueueUrl(stage, homeRegion)
	if err != nil {
		return nil, errors.Wrap(err, "unable to determine destination queue")
	}

	return NewProtoSQSProducer(queueUrl, sqs.New(sess)), nil
}

// SetStreamId records the stream id that sendStreamEvent attaches (as
// StreamId) to subsequent stream-scoped events from this client.
func (c *secClient) SetStreamId(streamId string) {
	c.streamId = streamId
}

// getSQSProducer returns the SQS producer for region, creating and caching it
// on first use. There is one SEC queue per region, so producers are keyed by
// region. A creation failure is returned and nothing is cached.
//
// NOTE(review): c.clients is read and written here without locking — confirm
// callers never use one secClient from several goroutines at once.
func (c *secClient) getSQSProducer(region string) (ProtobufSQSProducer, error) {
	if cached, ok := c.clients[region]; ok {
		return cached, nil
	}

	fresh, err := c.createSQSProducer(c.stage, region)
	if err != nil {
		return nil, errors.Wrap(err, "Unable to create the protobuf SQS producer")
	}
	c.clients[region] = fresh
	return fresh, nil
}

// sendEvent wraps a single event in a one-element batch and sends it through
// the given SQS producer.
func (c *secClient) sendEvent(ctx context.Context, producer ProtobufSQSProducer, data *StarfruitSECTwirp.StreamEventData) error {
	batch := new(StarfruitSECTwirp.StreamEventDataBatch)
	batch.Data = append(batch.Data, data)
	return producer.Send(ctx, batch)
}

// sendStreamEvent publishes one event for the stream this client is bound to.
// Channel, customer, session and stream identity come from the client (see
// NewSECClient and SetStreamId), and the event goes through c.client.
//
// Clients built by NewSECBaseClient have no per-stream producer (c.client is
// nil). For those an error is returned instead of a nil-pointer panic.
func (c *secClient) sendStreamEvent(ctx context.Context, eventName StarfruitSECTwirp.EventName, eventGroup StarfruitSECTwirp.EventGroup, reason string, extra string) error {
	if c.client == nil {
		return errors.New("SEC client is not bound to a stream; create it with NewSECClient to send stream events")
	}

	// uuid.UUID.MarshalBinary never returns a non-nil error, so it is dropped.
	id, _ := uuid.New().MarshalBinary()
	data := &StarfruitSECTwirp.StreamEventData{
		EventUuid:           id,
		EventTime:           &timestamp.Timestamp{Seconds: time.Now().Unix()},
		ChannelArn:          c.channelArn,
		ChannelName:         c.channelName,
		CustomerId:          c.customerId,
		BroadcastSessionId:  c.sessionId,
		StreamId:            c.streamId,
		OriginDc:            c.originDc,
		EventName:           eventName.String(),
		EventGroup:          eventGroup.String(),
		Reason:              reason,
		Extra:               extra,
		CustomerChannelName: "",
	}
	return c.client.SendEventToSEC(ctx, data)
}

// sendLimitBreachEvent publishes a limit-exceeded event whose Extra field is
// the JSON encoding of LimitExtraData{limitThreshold, exceededBy}.
//
// The destination region comes from channel when channel is non-empty;
// otherwise homeRegion is used. When both are empty an error is returned.
// An empty customerId falls back to the client's own customerId.
func (c *secClient) sendLimitBreachEvent(ctx context.Context, channel string, customerId string, homeRegion string, eventName StarfruitSECTwirp.EventName, eventGroup StarfruitSECTwirp.EventGroup, limitThreshold int, exceededBy int) error {
	var (
		region     string
		channelArn string
	)

	switch {
	case channel != "":
		channelArn = GetChannelArn(channel)
		resolved, lookupErr := GetHomeRegionForChannel(channelArn)
		if lookupErr != nil {
			return errors.Wrap(lookupErr, "Channel name is not a valid ARN or does not have home region. ")
		}
		region = resolved
	case homeRegion != "":
		region = homeRegion
	default:
		return errors.New("No home region was provided so unable to get right SQS client to send the event")
	}

	producer, err := c.getSQSProducer(region)
	if err != nil {
		return errors.Wrap(err, "Unable to get the SQS producer to send the event")
	}

	extra, err := json.Marshal(&StarfruitSECTwirp.LimitExtraData{
		LimitThreshold: int64(limitThreshold),
		ExceededBy:     int64(exceededBy),
	})
	if err != nil {
		return errors.Wrap(err, "Unable to marshal the limit data into json")
	}

	// SEC keys the limit-breach table (PK and GSI sort key) on CustomerID.
	// Channel-scoped breaches such as bitrate pass no customer id, so fall
	// back to the client's own.
	if customerId == "" {
		customerId = c.customerId
	}

	// uuid.UUID.MarshalBinary never returns a non-nil error.
	id, _ := uuid.New().MarshalBinary()
	return c.sendEvent(ctx, producer, &StarfruitSECTwirp.StreamEventData{
		EventUuid:          id,
		EventTime:          &timestamp.Timestamp{Seconds: time.Now().Unix()},
		ChannelArn:         channelArn,
		ChannelName:        channel,
		CustomerId:         customerId,
		BroadcastSessionId: c.sessionId,
		OriginDc:           c.originDc,
		EventName:          eventName.String(),
		EventGroup:         eventGroup.String(),
		Extra:              string(extra),
	})
}

// sendStreamEventBase publishes an event for an explicitly supplied channel and
// broadcast session. It is used by the control-plane Send* methods, which do
// not rely on the client being bound to a stream.
//
// It is sendStreamEventBaseV2 with an empty customer channel name. Delegating
// keeps the two from drifting apart; they used to be duplicated line for line.
func (c *secClient) sendStreamEventBase(ctx context.Context, channel string, sessionId string, eventName StarfruitSECTwirp.EventName,
	eventGroup StarfruitSECTwirp.EventGroup, reason string, extra string) error {
	return c.sendStreamEventBaseV2(ctx, channel, sessionId, eventName, eventGroup, reason, extra, "")
}

// sendStreamEventBaseV2 publishes an event for an explicitly supplied channel
// and broadcast session, also recording customerChannelName.
//
// The channel's ARN determines the home region, and so the regional SQS
// producer used. Customer id and content id are derived from channel. The
// error is returned if the region cannot be resolved or no producer can be
// obtained.
func (c *secClient) sendStreamEventBaseV2(ctx context.Context, channel string, sessionId string, eventName StarfruitSECTwirp.EventName,
	eventGroup StarfruitSECTwirp.EventGroup, reason string, extra string, customerChannelName string) error {
	arn := GetChannelArn(channel)
	region, err := GetHomeRegionForChannel(arn)
	if err != nil {
		return errors.Wrap(err, "Channel name is not a valid ARN or does not have home region. ")
	}
	customerId, contentId := GetCustomerIdContentIdFromChannel(channel)

	producer, err := c.getSQSProducer(region)
	if err != nil {
		return errors.Wrap(err, "Unable to get the SQS producer to send the event")
	}

	// uuid.UUID.MarshalBinary never returns a non-nil error.
	eventID, _ := uuid.New().MarshalBinary()
	event := &StarfruitSECTwirp.StreamEventData{
		EventUuid:           eventID,
		EventTime:           &timestamp.Timestamp{Seconds: time.Now().Unix()},
		EventName:           eventName.String(),
		EventGroup:          eventGroup.String(),
		ChannelArn:          arn,
		ChannelName:         contentId,
		CustomerId:          customerId,
		CustomerChannelName: customerChannelName,
		BroadcastSessionId:  sessionId,
		OriginDc:            c.originDc,
		Reason:              reason,
		Extra:               extra,
	}
	return c.sendEvent(ctx, producer, event)
}

// SendStarvationStartEvent reports STARVATION_STARTED (starvation group) for the stream this client is bound to.
func (c *secClient) SendStarvationStartEvent(ctx context.Context, reason string, extra string) error {
	name, group := StarfruitSECTwirp.EventName_STARVATION_STARTED, StarfruitSECTwirp.EventGroup_STARVATION_EVENT_GROUP
	return c.sendStreamEvent(ctx, name, group, reason, extra)
}

// SendStarvationEndEvent reports STARVATION_ENDED (starvation group) for the stream this client is bound to.
func (c *secClient) SendStarvationEndEvent(ctx context.Context, reason string, extra string) error {
	name, group := StarfruitSECTwirp.EventName_STARVATION_ENDED, StarfruitSECTwirp.EventGroup_STARVATION_EVENT_GROUP
	return c.sendStreamEvent(ctx, name, group, reason, extra)
}

// SendTranscodeStartedEvent reports TRANSCODE_STARTED (transcode group) for the stream this client is bound to.
func (c *secClient) SendTranscodeStartedEvent(ctx context.Context, reason string, extra string) error {
	name, group := StarfruitSECTwirp.EventName_TRANSCODE_STARTED, StarfruitSECTwirp.EventGroup_TRANSCODE_EVENT_GROUP
	return c.sendStreamEvent(ctx, name, group, reason, extra)
}

// SendTranscodeEndedEvent reports TRANSCODE_ENDED (transcode group) for the stream this client is bound to.
func (c *secClient) SendTranscodeEndedEvent(ctx context.Context, reason string, extra string) error {
	name, group := StarfruitSECTwirp.EventName_TRANSCODE_ENDED, StarfruitSECTwirp.EventGroup_TRANSCODE_EVENT_GROUP
	return c.sendStreamEvent(ctx, name, group, reason, extra)
}

// SendVODStartedEvent reports VOD_STARTED (VOD group) for the stream this client is bound to.
func (c *secClient) SendVODStartedEvent(ctx context.Context, reason string, extra string) error {
	name, group := StarfruitSECTwirp.EventName_VOD_STARTED, StarfruitSECTwirp.EventGroup_VOD_EVENT_GROUP
	return c.sendStreamEvent(ctx, name, group, reason, extra)
}

// SendVODEndedEvent reports VOD_ENDED (VOD group) for the stream this client is bound to.
func (c *secClient) SendVODEndedEvent(ctx context.Context, reason string, extra string) error {
	name, group := StarfruitSECTwirp.EventName_VOD_ENDED, StarfruitSECTwirp.EventGroup_VOD_EVENT_GROUP
	return c.sendStreamEvent(ctx, name, group, reason, extra)
}

// SendRecordingStartedEvent reports RECORDING_STARTED (recording group) for the stream this client is bound to.
func (c *secClient) SendRecordingStartedEvent(ctx context.Context, reason string, extra string) error {
	name, group := StarfruitSECTwirp.EventName_RECORDING_STARTED, StarfruitSECTwirp.EventGroup_RECORDING_EVENT_GROUP
	return c.sendStreamEvent(ctx, name, group, reason, extra)
}

// SendRecordingStartFailedEvent reports RECORDING_START_FAILED (recording group) for the stream this client is bound to.
func (c *secClient) SendRecordingStartFailedEvent(ctx context.Context, reason string, extra string) error {
	name, group := StarfruitSECTwirp.EventName_RECORDING_START_FAILED, StarfruitSECTwirp.EventGroup_RECORDING_EVENT_GROUP
	return c.sendStreamEvent(ctx, name, group, reason, extra)
}

// SendRecordingEndedEvent reports RECORDING_ENDED (recording group) for the stream this client is bound to.
func (c *secClient) SendRecordingEndedEvent(ctx context.Context, reason string, extra string) error {
	name, group := StarfruitSECTwirp.EventName_RECORDING_ENDED, StarfruitSECTwirp.EventGroup_RECORDING_EVENT_GROUP
	return c.sendStreamEvent(ctx, name, group, reason, extra)
}

// SendRecordingEndedWithFailureEvent reports RECORDING_ENDED_WITH_FAILURE (recording group) for the stream this client is bound to.
func (c *secClient) SendRecordingEndedWithFailureEvent(ctx context.Context, reason string, extra string) error {
	name, group := StarfruitSECTwirp.EventName_RECORDING_ENDED_WITH_FAILURE, StarfruitSECTwirp.EventGroup_RECORDING_EVENT_GROUP
	return c.sendStreamEvent(ctx, name, group, reason, extra)
}

// SendBitrateLimitExceededEvent reports BITRATE_LIMIT_EXCEEDED for the client's own channel.
// The region is derived from c.channel, and the customer id falls back to the client's.
func (c *secClient) SendBitrateLimitExceededEvent(ctx context.Context, limitThreshold int, exceededBy int) error {
	return c.sendLimitBreachEvent(ctx, c.channel, "", "",
		StarfruitSECTwirp.EventName_BITRATE_LIMIT_EXCEEDED,
		StarfruitSECTwirp.EventGroup_LIMIT_EXCEEDED_EVENT_GROUP,
		limitThreshold, exceededBy)
}

// SendDimensionLimitExceededEvent reports DIMENSION_LIMIT_EXCEEDED for the client's own channel.
func (c *secClient) SendDimensionLimitExceededEvent(ctx context.Context, limitThreshold int, exceededBy int) error {
	return c.sendLimitBreachEvent(ctx, c.channel, "", "",
		StarfruitSECTwirp.EventName_DIMENSION_LIMIT_EXCEEDED,
		StarfruitSECTwirp.EventGroup_LIMIT_EXCEEDED_EVENT_GROUP,
		limitThreshold, exceededBy)
}

// SendPixelLimitExceededEvent reports PIXEL_LIMIT_EXCEEDED for the client's own channel.
func (c *secClient) SendPixelLimitExceededEvent(ctx context.Context, limitThreshold int, exceededBy int) error {
	return c.sendLimitBreachEvent(ctx, c.channel, "", "",
		StarfruitSECTwirp.EventName_PIXEL_LIMIT_EXCEEDED,
		StarfruitSECTwirp.EventGroup_LIMIT_EXCEEDED_EVENT_GROUP,
		limitThreshold, exceededBy)
}

// SendStarvationStatusEvent reports STARVATION_STATUS (starvation group) for the stream this client is bound to.
func (c *secClient) SendStarvationStatusEvent(ctx context.Context, reason string, extra string) error {
	name, group := StarfruitSECTwirp.EventName_STARVATION_STATUS, StarfruitSECTwirp.EventGroup_STARVATION_EVENT_GROUP
	return c.sendStreamEvent(ctx, name, group, reason, extra)
}

// SendTranscodeHealthEvent reports TRANSCODE_HEALTH (transcode group) for the stream this client is bound to.
func (c *secClient) SendTranscodeHealthEvent(ctx context.Context, reason string, extra string) error {
	name, group := StarfruitSECTwirp.EventName_TRANSCODE_HEALTH, StarfruitSECTwirp.EventGroup_TRANSCODE_EVENT_GROUP
	return c.sendStreamEvent(ctx, name, group, reason, extra)
}

// SendTranscodeRestartedEvent reports TRANSCODE_RESTARTED (transcode group) for the given channel and session.
func (c *secClient) SendTranscodeRestartedEvent(ctx context.Context, channel string, sessionId string, reason string, extra string) error {
	name, group := StarfruitSECTwirp.EventName_TRANSCODE_RESTARTED, StarfruitSECTwirp.EventGroup_TRANSCODE_EVENT_GROUP
	return c.sendStreamEventBase(ctx, channel, sessionId, name, group, reason, extra)
}

// SendCCBLimitExceededEvent reports CCB_LIMIT_EXCEEDED for a customer. No channel
// is involved, so the event is routed by homeRegion.
func (c *secClient) SendCCBLimitExceededEvent(ctx context.Context, customerId string, homeRegion string, limitThreshold int, exceededBy int) error {
	return c.sendLimitBreachEvent(ctx, "", customerId, homeRegion,
		StarfruitSECTwirp.EventName_CCB_LIMIT_EXCEEDED,
		StarfruitSECTwirp.EventGroup_LIMIT_EXCEEDED_EVENT_GROUP,
		limitThreshold, exceededBy)
}

// SendCCVLimitExceededEvent reports CCV_LIMIT_EXCEEDED for a customer. No channel
// is involved, so the event is routed by homeRegion.
func (c *secClient) SendCCVLimitExceededEvent(ctx context.Context, customerId string, homeRegion string, limitThreshold int, exceededBy int) error {
	return c.sendLimitBreachEvent(ctx, "", customerId, homeRegion,
		StarfruitSECTwirp.EventName_CCV_LIMIT_EXCEEDED,
		StarfruitSECTwirp.EventGroup_LIMIT_EXCEEDED_EVENT_GROUP,
		limitThreshold, exceededBy)
}

// SendStreamCreatedEvent reports STREAM_CREATED for the given channel and session.
// NOTE(review): this uses the transcode group, while the other STREAM_* events in
// this file use the session group — confirm this is what SEC expects.
func (c *secClient) SendStreamCreatedEvent(ctx context.Context, channel string, sessionId string, reason string, extra string) error {
	name, group := StarfruitSECTwirp.EventName_STREAM_CREATED, StarfruitSECTwirp.EventGroup_TRANSCODE_EVENT_GROUP
	return c.sendStreamEventBase(ctx, channel, sessionId, name, group, reason, extra)
}

// SendStreamAuthenticationEvent reports STREAM_AUTHENTICATION (session group) for the given channel and session.
func (c *secClient) SendStreamAuthenticationEvent(ctx context.Context, channel string, sessionId string, reason string, extra string) error {
	name, group := StarfruitSECTwirp.EventName_STREAM_AUTHENTICATION, StarfruitSECTwirp.EventGroup_SESSION_EVENT_GROUP
	return c.sendStreamEventBase(ctx, channel, sessionId, name, group, reason, extra)
}

// SendTranscodeCreatedEvent reports TRANSCODE_CREATED (transcode group) for the given channel and session.
func (c *secClient) SendTranscodeCreatedEvent(ctx context.Context, channel string, sessionId string, reason string, extra string) error {
	name, group := StarfruitSECTwirp.EventName_TRANSCODE_CREATED, StarfruitSECTwirp.EventGroup_TRANSCODE_EVENT_GROUP
	return c.sendStreamEventBase(ctx, channel, sessionId, name, group, reason, extra)
}

// SendStreamDisconnectedEvent reports STREAM_DISCONNECTED (session group) for the given channel and session.
func (c *secClient) SendStreamDisconnectedEvent(ctx context.Context, channel string, sessionId string, reason string, extra string) error {
	name, group := StarfruitSECTwirp.EventName_STREAM_DISCONNECTED, StarfruitSECTwirp.EventGroup_SESSION_EVENT_GROUP
	return c.sendStreamEventBase(ctx, channel, sessionId, name, group, reason, extra)
}

// SendStreamDeletedEvent reports STREAM_DELETED (session group) for the given channel and session.
func (c *secClient) SendStreamDeletedEvent(ctx context.Context, channel string, sessionId string, reason string, extra string) error {
	name, group := StarfruitSECTwirp.EventName_STREAM_DELETED, StarfruitSECTwirp.EventGroup_SESSION_EVENT_GROUP
	return c.sendStreamEventBase(ctx, channel, sessionId, name, group, reason, extra)
}

// SendTranscodeHeartbeatEvent reports TRANSCODE_HEARTBEAT (heartbeat group) for the
// stream this client is bound to. It is a no-op returning nil when
// enableTranscodeHeartbeat is false.
func (c *secClient) SendTranscodeHeartbeatEvent(ctx context.Context, reason string, extra string) error {
	if !enableTranscodeHeartbeat {
		return nil
	}
	name, group := StarfruitSECTwirp.EventName_TRANSCODE_HEARTBEAT, StarfruitSECTwirp.EventGroup_HEARTBEAT_EVENT_GROUP
	return c.sendStreamEvent(ctx, name, group, reason, extra)
}

// SendStarvationHeartbeatEvent reports STARVATION_HEARTBEAT (heartbeat group) for the
// stream this client is bound to. It is a no-op returning nil when
// enableStarvationHeartbeat is false.
func (c *secClient) SendStarvationHeartbeatEvent(ctx context.Context, reason string, extra string) error {
	if !enableStarvationHeartbeat {
		return nil
	}
	name, group := StarfruitSECTwirp.EventName_STARVATION_HEARTBEAT, StarfruitSECTwirp.EventGroup_HEARTBEAT_EVENT_GROUP
	return c.sendStreamEvent(ctx, name, group, reason, extra)
}

// SendStreamAuthenticationEventV2 reports STREAM_AUTHENTICATION (session group) for the
// given channel and session, also recording customerChannelName.
func (c *secClient) SendStreamAuthenticationEventV2(ctx context.Context, channel string, sessionId string, reason string, extra string, customerChannelName string) error {
	name, group := StarfruitSECTwirp.EventName_STREAM_AUTHENTICATION, StarfruitSECTwirp.EventGroup_SESSION_EVENT_GROUP
	return c.sendStreamEventBaseV2(ctx, channel, sessionId, name, group, reason, extra, customerChannelName)
}

// SendStreamDetails publishes stream detail / health information for the stream
// this client is bound to. The whole StreamDetails message, encoded with
// protojson, is carried in the event's Extra field.
//
// sds is modified in place: its ChannelArn and StreamId are overwritten with
// the client's channel ARN and session id. A nil sds returns an error, and a
// protojson encoding error is returned unchanged.
func (c *secClient) SendStreamDetails(ctx context.Context, sds *StarfruitSECTwirp.StreamDetails, reason string) error {
	if sds == nil {
		return fmt.Errorf("parameter struct StreamDetails is nil, cannot send message to SEC")
	}

	sds.ChannelArn, sds.StreamId = c.channelArn, c.sessionId
	payload, err := protojson.Marshal(sds)
	if err != nil {
		return err
	}
	return c.sendStreamEvent(ctx, StarfruitSECTwirp.EventName_STREAM_DETAILS, StarfruitSECTwirp.EventGroup_STREAM_DETAIL_GROUP, reason, string(payload))
}

// SendStreamDetailsV2 publishes stream detail / health information for an
// explicitly identified channel and broadcast session (intended for digestion,
// where the client need not be bound to a stream). The StreamDetails message,
// encoded with protojson, is carried in the event's Extra field.
//
// sds is modified in place: its ChannelArn and StreamId are set from channel
// and sessionId, so the embedded payload matches the event it is attached to.
// A nil sds returns an error, and a protojson encoding error is returned
// unchanged.
func (c *secClient) SendStreamDetailsV2(ctx context.Context, channel, sessionId string, sds *StarfruitSECTwirp.StreamDetails, reason string) error {
	if sds == nil {
		return fmt.Errorf("parameter struct StreamDetails is nil, cannot send message to SEC")
	}

	// Use the caller-supplied identifiers rather than c.channelArn/c.sessionId.
	// Clients built by NewSECBaseClient leave those fields empty, which made the
	// payload disagree with the event's ChannelArn/BroadcastSessionId.
	sds.ChannelArn = GetChannelArn(channel)
	sds.StreamId = sessionId
	b, err := protojson.Marshal(sds)
	if err != nil {
		return err
	}
	return c.sendStreamEventBase(ctx, channel, sessionId, StarfruitSECTwirp.EventName_STREAM_DETAILS, StarfruitSECTwirp.EventGroup_STREAM_DETAIL_GROUP, reason, string(b))
}

// shouldSendEventsToSEC reports whether SEC events should be produced for
// channel, together with a human-readable reason for the decision.
//
// Rules, checked in order:
//   - an empty channel is rejected;
//   - the enableSendingEventsToSEC feature flag must be on;
//   - Starfruit channels are enabled, LVS channels are disabled;
//   - any other name containing "." is unrecognized and disabled;
//   - the remaining names are treated as Twitch channels and are enabled only
//     when forceSendEventsForTwitchChannels is set.
func shouldSendEventsToSEC(channel string) (bool, string) {
	switch {
	case channel == "":
		return false, "SEC event processing is disabled. Channel name is empty"
	case !enableSendingEventsToSEC:
		return false, "SEC event processing is disabled because of the feature flag"
	case IsStarfruitChannel(channel):
		return true, fmt.Sprintf("SEC event processing enabled for Starfruit Channel:%s", channel)
	case IsLvsChannel(channel):
		return false, fmt.Sprintf("SEC event processing disabled for LVS Channel:%s", channel)
	case strings.Contains(channel, "."):
		return false, fmt.Sprintf("SEC event processing disabled. Unrecognized Channel:%s", channel)
	case forceSendEventsForTwitchChannels:
		return true, "SEC event processing enabled for all channels including Twitch channels"
	default:
		return false, "SEC event processing is disabled for Twitch channels"
	}
}

// SendPerodicTranscodeHeartbeat blocks, sending a transcode heartbeat every
// transcodeHeartbeatTime seconds until ctx is cancelled. Send failures are
// logged and do not stop the loop. (The misspelled name is kept for existing
// callers.)
func (c *secClient) SendPerodicTranscodeHeartbeat(ctx context.Context, reason string, extra string) {
	interval := time.Duration(transcodeHeartbeatTime) * time.Second
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if err := c.SendTranscodeHeartbeatEvent(ctx, reason, extra); err != nil {
				log.Printf("Unable to send transcode heartbeat event to SEC :%v\n", err)
			}
		case <-ctx.Done():
			return
		}
	}
}
