package kinesis

import (
	"context"
	"encoding/json"
	"log"
	"math/rand"
	"regexp"
	"runtime"
	"runtime/debug"
	"runtime/pprof"
	"strconv"
	"sync"
	"time"

	awsKinesis "github.com/aws/aws-sdk-go/service/kinesis"

	"a.yandex-team.ru/security/osquery/osquery-sender/awscommon"
	"a.yandex-team.ru/security/osquery/osquery-sender/config"
	"a.yandex-team.ru/security/osquery/osquery-sender/metrics"
	"a.yandex-team.ru/security/osquery/osquery-sender/parser"
	"a.yandex-team.ru/security/osquery/osquery-sender/util"
)

// Tunables of the Kinesis sender.
const (
	// defaultMaxLogsPerBatch: batch size used when the config leaves MaxLogsPerBatch at 0.
	// defaultQueueLength: capacity of the job queue when the config leaves QueueLength at 0.
	// requestTimeout: deadline of a single PutRecords call; also handed to the AWS session.
	// defaultBackoffOnFail: base sleep after a failed put when BackoffOnFailMillis is 0.
	// backoffJitter: the backoff is stretched by a random factor in [1, 1+backoffJitter).
	defaultMaxLogsPerBatch = 100
	defaultQueueLength     = 1000
	requestTimeout         = time.Second * 5
	defaultBackoffOnFail   = time.Millisecond * 250
	backoffJitter          = 0.25

	// Passed to metrics.NewLoadReporter for every worker.
	// NOTE(review): presumably the averaging window of the load metric - confirm in metrics.
	reportLoadPeriod = time.Second * 30
)

// RecordsBatch is a group of serialized records that is sent to Kinesis in a single PutRecords call.
type RecordsBatch []*awsKinesis.PutRecordsRequestEntry

// workerJob is the unit of work passed through the sender queue to the workers:
// one batch that a worker submits with a single PutRecords request.
type workerJob struct {
	// Serialized events, at most maxLogsPerBatch elements.
	records RecordsBatch
}

// KinesisSender samples parsed events, groups them into batches and puts them to a Kinesis
// stream from a pool of worker goroutines fed by a bounded queue.
// Create it with NewSender, then call Start before Send and Stop when done.
type KinesisSender struct {
	// Kinesis client, target stream name and the batch size at which splitAndSerialize
	// closes a batch.
	client          *awsKinesis.Kinesis
	stream          string
	maxLogsPerBatch int

	// Sampling: default percent of batches to send, per-event-name overrides of that percent,
	// and a host pattern for which sampling is bypassed entirely.
	percent         int
	percentPerName  map[string]int
	sendAllForHosts *regexp.Regexp

	// Base sleep after a failed put (jitter is added on top) and a switch for logging
	// every successful put.
	backoffOnFail time.Duration
	enableDebug   bool

	// Worker pool: buffered job queue, number of workers launched by Start, one load
	// reporter per worker (indexed by worker number) and the group Stop waits on.
	queue      chan workerJob
	maxWorkers int
	workerLoad []*metrics.LoadReporter
	workerWg   sync.WaitGroup
}

// Start launches maxWorkers worker goroutines that drain the queue.
// Each goroutine runs under the pprof label name=kinesis-worker-<N> and is
// registered in workerWg so that Stop can wait for it.
func (s *KinesisSender) Start() {
	s.workerWg.Add(s.maxWorkers)
	for n := 0; n < s.maxWorkers; n++ {
		// Per-iteration copy so the closure does not observe later values of n.
		workerNum := n
		labels := pprof.Labels("name", "kinesis-worker-"+strconv.Itoa(workerNum))
		run := func() {
			s.workerRun(workerNum)
		}
		go util.RunWithLabels(labels, run)
	}
}

// Stop closes the job queue and blocks until every worker has drained it and exited.
//
// The queue is closed here, so Stop must be called at most once and Send must not be
// called afterwards or concurrently: sending on (or re-closing) a closed channel panics.
func (s *KinesisSender) Stop() {
	log.Printf("stopping Kinesis sender\n")
	// Closing the channel ends the `range s.queue` loop of each worker once it is empty.
	close(s.queue)
	s.workerWg.Wait()
	log.Printf("stopped Kinesis sender\n")
}

