package main

import (
	"database/sql/driver"
	"encoding/csv"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"os"
	"runtime"
	"runtime/pprof"
	"runtime/trace"
	"strconv"
	"strings"
	"time"

	ch "github.com/ClickHouse/clickhouse-go"
	"github.com/aws/aws-sdk-go/aws"
	awsClient "github.com/aws/aws-sdk-go/aws/client"
	awsCredentials "github.com/aws/aws-sdk-go/aws/credentials"
	awsSession "github.com/aws/aws-sdk-go/aws/session"
	awsS3 "github.com/aws/aws-sdk-go/service/s3"
	"golang.org/x/net/http2"

	"a.yandex-team.ru/security/osquery/osquery-sender/clickhouse"
	"a.yandex-team.ru/security/osquery/osquery-sender/config"
	"a.yandex-team.ru/security/osquery/osquery-sender/s3"
)

const (
	// minPartSizeForMerge is 256 MiB; main passes it to the S3 manager as
	// MinUploadPartSize.
	minPartSizeForMerge = 256 * 1024 * 1024
)

// Command-line flags of the copier.
var (
	// configPath is the osquery-sender config file; S3 and Clickhouse
	// settings are read from it.
	configPath = flag.String("config", "/etc/osquery-sender/conf.yaml", "osquery-sender config")
	// timezone is used by doChToS3 to compute day boundaries when the config
	// does not set an S3 timezone.
	timezone   = flag.String("tz", "Europe/Moscow", "timezone to use")

	// Mode selection: main checks chToS3 first and only looks at s3ToCh when
	// chToS3 is empty.
	chToS3 = flag.String("ch-to-s3", "", "copy the data from Clickhouse to S3, argument must be table_name/yyyy-mm-dd")
	s3ToCh = flag.String("s3-to-ch", "", "copy the data from S3 to Clickhouse, argument must be table_name/yyyy-mm-dd (this is the start date, the end date will be determined automatically)")

	// Optional diagnostics output; an empty path disables the corresponding
	// profile or trace.
	cpuProfilePath = flag.String("cpuprofile", "", "write CPU profile to file")
	tracePath      = flag.String("trace", "", "write trace to file")
)

// main configures logging, parses the flags and exits with the status
// reported by runCopier.
func main() {
	log.SetFlags(log.Lshortfile | log.Ldate | log.Ltime)
	log.SetPrefix("[osquery-sender copier] ")

	flag.Parse()

	// os.Exit (and therefore log.Fatal*) does not run deferred functions, so
	// everything that registers deferred cleanup lives in runCopier and the
	// exit happens only after it has returned.
	if code := runCopier(); code != 0 {
		os.Exit(code)
	}
}

