package main

import (
	"encoding/hex"
	"flag"
	"fmt"
	"log"
	"math/rand"
	"net/url"
	"os"
	"runtime"
	"runtime/pprof"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/c2h5oh/datasize"

	"a.yandex-team.ru/security/osquery/osquery-sender/clickhouse"
	"a.yandex-team.ru/security/osquery/osquery-sender/config"
	"a.yandex-team.ru/security/osquery/osquery-sender/metrics"
	"a.yandex-team.ru/security/osquery/osquery-sender/parser"
	"a.yandex-team.ru/security/osquery/osquery-sender/s3"
)

// fimPackName is the base name of every generated event (main derives the
// extra -num-event-names variants from it as "<base>_N"), and it is also
// written into each generated event's "name" column.
const fimPackName = "this_is_a_fim_event"

var (
	// ClickHouse backend; selected when -ch-hosts is non-empty.
	chHosts        = flag.String("ch-hosts", "", "clickhouse hosts to connect to")
	chParams       = flag.String("ch-params", "", "clickhouse params in URL form")
	chPasswordFile = flag.String("ch-password", "", "file with clickhouse password")

	// S3 backend; used when -ch-hosts is empty.
	s3Endpoint     = flag.String("s3-endpoint", "", "s3 endpoint, e.g. storage.yandexcloud.net")
	s3AccessKeyID  = flag.String("s3-access-key", "", "s3 access key id")
	s3SecretFile   = flag.String("s3-secret-file", "", "file with s3 secret access key")
	s3Bucket       = flag.String("s3-bucket", "", "s3 bucket")
	s3Region       = flag.String("s3-region", "ru-central1", "s3 region")
	s3Compression  = flag.String("s3-compression", "", "compression algorithm to use with S3")
	s3VerboseDebug = flag.Bool("s3-verbose-debug", false, "log all S3 requests")

	// Load shape, sender limits and diagnostics. maxDelay is passed through
	// verbatim as the sender configs' MaxDelay string; its accepted format is
	// defined by the config package, not here.
	numThreads     = flag.Int("num-threads", 1, "number of threads")
	numEvents      = flag.Int("num-events", 1000, "number of events to send in one batch")
	numBatches     = flag.Int("num-batches", 100, "number of batches")
	maxMemoryMb    = flag.Int("max-memory", 1000, "maximum memory in megabytes")
	maxDelay       = flag.String("max-delay", "", "maximum delay in seconds")
	largeEvents    = flag.Bool("large-events", false, "use larger events")
	numEventNames  = flag.Int("num-event-names", 1, "number of different event names")
	cpuProfilePath = flag.String("cpuprofile", "", "path to write CPU profile to")
	memProfilePath = flag.String("memprofile", "", "path to write heap profile to")
	logVerbose     = flag.Bool("log-verbose", false, "log times for each batch")

	// eventNames is filled by main from fimPackName and -num-event-names;
	// each generated event picks one of them at random.
	eventNames []string

	// Word lists used to build synthetic hostnames and paths. Their lengths
	// must be powers of two (checked by initRandom) so that rnd.Intn over
	// them stays on math/rand's cheap power-of-two path.
	hostComponents = []string{"yc", "cloud", "vla", "sas", "myt", "prod", "preprod", "osquery",
		"dev", "host", "kms", "billing", "compute", "mdb", "iam", "serverless"}
	pathComponents = []string{"bin", "home", "usr", "var", "lock", "load", "hack", "test",
		"very", "much", "include", "docker", "kubelet", "daemon", "systemd", "modules"}

	// randomNums holds decimal strings precomputed by initRandom.
	// strconv.Itoa takes a noticeable share of event-generation time, so the
	// numbers are formatted once up front and then picked at random.
	randomNums []string
)

