package s3

import (
	"fmt"
	"log"
	"math"
	"runtime"
	"runtime/debug"
	"runtime/pprof"
	"sort"
	"time"

	awsS3 "github.com/aws/aws-sdk-go/service/s3"
	"github.com/c2h5oh/datasize"
	"go.uber.org/atomic"

	"a.yandex-team.ru/security/osquery/osquery-sender/awscommon"
	"a.yandex-team.ru/security/osquery/osquery-sender/batcher"
	"a.yandex-team.ru/security/osquery/osquery-sender/config"
	"a.yandex-team.ru/security/osquery/osquery-sender/metrics"
	"a.yandex-team.ru/security/osquery/osquery-sender/parser"
	"a.yandex-team.ru/security/osquery/osquery-sender/util"
)

// S3Sender batches parsed events and uploads them to an S3 bucket as compressed TSV objects.
//
// NewSender hands the sender to batcher.NewWorkers, so besides the lifecycle methods
// (Start/Stop) it also carries the callbacks defined in this file (SubmitEvents and the On* hooks).
type S3Sender struct {
	// enableDebug gates the verbose per-batch log lines.
	enableDebug bool

	// manager performs uploads and object listings against bucket.
	manager *S3Manager
	bucket  string

	// alg is the compression algorithm passed to the TSV writer, the object metadata and the object key.
	alg CompressionAlg

	// loc is the time zone passed to MakeKey and to the batcher as WorkersConfig.SplitLoc.
	loc *time.Location

	// merger is nil when daily merging is disabled in the config.
	merger *s3Merger

	// workers owns the batching queue; it is created in NewSender with this sender as its target.
	workers *batcher.BatcherWorkers

	// closed is set by Stop and polled by the refreshSizes loop so that it exits.
	closed atomic.Bool
}

// Defaults and tuning constants of the S3 sender.
const (
	// NOTE(review): not referenced in the visible part of this file; presumably the fallback
	// used by ParseCompressionAlg — confirm.
	defaultCompressionAlg = CompressionSnappy

	// Fallbacks applied by NewSender when config.MaxMemory is zero / config.MaxDelay is empty.
	// Use 1Gb as a sane default.
	defaultMaxMemory = datasize.GB
	defaultMaxDelay  = time.Second * 30

	// Fallback applied by NewSender when config.NumRetries is zero.
	defaultNumRetries = 5

	// 16 MiB: minimum upload part size for regular (non-merge) uploads.
	minPartSize = 16 * 1024 * 1024
	// 256 MiB: use bigger parts for the merger, otherwise we hit the partNumber limit
	// (see https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html and https://st.yandex-team.ru/CLOUD-75886)
	minPartSizeForMerge = 256 * 1024 * 1024

	// Refresh the S3 folder size metrics every 60 minutes.
	refreshSizesInterval = time.Minute * 60
)