// runCopier reads the config, optionally starts the CPU profile and the
// execution trace, builds the S3 and Clickhouse clients and runs the copy
// selected by the -ch-to-s3 / -s3-to-ch flags.
//
// It returns the process exit code: 0 on success, 1 after logging a failure.
// Failures are reported by returning instead of calling log.Fatal so that the
// deferred profile/trace finalizers run and the output files are complete.
// Helpers that call log.Fatal themselves (parseTableAndDate, doS3ToCh) still
// terminate the process directly.
func runCopier() int {
	conf, err := config.FromFile(*configPath)
	if err != nil {
		log.Printf("could not read config %s: %v\n", *configPath, err)
		return 1
	}

	if *cpuProfilePath != "" {
		cpuProfileFile, err := os.Create(*cpuProfilePath)
		if err != nil {
			log.Printf("could not open CPU profile file: %v\n", err)
			return 1
		}
		defer func() {
			_ = cpuProfileFile.Close()
		}()
		err = pprof.StartCPUProfile(cpuProfileFile)
		if err != nil {
			log.Printf("could not start CPU profile: %v\n", err)
			return 1
		}
		// Deferred after the Close above, so it runs first and the profile is
		// flushed before the file is closed.
		defer pprof.StopCPUProfile()
	}

	if *tracePath != "" {
		traceFile, err := os.Create(*tracePath)
		if err != nil {
			log.Printf("could not open tracePath file: %v\n", err)
			return 1
		}
		defer func() {
			_ = traceFile.Close()
		}()
		err = trace.Start(traceFile)
		if err != nil {
			log.Printf("could not start trace: %v\n", err)
			return 1
		}
		defer trace.Stop()
	}

	secretAccessKeyBytes, err := ioutil.ReadFile(conf.S3.SecretAccessKeyFile)
	if err != nil {
		log.Printf("secret file %s error: %v\n", conf.S3.SecretAccessKeyFile, err)
		return 1
	}
	secretAccessKey := strings.TrimSpace(string(secretAccessKeyBytes))
	credentials := awsCredentials.NewStaticCredentials(conf.S3.AccessKeyID, secretAccessKey, "")

	// Clone the default transport so that the HTTP/2 settings below do not
	// affect http.DefaultTransport.
	httpTransport := http.DefaultTransport.(*http.Transport).Clone()
	h2Transport, err := http2.ConfigureTransports(httpTransport)
	if err != nil {
		log.Print(err)
		return 1
	}
	h2Transport.ReadIdleTimeout = time.Second * 30

	// awsLogger := aws.LoggerFunc(func(i ...interface{}) {
	// 	log.Println("AWS SDK:", fmt.Sprint(i...))
	// })
	session, err := awsSession.NewSessionWithOptions(awsSession.Options{
		Config: aws.Config{
			Credentials: credentials,
			Endpoint:    &conf.S3.Endpoint,
			Region:      &conf.S3.Region,
			Retryer: awsClient.DefaultRetryer{
				NumMaxRetries: 5,
				MinRetryDelay: time.Second * 5,
			},
			HTTPClient: &http.Client{
				Transport: httpTransport,
				Timeout:   time.Second * 30,
			},
			// Logger:   awsLogger,
			// LogLevel: aws.LogLevel(aws.LogDebugWithRequestErrors | aws.LogDebugWithRequestRetries),
		},
	})
	if err != nil {
		log.Print(err)
		return 1
	}
	client := awsS3.New(session)
	manager := s3.NewS3Manager(client, &s3.S3ManagerConfig{
		EnableDebug:        true,
		NumRetries:         5,
		NumDownloadWorkers: runtime.NumCPU(),
		NumUploadWorkers:   runtime.NumCPU(),
		NumGetInfoWorkers:  runtime.NumCPU(),
		MinUploadPartSize:  minPartSizeForMerge,
	})

	chParams, err := clickhouse.MakeClickhouseParams(conf.Clickhouse)
	if err != nil {
		log.Printf("clickhouse params error: %v\n", err)
		return 1
	}

	chPool, err := clickhouse.NewPool(clickhouse.PoolParams{
		Hosts:             conf.Clickhouse.Hosts,
		Port:              conf.Clickhouse.Port,
		Params:            chParams,
		Size:              runtime.NumCPU(),
		NumRetries:        5,
		RetryBackoff:      time.Second * 5,
		WaitForConnection: time.Second * 5,
	})
	if err != nil {
		log.Printf("clickhouse pool error: %v\n", err)
		return 1
	}

	switch {
	case *chToS3 != "":
		tableName, year, month, day := parseTableAndDate(*chToS3)
		err = doChToS3(chPool, manager, conf, tableName, year, month, day)
		if err != nil {
			log.Printf("error copying the data from Clickhouse to S3: %v\n", err)
			return 1
		}
	case *s3ToCh != "":
		tableName, year, month, day := parseTableAndDate(*s3ToCh)
		doS3ToCh(chPool, manager, conf, tableName, year, month, day)
	default:
		// Not an error (the exit status stays 0), but say so instead of
		// exiting silently.
		log.Println("nothing to do: neither -ch-to-s3 nor -s3-to-ch was given")
	}
	return 0
}