// Send serializes events into per-event-name batches and enqueues them for the workers.
//
// Every batch is sampled independently: it is kept with the probability configured for
// its event name (see getPercentForEvent). Sampling is bypassed for the whole call when
// the host of the first event matches sendAllForHosts. Enqueueing never blocks; batches
// that do not fit into the queue are dropped by putJobToQueue.
func (s *KinesisSender) Send(events []*parser.ParsedEvent) {
	if len(events) == 0 {
		return
	}
	// Only the first event is inspected for the host match.
	bypassSampling := s.sendAllForHosts.MatchString(events[0].Host)
	for name, batches := range s.splitAndSerialize(events) {
		threshold := s.getPercentForEvent(name)
		for i := range batches {
			if bypassSampling || rand.Int31n(100) < threshold {
				s.putJobToQueue(workerJob{records: batches[i]})
			}
		}
	}
	s.updateMetricsMaxValues()
}

// UpdateMetrics reports the current queue length and publishes the average load of
// every worker, ordered by worker number.
func (s *KinesisSender) UpdateMetrics() {
	s.updateMetricsMaxValues()
	averages := make([]float64, 0, len(s.workerLoad))
	for _, reporter := range s.workerLoad {
		averages = append(averages, reporter.GetAverage())
	}
	metrics.SetKinesisWorkerLoads(averages)
}

// workerRun is the body of one worker goroutine: it processes jobs from the queue until
// the queue is closed and drained, recording the time spent on each job in the worker's
// load reporter and in the request-time metric.
//
// workerNum is the index of this worker in s.workerLoad.
func (s *KinesisSender) workerRun(workerNum int) {
	// Deferred so that Stop cannot hang in Wait if the worker leaves abnormally.
	defer s.workerWg.Done()

	workerLoad := s.workerLoad[workerNum]
	for job := range s.queue {
		startTime := time.Now()
		s.workerJob(&job)
		// Includes the backoff sleep workerJob performs after a failed put.
		deltaTime := time.Since(startTime)
		workerLoad.ReportWork(deltaTime)
		metrics.KinesisRequestTimeMillis.Add(deltaTime.Milliseconds())
	}
}

// workerJob submits one batch with a single PutRecords request (bounded by requestTimeout)
// and re-queues whatever Kinesis did not accept.
//
// If the request itself fails, the whole batch is retried; if only some records are
// rejected, just those records are retried. Before re-queueing the worker sleeps for a
// jittered backoff. The retry goes through putJobToQueue, so it is dropped when the queue
// is full. Panics are recovered, logged and counted as errors so that one bad job does
// not kill the worker.
func (s *KinesisSender) workerJob(job *workerJob) {
	defer func() {
		if r := recover(); r != nil {
			log.Println("ERROR: recovered in workerJob ", r, string(debug.Stack()))
			metrics.KinesisErrors.Inc()
		}
	}()

	// PutRecords requires at least one record; an empty job would only produce a
	// validation error that is then mistaken for success because there is nothing to retry.
	if len(job.records) == 0 {
		return
	}

	startTime := time.Now()
	ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
	defer cancel()
	out, err := s.client.PutRecordsWithContext(ctx, &awsKinesis.PutRecordsInput{
		Records:    job.records,
		StreamName: &s.stream,
	})

	var retryRecords []*awsKinesis.PutRecordsRequestEntry
	// Error description -> number of affected records, for the log line below.
	errorMap := make(map[string]int)
	switch {
	case err != nil:
		// The request as a whole failed: surface the cause and retry everything.
		errorMap[err.Error()] = len(job.records)
		retryRecords = job.records
	case out != nil && out.FailedRecordCount != nil && *out.FailedRecordCount > 0:
		// Result entries correspond to request entries by position.
		for i, o := range out.Records {
			if i >= len(job.records) {
				break
			}
			if o == nil || o.ErrorCode == nil {
				continue
			}
			errorKey := *o.ErrorCode
			// The message is optional; do not dereference it blindly.
			if o.ErrorMessage != nil {
				errorKey += ": " + *o.ErrorMessage
			}
			errorMap[errorKey]++
			retryRecords = append(retryRecords, job.records[i])
		}
	}

	if len(retryRecords) > 0 {
		metrics.KinesisErrors.Inc()
		// Most likely we are submitting too fast, sleep for a bit.
		backoff := s.getBackoffTime()
		failedRecordCount := len(job.records)
		if out != nil && out.FailedRecordCount != nil {
			failedRecordCount = int(*out.FailedRecordCount)
		}
		log.Printf("WARN: could not put %d records: %v, retry %d entries, sleep for %v\n",
			failedRecordCount, errorMap, len(retryRecords), backoff)
		time.Sleep(backoff)

		// Try putting the job back to the queue.
		s.putJobToQueue(workerJob{records: retryRecords})
		return
	}

	deltaTime := time.Since(startTime)
	metrics.KinesisRecords.Inc()
	if s.enableDebug {
		log.Printf("put %d records in %v\n", len(job.records), deltaTime)
	}
}