// NewSender builds an S3Sender from the given configuration.
//
// A nil config means "S3 output disabled": the function then returns (nil, nil), so callers
// must check the returned sender for nil in addition to the error.
//
// Unset config values fall back to defaults: MaxMemory -> defaultMaxMemory, MaxDelay ->
// defaultMaxDelay, MaxWorkers -> runtime.NumCPU(), NumRetries -> defaultNumRetries,
// MergeDaily -> true, DeleteMergedAfterDays -> 730, Timezone -> time.Local.
// Debug logging is enabled when either the enableDebug argument or config.EnableDebug is set.
//
// An error is returned when max_delay cannot be parsed, the AWS session cannot be created,
// the compression algorithm is unknown, or the time zone cannot be loaded.
func NewSender(config *config.S3Config, addDecorators []string, enableDebug bool) (*S3Sender, error) {
	if config == nil {
		return nil, nil
	}
	enableDebug = enableDebug || config.EnableDebug

	maxMemory := config.MaxMemory
	if maxMemory == 0 {
		maxMemory = defaultMaxMemory
	}

	maxDelay := defaultMaxDelay
	if config.MaxDelay != "" {
		parsedDelay, err := time.ParseDuration(config.MaxDelay)
		if err != nil {
			// %w keeps the underlying error available to errors.Is / errors.As.
			return nil, fmt.Errorf("could not parse max_delay: %w", err)
		}
		maxDelay = parsedDelay
	}

	maxWorkers := config.MaxWorkers
	if maxWorkers == 0 {
		maxWorkers = runtime.NumCPU()
	}

	numRetries := config.NumRetries
	if numRetries == 0 {
		numRetries = defaultNumRetries
	}

	session, err := awscommon.NewSession(&awscommon.SessionParams{
		Endpoint:            config.Endpoint,
		Region:              config.Region,
		AccessKeyID:         config.AccessKeyID,
		SecretAccessKeyFile: config.SecretAccessKeyFile,
		NumRetries:          numRetries,
		RequestTimeout:      requestTimeout,
		EnableVerboseDebug:  config.EnableVerboseDebug,
	})
	if err != nil {
		return nil, err
	}
	s3Client := awsS3.New(session)

	alg, err := ParseCompressionAlg(config.Compression)
	if err != nil {
		return nil, err
	}

	// Daily merging is on unless the config explicitly turns it off.
	mergeDaily := true
	if config.MergeDaily != nil {
		mergeDaily = *config.MergeDaily
	}

	loc := time.Local
	if config.Timezone != "" {
		loc, err = time.LoadLocation(config.Timezone)
		if err != nil {
			return nil, err
		}
	}

	var merger *s3Merger
	if mergeDaily {
		deleteMergedAfterDays := 730
		if config.DeleteMergedAfterDays != nil {
			deleteMergedAfterDays = *config.DeleteMergedAfterDays
		}

		// The merger gets its own manager: single download/upload worker and bigger upload parts.
		mergerManager := NewS3Manager(s3Client, &S3ManagerConfig{
			EnableDebug: enableDebug,
			// We retry harder when merging because otherwise the whole merging process restarts on each error.
			NumRetries:         numRetries,
			NumDownloadWorkers: 1,
			NumUploadWorkers:   1,
			NumGetInfoWorkers:  maxWorkers,
			MinUploadPartSize:  minPartSizeForMerge,
		})
		merger = &s3Merger{
			enableDebug:           enableDebug,
			bucket:                config.Bucket,
			manager:               mergerManager,
			mergedAlg:             alg,
			loc:                   loc,
			deleteMergedAfterDays: deleteMergedAfterDays,
			mergerLoad:            metrics.NewLoadReporter(time.Hour * 24),
		}
	}

	manager := NewS3Manager(s3Client, &S3ManagerConfig{
		EnableDebug: enableDebug,
		// There are automatic retries in BatcherWorkers, do not retry at S3 level.
		NumRetries:        1,
		NumUploadWorkers:  maxWorkers,
		MinUploadPartSize: minPartSize,
	})
	ret := &S3Sender{
		enableDebug: enableDebug,
		manager:     manager,
		bucket:      config.Bucket,
		alg:         alg,
		loc:         loc,
		merger:      merger,
	}

	eventBatcher := batcher.New(config.RemovePrefix, config.RemoveSuffix, true, addDecorators)
	workersConfig := batcher.WorkersConfig{
		MaxMemory:  int64(maxMemory.Bytes()),
		MaxDelay:   maxDelay,
		MaxWorkers: maxWorkers,
		// This is important for the later merging of objects.
		// NOTE(review): the original comment was truncated ("we assume that either the merged
		// object consists of several smaller objects"); presumably the merger relies on every
		// uploaded object covering a single day in loc — confirm against s3Merger.
		SplitDays: true,
		SplitLoc:  loc,
	}
	// The sender itself is the target the workers submit batches to.
	ret.workers = batcher.NewWorkers(eventBatcher, workersConfig, ret, "s3")

	return ret, nil
}

// Start launches the batching workers, then the merger (when configured), and finally a
// background goroutine, labelled "s3-refresh-sizes" for pprof, that runs refreshSizes.
func (s *S3Sender) Start() {
	s.workers.Start()

	if merger := s.merger; merger != nil {
		merger.Start()
	}

	labels := pprof.Labels("name", "s3-refresh-sizes")
	go util.RunWithLabels(labels, func() {
		s.refreshSizes()
	})
}

// Stop shuts the sender down: workers first, then the S3 manager, then the closed flag is
// raised for the refreshSizes loop, and finally the merger (when configured) is stopped.
func (s *S3Sender) Stop() {
	log.Printf("stopping S3 sender\n")

	s.workers.Stop()
	// The manager must outlive the workers: they still need it to submit their upload parts.
	s.manager.Stop()
	s.closed.Store(true)

	if merger := s.merger; merger != nil {
		merger.Stop()
	}

	log.Printf("stopped S3 sender\n")
}

// UpdateMetrics reports the current queue/batcher gauges and the worker loads, and asks the
// merger (when configured) to update its own metrics.
func (s *S3Sender) UpdateMetrics() {
	s.updateMetricsMaxValues()
	metrics.SetS3WorkerLoads(s.workers.Load())

	if merger := s.merger; merger != nil {
		merger.UpdateMetrics()
	}
}