// parseTableAndDate splits an argument of the form "table_name/yyyy-mm-dd"
// into the table name and the numeric year, month and day.
//
// Any malformed argument terminates the process through log.Fatalf: a missing
// or empty table name, a date that does not have three numeric components, or
// a date that does not exist in the calendar (for example 2021-02-30).
func parseTableAndDate(s string) (string, int, int, int) {
	v := strings.Split(s, "/")
	if len(v) != 2 || v[0] == "" {
		log.Fatalf("error parsing argument %s\n", s)
	}
	tableName := v[0]

	d := strings.Split(v[1], "-")
	if len(d) != 3 {
		log.Fatalf("error parsing argument %s\n", s)
	}
	year, err := strconv.Atoi(d[0])
	if err != nil {
		log.Fatalf("error parsing argument %s date: %v\n", s, err)
	}
	month, err := strconv.Atoi(d[1])
	if err != nil {
		log.Fatalf("error parsing argument %s date: %v\n", s, err)
	}
	day, err := strconv.Atoi(d[2])
	if err != nil {
		log.Fatalf("error parsing argument %s date: %v\n", s, err)
	}

	// time.Date silently normalizes out-of-range values (month 13 becomes
	// January of the next year), so a date that does not survive the round
	// trip is not a real calendar date. Callers build both timestamps and an
	// integer yyyymmdd key from these numbers; rejecting such input keeps the
	// two in agreement.
	normalized := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.UTC)
	if normalized.Year() != year || int(normalized.Month()) != month || normalized.Day() != day {
		log.Fatalf("error parsing argument %s date: not a valid calendar date\n", s)
	}

	return tableName, year, month, day
}