// putJobToQueue enqueues job without blocking. When the queue has no free slot the job
// is dropped, which is logged and counted in KinesisDroppedDueToFullQueue.
func (s *KinesisSender) putJobToQueue(job workerJob) {
	select {
	case s.queue <- job:
		return
	default:
		// No room in the queue: fall through and drop the job.
	}
	log.Printf("ERROR: kinesis queue is full")
	metrics.KinesisDroppedDueToFullQueue.Inc()
}

// getBackoffTime returns the configured backoff stretched by a random factor in
// [1, 1+backoffJitter), truncated to whole milliseconds.
func (s *KinesisSender) getBackoffTime() time.Duration {
	baseMillis := float64(s.backoffOnFail.Milliseconds())
	factor := 1.0 + rand.Float64()*backoffJitter
	return time.Duration(baseMillis*factor) * time.Millisecond
}

// updateMetricsMaxValues reports the number of jobs currently waiting in the queue
// to the KinesisQueueLen metric.
func (s *KinesisSender) updateMetricsMaxValues() {
	queued := len(s.queue)
	metrics.KinesisQueueLen.Report(uint64(queued))
}

// getPercentForEvent returns the sampling percent for the given event name: the
// per-name override when one is configured, the sender-wide percent otherwise.
func (s *KinesisSender) getPercentForEvent(name string) int32 {
	percent, overridden := s.percentPerName[name]
	if !overridden {
		percent = s.percent
	}
	return int32(percent)
}

