package migration

import (
	"encoding/json"
	"fmt"
	"golang.org/x/net/context"
	"log"
	"reflect"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/kinesis"
	"github.com/cactus/go-statsd-client/statsd"

	"github.com/zenazn/goji/graceful"

	"code.justin.tv/chat/timing"
	"code.justin.tv/d8a/migration/comparison"
)

const (
	// saveFrequency is how many records a shard reader processes before it
	// checkpoints its progress to the consumer group (progress is also saved
	// at least once per second while records are flowing).
	saveFrequency = 25
)

// anyAmazonError matches any error value that also exposes an AWS-style
// error code, so callers can branch on Code() (e.g. to detect throttling)
// without depending on a concrete SDK error type.
type anyAmazonError interface {
	Error() string
	Code() string
}

// Replayer re-executes recorded calls against an implementation so that the
// replayed results can be compared with the results recorded on the stream.
type Replayer interface {
	// ReplayCall re-issues the call named methodName. jsonData is the raw JSON
	// of the whole recorded Kinesis record; the returned values are compared
	// against the record's "ReturnVals" field.
	ReplayCall(ctx context.Context, methodName string, jsonData []byte) ([]interface{}, error)

	// InterfaceName identifies the replayed interface; it is used as the last
	// segment of the statsd metric prefix.
	InterfaceName() string

	// LogLevel is passed to comparison.LogMismatch when results differ.
	LogLevel() comparison.LogLevel

	// ComparePreprocessor is passed to comparison.DeepCompare when comparing
	// recorded and replayed results.
	ComparePreprocessor() comparison.ComparePreprocessor
}

func ExecuteReplay(replayer Replayer, statsdHostPort string, environment string, repo string, awsProfile string, region string, streamName string, consumerGroupName string) error {
	stats, err := statsd.NewBufferedClient(statsdHostPort,
		fmt.Sprintf("d8a-migration.%s.%s.%s", environment, extractRepoFromSvcname(repo), replayer.InterfaceName()),
		time.Second, 0)
	if err != nil {
		log.Fatalln(err)
	}

	awsConfig := &aws.Config{
		Credentials: getCredentialsFromProfile(awsProfile),
		Region:      aws.String(region),
	}

	consumerGroup := createConsumerGroup(awsConfig, streamName, consumerGroupName)

	streamClient := kinesis.New(session.New(), awsConfig)
	stream, err := streamClient.DescribeStream(&kinesis.DescribeStreamInput{
		StreamName: aws.String(streamName),
	})
	if err != nil {
		return err
	}

	for _, shard := range stream.StreamDescription.Shards {
		go readShard(replayer, stats, shard, streamClient, streamName, consumerGroup)
	}

	graceful.HandleSignals()
	graceful.Wait()
	return nil
}