// doChToS3 copies one day of rows of the Clickhouse table name into a single
// compressed, tab-separated object in the configured S3 bucket.
//
// The day starts at midnight of year-month-day in the S3 timezone from the
// config (falling back to the -tz flag) and lasts 24 hours. Rows are read in
// time-ordered batches sized from an up-front row count, converted to strings
// with valueToString and written after a header line with the column names.
//
// Every failure is returned as an error. Once the upload has been created it
// is aborted on failure, except when closing the upload itself fails. A day
// without rows is reported as an error and nothing is uploaded.
func doChToS3(chPool *clickhouse.ClickhousePool, manager *s3.S3Manager, conf *config.SenderConfig, name string, year int, month int, day int) error {
	s3Timezone := *timezone
	if conf.S3.Timezone != "" {
		s3Timezone = conf.S3.Timezone
	}
	s3Loc, err := time.LoadLocation(s3Timezone)
	if err != nil {
		return fmt.Errorf("could not load location %s: %v", s3Timezone, err)
	}

	s3Alg := s3.CompressionLz4
	if conf.S3.Compression != "" {
		s3Alg, err = s3.ParseCompressionAlg(conf.S3.Compression)
		if err != nil {
			return fmt.Errorf("could not parse compression alg %s: %v", conf.S3.Compression, err)
		}
	}

	fromTimestamp := time.Date(year, time.Month(month), day, 0, 0, 0, 0, s3Loc)
	toTimestamp := fromTimestamp.Add(time.Hour * 24)

	schema, err := clickhouse.DescribeTable(chPool, name)
	if err != nil {
		return fmt.Errorf("could not describe %s: %v", name, err)
	}
	columns := orderColumns(schema.SortedColumns())

	date := year*10000 + month*100 + day
	uploadKey := s3.MakeMergedKey(name, date, s3Alg)
	uploadMetadata := s3.MakeMergedMetadata(date, s3Alg, columns)

	startTime := time.Now()
	log.Printf("Starting copying from %s to %s (date %d, columns %v)\n", name, uploadKey, date, columns)

	upload := manager.Upload(conf.S3.Bucket, uploadKey, uploadMetadata)
	// abort discards the upload and passes the error through, so that every
	// failure path below cleans up in the same way.
	abort := func(err error) error {
		upload.Abort()
		return err
	}

	writer, err := s3.NewCompressedWriter(s3Alg, upload)
	if err != nil {
		return abort(fmt.Errorf("error creating writer: %v", err))
	}
	csvW := csv.NewWriter(writer)
	csvW.Comma = '\t'

	err = csvW.Write(columns)
	if err != nil {
		return abort(fmt.Errorf("error writing header: %v", err))
	}

	// Estimate the number of rows.
	query := fmt.Sprintf("SELECT COUNT(*) FROM `%s` WHERE timestamp >= toDateTime64(%d, 0) AND timestamp < toDateTime64(%d, 0)",
		name, fromTimestamp.Unix(), toTimestamp.Unix())
	var count uint64
	err = clickhouse.RunQuery(chPool, query, func(rows []map[string]driver.Value) error {
		if len(rows) != 1 {
			return fmt.Errorf("got strange number of rows for count: %d", len(rows))
		}
		for _, v := range rows[0] {
			var ok bool
			if count, ok = v.(uint64); ok {
				return nil
			}
			return fmt.Errorf("strange type for count: %T", v)
		}
		return fmt.Errorf("no columns for count")
	})
	if err != nil {
		return abort(fmt.Errorf("could not count rows in %s: %v", name, err))
	}
	if count == 0 {
		// Do not replace whatever is stored under uploadKey with a
		// header-only object.
		return abort(fmt.Errorf("no rows in %s for date %d", name, date))
	}

	// We want to use ~256Mb, while the average record is ~1Kb so we can hold ~250k records at once.
	const maxRecords = 250000
	timeStep := chToS3TimeStep(count, maxRecords)
	log.Printf("Got %d lines in table, use %v timestep\n", count, timeStep)

	for fromBatch := fromTimestamp; fromBatch.Before(toTimestamp); fromBatch = fromBatch.Add(timeStep) {
		toBatch := fromBatch.Add(timeStep)
		if toBatch.After(toTimestamp) {
			toBatch = toTimestamp
		}

		batchQuery := fmt.Sprintf("SELECT * FROM `%s` WHERE timestamp >= toDateTime64(%d, 0) AND timestamp < toDateTime64(%d, 0) ORDER BY timestamp",
			name, fromBatch.Unix(), toBatch.Unix())
		err = clickhouse.RunTx(chPool, func(conn ch.Clickhouse) error {
			stmt, err := conn.Prepare(batchQuery)
			if err != nil {
				return err
			}
			defer func() {
				err := stmt.Close()
				if err != nil {
					log.Printf("ERROR: closing statement failed: %v\n", err)
				}
			}()
			//goland:noinspection GoDeprecation
			rows, err := stmt.Query([]driver.Value{})
			if err != nil {
				return err
			}
			defer func() {
				err := rows.Close()
				if err != nil {
					log.Printf("ERROR: closing rows failed: %v\n", err)
				}
			}()

			log.Printf("Read batch from %d (%v) to %d (%v)\n",
				fromBatch.Unix(), fromBatch.Format("2006-01-02T15-04-05"), toBatch.Unix(), toBatch.Format("2006-01-02T15-04-05"))

			// columnIdx[i] is the position in the result row of the i-th
			// output column. A column that the query did not return is an
			// error rather than an out-of-range index.
			rowColumns := rows.Columns()
			columnIdx := make([]int, len(columns))
			for i, col := range columns {
				idx := stringsIndexOf(rowColumns, col)
				if idx == -1 {
					return fmt.Errorf("column %s is missing from the query result", col)
				}
				columnIdx[i] = idx
			}

			rowValues := make([]driver.Value, len(rowColumns))
			record := make([]string, len(columns))
			for {
				err = rows.Next(rowValues)
				if err == io.EOF {
					return nil
				}
				if err != nil {
					return err
				}

				for i, idx := range columnIdx {
					record[i], err = valueToString(rowValues[idx])
					if err != nil {
						return err
					}
				}
				err = csvW.Write(record)
				if err != nil {
					return err
				}
			}
		})
		if err != nil {
			return abort(err)
		}
		// csv.Writer buffers its output, so a write failure may only become
		// visible here.
		csvW.Flush()
		if err = csvW.Error(); err != nil {
			return abort(fmt.Errorf("error writing rows: %v", err))
		}
	}

	err = writer.Close()
	if err != nil {
		return abort(err)
	}

	err = upload.Close()
	if err != nil {
		return err
	}
	log.Printf("Finished uploading to %s in %v\n", uploadKey, time.Since(startTime))
	return nil
}