// main runs a synthetic load benchmark against one osquery-sender backend:
// ClickHouse when -ch-hosts is set, S3 otherwise. Each of -num-threads
// goroutines generates -num-batches batches of -num-events synthetic events
// and enqueues them, sleeping while the sender reports more memory in use
// than 3/4 of -max-memory. At the end it logs per-thread generation and send
// times, the total backoff, the duration of the final Stop(), the overall
// runtime and the collected metrics, and optionally writes CPU/heap profiles.
func main() {
	flag.Parse()

	// Reject invalid sizes up front, before any defer is registered
	// (log.Fatal exits via os.Exit and skips deferred calls). Left unchecked
	// they would panic later (negative make length, negative WaitGroup
	// counter) or turn the memory backoff budget into nonsense.
	if *numThreads < 1 {
		log.Fatalf("-num-threads must be at least 1, got %d\n", *numThreads)
	}
	if *numEvents < 0 {
		log.Fatalf("-num-events must not be negative, got %d\n", *numEvents)
	}
	if *numBatches < 0 {
		log.Fatalf("-num-batches must not be negative, got %d\n", *numBatches)
	}
	if *maxMemoryMb <= 0 {
		log.Fatalf("-max-memory must be positive, got %d\n", *maxMemoryMb)
	}

	// The base name plus numbered variants; always at least one name.
	eventNames = append(eventNames, fimPackName)
	for i := 1; i < *numEventNames; i++ {
		eventNames = append(eventNames, fmt.Sprintf("%s_%d", fimPackName, i))
	}

	initRandom()

	if *cpuProfilePath != "" {
		f, err := os.Create(*cpuProfilePath)
		if err != nil {
			log.Fatal("could not create CPU profile: ", err)
		}
		defer func() {
			_ = f.Close()
		}()
		if err := pprof.StartCPUProfile(f); err != nil {
			log.Fatal("could not start CPU profile: ", err)
		}
		defer pprof.StopCPUProfile()
	}

	startTime := time.Now()

	maxMemory := datasize.MB * datasize.ByteSize(*maxMemoryMb)
	// Producers back off while the sender holds more than this many bytes.
	memoryHighWater := int64(maxMemory.Bytes()) * 3 / 4

	// Backend-agnostic hooks, bound once here so the worker loop below does
	// not branch on the backend for every batch.
	var (
		totalMemory func() int64
		enqueue     func([]*parser.ParsedEvent)
		stopSender  func()
	)
	if *chHosts != "" {
		log.Printf("Benchmarking Clickhouse")
		connectionParams, err := parseChParams(*chParams)
		if err != nil {
			log.Fatalf("Error parsing params: %v\n", err)
		}
		chSender, err := clickhouse.NewSender(&config.ClickhouseConfig{
			Hosts:            strings.Split(*chHosts, ","),
			ConnectionParams: connectionParams,
			PasswordFile:     *chPasswordFile,
			EventNames:       eventNames,
			MaxConnections:   *numThreads,
			MaxMemory:        maxMemory,
			MaxDelay:         *maxDelay,
		}, nil, *logVerbose)
		if err != nil {
			log.Fatalf("Error connecting to clickhouse: %v\n", err)
		}
		chSender.Start()
		totalMemory = func() int64 { return chSender.TotalMemory() }
		enqueue = func(events []*parser.ParsedEvent) { chSender.Enqueue(events) }
		stopSender = func() { chSender.Stop() }
	} else {
		log.Printf("Benchmarking S3")
		s3Sender, err := s3.NewSender(&config.S3Config{
			AccessKeyID:         *s3AccessKeyID,
			SecretAccessKeyFile: *s3SecretFile,
			Endpoint:            *s3Endpoint,
			Bucket:              *s3Bucket,
			Region:              *s3Region,
			Compression:         *s3Compression,
			MaxWorkers:          *numThreads,
			MaxMemory:           maxMemory,
			MaxDelay:            *maxDelay,
			EnableVerboseDebug:  *s3VerboseDebug,
		}, nil, *logVerbose)
		if err != nil {
			log.Fatalf("Error connecting to S3: %v\n", err)
		}
		s3Sender.Start()
		totalMemory = func() int64 { return s3Sender.TotalMemory() }
		enqueue = func(events []*parser.ParsedEvent) { s3Sender.Enqueue(events) }
		stopSender = func() { s3Sender.Stop() }
	}

	var wg sync.WaitGroup
	wg.Add(*numThreads)

	// genTimes[t][b] / sendTimes[t][b]: time thread t spent generating /
	// enqueueing (including backoff) batch b. Each goroutine writes only its
	// own row; rows are read after wg.Wait.
	genTimes := make([][]time.Duration, *numThreads)
	sendTimes := make([][]time.Duration, *numThreads)
	var totalBackoff int64 // milliseconds, updated atomically by workers
	for t := 0; t < *numThreads; t++ {
		go func(t int) {
			defer wg.Done()

			genTimes[t] = make([]time.Duration, *numBatches)
			sendTimes[t] = make([]time.Duration, *numBatches)

			// One generator per goroutine: *rand.Rand is not safe for
			// concurrent use.
			rnd := rand.New(rand.NewSource(int64(t) + time.Now().Unix()))
			// NOTE(review): the same events slice is refilled for every
			// batch. This is only safe if Enqueue does not keep a reference
			// to the slice after it returns — confirm against the senders.
			events := make([]*parser.ParsedEvent, *numEvents)
			for b := 0; b < *numBatches; b++ {
				batchStart := time.Now()
				for e := range events {
					events[e] = generateFimEvent(rnd)
				}
				genTimes[t][b] = time.Since(batchStart)

				batchStart = time.Now()
				// Sleep 1..numThreads ms (varying with thread and batch)
				// until the sender drops below the high-water mark.
				for totalMemory() > memoryHighWater {
					backoff := (t+b)%*numThreads + 1
					atomic.AddInt64(&totalBackoff, int64(backoff))
					time.Sleep(time.Millisecond * time.Duration(backoff))
				}
				enqueue(events)
				sendTimes[t][b] = time.Since(batchStart)
				if *logVerbose {
					log.Printf("Thread %d: got generation time %v, send time %v for batch %d\n",
						t, genTimes[t][b], sendTimes[t][b], b)
				}
			}
		}(t)
	}

	wg.Wait()
	stopStart := time.Now()
	stopSender()
	stopTime := time.Since(stopStart)

	for t := 0; t < *numThreads; t++ {
		var totalGenTime, totalSendTime time.Duration
		for b := 0; b < *numBatches; b++ {
			totalGenTime += genTimes[t][b]
			totalSendTime += sendTimes[t][b]
		}
		log.Printf("Thread %d: total generation time %v, total send time %v\n", t, totalGenTime, totalSendTime)
	}
	log.Printf("Total backoff: %v\n", time.Millisecond*time.Duration(totalBackoff))
	log.Printf("Final Stop()ping time: %v\n", stopTime)

	log.Printf("Finished in %v\n", time.Since(startTime))

	log.Printf("Metrics: %v\n", metrics.ReportUnistat())

	if *memProfilePath != "" {
		// Finish the CPU profile first: a log.Fatal below would exit without
		// running the deferred StopCPUProfile and leave that profile
		// incomplete. StopCPUProfile does nothing when no profile is active,
		// so the deferred call above remains harmless.
		pprof.StopCPUProfile()

		f, err := os.Create(*memProfilePath)
		if err != nil {
			log.Fatal("could not create memory profile: ", err)
		}
		defer func() {
			_ = f.Close()
		}()
		runtime.GC()
		if err := pprof.WriteHeapProfile(f); err != nil {
			log.Fatal("could not write memory profile: ", err)
		}
	}
}

