package s3

import (
	"fmt"
	"log"
	"math"
	"math/rand"
	"runtime/debug"
	"runtime/pprof"
	"sort"
	"time"

	awsS3 "github.com/aws/aws-sdk-go/service/s3"
	"go.uber.org/atomic"

	"a.yandex-team.ru/security/osquery/osquery-sender/metrics"
	"a.yandex-team.ru/security/osquery/osquery-sender/util"
)

const (
	// Age threshold for merging, in days. The merge threshold date is the calendar date (in the merger's
	// location) of "now minus mergeDaysThreshold days"; source objects whose batch date is on or before
	// that date are merged.
	mergeDaysThreshold = 2
)

// objectInfo describes one source (not yet merged) object, built from its key and its metadata.
type objectInfo struct {
	key       string         // Key of the object inside the bucket.
	batchTime time.Time      // Batch time parsed from the object metadata; used for grouping by date and for ordering.
	alg       CompressionAlg // Compression algorithm parsed from the object metadata.
	// Different objects even from the same day may have different number of columns and/or reordered columns (this
	// should not happen now as we sort, but could happen before).
	columns []string
}

// s3Merger periodically merges source objects into one merged object per name and date, and deletes merged
// objects once they are old enough.
type s3Merger struct {
	// Enables an extra log line before the upload of each merged object is created.
	enableDebug bool

	manager *S3Manager // Performs the S3 operations used here (list, get infos, exists, upload, download, delete).
	bucket  string     // Bucket that every operation of the merger is issued against.

	// Compression algorithm used for the merged objects this merger creates.
	mergedAlg CompressionAlg

	// Location in which times are converted to calendar dates.
	loc *time.Location

	// Age in days that is passed to the cleanup step as the deletion threshold for merged objects.
	deleteMergedAfterDays int

	// Set by Stop; the background loop checks it at the start of each iteration.
	closed atomic.Bool

	// Receives the duration of every merge-and-cleanup cycle; its average is exported by UpdateMetrics.
	mergerLoad *metrics.LoadReporter
}

// Start launches the daily merge loop in a new goroutine. The loop is run through util.RunWithLabels with the
// pprof label name=s3-merge-daily-objects and uses deleteMergedAfterDays as its cleanup threshold.
func (s *s3Merger) Start() {
	labels := pprof.Labels("name", "s3-merge-daily-objects")
	loop := func() {
		s.mergeDailyObjects(s.deleteMergedAfterDays)
	}
	go util.RunWithLabels(labels, loop)
}

// Stop asks the background merge loop to exit by setting the closed flag. The loop only checks the flag at the
// start of an iteration, so Stop does not interrupt a cycle that is already running nor the wait between cycles.
func (s *s3Merger) Stop() {
	s.closed.Store(true)
}

// mergeDailyObjects runs the merge-and-cleanup cycle once every 24 hours until the merger is stopped.
//
// Before the first cycle it sleeps for a random 100 to 199 minutes. Every cycle merges the source objects,
// deletes merged objects using dropMergedOlderThan (in days) as the age threshold, and reports the time spent
// to mergerLoad. The closed flag is checked only at the start of a cycle.
func (s *s3Merger) mergeDailyObjects(dropMergedOlderThan int) {
	const minSleepTime = 100
	const maxSleepTime = 200
	// Sleep a random amount of time to reduce the probability of collision with another instance doing a merge.
	sleepTime := minSleepTime + rand.Intn(maxSleepTime-minSleepTime)
	log.Printf("Sleeping %d minutes before starting merging", sleepTime)
	time.Sleep(time.Minute * time.Duration(sleepTime))

	ticker := time.NewTicker(time.Hour * 24)
	// Release the ticker once the loop exits; without this it stays armed after Stop.
	defer ticker.Stop()
	for {
		if s.closed.Load() {
			return
		}

		startTime := time.Now()
		s.mergeDailyObjectsImpl()
		s.dropMergedObjectsImpl(dropMergedOlderThan)
		s.mergerLoad.ReportWork(time.Since(startTime))
		// No way to force the Ticker to tick immediately on start, so the wait is at the end of the cycle.
		<-ticker.C
	}
}