// NewSender builds a KinesisSender from config. The sender is idle until Start is called.
//
// A nil config means Kinesis output is disabled: the function then returns (nil, nil),
// so callers must check the returned sender for nil. enableDebug turns on per-put logging
// and is OR-ed with config.EnableDebug.
//
// Unset (zero) and invalid (negative) numeric settings fall back to defaults:
// MaxLogsPerBatch -> defaultMaxLogsPerBatch, QueueLength -> defaultQueueLength,
// MaxWorkers -> runtime.NumCPU(), BackoffOnFailMillis -> defaultBackoffOnFail.
// Percent defaults to 100 when not set.
//
// An error is returned only when the AWS session cannot be created.
func NewSender(config *config.KinesisConfig, enableDebug bool) (*KinesisSender, error) {
	if config == nil {
		return nil, nil
	}
	enableDebug = enableDebug || config.EnableDebug

	session, err := awscommon.NewSession(&awscommon.SessionParams{
		Endpoint:            config.Endpoint,
		Region:              config.Region,
		AccessKeyID:         config.AccessKeyID,
		SecretAccessKeyFile: config.SecretAccessKeyFile,
		NumRetries:          1,
		RequestTimeout:      requestTimeout,
		EnableVerboseDebug:  config.EnableVerboseDebug,
	})
	if err != nil {
		return nil, err
	}
	kinesisClient := awsKinesis.New(session)

	maxLogsPerBatch := config.MaxLogsPerBatch
	if maxLogsPerBatch <= 0 {
		maxLogsPerBatch = defaultMaxLogsPerBatch
	}

	percent := 100
	if config.Percent != nil {
		percent = *config.Percent
	}
	sendAllForHosts := util.GlobsToRegexp(config.SendAllForHosts)

	// A negative capacity would make the channel allocation below panic.
	queueLength := config.QueueLength
	if queueLength <= 0 {
		queueLength = defaultQueueLength
	}
	// A negative worker count would make WaitGroup.Add panic in Start.
	maxWorkers := config.MaxWorkers
	if maxWorkers <= 0 {
		maxWorkers = runtime.NumCPU()
	}
	backoffOnFail := time.Millisecond * time.Duration(config.BackoffOnFailMillis)
	if backoffOnFail <= 0 {
		backoffOnFail = defaultBackoffOnFail
	}

	ret := &KinesisSender{
		client:          kinesisClient,
		stream:          config.Stream,
		maxLogsPerBatch: maxLogsPerBatch,

		percent:         percent,
		percentPerName:  config.PercentPerEventName,
		sendAllForHosts: sendAllForHosts,

		backoffOnFail: backoffOnFail,
		enableDebug:   enableDebug,

		queue:      make(chan workerJob, queueLength),
		maxWorkers: maxWorkers,
		workerLoad: make([]*metrics.LoadReporter, 0, maxWorkers),
	}
	// One load reporter per worker, indexed by worker number.
	for i := 0; i < maxWorkers; i++ {
		ret.workerLoad = append(ret.workerLoad, metrics.NewLoadReporter(reportLoadPeriod))
	}
	return ret, nil
}

// splitAndSerialize groups events by event name and serializes them into batches of at
// most maxLogsPerBatch records each.
//
// The result maps an event name to its batches. Every returned batch is non-empty; the
// last batch of a name may be shorter than maxLogsPerBatch. Each record carries the
// event data with its log type, and uses the event host as the partition key.
func (s *KinesisSender) splitAndSerialize(logs []*parser.ParsedEvent) map[string][]RecordsBatch {
	ret := make(map[string][]RecordsBatch)
	// Batches still being filled, keyed by event name.
	curBatch := make(map[string]RecordsBatch)
	for _, event := range logs {
		// Do not reference the whole input JSON for one field.
		copiedHost := util.CopyString(event.Host)
		batch := append(curBatch[event.Name], &awsKinesis.PutRecordsRequestEntry{
			Data:         serializeData(event.Data, event.LogType),
			PartitionKey: &copiedHost,
		})
		if len(batch) >= s.maxLogsPerBatch {
			ret[event.Name] = append(ret[event.Name], batch)
			// Forget the full batch entirely, so that the flush below does not emit an
			// empty batch for a name whose events filled the last batch exactly.
			delete(curBatch, event.Name)
			continue
		}
		// append may have reallocated, so the grown slice has to be stored back.
		curBatch[event.Name] = batch
	}
	// Flush the partially filled batches.
	for name, batch := range curBatch {
		if len(batch) == 0 {
			continue
		}
		ret[name] = append(ret[name], batch)
	}
	return ret
}

// serializeData encodes an event as a JSON object that holds every field of data plus a
// "log_type" field set to logType. A "log_type" key already present in data is replaced.
// The input map is not modified. Panics if the data cannot be marshalled to JSON.
func serializeData(data map[string]interface{}, logType string) []byte {
	const logTypeKey = "log_type"

	payload := make(map[string]interface{}, 1+len(data))
	payload[logTypeKey] = logType
	for key, value := range data {
		if key == logTypeKey {
			// logType always wins over a same-named field of the event.
			continue
		}
		payload[key] = value
	}

	encoded, err := json.Marshal(payload)
	if err != nil {
		panic(err)
	}
	return encoded
}