// readShard replays every record of a single Kinesis shard until the process
// exits.
//
// Reading resumes from the consumer group's saved progress for the shard (see
// getBestIterator). Records are fetched one per GetRecords call and handed to
// processRecord. The last processed sequence number is checkpointed every
// saveFrequency records, or after one second if at least one record was
// processed since the previous save. Per-record errors are logged and skipped.
// Failing to obtain an iterator, a non-throttling GetRecords error, or the
// shard closing terminates the process via log.Fatalf.
func readShard(replayer Replayer, stats statsd.Statter, shard *kinesis.Shard, stream *kinesis.Kinesis, streamName string, consumerGroup *consumerGroup) {
	// Kinesis allows only 5 GetRecords reads per second per shard, so every
	// loop iteration is stretched to at least this long, measured from the
	// start of that same iteration.
	const minIterationTime = 200 * time.Millisecond

	// Dereference once: logging the *string directly would print an address.
	shardID := aws.StringValue(shard.ShardId)
	log.Println("Opening shard", shardID)

	iterator := new(string)
	if err := getBestIterator(stream, consumerGroup, streamName, shardID, iterator); err != nil {
		log.Fatalf("Could not retrieve iterator for shard %s: %v\n", shardID, err)
	}

	recordID := ""
	recordsSinceLastSave := 0
	nextMandatorySave := time.Now().Add(time.Second)

	// Checkpoint the last processed sequence number if the loop is ever left.
	// A closure is used so the values are read at exit time; plain deferred
	// call arguments would be evaluated right here (before any record is read),
	// and progress must be a sequence number, not a shard iterator.
	defer func() {
		if recordsSinceLastSave > 0 {
			attemptSaveProgress(consumerGroup, shardID, recordID)
		}
	}()

	log.Println("Shard", shardID, "open")

	for {
		iterationStart := time.Now()

		if iterator == nil {
			log.Fatalf("Shard ID %s has closed unexpectedly\n", shardID)
		}

		recordset, err := stream.GetRecords(&kinesis.GetRecordsInput{
			ShardIterator: iterator,
			Limit:         aws.Int64(1),
		})
		if err != nil {
			if awsErr, ok := err.(anyAmazonError); ok && awsErr.Code() == "ProvisionedThroughputExceededException" {
				//We're just hitting the API too fast... somehow
				log.Println("Exceeded provisioned kinesis replay capacity- pausing briefly...")
				time.Sleep(1 * time.Second)
				continue
			}
			log.Fatalf("Could not retrieve records for shard ID %s: %v\n", shardID, err)
		}

		if recordset.MillisBehindLatest != nil {
			stats.Timing("replay-lag", *recordset.MillisBehindLatest, StatsdSampleRate)
		}
		for _, record := range recordset.Records {
			if err := processRecord(replayer, stats, record); err != nil {
				log.Println(err)
			}
			recordsSinceLastSave++
			recordID = *record.SequenceNumber
		}

		iterator = recordset.NextShardIterator

		if recordsSinceLastSave >= saveFrequency || (recordsSinceLastSave > 0 && time.Now().After(nextMandatorySave)) {
			attemptSaveProgress(consumerGroup, shardID, recordID)
			log.Println("Saved")
			recordsSinceLastSave = 0
			nextMandatorySave = time.Now().Add(time.Second)
		}

		// Pace from this iteration's own start (taken before the request), so
		// consecutive GetRecords calls are always at least minIterationTime apart.
		if remaining := minIterationTime - time.Since(iterationStart); remaining > 0 {
			time.Sleep(remaining)
		}
	}
}

// processRecord replays a single recorded call and compares the replayed
// results with the recorded ones.
//
// record.Data must be a JSON object containing at least:
//   - "Exp__Method" (string): the method name passed to replayer.ReplayCall.
//   - "Exp__Duration" (number): the original call's duration, converted with
//     time.Duration. NOTE(review): that implies nanoseconds; confirm with the
//     producer of the stream.
//   - "ReturnVals" (array): the originally recorded return values.
//
// The complete record.Data is passed through to replayer.ReplayCall.
//
// Reported to stats:
//   - the original duration as "replay-record-original.<method>.success";
//   - timing transactions named "replay-record" (the whole record) and
//     "replay-record.<method>" (replay plus comparison);
//   - match/mismatch counters under "replay-record-compare.<method>".
//
// Mismatches are also logged via comparison.LogMismatch. An error from
// comparison.DeepCompare is logged and recorded on the per-method timer but is
// not returned. Missing or mistyped fields, replay errors, and failures to
// convert the replayed results are returned as errors.
func processRecord(replayer Replayer, stats statsd.Statter, record *kinesis.Record) error {
	globalTimer := timing.Xact{
		Stats:            stats,
		StatsdSampleRate: StatsdSampleRate,
	}
	globalTimer.AddName("replay-record")
	globalTimer.Start()
	// Every early return is reported as "err"; success paths call End
	// explicitly first. NOTE(review): this relies on timing.Xact.End only
	// recording the first call; confirm.
	defer globalTimer.End("err")

	data := make(map[string]interface{})
	err := json.Unmarshal(record.Data, &data)
	if err != nil {
		return err
	}

	val, ok := data["Exp__Method"]
	if !ok {
		return missingFieldError("Exp__Method")
	}
	methodName, ok := val.(string)
	if !ok {
		return wrongTypeError("Exp__Method", "string", val)
	}

	val, ok = data["Exp__Duration"]
	if !ok {
		return missingFieldError("Exp__Duration")
	}
	// encoding/json decodes every JSON number into float64 when the target is interface{}.
	durationNum, ok := val.(float64)
	if !ok {
		return wrongTypeError("Exp__Duration", "float64", val)
	}

	val, ok = data["ReturnVals"]
	if !ok {
		return missingFieldError("ReturnVals")
	}
	oldResults, ok := val.([]interface{})
	if !ok {
		return wrongTypeError("ReturnVals", "[]interface{}", val)
	}

	duration := time.Duration(durationNum)
	stats.TimingDuration(fmt.Sprintf("replay-record-original.%s.success", methodName), duration, StatsdSampleRate)

	replayTimer := timing.Xact{
		Stats:            stats,
		StatsdSampleRate: StatsdSampleRate,
	}
	replayTimer.AddName(fmt.Sprintf("replay-record.%s", methodName))
	replayTimer.Start()
	defer replayTimer.End("err")

	newResults, err := replayer.ReplayCall(context.Background(), methodName, record.Data)
	if err != nil {
		return err
	}

	oldComparable := comparison.ComparablesFromJSONObj(oldResults)
	newComparable, err := comparison.ComparablesFromStruct(newResults)
	if err != nil {
		// This path returns an error to the caller, so let the deferred
		// End("err") record it instead of reporting the record as a success.
		return err
	}
	resultsSame, err := comparison.DeepCompare(methodName, replayer.ComparePreprocessor(), oldComparable, newComparable)
	if err != nil {
		log.Println(err)
		replayTimer.End("err")
	} else {
		replayTimer.End("success")

		if resultsSame {
			stats.Inc(fmt.Sprintf("replay-record-compare.%s.match", methodName), 1, StatsdSampleRate)
		} else {
			comparison.LogMismatch(methodName, record.Data, oldResults, newResults, replayer.LogLevel())
			stats.Inc(fmt.Sprintf("replay-record-compare.%s.mismatch", methodName), 1, StatsdSampleRate)
		}
	}

	globalTimer.End("success")
	return nil
}

