package StarfruitNyxClient

import (
	"context"
	"fmt"
	"time"

	rpc "code.justin.tv/amzn/StarfruitNyxTwirp"
	telemetry "code.justin.tv/amzn/TwitchTelemetry"
	"code.justin.tv/video/invoker"

	"github.com/aws/aws-sdk-go/aws/arn"
	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/klauspost/compress/zstd"
)

const (
	// eventsChannelSize is an event-count tuning constant; presumably the
	// capacity of each stream's buffered event channel (its consumer is not
	// in this file section — confirm in the stream implementation).
	eventsChannelSize = 10_000

	// eventDataSize is 250 KB expressed in bytes (250 * 1000, not 250 KiB);
	// presumably a payload size bound used by the stream — confirm at use site.
	eventDataSize = 250 * 1000
)

// Sentinel errors exposed by this package. Some of them reach callers wrapped
// with %w (for example ErrUnsupportedChannelRegion via AddEvent), so match
// them with errors.Is rather than ==.
var (
	// ErrInvalidEventType is returned by NewClient when ClientConfig.EventType
	// is the zero value (it has no stream name).
	ErrInvalidEventType = fmt.Errorf("invalid event type")

	// ErrInvalidSession is returned by NewClient when ClientConfig.Session is nil.
	ErrInvalidSession = fmt.Errorf("invalid session")

	// ErrMismatchedProtoType is returned by AddEvent and AddEventBlocking when
	// the event's concrete type does not belong to the client's EventType.
	ErrMismatchedProtoType = fmt.Errorf("proto message doesn't match event type")

	// ErrEventQueueFull presumably reports that a non-blocking add found the
	// stream's queue full (raised outside this file section — confirm).
	ErrEventQueueFull = fmt.Errorf("event queue is full")

	// ErrMessageMarshalFailed presumably reports a MarshalTo failure for an
	// added event (raised outside this file section — confirm).
	ErrMessageMarshalFailed = fmt.Errorf("failed to marshal added protobuf message")

	// ErrUnsupportedChannelRegion is returned when the region in an event's
	// channel ARN matches none of the client's regional streams.
	ErrUnsupportedChannelRegion = fmt.Errorf("unsupported channel home region")
)

// ProtoMessage is the serialization surface the client needs from a generated
// protobuf message. These are gogo/protobuf-style methods; presumably Size
// reports the encoded length and MarshalTo encodes into data, returning the
// number of bytes written — confirm against the generated rpc types.
type ProtoMessage interface {
	Size() int
	MarshalTo(data []byte) (int, error)
}

// NyxProtoEvent is a ProtoMessage that also carries the ARN of the channel it
// belongs to. The client parses that ARN's region to pick the regional stream
// an event is sent to.
type NyxProtoEvent interface {
	ProtoMessage
	GetChannelArn() string
}

// EventType identifies one kind of Nyx event a Client can submit. Its fields
// are unexported, so callers use the predefined EventType* values.
//
// Field order matters: the predefined values may be written as positional
// composite literals.
type EventType struct {
	name, stream, role string          // short name (e.g. "ESR"), stream name, IAM role name to assume
	payloadType        rpc.PayloadType // rpc payload type associated with this event kind
}

// Predefined event types; pass exactly one of these as ClientConfig.EventType.
// Each role is the IAM role name the client assumes in every supported region.
var (
	EventTypeESR = EventType{
		name:        "ESR",
		stream:      "nyx-esr",
		role:        "esr-submitter",
		payloadType: rpc.PayloadType_ESR,
	}
	EventTypeEPR = EventType{
		name:        "EPR",
		stream:      "nyx-epr",
		role:        "epr-submitter",
		payloadType: rpc.PayloadType_EPR,
	}
	EventTypeRSR = EventType{
		name:        "RSR",
		stream:      "nyx-rsr",
		role:        "rsr-submitter",
		payloadType: rpc.PayloadType_RSR,
	}
	EventTypeBeacon = EventType{
		name:        "Beacon",
		stream:      "nyx-beacon",
		role:        "beacon-submitter",
		payloadType: rpc.PayloadType_BEACON,
	}
)