// mergeDailyObjectsImpl performs one merge pass: it lists the source objects that are old enough, groups them
// by name and by calendar date, and merges every (name, date) group into one object.
//
// A panic anywhere in the pass is recovered, logged with a stack trace and counted as a merge failure, so that
// the calling loop keeps running.
func (s *s3Merger) mergeDailyObjectsImpl() {
	defer func() {
		if r := recover(); r != nil {
			log.Println("ERROR: recovered in mergeDailyObjectsImpl ", r, string(debug.Stack()))
			metrics.S3FailuresDuringMerge.Inc()
		}
	}()

	log.Printf("started listing objects to merge")
	startTime := time.Now()
	thresholdDate := s.timeToDate(time.Now().Add(-time.Hour * 24 * mergeDaysThreshold))
	objectsToMerge := s.getObjectsOlderThan(thresholdDate)
	log.Printf("listed objects to merge in %v\n", time.Since(startTime))

	// Randomize the order of object processing to reduce the probability of collision with another instance
	// doing a merge.
	names := make([]string, 0, len(objectsToMerge))
	for name := range objectsToMerge {
		names = append(names, name)
	}
	rand.Shuffle(len(names), func(i, j int) {
		names[i], names[j] = names[j], names[i]
	})

	totalObjects := 0
	// Walk the shuffled names rather than the map itself, so the shuffle above actually decides the order.
	for _, name := range names {
		objects := objectsToMerge[name]

		// Group the objects of this name by calendar date.
		dateMap := map[int][]*objectInfo{}
		for _, object := range objects {
			date := s.timeToDate(object.batchTime)
			dateMap[date] = append(dateMap[date], object)
		}

		for date, dateObjects := range dateMap {
			s.mergeOneDailyObject(name, date, dateObjects)
		}
		totalObjects += len(objects)
	}

	log.Printf("merged %d objects in %v\n", totalObjects, time.Since(startTime))
}

// dropMergedObjectsImpl deletes the merged objects whose date is earlier than the calendar date of
// "now minus dropMergedOlderThan days".
//
// Like the merge pass, it recovers from panics (logging them and counting a merge failure) so that a failure
// here does not take down the goroutine running the daily loop.
func (s *s3Merger) dropMergedObjectsImpl(dropMergedOlderThan int) {
	defer func() {
		if r := recover(); r != nil {
			log.Println("ERROR: recovered in dropMergedObjectsImpl ", r, string(debug.Stack()))
			metrics.S3FailuresDuringMerge.Inc()
		}
	}()

	startTime := time.Now()
	thresholdDate := s.timeToDate(time.Now().Add(-time.Hour * 24 * time.Duration(dropMergedOlderThan)))
	keys := s.getMergedObjectsOlderThan(thresholdDate)
	if len(keys) == 0 {
		// Nothing matched, or the listing failed (already logged by the callee): do not issue a delete
		// request with an empty key list.
		log.Printf("no merged objects older than %d to delete\n", thresholdDate)
		return
	}
	log.Printf("%d merged objects are older than %d, deleting\n", len(keys), thresholdDate)

	err := s.manager.DeleteObjects(s.bucket, keys)
	if err != nil {
		log.Printf("ERROR: error deleting old merged objects: %v\n", errToString(err))
		return
	}

	metrics.S3Deletions.Add(int64(len(keys)))
	log.Printf("deleted %d merged objects older than %d in %v\n", len(keys), thresholdDate, time.Since(startTime))
}

// timeToDate converts t to the merger's location and encodes the resulting calendar date as the integer
// YYYYMMDD (for example 20210307).
func (s *s3Merger) timeToDate(t time.Time) int {
	local := t.In(s.loc)
	return local.Year()*10000 + int(local.Month())*100 + local.Day()
}