// chToS3TimeStep returns the length of the time window to read per batch so
// that, for count rows spread evenly over a day, one window holds about
// maxRecords rows. The result is clamped to [1 second, 24 hours].
//
// The clamping is done on whole seconds before converting to time.Duration:
// for very small counts the unclamped value exceeds the range of Duration,
// and for very large counts it truncates to zero. Either case would keep the
// batch loop from ever reaching the end of the day.
func chToS3TimeStep(count uint64, maxRecords uint64) time.Duration {
	const secondsPerDay = 60 * 60 * 24
	if count == 0 {
		return time.Hour * 24
	}
	seconds := maxRecords * secondsPerDay / count
	if seconds < 1 {
		seconds = 1
	}
	if seconds > secondsPerDay {
		seconds = secondsPerDay
	}
	return time.Second * time.Duration(seconds)
}

// doS3ToCh loads merged daily objects from S3 into the Clickhouse table name,
// starting with the object for year-month-day (UTC) and moving forward one
// day at a time.
//
// The end of the range is the smallest timestamp already present in the table
// at or after the start date: lines with a timestamp at or after it are not
// inserted, and reading of a day's object stops at the first such line. The
// loop over days ends once the next day would start after that timestamp.
//
// Any failure terminates the process through log.Fatalf.
func doS3ToCh(chPool *clickhouse.ClickhousePool, manager *s3.S3Manager, conf *config.SenderConfig, name string, year int, month int, day int) {
	timestamp := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.UTC)

	// Check the first timestamp present.
	query := fmt.Sprintf("SELECT MIN(timestamp) FROM `%s` WHERE timestamp >= toDateTime64(%d, 0)",
		name, timestamp.Unix())
	var toTimestamp time.Time
	err := clickhouse.RunQuery(chPool, query, func(rows []map[string]driver.Value) error {
		if len(rows) != 1 {
			log.Fatalf("got strange number of rows for timestamp: %d", len(rows))
		}
		for _, v := range rows[0] {
			var ok bool
			if toTimestamp, ok = v.(time.Time); ok {
				return nil
			}
			return fmt.Errorf("strange type for timestamp: %T", v)
		}
		log.Fatalf("no columns for timestamp")
		return nil
	})
	if err != nil {
		log.Fatalf("could not get timestamp in %s: %v", name, err)
	}

	log.Printf("Copying until %v\n", toTimestamp)

	s3Alg := s3.CompressionLz4
	if conf.S3.Compression != "" {
		var err error
		s3Alg, err = s3.ParseCompressionAlg(conf.S3.Compression)
		if err != nil {
			log.Fatalf("could not parse compression alg %s: %v", conf.S3.Compression, err)
		}
	}

	// The table schema does not depend on the day being processed, so it is
	// fetched once.
	schema, err := clickhouse.DescribeTable(chPool, name)
	if err != nil {
		log.Fatalf("could not describe %s: %v", name, err)
	}

	startTime := time.Now()
	var lineCount int64
	for {
		dateInt := int(timestamp.Year())*10000 + int(timestamp.Month())*100 + timestamp.Day()
		downloadKey := s3.MakeMergedKey(name, dateInt, s3Alg)

		infos, err := manager.GetInfos(conf.S3.Bucket, []string{downloadKey})
		if err != nil {
			log.Fatalf("error getting info %s: %v", downloadKey, err)
		}
		if len(infos) == 0 {
			log.Fatalf("error getting info %s: no info returned", downloadKey)
		}

		columns, err := s3.ParseMetadataColumns(infos[0].Metadata)
		if err != nil {
			log.Fatalf("error getting columns in %s: %v", downloadKey, err)
		}
		timestampColumnIdx := stringsIndexOf(columns, clickhouse.TimestampColumn)
		if timestampColumnIdx == -1 {
			log.Fatalf("%s does not contain timestamp", downloadKey)
		}

		log.Printf("Starting inserting from %s to %s (date %d, columns %v)\n", downloadKey, name, dateInt, columns)
		dayStartTime := time.Now()

		download := manager.Download(conf.S3.Bucket, downloadKey)
		reader, err := s3.NewCompressedReader(s3Alg, download)
		if err != nil {
			log.Fatalf("error creating reader: %v", err)
		}
		csvR := csv.NewReader(reader)
		csvR.Comma = '\t'

		// The header is skipped, but any failure to read it (not only EOF) is
		// fatal. With the default reader settings every later record must
		// have as many fields as the header, so checking the header against
		// the metadata columns makes the per-line indexing below safe.
		header, err := csvR.Read()
		if err != nil {
			log.Fatalf("error reading the header: %v", err)
		}
		if len(header) < len(columns) {
			log.Fatalf("error reading the header of %s: got %d fields, metadata lists %d columns",
				downloadKey, len(header), len(columns))
		}

		const maxBatchSize = 1000000
		var batch [][]string
		var dayLineCount int64
		for {
			line, err := csvR.Read()
			if err == io.EOF {
				break
			}
			if err != nil {
				// On a parse error the record may be nil; stop instead of
				// indexing into it.
				log.Fatalf("error reading %s: %v", downloadKey, err)
			}

			lineTimestamp, err := strconv.ParseInt(line[timestampColumnIdx], 10, 64)
			if err != nil {
				log.Fatalf("error parsing timestamp %s: %v", line[timestampColumnIdx], err)
			}
			if !time.Unix(lineTimestamp, 0).Before(toTimestamp) {
				break
			}

			batch = append(batch, line)
			if len(batch) >= maxBatchSize {
				err = insertIntoClickhouse(chPool, name, batch, columns, schema)
				if err != nil {
					log.Fatalf("error inserting into clickhouse: %v", err)
				}
				dayLineCount += int64(len(batch))
				batch = nil
			}
		}
		if len(batch) > 0 {
			err = insertIntoClickhouse(chPool, name, batch, columns, schema)
			if err != nil {
				log.Fatalf("error inserting into clickhouse: %v", err)
			}
			dayLineCount += int64(len(batch))
		}

		lineCount += dayLineCount
		log.Printf("Finished inserting from %s to %s (date %d), %d lines in %v\n", downloadKey, name, dateInt, dayLineCount, time.Since(dayStartTime))

		timestamp = timestamp.AddDate(0, 0, 1)
		if timestamp.After(toTimestamp) {
			break
		}
	}

	log.Printf("Finished inserting into %s %d lines in %v\n", name, lineCount, time.Since(startTime))
}