// Region describes one AWS region the client submits to: a short code, the
// AWS region name matched against channel-ARN regions, and the AWS account
// whose submitter IAM role is assumed there.
type Region struct {
	name, amazonName, accountID string
}

// Regions the client knows about, with their AWS region name and account ID.
var (
	RegionPDX = Region{name: "PDX", amazonName: "us-west-2", accountID: "602740783494"}
	RegionCMH = Region{name: "CMH", amazonName: "us-east-2", accountID: "883280163697"}
	RegionIAD = Region{name: "IAD", amazonName: "us-east-1", accountID: "640021745694"}
	RegionDUB = Region{name: "DUB", amazonName: "eu-west-1", accountID: "685868989575"}
	RegionFRA = Region{name: "FRA", amazonName: "eu-central-1", accountID: "950513999296"}
	RegionNRT = Region{name: "NRT", amazonName: "ap-northeast-1", accountID: "337771229831"}
	RegionICN = Region{name: "ICN", amazonName: "ap-northeast-2", accountID: "232718289047"}

	// SupportedRegions lists every region NewClient creates a stream for.
	SupportedRegions = []Region{
		RegionPDX,
		RegionCMH,
		RegionIAD,
		RegionDUB,
		RegionFRA,
		RegionNRT,
		RegionICN,
	}
)

// ClientConfig holds the settings passed to NewClient.
type ClientConfig struct {
	// EventType selects which kind of event the client submits. Required;
	// NewClient rejects the zero value with ErrInvalidEventType.
	EventType EventType

	// Session is the base AWS session used to assume each region's submitter
	// role. Required; NewClient rejects nil with ErrInvalidSession.
	Session *session.Session

	// SampleReporter receives metrics. If unset, no metrics are sent.
	SampleReporter *telemetry.SampleReporter

	// SendTimeout defaults to 30 seconds if unset.
	SendTimeout time.Duration

	// SendRetryTimeout defaults to 1 minute if unset.
	SendRetryTimeout time.Duration

	// FlushInterval: if unset, events are only flushed on reaching max size.
	FlushInterval time.Duration

	// OnError is an optional callback given the region name and the error.
	OnError func(region string, err error)
}

// Client submits events of a single EventType. Each event is routed to a
// regional stream chosen from the region in the event's channel ARN. Build one
// with NewClient and drive its streams with Run.
type Client struct {
	// ClientConfig is the configuration supplied to NewClient.
	ClientConfig

	// streams maps each entry of SupportedRegions to its stream. In the code
	// visible here it is filled once by NewClient and only read afterwards.
	streams map[Region]*stream
}

// Creation

// NewClient validates conf and builds a Client holding one stream per entry in
// SupportedRegions. For each region it creates STS assume-role credentials for
// "arn:aws:iam::<region account ID>:role/<EventType role>" from conf.Session.
// A single zstd encoder is created here and shared by every regional stream.
//
// It returns ErrInvalidEventType if conf.EventType is the zero value,
// ErrInvalidSession if conf.Session is nil, and otherwise a wrapped error if
// the encoder or any regional stream cannot be created (the stream error
// names the failing region).
func NewClient(conf ClientConfig) (*Client, error) {
	if conf.EventType.stream == "" {
		return nil, ErrInvalidEventType
	}
	if conf.Session == nil {
		return nil, ErrInvalidSession
	}

	// One encoder (configured with concurrency 1) is shared by all streams.
	encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedDefault), zstd.WithEncoderConcurrency(1))
	if err != nil {
		return nil, fmt.Errorf("failed to create zstd encoder: %w", err)
	}

	c := &Client{
		ClientConfig: conf,
		streams:      make(map[Region]*stream, len(SupportedRegions)),
	}

	for _, region := range SupportedRegions {
		roleARN := fmt.Sprintf("arn:aws:iam::%s:role/%s", region.accountID, conf.EventType.role)
		creds := stscreds.NewCredentials(conf.Session, roleARN)

		s, err := newStream(&streamConfig{
			region:           region,
			eventType:        conf.EventType,
			awsSession:       conf.Session,
			awsCreds:         creds,
			sampleReporter:   conf.SampleReporter,
			encoder:          encoder,
			onError:          conf.OnError,
			flushInterval:    conf.FlushInterval,
			sendTimeout:      conf.SendTimeout,
			sendRetryTimeout: conf.SendRetryTimeout,
		})
		if err != nil {
			// Name the region: a failure in one of seven regions is otherwise
			// indistinguishable, and ": " keeps the chained message readable.
			return nil, fmt.Errorf("failed to create stream for region %s: %w", region.name, err)
		}

		c.streams[region] = s
	}

	return c, nil
}