// getObjectsOlderThan lists the bucket and returns, grouped by name, the non-merged objects whose batch date is
// on or before thresholdDate (a YYYYMMDD integer).
//
// Any failure (listing, fetching infos, or parsing the key or metadata of an object) aborts the listing: the
// error is logged, counted as a merge failure, and nil is returned.
func (s *s3Merger) getObjectsOlderThan(thresholdDate int) map[string][]*objectInfo {
	grouped := make(map[string][]*objectInfo)

	collect := func(batch []*awsS3.Object) error {
		// Already merged objects are not candidates for merging.
		var candidates []string
		for _, obj := range batch {
			if key := *obj.Key; !isMergedKey(key) {
				candidates = append(candidates, key)
			}
		}

		infos, err := s.manager.GetInfos(s.bucket, candidates)
		if err != nil {
			return fmt.Errorf("error getting infos of objects older than: %v", errToString(err))
		}

		// infos is indexed in parallel with candidates; a nil entry is skipped.
		for idx, info := range infos {
			if info == nil {
				continue
			}
			key := candidates[idx]

			name, err := parseNameFromKey(key)
			if err != nil {
				return fmt.Errorf("error parsing key name in %s/%s: %v", s.bucket, key, err)
			}

			batchTime, err := parseMetadataBatchTime(info.Metadata)
			if err != nil {
				return fmt.Errorf("error parsing batch time in %s/%s: %v", s.bucket, key, err)
			}

			alg, err := ParseMetadataAlg(info.Metadata)
			if err != nil {
				return fmt.Errorf("error getting algorithm in %s/%s: %v", s.bucket, key, err)
			}

			columns, err := ParseMetadataColumns(info.Metadata)
			if err != nil {
				return fmt.Errorf("error getting columns in %s/%s: %v", s.bucket, key, err)
			}
			columns = addPredefinedColumns(columns)

			// Objects newer than the threshold date are left alone.
			if s.timeToDate(batchTime) > thresholdDate {
				continue
			}
			grouped[name] = append(grouped[name], &objectInfo{
				key:       key,
				batchTime: batchTime,
				alg:       alg,
				columns:   columns,
			})
		}

		return nil
	}

	if err := s.manager.ListObjects(s.bucket, nil, collect); err != nil {
		log.Printf("ERROR: getting objects to merge failed: %v\n", errToString(err))
		metrics.S3FailuresDuringMerge.Inc()
		return nil
	}
	return grouped
}

// getMergedObjectsOlderThan lists the bucket and returns the keys of the merged objects whose date is strictly
// earlier than thresholdDate (a YYYYMMDD integer). If the listing fails, the error is logged, counted as a
// merge failure, and nil is returned.
func (s *s3Merger) getMergedObjectsOlderThan(thresholdDate int) []string {
	var expired []string

	collect := func(batch []*awsS3.Object) error {
		for _, obj := range batch {
			key := *obj.Key
			// The date lookup is only performed for merged keys.
			if isMergedKey(key) && s.getMergedObjectDate(key) < thresholdDate {
				expired = append(expired, key)
			}
		}
		return nil
	}

	if err := s.manager.ListObjects(s.bucket, nil, collect); err != nil {
		log.Printf("ERROR: getting merged objects to delete failed: %v\n", errToString(err))
		metrics.S3FailuresDuringMerge.Inc()
		return nil
	}
	return expired
}