// updateMetricsMaxValues reports the workers' queue length, queue memory and batcher memory
// to the corresponding S3 metrics.
func (s *S3Sender) updateMetricsMaxValues() {
	w := s.workers

	queueLen := uint64(w.QueueLen())
	metrics.S3QueueLen.Report(queueLen)

	queueMem := uint64(w.QueueMemory())
	metrics.S3QueueMemorySize.Report(queueMem)

	batcherMem := uint64(w.BatcherMemory())
	metrics.S3BatcherMemorySize.Report(batcherMem)
}

// OnDropDueToFullMemory increments the S3DroppedDueToFullMemory counter.
// Presumably called by the batcher workers (the sender is passed to batcher.NewWorkers)
// when events are dropped because the memory limit is reached — confirm in package batcher.
func (s *S3Sender) OnDropDueToFullMemory() {
	metrics.S3DroppedDueToFullMemory.Inc()
}

// OnDropDueToFullQueue increments the S3DroppedDueToFullQueue counter.
// Presumably called by the batcher workers (the sender is passed to batcher.NewWorkers)
// when a batch is dropped because the queue is full — confirm in package batcher.
func (s *S3Sender) OnDropDueToFullQueue() {
	metrics.S3DroppedDueToFullQueue.Inc()
}

// OnFlushDueToMemoryPressure increments the S3FlushedDueToPressure counter.
// Presumably called by the batcher workers (the sender is passed to batcher.NewWorkers)
// when a batch is flushed early to free memory — confirm in package batcher.
func (s *S3Sender) OnFlushDueToMemoryPressure() {
	metrics.S3FlushedDueToPressure.Inc()
}

// OnDropAfterRetries increments the S3DroppedAfterRetries counter.
// Presumably called by the batcher workers (the sender is passed to batcher.NewWorkers)
// when a batch is given up on after its retries are exhausted — confirm in package batcher.
func (s *S3Sender) OnDropAfterRetries() {
	metrics.S3DroppedAfterRetries.Inc()
}

// Enqueue hands the parsed events over to the batching workers and then refreshes the
// queue/batcher memory gauges. With debug enabled the number of events is logged first.
func (s *S3Sender) Enqueue(events []*parser.ParsedEvent) {
	if s.enableDebug {
		log.Printf("enqueueing %d events to s3\n", len(events))
	}

	s.workers.Enqueue(events)
	// Report right after enqueueing so the gauges reflect the newly added events.
	s.updateMetricsMaxValues()
}

// TotalMemory returns the sum of the workers' batcher memory and queue memory.
func (s *S3Sender) TotalMemory() int64 {
	inBatcher := s.workers.BatcherMemory()
	inQueue := s.workers.QueueMemory()
	return inBatcher + inQueue
}

// SubmitEvents uploads one batch of events to S3 as a single compressed TSV object.
//
// The object key and metadata are built from name, the earliest timestamp in the batch, the
// sender's location and compression algorithm, and the sorted column names of the batch.
// An empty batch is a no-op and returns nil.
//
// On any failure the error is logged, the upload is aborted, metrics.S3FailedRetries is
// incremented and the error is returned; retrying is left to the caller (the S3 manager used
// here is configured without S3-level retries because BatcherWorkers retries on its own).
// On success metrics.S3Writes and metrics.S3WrittenBytes are updated.
func (s *S3Sender) SubmitEvents(name string, events *batcher.EventBatch) error {
	if s.enableDebug {
		log.Printf("sending %d events with name %s to s3\n", events.Length, name)
	}
	if events.Length == 0 {
		return nil
	}

	startTime := time.Now()
	batchTime := getBatchTime(events)
	columns := getColumns(events)
	metadata := MakeMetadata(batchTime, s.alg, columns)
	key := MakeKey(name, batchTime, s.loc, s.alg)
	if s.enableDebug {
		log.Printf("uploading to %s with metadata %v\n", key, metadata)
	}

	uploader := s.manager.Upload(s.bucket, key, metadata)
	// fail is the cleanup shared by every error path once the uploader exists:
	// release the upload, count the failed attempt and pass the error through.
	fail := func(err error) error {
		uploader.Abort()
		metrics.S3FailedRetries.Inc()
		return err
	}

	writer, err := NewTsvWriter(columns, s.alg, uploader)
	if err != nil {
		log.Printf("ERROR: %v\n", errToString(err))
		// The uploader has already been created at this point, so it must be aborted here
		// as well; otherwise it is never closed nor aborted.
		return fail(err)
	}

	if err = writer.Write(events); err != nil {
		log.Printf("ERROR: writing to %s/%s failed %v\n", s.bucket, key, errToString(err))
		return fail(err)
	}

	// The writer has to be closed before the uploader so that all buffered data reaches it.
	if err = writer.Close(); err != nil {
		log.Printf("ERROR: closing TSV writer for %s/%s failed %v\n", s.bucket, key, errToString(err))
		return fail(err)
	}

	if err = uploader.Close(); err != nil {
		log.Printf("ERROR: closing %s/%s failed %v\n", s.bucket, key, errToString(err))
		return fail(err)
	}

	if s.enableDebug {
		log.Printf("uploaded %d events with %d columns to %s (%dkb original, %dkb compressed) in %v\n",
			events.Length, events.NumColumns, key, writer.TotalBytes()/1024,
			writer.CompressedBytes()/1024, time.Since(startTime))
	}
	metrics.S3Writes.Inc()
	metrics.S3WrittenBytes.Add(writer.CompressedBytes())
	return nil
}

