package backend

import (
	"errors"

	"github.com/aws/aws-sdk-go/aws/ec2metadata"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/cloudwatch"
	"github.com/aws/aws-sdk-go/service/ec2"
	"github.com/aws/aws-sdk-go/service/ecr"
	"github.com/aws/aws-sdk-go/service/lambda"
	"github.com/aws/aws-sdk-go/service/lambda/lambdaiface"
	"github.com/aws/aws-sdk-go/service/sqs"
	"github.com/aws/aws-sdk-go/service/ssm"
)

// Client exposes the AWS operations this service performs, one method per SDK
// call, so that tests can substitute a generated fake for the real AWS clients.
// Each method name is prefixed with the service it targets: SQS, SSM, EC2M
// (EC2 instance metadata), EC2, ECR and CW (CloudWatch).
// If this interface changes, counterfeiter must be re-run using `make mocks`
type Client interface {
	// Amazon SQS queue management and message retrieval.
	SQSCreateQueue(input *sqs.CreateQueueInput) (*sqs.CreateQueueOutput, error)
	SQSGetQueueUrl(input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error)
	SQSReceiveMessage(input *sqs.ReceiveMessageInput) (*sqs.ReceiveMessageOutput, error)
	SQSDeleteQueue(input *sqs.DeleteQueueInput) (*sqs.DeleteQueueOutput, error)

	// AWS Systems Manager (SSM).
	SSMGetParameters(input *ssm.GetParametersInput) (*ssm.GetParametersOutput, error)

	// EC2 instance metadata and the EC2 API.
	EC2MGetInstanceIdentityDocument() (ec2metadata.EC2InstanceIdentityDocument, error)
	EC2DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error)

	// Amazon ECR.
	ECRGetAuthorizationToken(input *ecr.GetAuthorizationTokenInput) (*ecr.GetAuthorizationTokenOutput, error)

	// Amazon CloudWatch.
	CWPutMetricData(input *cloudwatch.PutMetricDataInput) (*cloudwatch.PutMetricDataOutput, error)

	// GetLambdaClient hands out the whole Lambda API client instead of
	// wrapping individual Lambda operations.
	GetLambdaClient() lambdaiface.LambdaAPI
}

// client is the concrete Client built by New. Each field holds the AWS SDK
// service client that backs the correspondingly prefixed Client methods.
type client struct {
	sqs    *sqs.SQS                 // backs the SQS* methods
	ssm    *ssm.SSM                 // backs SSMGetParameters
	ec2    *ec2.EC2                 // backs EC2DescribeInstances
	ec2m   *ec2metadata.EC2Metadata // backs EC2MGetInstanceIdentityDocument
	ecr    *ecr.ECR                 // backs ECRGetAuthorizationToken
	cw     *cloudwatch.CloudWatch   // backs CWPutMetricData
	lambda *lambda.Lambda           // returned as-is by GetLambdaClient
}

// New generates all the AWS clients required for this service from the
// supplied AWS session and returns them bundled behind the Client interface.
// Every service client is constructed from the same sess.
//
// New returns an error if sess is nil. Without this guard, the SDK
// constructors would dereference the nil session and panic.
func New(sess *session.Session) (Client, error) {
	if sess == nil {
		return nil, errors.New("backend: aws session must not be nil")
	}

	// Field order mirrors the client struct declaration.
	return &client{
		sqs:    sqs.New(sess),
		ssm:    ssm.New(sess),
		ec2:    ec2.New(sess),
		ec2m:   ec2metadata.New(sess),
		ecr:    ecr.New(sess),
		cw:     cloudwatch.New(sess),
		lambda: lambda.New(sess),
	}, nil
}

// GetLambdaClient exposes the wrapped Lambda client through the
// lambdaiface.LambdaAPI interface, for use in the twirp lambda transport.
func (c *client) GetLambdaClient() lambdaiface.LambdaAPI {
	var api lambdaiface.LambdaAPI = c.lambda
	return api
}

// SQSReceiveMessage passes input straight through to ReceiveMessage on the
// wrapped SQS client and returns its output and error unchanged.
func (c *client) SQSReceiveMessage(input *sqs.ReceiveMessageInput) (*sqs.ReceiveMessageOutput, error) {
	out, err := c.sqs.ReceiveMessage(input)
	return out, err
}

// SQSCreateQueue delegates to CreateQueue on the wrapped SQS client,
// returning whatever that call returns.
func (c *client) SQSCreateQueue(input *sqs.CreateQueueInput) (*sqs.CreateQueueOutput, error) {
	svc := c.sqs
	return svc.CreateQueue(input)
}

// SQSDeleteQueue forwards input to DeleteQueue on the wrapped SQS client
// and hands back its output and error as-is.
func (c *client) SQSDeleteQueue(input *sqs.DeleteQueueInput) (*sqs.DeleteQueueOutput, error) {
	out, err := c.sqs.DeleteQueue(input)
	return out, err
}

// SQSGetQueueUrl delegates to GetQueueUrl on the wrapped SQS client,
// returning its output and error unchanged. The "Url" spelling matches the
// Client interface and the SDK operation name.
func (c *client) SQSGetQueueUrl(input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) {
	svc := c.sqs
	return svc.GetQueueUrl(input)
}

// SSMGetParameters passes input through to GetParameters on the wrapped SSM
// client and returns its output and error as-is.
func (c *client) SSMGetParameters(input *ssm.GetParametersInput) (*ssm.GetParametersOutput, error) {
	out, err := c.ssm.GetParameters(input)
	return out, err
}

// EC2MGetInstanceIdentityDocument calls the underlying GetInstanceIdentityDocument
// from the ec2metadata backend. It returns the identity document by value,
// together with any error reported by that call.
func (c *client) EC2MGetInstanceIdentityDocument() (ec2metadata.EC2InstanceIdentityDocument, error) {
	return c.ec2m.GetInstanceIdentityDocument()
}

// EC2DescribeInstances calls the underlying DescribeInstances from the ec2
// backend and returns its output and error unchanged.
func (c *client) EC2DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) {
	return c.ec2.DescribeInstances(input)
}

// ECRGetAuthorizationToken calls the underlying GetAuthorizationToken from
// the ecr backend and returns its output and error unchanged.
func (c *client) ECRGetAuthorizationToken(input *ecr.GetAuthorizationTokenInput) (*ecr.GetAuthorizationTokenOutput, error) {
	return c.ecr.GetAuthorizationToken(input)
}

// CWPutMetricData delegates to PutMetricData on the wrapped CloudWatch
// client, returning whatever that call returns.
func (c *client) CWPutMetricData(input *cloudwatch.PutMetricDataInput) (*cloudwatch.PutMetricDataOutput, error) {
	svc := c.cw
	return svc.PutMetricData(input)
}