// mergeOneDailyObject merges the given source objects, which all belong to the same name and date, into a
// single merged object and then deletes the sources.
//
// name and date (a YYYYMMDD integer) determine the key of the merged object. objects is sorted in place by
// batch time. Failures are logged and the upload is aborted; the source objects are deleted only if the merged
// object already existed or its upload was closed successfully.
func (s *s3Merger) mergeOneDailyObject(name string, date int, objects []*objectInfo) {
	startTime := time.Now()
	log.Printf("merging %d objects for %s, date %d\n", len(objects), name, date)

	// NOTE: The process of merging old objects consists of 3 steps:
	//  1. List all old objects to merge.
	//  2. Create a new object from old objects.
	//  3. Delete old objects that have been merged.
	//
	// The process can get interrupted at any step or two instances may step on each other while running the process.
	// The consistency is guaranteed by two reasons:
	//  1. If the merging process is run on the same set of objects twice, the end result will be identical. This allows
	//  concurrently running merges to overwrite the merged object. If one of the merges will start deleting the
	//  objects, the other one will silently abort the merge.
	//  2. If the merging process fails after completing the step 2, the next rerun will check that the merged object
	//  already exists and proceed with step 3. Note that this relies on strong consistency in S3 object metabase,
	//  otherwise e.g. deletion of source objects may get reordered with creation of merged object and the merging
	//  process will try to re-merge the remaining source objects, overwriting the merged object.
	mergedKey := MakeMergedKey(name, date, s.mergedAlg)
	if s.manager.ObjectExists(s.bucket, mergedKey) {
		// Step 2 was completed earlier: only the deletion of the sources (step 3) is left.
		log.Printf("merged object %s already exists, removing the stale %d objects\n", mergedKey, len(objects))
		s.deleteObjectInfos(objects)
		return
	}

	// Append the objects in batch-time order, so the order does not depend on how they were listed.
	sort.Slice(objects, func(i, j int) bool {
		return objects[i].batchTime.Before(objects[j].batchTime)
	})

	// The merged object uses the union of the columns of all source objects.
	mergedColumns := mergeColumns(objects)
	metadata := MakeMergedMetadata(date, s.mergedAlg, mergedColumns)
	if s.enableDebug {
		log.Printf("creating upload to %s", mergedKey)
	}

	uploader := s.manager.Upload(s.bucket, mergedKey, metadata)
	merger, err := NewTsvMerger(mergedColumns, s.mergedAlg, uploader)
	if err != nil {
		log.Printf("ERROR: %v\n", errToString(err))
		metrics.S3FailuresDuringMerge.Inc()
		uploader.Abort()
		return
	}

	for _, object := range objects {
		objectReader := s.manager.Download(s.bucket, object.key)
		err = merger.Append(object.columns, object.alg, objectReader)
		if err != nil {
			// The object may have been deleted by another instance.
			// Only a failure on an object that still exists is counted as a merge failure.
			if s.manager.ObjectExists(s.bucket, object.key) {
				log.Printf("ERROR: merging %s into %s failed: %v\n", object.key, mergedKey, errToString(err))
				metrics.S3FailuresDuringMerge.Inc()
			} else {
				log.Printf("WARNING: aborting merge into %s: %v\n", mergedKey, errToString(err))
			}
			uploader.Abort()
			return
		}
	}

	err = merger.Close()
	if err != nil {
		log.Printf("ERROR: closing TSV merger for %s failed: %v\n", mergedKey, errToString(err))
		metrics.S3FailuresDuringMerge.Inc()
		uploader.Abort()
		return
	}

	err = uploader.Close()
	if err != nil {
		log.Printf("ERROR: closing %s failed: %v\n", mergedKey, errToString(err))
		uploader.Abort()
		// NOTE(review): every other failure path in this function increments S3FailuresDuringMerge; confirm
		// that S3FailedRetries is the intended metric here.
		metrics.S3FailedRetries.Inc()
		return
	}

	log.Printf("merged %d objects into %s (%dkb original, %dkb compressed) in %v\n", len(objects), mergedKey,
		merger.TotalBytes()/1024, merger.CompressedBytes()/1024, time.Since(startTime))
	metrics.S3Merges.Inc()
	metrics.S3WrittenBytes.Add(merger.CompressedBytes())

	// The merged object is in place, so the sources can be removed now.
	s.deleteObjectInfos(objects)
}

// mergeColumns returns the union of the columns of all objects. Every column appears once, at the position of
// its first occurrence when the objects are scanned in order. The result is nil when there are no columns.
func mergeColumns(objects []*objectInfo) []string {
	var merged []string
	seen := make(map[string]struct{})

	for _, obj := range objects {
		for _, col := range obj.columns {
			if _, dup := seen[col]; dup {
				continue
			}
			seen[col] = struct{}{}
			merged = append(merged, col)
		}
	}

	return merged
}

// deleteObjectInfos deletes the given objects from the bucket in a single DeleteObjects call. On success the
// deletion counter is increased by the number of objects; on failure the error is logged and nothing is counted.
func (s *s3Merger) deleteObjectInfos(objects []*objectInfo) {
	var keys []string
	for i := range objects {
		keys = append(keys, objects[i].key)
	}

	if err := s.manager.DeleteObjects(s.bucket, keys); err != nil {
		log.Printf("ERROR: failed deleting objects: %v\n", errToString(err))
		return
	}

	metrics.S3Deletions.Add(int64(len(objects)))
}

// getMergedObjectDate returns the date stored in the metadata of the merged object with the given key.
//
// When the info cannot be fetched, is missing, or carries an unparsable date, math.MaxInt32 is returned. The
// caller deletes an object only when its date is lower than a threshold, so such an object is kept.
func (s *s3Merger) getMergedObjectDate(key string) int {
	const unknownDate = math.MaxInt32

	infos, err := s.manager.GetInfos(s.bucket, []string{key})
	switch {
	case err != nil:
		log.Printf("ERROR: get info for %s/%s failed: %v", s.bucket, key, errToString(err))
		return unknownDate
	case infos[0] == nil:
		return unknownDate
	}

	date, err := parseMetadataDate(infos[0].Metadata)
	if err != nil {
		log.Printf("ERROR: merged object %s/%s metadata is incorrect: %v", s.bucket, key, err)
		metrics.S3FailuresDuringMerge.Inc()
		return unknownDate
	}
	return date
}

// UpdateMetrics stores the current average reported by mergerLoad into the S3MergerLoad metric.
func (s *s3Merger) UpdateMetrics() {
	average := s.mergerLoad.GetAverage()
	metrics.S3MergerLoad.Store(average)
}