// getBatchTime returns the earliest of the first events.Length timestamps in the batch,
// interpreted as Unix seconds. For a batch with Length == 0 the result is
// time.Unix(math.MaxInt64, 0); SubmitEvents filters empty batches out before calling this.
func getBatchTime(events *batcher.EventBatch) time.Time {
	earliest := int64(math.MaxInt64)
	for idx := 0; idx < events.Length; idx++ {
		ts := events.Timestamps[idx]
		if ts < earliest {
			earliest = ts
		}
	}
	return time.Unix(earliest, 0)
}

// getColumns collects the column names of the batch — the keys of both Float64Values and
// StringValues — and returns them sorted in ascending order.
func getColumns(events *batcher.EventBatch) []string {
	names := make([]string, 0, events.NumColumns)
	for name := range events.Float64Values {
		names = append(names, name)
	}
	for name := range events.StringValues {
		names = append(names, name)
	}
	// Map iteration order is random; sorting makes the column order deterministic.
	sort.Sort(sort.StringSlice(names))
	return names
}

// refreshSizes updates the folder size metrics once immediately and then again on every
// refreshSizesInterval tick until the sender is closed.
//
// NOTE(review): the closed flag is only checked when a tick arrives, so after Stop this loop
// (and its goroutine) can linger for up to one refreshSizesInterval.
func (s *S3Sender) refreshSizes() {
	s.updateFolderSizes()

	tick := time.NewTicker(refreshSizesInterval)
	defer tick.Stop()

	for {
		<-tick.C
		if s.closed.Load() {
			return
		}
		s.updateFolderSizes()
	}
}

// updateFolderSizes lists every object in the bucket and publishes, per name parsed from the
// object key, the total size and the object count, plus the overall number of merged objects.
//
// Objects whose key cannot be parsed are logged and skipped. If the listing itself fails, the
// error is logged and no metrics are published: the collected numbers would only cover a part
// of the bucket and would overwrite the previous, complete values.
//
// Panics (for example from a nil Key or Size pointer in a listed object) are recovered and
// logged with a stack trace so that the refresh goroutine keeps running.
func (s *S3Sender) updateFolderSizes() {
	defer func() {
		if r := recover(); r != nil {
			log.Println("ERROR: recovered in updateFolderSizes ", r, string(debug.Stack()))
		}
	}()

	log.Printf("started listing folder sizes")
	startTime := time.Now()
	sizes := map[string]int64{}
	counts := map[string]int64{}
	totalObjects := int64(0)
	mergedObjects := int64(0)

	err := s.manager.ListObjects(s.bucket, nil, func(objects []*awsS3.Object) error {
		for _, object := range objects {
			name, err := parseNameFromKey(*object.Key)
			if err != nil {
				// A single unparsable key must not abort the whole listing.
				log.Printf("ERROR: parsing key when listing objects: %v\n", errToString(err))
				continue
			}
			sizes[name] += *object.Size
			counts[name]++
			totalObjects++
			if isMergedKey(*object.Key) {
				mergedObjects++
			}
		}
		return nil
	})
	if err != nil {
		log.Printf("ERROR: getting folder sizes failed: %v", errToString(err))
		// Do not publish partial results; keep the metrics from the last complete listing.
		return
	}

	log.Printf("got %d folder sizes in %v (%d total objects, %d merged objects)\n",
		len(sizes), time.Since(startTime), totalObjects, mergedObjects)
	metrics.SetS3FolderSizes(sizes)
	metrics.SetS3FolderObjects(counts)
	metrics.S3MergedObjects.Store(mergedObjects)
}
