// +build integration

package cwlogevent

import (
	"context"
	"fmt"
	"testing"
	"time"

	"code.justin.tv/hygienic/distconf"
	"code.justin.tv/hygienic/log"
	"code.justin.tv/hygienic/messagebatch"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/cloudwatchlogs"
	"github.com/aws/aws-sdk-go/service/sts"
	"github.com/cep21/circuit"
	. "github.com/smartystreets/goconvey/convey"
)

// Compile-time check that *circuit.Circuit is assignable to this package's
// Circuit type (presumably an interface declared elsewhere in the package —
// confirm). Nothing is stored; the build simply fails if the types diverge.
var _ Circuit = &circuit.Circuit{}

// createAWSSession builds the AWS session and the per-client config overrides
// used to construct AWS service clients. It is driven by these distconf keys:
//
//   - aws.profile: when non-empty, the returned session's credentials are
//     replaced with shared credentials for that profile.
//   - aws.region: when non-empty, a config carrying that region is added to
//     the returned overrides.
//   - aws.assume_role: when non-empty, a config whose credentials come from an
//     STS AssumeRoleProvider for that role ARN is added to the returned
//     overrides.
//
// The returned slice is always non-nil on success. An error is returned if
// either the base session or the STS session cannot be created.
func createAWSSession(dconf *distconf.Distconf) (*session.Session, []*aws.Config, error) {
	sess, err := session.NewSession()
	if err != nil {
		return nil, nil, err
	}
	if profile := dconf.Str("aws.profile", "").Get(); profile != "" {
		sess.Config.Credentials = credentials.NewSharedCredentials("", profile)
	}

	overrides := []*aws.Config{}
	if region := dconf.Str("aws.region", "").Get(); region != "" {
		overrides = append(overrides, &aws.Config{Region: aws.String(region)})
	}

	roleARN := dconf.Str("aws.assume_role", "").Get()
	if roleARN == "" {
		return sess, overrides, nil
	}

	// The STS client runs on its own session, configured only with the
	// overrides gathered so far (i.e. the region, when one was set).
	stsSess, err := session.NewSession(overrides...)
	if err != nil {
		return nil, nil, err
	}
	provider := &stscreds.AssumeRoleProvider{
		Client:       sts.New(stsSess),
		RoleARN:      roleARN,
		ExpiryWindow: 10 * time.Second,
	}
	overrides = append(overrides, &aws.Config{
		Credentials: credentials.NewCredentials(provider),
	})
	return sess, overrides, nil
}

// TestCWIntegration sends a single uniquely-named event through a
// CloudwatchLogBatcher against the real CloudWatch Logs API and verifies that
// the event can be read back from the batcher's log stream.
//
// It talks to AWS (region us-west-2, log group "testing-loggroup") and so only
// runs under the "integration" build tag with usable AWS credentials.
func TestCWIntegration(t *testing.T) {
	const (
		// How long to wait for the sent event to become readable.
		visibilityTimeout = 20 * time.Second
		// Pause between GetLogEvents calls so the poll loop does not hammer
		// the API and trip request throttling.
		pollInterval = 500 * time.Millisecond
		// How long to wait for the background Start goroutine to return.
		stopTimeout = 20 * time.Second
	)

	Convey("Batcher should work", t, func(c C) {
		mem := distconf.InMemory{}
		mem.StoreConfig(map[string][]byte{
			"log_group_name": []byte("testing-loggroup"),
			"aws.region":     []byte("us-west-2"),
		})
		dconf := &distconf.Distconf{
			Readers: []distconf.Reader{&mem},
		}
		cfg := Config{
			LogGroupName: "testing-loggroup",
		}
		s, ccfg, err := createAWSSession(dconf)
		So(err, ShouldBeNil)
		client := cloudwatchlogs.New(s, ccfg...)
		b := CloudwatchLogBatcher{
			Batcher: messagebatch.Batcher{
				Log:    log.Discard,
				Events: make(chan interface{}, 100),
			},
			Config:  &cfg,
			Client:  client,
			Circuit: circuit.NewCircuitFromConfig("name", circuit.Config{}),
		}
		// Ignore any errors here.  Maybe the log group was already created
		_, _ = client.CreateLogGroup(&cloudwatchlogs.CreateLogGroupInput{
			LogGroupName: aws.String(b.Config.LogGroupName),
		})
		So(b.Setup(), ShouldBeNil)

		// Run the batcher in the background; ended is closed once Start has
		// returned and its result has been asserted.
		ended := make(chan struct{})
		go func() {
			defer close(ended)
			c.So(b.Start(), ShouldBeNil)
		}()

		eventString := fmt.Sprintf("hello-world-%d", time.Now().UnixNano())
		b.Event(eventString, time.Time{})
		So(b.Close(), ShouldBeNil)
		So(b.EmptyEvents(context.Background()), ShouldBeNil)

		// Join the background goroutine so its assertion is recorded before
		// the test finishes and the goroutine is not leaked.
		select {
		case <-ended:
		case <-time.After(stopTimeout):
			t.Fatal("batcher Start did not return after Close")
		}

		// Poll until the event is readable, pausing between attempts.
		var out *cloudwatchlogs.GetLogEventsOutput
		deadline := time.Now().Add(visibilityTimeout)
		for {
			out, err = client.GetLogEvents(&cloudwatchlogs.GetLogEventsInput{
				LogGroupName:  aws.String(b.Config.LogGroupName),
				LogStreamName: &b.LogStreamName,
			})
			So(err, ShouldBeNil)
			if len(out.Events) != 0 {
				break
			}
			if time.Now().After(deadline) {
				t.Fatal("took too long to see the sent event")
			}
			time.Sleep(pollInterval)
		}

		firstEvent := out.Events[0]
		// aws.StringValue tolerates a nil Message, turning a would-be panic
		// into an ordinary assertion failure.
		So(aws.StringValue(firstEvent.Message), ShouldEqual, eventString)
	})
}
