package main

import (
	"bufio"
	"flag"
	"io/ioutil"
	"log"
	"net/http"
	"os"
	"regexp"
	"runtime"
	"runtime/pprof"
	"runtime/trace"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	awsClient "github.com/aws/aws-sdk-go/aws/client"
	awsCredentials "github.com/aws/aws-sdk-go/aws/credentials"
	awsSession "github.com/aws/aws-sdk-go/aws/session"
	awsS3 "github.com/aws/aws-sdk-go/service/s3"
	"golang.org/x/net/http2"

	"a.yandex-team.ru/security/osquery/osquery-sender/config"
	"a.yandex-team.ru/security/osquery/osquery-sender/s3"
)

// minPartSizeForMerge is 256 MiB expressed in bytes; main passes it to the S3
// manager as MinUploadPartSize.
const minPartSizeForMerge = 256 << 20

// Location of the osquery-sender configuration file.
var configPath = flag.String("config", "/etc/osquery-sender/conf.yaml", "osquery-sender config")

// What to look for in every line and what to put in its place.
var (
	replaceRegexStr = flag.String("replace-regex", "", "regex to replace")
	replaceWith     = flag.String("replace-with", "", "text to replace regex with")
)

// Which S3 objects to process (directory and inclusive date range) and
// whether the rewritten objects replace the originals.
var (
	s3Path      = flag.String("s3", "", "path to directory with objects to clean")
	fromDateStr = flag.String("from-date", "", "date to replace data from")
	toDateStr   = flag.String("to-date", "", "date to replace data to (including)")
	s3NoReplace = flag.Bool("s3-no-replace", false, "do not actually replace the objects, create the new ones with '-tmp' postfixes")
)

// Optional profiling outputs; an empty path disables the corresponding output.
var (
	cpuProfilePath = flag.String("cpuprofile", "", "write CPU profile to file")
	tracePath      = flag.String("trace", "", "write trace to file")
)

