package sns

import (
	"errors"
	"math/rand"
	"sync"
	"time"

	"code.justin.tv/feeds/distconf"
	"code.justin.tv/feeds/graphdb/cmd/graphdb/internal/graphdbmodel"
	"code.justin.tv/feeds/graphdb/proto/graphdb"
	"code.justin.tv/feeds/log"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/sns"
)

// Values of the "event" message attribute attached to every published SNS
// message. They are variables rather than constants because the publishing
// code stores their addresses in sns.MessageAttributeValue.StringValue.
var (
	newInvalidateCacheEvent  = "invalidatecache"
	newEdgeCountRequestEvent = "edgecountrequest"
	newEdgeGetRequestEvent   = "edgegetrequest"
	newEdgeListRequestEvent  = "edgelistrequest"
)

// SNSConfig holds the SNS topic ARNs that SNSClient publishes to. Each field
// is a distconf-backed string populated by Load; the publishing methods call
// Get on the relevant field each time they send a message.
type SNSConfig struct {
	// EdgeListRequestARN is read from graphdb.sns.edge_list_request_arn.
	EdgeListRequestARN  *distconf.Str
	// EdgeGetRequestARN is read from graphdb.sns.edge_get_request_arn.
	EdgeGetRequestARN   *distconf.Str
	// EdgeCountRequestARN is read from graphdb.sns.edge_count_request_arn.
	EdgeCountRequestARN *distconf.Str
	// InvalidateCacheARN is read from graphdb.sns.invalidate_cache_arn.
	InvalidateCacheARN  *distconf.Str
}

// Load binds every topic ARN field of c to its distconf key (default "")
// and verifies that each currently has a non-empty value. Keys are processed
// in declaration order; on the first empty value Load returns an error naming
// that field and key, leaving the fields after it unassigned.
func (c *SNSConfig) Load(d *distconf.Distconf) error {
	settings := []struct {
		dst  **distconf.Str
		name string
		key  string
	}{
		{&c.EdgeListRequestARN, "EdgeListRequestARN", "graphdb.sns.edge_list_request_arn"},
		{&c.EdgeGetRequestARN, "EdgeGetRequestARN", "graphdb.sns.edge_get_request_arn"},
		{&c.EdgeCountRequestARN, "EdgeCountRequestARN", "graphdb.sns.edge_count_request_arn"},
		{&c.InvalidateCacheARN, "InvalidateCacheARN", "graphdb.sns.invalidate_cache_arn"},
	}

	for _, setting := range settings {
		value := d.Str(setting.key, "")
		// Assign before validating so the field is set even when it is empty.
		*setting.dst = value
		if value.Get() == "" {
			return errors.New("Unable to find " + setting.name + " from " + setting.key)
		}
	}
	return nil
}

// SNS is the subset of the AWS SNS client API that SNSClient depends on.
// Its single method has the same signature as Publish on aws-sdk-go's
// *sns.SNS client, so that client satisfies this interface.
type SNS interface {
	Publish(*sns.PublishInput) (*sns.PublishOutput, error)
}

// SNSClient publishes graphdb request/response pairs and cache-invalidation
// events to SNS topics as JSON-encoded messages.
type SNSClient struct {
	// BaseClient performs the actual Publish calls.
	BaseClient  SNS
	// SnsConfig supplies the destination topic ARN for each message kind.
	SnsConfig   SNSConfig
	// MarshalJSON serializes each message payload. NOTE(review): presumably
	// json.Marshal in production — confirm at the construction site.
	MarshalJSON func(interface{}) ([]byte, error)
	// Log records failed Publish calls.
	Log         *log.ElevatedLog
}

// EdgeListRequestMessage is the JSON payload published by
// SendEdgeListRequestMessage: an EdgeList request together with the response
// passed alongside it.
type EdgeListRequestMessage struct {
	Request  *graphdb.EdgeListRequest  `json:"request"`
	Response *graphdb.EdgeListResponse `json:"response"`
}

// EdgeGetRequestMessage is the JSON payload published by
// SendEdgeGetRequestMessage: an EdgeGet request together with the response
// passed alongside it.
type EdgeGetRequestMessage struct {
	Request  *graphdb.EdgeGetRequest  `json:"request"`
	Response *graphdb.EdgeGetResponse `json:"response"`
}

// EdgeCountRequestMessage is the JSON payload published by
// SendEdgeCountRequestMessage: an EdgeCount request together with the
// response passed alongside it.
type EdgeCountRequestMessage struct {
	Request  *graphdb.EdgeCountRequest  `json:"request"`
	Response *graphdb.EdgeCountResponse `json:"response"`
}