// initRandom validates the component word lists and fills randomNums with
// precomputed decimal strings, so event generation does not pay for
// strconv.Itoa on every call. It must run before any generate* function,
// because those index randomNums via rnd.Intn(len(randomNums)).
func initRandom() {
	// math/rand's Intn takes a cheap bit-mask path when n is a power of two
	// and a slower rejection loop otherwise, so the tables indexed with
	// rnd.Intn(len(...)) must have power-of-two lengths. An empty table is
	// rejected as well: rand.Intn(0) panics, and 0&(0-1) == 0 would
	// otherwise slip through the bit test.
	for _, table := range []struct {
		name string
		size int
	}{
		{"hostComponents", len(hostComponents)},
		{"pathComponents", len(pathComponents)},
	} {
		if table.size <= 0 || table.size&(table.size-1) != 0 {
			log.Fatalf("Length of %s must be pow-of-2: %d\n", table.name, table.size)
		}
	}

	// Values are uniformly drawn from [randomBase, randomBase+randomSpread).
	const numRandomNums = 100000
	const randomBase = 10
	const randomSpread = 10000
	randomNums = make([]string, numRandomNums)
	for i := range randomNums {
		randomNums[i] = strconv.Itoa(randomBase + rand.Intn(randomSpread))
	}
}