// main is the entry point of the cleaner. It parses the command-line flags,
// loads the osquery-sender config, optionally starts CPU profiling and
// tracing, builds an S3 client from the configured credentials and endpoint,
// and then, for every object under -s3 whose name falls into the
// [-from-date, -to-date] range, writes a copy with the regex replaced to
// "<key>-tmp". Unless -s3-no-replace is set, the temporary object is then
// moved over the original key.
//
// NOTE: every failure path (here and in the helpers) uses log.Fatalf, which
// exits the process without running deferred functions, so the CPU profile
// and trace are only finalized when the run completes successfully.
func main() {
	log.SetFlags(log.Lshortfile | log.Ldate | log.Ltime)
	log.SetPrefix("[osquery-sender cleaner] ")

	flag.Parse()

	conf, err := config.FromFile(*configPath)
	if err != nil {
		log.Fatalf("could not read config %s: %v\n", *configPath, err)
	}

	// Both dates are mandatory and must be in YYYY-MM-DD form.
	fromDate, err := time.Parse("2006-01-02", *fromDateStr)
	if err != nil {
		log.Fatalf("could not parse date %v: %v\n", *fromDateStr, err)
	}
	toDate, err := time.Parse("2006-01-02", *toDateStr)
	if err != nil {
		// Report the value that actually failed to parse (to-date, not from-date).
		log.Fatalf("could not parse date %v: %v\n", *toDateStr, err)
	}

	if *replaceRegexStr == "" {
		log.Fatalf("replace-regex is mandatory\n")
	}
	replaceRegex, err := regexp.Compile(*replaceRegexStr)
	if err != nil {
		log.Fatalf("could not parse replace-regex %v: %v", *replaceRegexStr, err)
	}
	if *replaceWith == "" {
		// An empty replacement is allowed (it deletes the matches) but is worth flagging.
		log.Printf("WARNING: replacing %v with empty string\n", *replaceRegexStr)
	}

	if *cpuProfilePath != "" {
		cpuProfileFile, err := os.Create(*cpuProfilePath)
		if err != nil {
			log.Fatalf("could not open CPU profile file: %v\n", err)
		}
		defer func() {
			_ = cpuProfileFile.Close()
		}()
		err = pprof.StartCPUProfile(cpuProfileFile)
		if err != nil {
			log.Fatalf("could not start CPU profile: %v\n", err)
		}
		// Deferred after the Close above, so the profile is stopped before the file is closed.
		defer pprof.StopCPUProfile()
	}

	if *tracePath != "" {
		traceFile, err := os.Create(*tracePath)
		if err != nil {
			log.Fatalf("could not open tracePath file: %v\n", err)
		}
		defer func() {
			_ = traceFile.Close()
		}()
		err = trace.Start(traceFile)
		if err != nil {
			log.Fatalf("could not start trace: %v\n", err)
		}
		// Deferred after the Close above, so the trace is stopped before the file is closed.
		defer trace.Stop()
	}

	// The secret key lives in a separate file; surrounding whitespace (e.g. a
	// trailing newline) is stripped before use.
	secretAccessKeyBytes, err := ioutil.ReadFile(conf.S3.SecretAccessKeyFile)
	if err != nil {
		log.Fatalf("secret file %s error: %v\n", conf.S3.SecretAccessKeyFile, err)
	}
	secretAccessKey := strings.TrimSpace(string(secretAccessKeyBytes))
	credentials := awsCredentials.NewStaticCredentials(conf.S3.AccessKeyID, secretAccessKey, "")

	// Use a private copy of the default transport with HTTP/2 configured and a
	// read-idle timeout set on the HTTP/2 side.
	httpTransport := http.DefaultTransport.(*http.Transport).Clone()
	h2Transport, err := http2.ConfigureTransports(httpTransport)
	if err != nil {
		log.Fatal(err)
	}
	h2Transport.ReadIdleTimeout = time.Second * 30

	// awsLogger := aws.LoggerFunc(func(i ...interface{}) {
	// 	log.Println("AWS SDK:", fmt.Sprint(i...))
	// })
	session, err := awsSession.NewSessionWithOptions(awsSession.Options{
		Config: aws.Config{
			Credentials: credentials,
			Endpoint:    &conf.S3.Endpoint,
			Region:      &conf.S3.Region,
			Retryer: awsClient.DefaultRetryer{
				NumMaxRetries: 5,
				MinRetryDelay: time.Second * 5,
			},
			HTTPClient: &http.Client{
				Transport: httpTransport,
				Timeout:   time.Second * 30,
			},
			// Logger:   awsLogger,
			// LogLevel: aws.LogLevel(aws.LogDebugWithRequestErrors | aws.LogDebugWithRequestRetries),
		},
	})
	if err != nil {
		log.Fatal(err)
	}
	s3Client := awsS3.New(session)
	s3Manager := s3.NewS3Manager(s3Client, &s3.S3ManagerConfig{
		EnableDebug:        true,
		NumRetries:         5,
		NumDownloadWorkers: runtime.NumCPU(),
		NumUploadWorkers:   runtime.NumCPU(),
		NumGetInfoWorkers:  runtime.NumCPU(),
		MinUploadPartSize:  minPartSizeForMerge,
	})

	if *s3Path != "" {
		matchingObjects := listS3(s3Manager, conf.S3.Bucket, *s3Path, fromDate, toDate)
		for _, key := range matchingObjects {
			// The rewritten data always goes to a temporary key first; the
			// original is only overwritten once the rewrite has succeeded.
			tmpKey := key + "-tmp"
			log.Printf("Replacing '%v' with '%v' in %v\n", *replaceRegexStr, *replaceWith, key)
			replaceS3(s3Manager, conf.S3.Bucket, key, tmpKey, replaceRegex, *replaceWith)
			if !*s3NoReplace {
				log.Printf("Moving %v -> %v\n", tmpKey, key)
				moveS3(s3Manager, conf.S3.Bucket, tmpKey, key)
			}
		}
	}
	log.Printf("All done!")
}

// listS3 lists the objects of s3Bucket under path (a trailing "/" is added
// when missing) and returns the keys whose file name starts with a date in
// the inclusive range [fromDate, toDate].
//
// Keys that do not start with path and keys under the "merged/" subdirectory
// are skipped. The date is the part of the file name before the first "T"
// or, when there is no "T", before the first "."; it must be in YYYY-MM-DD
// form. A key with neither separator, an unparsable date, or a listing
// failure terminates the process via log.Fatalf.
func listS3(manager *s3.S3Manager, s3Bucket string, path string, fromDate time.Time, toDate time.Time) []string {
	if !strings.HasSuffix(path, "/") {
		path += "/"
	}

	var matched []string
	collect := func(objects []*awsS3.Object) error {
		for _, obj := range objects {
			key := *obj.Key
			if !strings.HasPrefix(key, path) {
				continue
			}
			name := key[len(path):]
			if strings.HasPrefix(name, "merged/") {
				continue
			}

			// A timestamp-style name ends its date part at "T"; a date-style
			// name ends it at the first ".".
			end := strings.Index(name, "T")
			if end == -1 {
				end = strings.Index(name, ".")
			}
			if end == -1 {
				log.Fatalf("key %s is not a date or a timestamp, bailing out\n", key)
			}

			day, err := time.Parse("2006-01-02", name[:end])
			if err != nil {
				log.Fatalf("key %s is not a date or a timestamp, bailing out: %v\n", key, err)
			}
			if !day.Before(fromDate) && !day.After(toDate) {
				matched = append(matched, key)
			}
		}
		return nil
	}

	if err := manager.ListObjects(s3Bucket, &path, collect); err != nil {
		log.Fatalf("failed to list objects: %v\n", err)
	}
	return matched
}