// Event Addition

// AddEvent hands e to the stream for the region in e's channel ARN. It is the
// counterpart of AddEventBlocking that takes no context; presumably it does
// not wait for queue space (see stream.addEvent — confirm).
//
// It returns ErrMismatchedProtoType if e's concrete type does not belong to
// the client's EventType, a wrapped error if the channel ARN cannot be parsed
// or its region has no stream (ErrUnsupportedChannelRegion, match with
// errors.Is), and otherwise whatever the stream's addEvent returns.
func (c *Client) AddEvent(e NyxProtoEvent) error {
	if !c.eventAllowed(e) {
		return ErrMismatchedProtoType
	}

	s, err := c.getEventStream(e)
	if err != nil {
		// ": " separates this context from the wrapped cause.
		return fmt.Errorf("failed to get event region: %w", err)
	}

	return s.addEvent(e)
}

// AddEventBlocking is like AddEvent but passes ctx to the regional stream's
// blocking add; presumably it waits for queue space until ctx is done
// (see stream.addEventBlocking — confirm).
//
// It returns ErrMismatchedProtoType if e's concrete type does not belong to
// the client's EventType, a wrapped error if the channel ARN cannot be parsed
// or its region has no stream (ErrUnsupportedChannelRegion, match with
// errors.Is), and otherwise whatever the stream's addEventBlocking returns.
func (c *Client) AddEventBlocking(ctx context.Context, e NyxProtoEvent) error {
	if !c.eventAllowed(e) {
		return ErrMismatchedProtoType
	}

	s, err := c.getEventStream(e)
	if err != nil {
		// ": " separates this context from the wrapped cause.
		return fmt.Errorf("failed to get event region: %w", err)
	}

	return s.addEventBlocking(ctx, e)
}

// eventAllowed reports whether e's concrete rpc type is the one produced for
// the client's configured EventType. Unknown types (including a nil
// interface) are never allowed.
func (c *Client) eventAllowed(e NyxProtoEvent) bool {
	var expected EventType
	switch e.(type) {
	case *rpc.EdgeSegmentRequestEvent:
		expected = EventTypeESR
	case *rpc.EdgePlaylistRequestEvent:
		expected = EventTypeEPR
	case *rpc.ReplicationSegmentRequestEvent:
		expected = EventTypeRSR
	case *rpc.BeaconEvent:
		expected = EventTypeBeacon
	default:
		return false
	}
	return expected == c.EventType
}

// getEventStream returns the stream whose region's AWS name equals the region
// field of e's channel ARN. It returns a wrapped parse error (including the
// offending ARN) if the ARN is malformed, or the unwrapped sentinel
// ErrUnsupportedChannelRegion if no stream matches.
func (c *Client) getEventStream(e NyxProtoEvent) (*stream, error) {
	channelARN := e.GetChannelArn()
	parsed, err := arn.Parse(channelARN)
	if err != nil {
		return nil, fmt.Errorf("failed to parse channel arn %q: %w", channelARN, err)
	}

	// streams is keyed by the whole Region value, so match on the AWS region
	// name with a linear scan over the (few) configured regions.
	for region, s := range c.streams {
		if region.amazonName == parsed.Region {
			return s, nil
		}
	}

	return nil, ErrUnsupportedChannelRegion
}

// Main run loop

// Run registers every regional stream's run method with a new invoker and
// returns the result of running that invoker with ctx. Presumably it blocks
// until ctx is cancelled or a stream's run returns an error — confirm against
// code.justin.tv/video/invoker.
func (c *Client) Run(ctx context.Context) error {
	inv := invoker.New()
	for _, s := range c.streams {
		inv.Add(s.run)
	}
	return inv.Run(ctx)
}