// insertIntoClickhouse writes batch into the table name as one block inside a
// single clickhouse.RunTx call.
//
// columns names the table column for every field position of a line, and
// schema supplies the type used to parse that field: Float64 fields are parsed
// as floats, Int64 and DateTime64 fields as base-10 integers, String fields
// are written unchanged.
//
// The input is validated before a connection is used: a column that is not in
// schema, a column of any other type, or a line with fewer fields than columns
// is returned as an error instead of producing a block with missing values.
func insertIntoClickhouse(pool *clickhouse.ClickhousePool, name string, batch [][]string, columns []string, schema clickhouse.TableSchema) error {
	quotedColumns := make([]string, 0, len(columns))
	columnTypes := make([]clickhouse.ColumnType, 0, len(columns))
	placeholders := make([]string, 0, len(columns))
	for _, column := range columns {
		columnType, ok := schema[column]
		if !ok {
			return fmt.Errorf("column not in schema: %s", column)
		}
		switch columnType {
		case clickhouse.ColumnFloat64, clickhouse.ColumnInt64, clickhouse.ColumnDateTime64, clickhouse.ColumnString:
		default:
			return fmt.Errorf("column %s has unsupported type %v", column, columnType)
		}
		quotedColumns = append(quotedColumns, "`"+column+"`")
		columnTypes = append(columnTypes, columnType)
		placeholders = append(placeholders, "?")
	}
	for i, line := range batch {
		if len(line) < len(columns) {
			return fmt.Errorf("line %d has %d fields, expected at least %d", i, len(line), len(columns))
		}
	}

	sql := fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s)", name, strings.Join(quotedColumns, ", "),
		strings.Join(placeholders, ", "))

	return clickhouse.RunTx(pool, func(conn ch.Clickhouse) error {
		// The statement itself is not used; the rows are written through the
		// block obtained from the connection below.
		_, err := conn.Prepare(sql)
		if err != nil {
			return err
		}

		block, err := conn.Block()
		if err != nil {
			return err
		}

		block.Reserve()
		block.NumRows += uint64(len(batch))

		for _, line := range batch {
			for j, columnType := range columnTypes {
				switch columnType {
				case clickhouse.ColumnFloat64:
					v, err := strconv.ParseFloat(line[j], 64)
					if err != nil {
						return fmt.Errorf("could not parse column %s value %s: %v", columns[j], line[j], err)
					}
					err = block.WriteFloat64(j, v)
					if err != nil {
						return fmt.Errorf("could not write column %s value %s: %v", columns[j], line[j], err)
					}
				case clickhouse.ColumnInt64, clickhouse.ColumnDateTime64:
					v, err := strconv.ParseInt(line[j], 10, 64)
					if err != nil {
						return fmt.Errorf("could not parse column %s value %s: %v", columns[j], line[j], err)
					}
					err = block.WriteInt64(j, v)
					if err != nil {
						return fmt.Errorf("could not write column %s value %s: %v", columns[j], line[j], err)
					}
				case clickhouse.ColumnString:
					err = block.WriteString(j, line[j])
					if err != nil {
						return fmt.Errorf("could not write column %s value %s: %v", columns[j], line[j], err)
					}
				default:
					// Excluded by the validation above; kept so that a value
					// is never skipped silently.
					return fmt.Errorf("column %s has unsupported type %v", columns[j], columnType)
				}
			}
		}

		return conn.WriteBlock(block)
	})
}