// attemptSaveProgress checkpoints progress for shardID via the consumer group.
// progress is expected to be a Kinesis sequence number, because
// getBestIterator resumes with AFTER_SEQUENCE_NUMBER from the saved value.
//
// Failures are logged and otherwise ignored (best effort). A missed
// checkpoint means reading resumes from an earlier saved position, so some
// records may be replayed again.
func attemptSaveProgress(consumerGroup *consumerGroup, shardID string, progress string) {
	if err := consumerGroup.SaveProgress(shardID, progress); err != nil {
		// Use the log package like the rest of this file, not a bare stdout print.
		log.Printf("Error saving progress for shard %s: %v\n", shardID, err)
	}
}

// getBestIterator obtains a shard iterator for shardID and writes it to
// *iterator.
//
// If the consumer group has saved progress for the shard, the iterator starts
// immediately after that saved sequence number (AFTER_SEQUENCE_NUMBER).
// Otherwise it starts at TRIM_HORIZON, the oldest record still retained by the
// shard. Errors from loading progress or from GetShardIterator are returned
// unchanged, and *iterator is left untouched in that case.
func getBestIterator(stream *kinesis.Kinesis, consumerGroup *consumerGroup, streamName string, shardID string, iterator *string) error {
	lastSeq, loadErr := consumerGroup.LoadProgress(shardID)
	if loadErr != nil {
		return loadErr
	}

	// Default to reading the whole retained shard; override when resuming.
	input := kinesis.GetShardIteratorInput{
		ShardId:           &shardID,
		ShardIteratorType: aws.String(kinesis.ShardIteratorTypeTrimHorizon),
		StreamName:        aws.String(streamName),
	}
	if lastSeq != "" {
		input.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAfterSequenceNumber)
		input.StartingSequenceNumber = aws.String(lastSeq)
	}

	out, reqErr := stream.GetShardIterator(&input)
	if reqErr != nil {
		return reqErr
	}

	*iterator = *out.ShardIterator
	return nil
}

func missingFieldError(fieldName string) error {
	return fmt.Errorf("Kinesis record missing field '%s' - is something wrong with the stream?", fieldName)
}

// wrongTypeError builds the error reported when the decoded Kinesis record
// field fieldName holds a value whose dynamic type is not typeName.
//
// The actual type is rendered with reflect.Type.String, not Name. String gives
// a useful description for unnamed types such as []interface{} or
// map[string]interface{}, where Name returns "". A nil val (a JSON null) is
// reported as "<nil>", matching fmt's %T, because reflect.TypeOf(nil) is a nil
// Type and calling a method on it would panic.
func wrongTypeError(fieldName string, typeName string, val interface{}) error {
	realTypeName := "<nil>"
	if t := reflect.TypeOf(val); t != nil {
		realTypeName = t.String()
	}
	return fmt.Errorf("Kinesis record's '%s' field was not of type %s and instead was type %s - is something wrong with the stream?", fieldName, typeName, realTypeName)
}