// replaceS3 downloads the object fromKey from bucket, replaces every match of
// from with the literal text to in each line, and uploads the result to toKey
// in the same bucket.
//
// The source object's metadata is reused for the new object, and the
// compression algorithm parsed from that metadata is used both for reading
// and for writing. The replacement text is inserted verbatim (no "$1"-style
// expansion). Every output line is terminated with '\n'. Any failure
// terminates the process via log.Fatalf.
func replaceS3(manager *s3.S3Manager, bucket string, fromKey string, toKey string, from *regexp.Regexp, to string) {
	// bufio.Scanner rejects lines longer than 64 KiB by default, which would
	// abort the whole run on a single long record; allow much longer lines.
	const maxLineSize = 64 * 1024 * 1024

	info, err := manager.GetInfos(bucket, []string{fromKey})
	if err != nil {
		log.Fatalf("failed to get info for %v: %v\n", fromKey, err)
	}
	if len(info) == 0 {
		// Guard the info[0] access below against an empty result.
		log.Fatalf("failed to get info for %v: no info returned\n", fromKey)
	}
	metadata := info[0].Metadata
	alg, err := s3.ParseMetadataAlg(metadata)
	if err != nil {
		log.Fatalf("error getting algorithm in %s/%s: %v\n", bucket, fromKey, err)
	}

	downloader := manager.Download(bucket, fromKey)
	compressedReader, err := s3.NewCompressedReader(alg, downloader)
	if err != nil {
		log.Fatalf("error opening downloader for %s/%s: %v\n", bucket, fromKey, err)
	}
	uploader := manager.Upload(bucket, toKey, metadata)
	compressedWriter, err := s3.NewCompressedWriter(alg, uploader)
	if err != nil {
		log.Fatalf("error opening uploader for %s/%s: %v\n", bucket, toKey, err)
	}

	scanner := bufio.NewScanner(compressedReader)
	scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), maxLineSize)
	toBytes := []byte(to)
	newline := []byte{'\n'}
	totalLines := 0
	replaced := 0
	for scanner.Scan() {
		fromLine := scanner.Bytes()
		// ReplaceAllFunc is used (rather than ReplaceAll) so that matches can
		// be counted and the replacement is taken literally.
		toLine := from.ReplaceAllFunc(fromLine, func(s []byte) []byte {
			replaced++
			return toBytes
		})
		if _, err = compressedWriter.Write(toLine); err != nil {
			log.Fatalf("error writing to %s/%s: %v\n", bucket, toKey, err)
		}
		if _, err = compressedWriter.Write(newline); err != nil {
			log.Fatalf("error writing to %s/%s: %v\n", bucket, toKey, err)
		}
		totalLines++
	}
	if err = scanner.Err(); err != nil {
		log.Fatalf("error reading from %s/%s: %v\n", bucket, fromKey, err)
	}

	// Close the compressed writer before the uploader it writes into.
	if err = compressedWriter.Close(); err != nil {
		log.Fatalf("error when closing compressed writer for %s/%s: %v\n", bucket, toKey, err)
	}
	if err = uploader.Close(); err != nil {
		log.Fatalf("error when closing uploader for %s/%s: %v\n", bucket, toKey, err)
	}

	log.Printf("Read %d lines, replaced %d, compression alg %s\n", totalLines, replaced, alg)
}

// moveS3 moves an object within bucket by copying fromKey to toKey and then
// deleting fromKey. The two steps are not atomic: if the delete fails, the
// copy at toKey has already been made. Any failure terminates the process via
// log.Fatalf, with the underlying error included in the message.
func moveS3(manager *s3.S3Manager, bucket string, fromKey string, toKey string) {
	if err := manager.CopyObject(bucket, fromKey, toKey); err != nil {
		log.Fatalf("failed to copy %s to %s: %v\n", fromKey, toKey, err)
	}
	if err := manager.DeleteObjects(bucket, []string{fromKey}); err != nil {
		log.Fatalf("failed to delete %s: %v\n", fromKey, err)
	}
}