// stringsIndexOf reports the position of the first element of haystack that
// equals needle, or -1 when there is no such element.
func stringsIndexOf(haystack []string, needle string) int {
	for pos := 0; pos < len(haystack); pos++ {
		if haystack[pos] != needle {
			continue
		}
		return pos
	}
	return -1
}

// orderColumns returns a new slice that starts with the timestamp, action and
// host columns, in that order, followed by every other element of columns in
// its original order. The three leading columns are always present in the
// result, whether or not columns contains them.
func orderColumns(columns []string) []string {
	ordered := make([]string, 0, len(columns))
	ordered = append(ordered, clickhouse.TimestampColumn, clickhouse.ActionColumn, clickhouse.HostColumn)
	for _, column := range columns {
		leading := column == clickhouse.TimestampColumn ||
			column == clickhouse.ActionColumn ||
			column == clickhouse.HostColumn
		if leading {
			// Already placed at the front.
			continue
		}
		ordered = append(ordered, column)
	}
	return ordered
}

// valueToString renders a value read from the database as text.
//
// Strings are returned unchanged, int64 values in base 10, float64 values in
// the shortest 'g' form that parses back to the same number, and time.Time
// values as Unix seconds. A value of any other type yields an error.
func valueToString(v driver.Value) (string, error) {
	if text, ok := v.(string); ok {
		return text, nil
	}
	if number, ok := v.(int64); ok {
		return strconv.FormatInt(number, 10), nil
	}
	if number, ok := v.(float64); ok {
		return strconv.FormatFloat(number, 'g', -1, 64), nil
	}
	if moment, ok := v.(time.Time); ok {
		return strconv.FormatInt(moment.Unix(), 10), nil
	}
	return "", fmt.Errorf("unsupport column type %T: %v", v, v)
}