// parseChParams decodes a URL query string such as "a=1&b=2" into a flat
// map. When a key occurs more than once, only its first value is kept.
// On a decoding error it returns a nil map and the error from url.ParseQuery.
func parseChParams(p string) (map[string]string, error) {
	values, err := url.ParseQuery(p)
	if err != nil {
		return nil, err
	}
	params := make(map[string]string, len(values))
	for name, all := range values {
		// url.ParseQuery only creates a key together with a value, so
		// every slice here has at least one element.
		params[name] = all[0]
	}
	return params, nil
}

// generateFimEvent builds one synthetic benchmark event with action "added".
// Host and NodeKey are both set to the same random hostname, and the event
// name is chosen uniformly from eventNames. All randomness comes from rnd,
// so each worker goroutine can use its own generator without sharing state.
// initRandom must have run first (generateHost/generatePath read randomNums).
func generateFimEvent(rnd *rand.Rand) *parser.ParsedEvent {
	hostname := generateHost(rnd)
	eventName := eventNames[rnd.Intn(len(eventNames))]
	return &parser.ParsedEvent{
		Host:    hostname,
		LogType: "some_type",
		NodeKey: hostname,
		Name:    eventName,
		Data: map[string]interface{}{
			"action": "added",
			"columns": map[string]interface{}{
				"category":  "root",
				"path":      generatePath(rnd),
				"real_path": generatePath(rnd),
				"name":      fimPackName,
				"hostname":  hostname,
				"sha256":    generateSha256(rnd),
				// Use the caller's generator like every other field. The
				// package-level math/rand source is shared by all worker
				// goroutines, and this function exists to avoid that sharing.
				"some_value": rnd.Float64(),
			},
		},
	}
}

// generateHost returns a synthetic dot-separated hostname made of
// (component, number) label pairs: 3 or 4 pairs normally, 19 pairs when
// -large-events is set. randomNums must already be filled by initRandom.
func generateHost(rnd *rand.Rand) string {
	const (
		minC      = 3
		maxC      = 5 // maxC - minC must be a power of two (cheap rnd.Intn path)
		largeMaxC = 19
	)
	pairs := largeMaxC
	if !*largeEvents {
		pairs = minC + rnd.Intn(maxC-minC)
	}
	// A fixed-size backing array holds the labels for the largest case, so
	// appends below never grow the slice.
	var backing [2 * largeMaxC]string
	labels := backing[:0]
	for n := 0; n < pairs; n++ {
		labels = append(labels,
			hostComponents[rnd.Intn(len(hostComponents))],
			randomNums[rnd.Intn(len(randomNums))])
	}
	return strings.Join(labels, ".")
}

// generatePath returns a synthetic "/"-separated path made of
// (component, number) segment pairs: 4 to 7 pairs normally, 68 pairs when
// -large-events is set. The result has no leading slash. randomNums must
// already be filled by initRandom.
func generatePath(rnd *rand.Rand) string {
	const (
		minC      = 4
		maxC      = 8 // maxC - minC must be a power of two (cheap rnd.Intn path)
		largeMaxC = 68
	)
	var pairs int
	switch {
	case *largeEvents:
		pairs = largeMaxC
	default:
		pairs = minC + rnd.Intn(maxC-minC)
	}
	// Sized for the largest case so appends stay within the array.
	var storage [2 * largeMaxC]string
	segments := storage[:0]
	for p := 0; p < pairs; p++ {
		dir := pathComponents[rnd.Intn(len(pathComponents))]
		num := randomNums[rnd.Intn(len(randomNums))]
		segments = append(segments, dir, num)
	}
	return strings.Join(segments, "/")
}

// generateSha256 returns a random value formatted like a hex-encoded
// SHA-256 digest: 32 bytes drawn from rnd rendered as 64 lowercase hex
// digits. It is not the hash of anything.
func generateSha256(rnd *rand.Rand) string {
	const digestSize = 32
	raw := make([]byte, digestSize)
	_, _ = rnd.Read(raw) // (*rand.Rand).Read always returns len(raw), nil
	return hex.EncodeToString(raw)
}