// SendEdgeListRequestMessage publishes req and resp, wrapped in an
// EdgeListRequestMessage, to the EdgeListRequestARN topic with the "event"
// message attribute set to "edgelistrequest". A marshal error is returned
// unchanged; a Publish error is logged via p.Log and then returned.
func (p *SNSClient) SendEdgeListRequestMessage(req *graphdb.EdgeListRequest, resp *graphdb.EdgeListResponse) error {
	payload := EdgeListRequestMessage{
		Request:  req,
		Response: resp,
	}
	return p.publishEvent(p.SnsConfig.EdgeListRequestARN, &newEdgeListRequestEvent, payload,
		"error sending EdgeList request sns message")
}

// SendEdgeGetRequestMessage publishes req and resp, wrapped in an
// EdgeGetRequestMessage, to the EdgeGetRequestARN topic with the "event"
// message attribute set to "edgegetrequest". A marshal error is returned
// unchanged; a Publish error is logged via p.Log and then returned.
func (p *SNSClient) SendEdgeGetRequestMessage(req *graphdb.EdgeGetRequest, resp *graphdb.EdgeGetResponse) error {
	payload := EdgeGetRequestMessage{
		Request:  req,
		Response: resp,
	}
	return p.publishEvent(p.SnsConfig.EdgeGetRequestARN, &newEdgeGetRequestEvent, payload,
		"error sending EdgeGet request sns message")
}

// SendEdgeCountRequestMessage publishes req and resp, wrapped in an
// EdgeCountRequestMessage, to the EdgeCountRequestARN topic with the "event"
// message attribute set to "edgecountrequest". A marshal error is returned
// unchanged; a Publish error is logged via p.Log and then returned.
func (p *SNSClient) SendEdgeCountRequestMessage(req *graphdb.EdgeCountRequest, resp *graphdb.EdgeCountResponse) error {
	payload := EdgeCountRequestMessage{
		Request:  req,
		Response: resp,
	}
	return p.publishEvent(p.SnsConfig.EdgeCountRequestARN, &newEdgeCountRequestEvent, payload,
		"error sending EdgeCount request sns message")
}

// SendInvalidateCacheMessage publishes edge to the InvalidateCacheARN topic
// with the "event" message attribute set to "invalidatecache". A marshal
// error is returned unchanged; a Publish error is logged via p.Log and then
// returned.
func (p *SNSClient) SendInvalidateCacheMessage(edge graphdbmodel.Edge) error {
	// Marshal through a pointer. encoding/json only invokes pointer-receiver
	// MarshalJSON methods on addressable values, and the Edge's fields are
	// addressable only when reached via &edge. NOTE(review): the original
	// note says only *Node defines a marshal method, which implies Edge holds
	// Node values; this also assumes p.MarshalJSON behaves like json.Marshal.
	return p.publishEvent(p.SnsConfig.InvalidateCacheARN, &newInvalidateCacheEvent, &edge,
		"error sending InvalidateCache sns message")
}

// publishEvent serializes payload with p.MarshalJSON and publishes it to the
// topic named by arn, tagging the message with an "event" String attribute
// whose value is *event. arn is read only after marshalling succeeds. A
// marshal error is returned without logging; a Publish error is logged with
// logMsg and then returned.
func (p *SNSClient) publishEvent(arn *distconf.Str, event *string, payload interface{}, logMsg string) error {
	message, err := p.MarshalJSON(payload)
	if err != nil {
		return err
	}

	params := &sns.PublishInput{
		Message: aws.String(string(message)),
		MessageAttributes: map[string]*sns.MessageAttributeValue{
			"event": {
				DataType:    aws.String("String"),
				StringValue: event,
			},
		},
		TopicArn: aws.String(arn.Get()),
	}

	if _, err := p.BaseClient.Publish(params); err != nil {
		p.Log.Log(logMsg, err)
		return err
	}
	return nil
}

// readRequestSampleRate is the denominator of SendReadRequest's sampling
// probability: roughly 1 in readRequestSampleRate calls returns true.
const readRequestSampleRate = 100

// readRequestRand is a random source private to SendReadRequest, seeded once
// at package initialization. Using a dedicated source avoids re-seeding (and
// thereby disturbing) the process-wide math/rand source on every call.
// *rand.Rand is not safe for concurrent use, so readRequestMu guards it.
var (
	readRequestMu   sync.Mutex
	readRequestRand = rand.New(rand.NewSource(time.Now().UnixNano()))
)

// SendReadRequest has a 1% (1/readRequestSampleRate) chance of returning
// true. It is safe to call from multiple goroutines.
func SendReadRequest() bool {
	readRequestMu.Lock()
	defer readRequestMu.Unlock()
	// Intn(n) returns a value in [0, n), so exactly one of n outcomes is true.
	return readRequestRand.Intn(readRequestSampleRate) == 0
}
