package clickhouse

import (
	"fmt"
	"io/ioutil"
	"log"
	"math/rand"
	"regexp"
	"runtime"
	"runtime/debug"
	"runtime/pprof"
	"strconv"
	"strings"
	"sync"
	"time"

	ch "github.com/ClickHouse/clickhouse-go"
	"github.com/c2h5oh/datasize"
	"go.uber.org/atomic"

	"a.yandex-team.ru/security/osquery/osquery-sender/batcher"
	"a.yandex-team.ru/security/osquery/osquery-sender/config"
	"a.yandex-team.ru/security/osquery/osquery-sender/metrics"
	"a.yandex-team.ru/security/osquery/osquery-sender/parser"
	"a.yandex-team.ru/security/osquery/osquery-sender/util"
)

// ClickhouseSender batches parsed events and writes them into Clickhouse
// tables, creating or altering the tables so that they contain every column
// present in the submitted batches. It also runs background loops that drop
// old partitions and refresh table size metrics (see Start).
type ClickhouseSender struct {
	// enableDebug turns on verbose per-batch logging.
	enableDebug bool

	// enableSharding makes table creation go through a "shard_"-prefixed table
	// plus a distributed table on top of it (see createOrAlterTable).
	enableSharding bool
	// eventNameRe selects the events FilterEvents returns as Clickhouse events.
	eventNameRe *regexp.Regexp
	// chOnlyEventNameRe selects the events FilterEvents withholds from the
	// "remaining" list.
	chOnlyEventNameRe *regexp.Regexp

	// dropAfterDays is the retention (in days) used for tables that have no
	// entry in customDropAfterDays; the latter is keyed by table name.
	dropAfterDays       int
	customDropAfterDays map[string]int

	// workers batch and submit events; pool provides Clickhouse connections.
	workers *batcher.BatcherWorkers
	pool    *ClickhousePool

	// Store all table schemas.
	// schemaCache is keyed by table name and guarded by schemaCacheMu.
	schemaCacheMu sync.Mutex
	schemaCache   map[string]TableSchema

	// All table creation/alter must be done while holding this lock.
	modifyTableMu sync.Mutex

	// closed is set by Stop and polled by the background loops so they exit.
	closed atomic.Bool
}

const (
	// Columns that are always added to every table schema (timestamp, action
	// and host), in addition to the columns derived from the event data.
	TimestampColumn = "timestamp"
	ActionColumn    = "action"
	HostColumn      = "host"

	// Prefix prepended to the table name to form the per-shard table name when
	// sharding is enabled.
	shardTableNamePrefix = "shard_"

	// Retention in days used when the config leaves DropAfterDays at zero.
	defaultDropAfterDays = 30

	// Retry settings handed to the connection pool; defaultNumRetries is used
	// when the config leaves NumRetries at zero.
	defaultNumRetries = 5
	retryBackoff      = time.Second * 5

	// Handed to the connection pool as its WaitForConnection parameter.
	waitForConnection = time.Second * 5

	// Use 1Gb as a sane default.
	// Both defaults below apply when the corresponding config field is unset.
	defaultMaxMemory = datasize.GB
	defaultMaxDelay  = time.Second * 5

	// Refresh table sizes each minute.
	refreshSizesInterval = time.Minute
)

// NewSender builds a ClickhouseSender from the given configuration.
//
// A nil config yields a nil sender and a nil error; FilterEvents and Enqueue
// accept a nil receiver for that case. Unset (zero/empty) config values fall
// back to the package defaults, and the number of connections defaults to the
// number of CPUs. An error is returned when max_delay cannot be parsed, when
// the connection parameters cannot be built, or when the pool cannot be
// created. The returned sender is not started; call Start to begin processing.
func NewSender(config *config.ClickhouseConfig, addDecorators []string, enableDebug bool) (*ClickhouseSender, error) {
	if config == nil {
		return nil, nil
	}

	memoryLimit := config.MaxMemory
	if memoryLimit == 0 {
		memoryLimit = defaultMaxMemory
	}

	flushDelay := defaultMaxDelay
	if config.MaxDelay != "" {
		parsed, err := time.ParseDuration(config.MaxDelay)
		if err != nil {
			return nil, fmt.Errorf("could not parse max_delay: %v", err)
		}
		flushDelay = parsed
	}

	connections := config.MaxConnections
	if connections == 0 {
		connections = runtime.NumCPU()
	}

	// The regular expressions are compiled before the connection parameters
	// are read, exactly as the fallible steps were ordered before.
	nameRe := makeRegexp(config)
	chOnlyRe := makeChOnlyRegexp(config)
	connParams, err := MakeClickhouseParams(config)
	if err != nil {
		return nil, err
	}

	retentionDays := config.DropAfterDays
	if retentionDays == 0 {
		retentionDays = defaultDropAfterDays
	}

	retries := config.NumRetries
	if retries == 0 {
		retries = defaultNumRetries
	}

	pool, err := NewPool(PoolParams{
		Hosts:             config.Hosts,
		Port:              config.Port,
		Params:            connParams,
		Size:              connections,
		NumRetries:        retries,
		RetryBackoff:      retryBackoff,
		WaitForConnection: waitForConnection,
	})
	if err != nil {
		return nil, err
	}

	sender := &ClickhouseSender{
		enableDebug:         enableDebug || config.EnableDebug,
		enableSharding:      config.EnableSharding,
		eventNameRe:         nameRe,
		chOnlyEventNameRe:   chOnlyRe,
		dropAfterDays:       retentionDays,
		customDropAfterDays: config.CustomDropAfterDays,
		pool:                pool,
		schemaCache:         make(map[string]TableSchema),
	}

	// The workers get the sender itself as the target for submitted batches,
	// so they can only be attached once the sender exists.
	sender.workers = batcher.NewWorkers(
		batcher.New(config.RemovePrefix, config.RemoveSuffix, true, addDecorators),
		batcher.WorkersConfig{
			MaxMemory:  int64(memoryLimit.Bytes()),
			MaxDelay:   flushDelay,
			MaxWorkers: connections,
			SplitDays:  false,
		},
		sender,
		"clickhouse",
	)

	return sender, nil
}

// MakeClickhouseParams returns a copy of the configured connection parameters.
// When a password file is configured, its whitespace-trimmed contents are
// stored under the "password" key, overriding any value copied from the
// config. An error is returned only if the password file cannot be read.
func MakeClickhouseParams(config *config.ClickhouseConfig) (map[string]string, error) {
	params := make(map[string]string, len(config.ConnectionParams))
	for name := range config.ConnectionParams {
		params[name] = config.ConnectionParams[name]
	}

	if config.PasswordFile == "" {
		return params, nil
	}

	raw, err := ioutil.ReadFile(config.PasswordFile)
	if err != nil {
		return nil, err
	}
	params["password"] = strings.TrimSpace(string(raw))
	return params, nil
}

// makeRegexp compiles a single regular expression that matches an event name
// if it fully matches any pattern from EventNames or ClickhouseOnlyEventNames.
//
// The alternation is wrapped in a non-capturing group: without it "^a|b|c$"
// parses as "(^a)|(b)|(c$)", so only the first pattern would be anchored at the
// start and only the last one at the end. The combined list is built in a
// fresh slice so that appending never writes into the backing array of
// config.EventNames.
//
// Panics (via regexp.MustCompile) if the configured patterns do not form a
// valid regular expression.
func makeRegexp(config *config.ClickhouseConfig) *regexp.Regexp {
	allEventNames := make([]string, 0, len(config.EventNames)+len(config.ClickhouseOnlyEventNames))
	allEventNames = append(allEventNames, config.EventNames...)
	allEventNames = append(allEventNames, config.ClickhouseOnlyEventNames...)
	joinedRe := "^(?:" + strings.Join(allEventNames, "|") + ")$"
	return regexp.MustCompile(joinedRe)
}

// makeChOnlyRegexp compiles a single regular expression that matches an event
// name if it fully matches any pattern from ClickhouseOnlyEventNames.
//
// The alternation is wrapped in a non-capturing group: without it "^a|b|c$"
// parses as "(^a)|(b)|(c$)", so only the first pattern would be anchored at the
// start and only the last one at the end. With no patterns configured the
// expression matches only the empty name.
//
// Panics (via regexp.MustCompile) if the configured patterns do not form a
// valid regular expression.
func makeChOnlyRegexp(config *config.ClickhouseConfig) *regexp.Regexp {
	joinedRe := "^(?:" + strings.Join(config.ClickhouseOnlyEventNames, "|") + ")$"
	return regexp.MustCompile(joinedRe)
}

// Start launches the batcher workers and then spawns the two background
// maintenance goroutines: old partition cleanup and table size refresh. Each
// goroutine runs under a pprof "name" label.
func (s *ClickhouseSender) Start() {
	s.workers.Start()
	log.Printf("started %d workers for clickhouse\n", s.workers.NumWorkers())

	background := []struct {
		label string
		run   func()
	}{
		{"clickhouse-drop-old-partitions", s.dropOldPartitions},
		{"clickhouse-refresh-sizes", s.refreshSizes},
	}
	for _, task := range background {
		// Arguments of a go statement are evaluated immediately, so each
		// goroutine gets its own label set and function.
		go util.RunWithLabels(pprof.Labels("name", task.label), task.run)
	}
}

// Stop shuts the sender down: it stops the batcher workers, closes the
// connection pool and finally raises the closed flag that the background
// loops poll before their next iteration.
func (s *ClickhouseSender) Stop() {
	log.Print("stopping clickhouse sender\n")

	s.workers.Stop()
	s.pool.Close()
	s.closed.Store(true)

	log.Print("stopped clickhouse sender\n")
}

// UpdateMetrics reports the current queue/batcher gauges and publishes the
// per-worker load values obtained from the batcher workers.
func (s *ClickhouseSender) UpdateMetrics() {
	s.updateMetricsMaxValues()

	loads := s.workers.Load()
	metrics.SetChWorkerLoads(loads)
}

// updateMetricsMaxValues reports the current queue length, queue memory and
// batcher memory of the workers to the corresponding metrics.
func (s *ClickhouseSender) updateMetricsMaxValues() {
	w := s.workers

	queueLen := uint64(w.QueueLen())
	metrics.ChQueueLen.Report(queueLen)

	queueMem := uint64(w.QueueMemory())
	metrics.ChQueueMemorySize.Report(queueMem)

	batcherMem := uint64(w.BatcherMemory())
	metrics.ChBatcherMemorySize.Report(batcherMem)
}

// OnDropDueToFullMemory increments the ChDroppedDueToFullMemory metric.
// NOTE(review): presumably a callback invoked by the batcher workers (the
// sender is passed to batcher.NewWorkers) when events are dropped because the
// memory limit was reached — confirm in the batcher package.
func (s *ClickhouseSender) OnDropDueToFullMemory() {
	metrics.ChDroppedDueToFullMemory.Inc()
}

// OnDropDueToFullQueue increments the ChDroppedDueToFullQueue metric.
// NOTE(review): presumably a callback invoked by the batcher workers when a
// batch is dropped because the submit queue is full — confirm in the batcher
// package.
func (s *ClickhouseSender) OnDropDueToFullQueue() {
	metrics.ChDroppedDueToFullQueue.Inc()
}

// OnFlushDueToMemoryPressure increments the ChFlushedDueToPressure metric.
// NOTE(review): presumably a callback invoked by the batcher workers when a
// batch is flushed early to free memory — confirm in the batcher package.
func (s *ClickhouseSender) OnFlushDueToMemoryPressure() {
	metrics.ChFlushedDueToPressure.Inc()
}

// OnDropAfterRetries increments the ChDroppedAfterRetries metric.
// NOTE(review): presumably a callback invoked by the batcher workers when a
// batch is given up on after SubmitEvents kept failing — confirm in the
// batcher package.
func (s *ClickhouseSender) OnDropAfterRetries() {
	metrics.ChDroppedAfterRetries.Inc()
}

// FilterEvents splits events into those destined for Clickhouse and those
// left for other consumers.
//
// An event is placed in chEvents when its name matches eventNameRe, and in
// remainingEvents unless its name matches chOnlyEventNameRe; an event may
// therefore appear in both results. On a nil sender nothing goes to Clickhouse
// and the input slice is returned unchanged as remainingEvents.
func (s *ClickhouseSender) FilterEvents(events []*parser.ParsedEvent) (chEvents []*parser.ParsedEvent, remainingEvents []*parser.ParsedEvent) {
	if s == nil {
		return nil, events
	}

	forClickhouse := make([]*parser.ParsedEvent, 0, len(events))
	forOthers := make([]*parser.ParsedEvent, 0, len(events))
	for i := range events {
		ev := events[i]
		if s.eventNameRe.MatchString(ev.Name) {
			forClickhouse = append(forClickhouse, ev)
		}
		if s.chOnlyEventNameRe.MatchString(ev.Name) {
			// Clickhouse-only events are withheld from the other consumers.
			continue
		}
		forOthers = append(forOthers, ev)
	}
	return forClickhouse, forOthers
}

// Enqueue hands the events over to the batcher workers and then refreshes the
// queue/batcher size metrics. It is a no-op on a nil sender (NewSender returns
// a nil sender when no Clickhouse config is given).
func (s *ClickhouseSender) Enqueue(events []*parser.ParsedEvent) {
	if s == nil {
		return
	}
	if s.enableDebug {
		log.Printf("enqueueing %d events to clickhouse\n", len(events))
	}

	s.workers.Enqueue(events)
	// Report the sizes right after enqueueing, when they have just changed.
	s.updateMetricsMaxValues()
}

// TotalMemory returns the sum of the batcher memory and the queue memory
// reported by the workers.
func (s *ClickhouseSender) TotalMemory() int64 {
	batched := s.workers.BatcherMemory()
	queued := s.workers.QueueMemory()
	return batched + queued
}

// SubmitEvents writes one batch of events into the table called name.
//
// Columns colliding with the standard column names are renamed in the batch
// first; the table is then created or altered to hold every column of the
// batch, and finally the rows are inserted. Any failure is logged, counted in
// ChFailedRetries and returned to the caller. After a failed insert the cached
// schema for name is dropped so that the next attempt re-checks the table.
func (s *ClickhouseSender) SubmitEvents(name string, events *batcher.EventBatch) error {
	if s.enableDebug {
		log.Printf("sending %d events with name %s to clickhouse\n", events.Length, name)
	}

	renameEventColumns(events)
	schema := getTableSchemaFromEvents(events)

	if err := s.createOrAlterTable(name, schema); err != nil {
		log.Printf("ERROR: %v\n", err)
		metrics.ChFailedRetries.Inc()
		return err
	}

	if err := s.insertIntoTable(name, schema, events); err != nil {
		log.Printf("ERROR: %v\n", err)
		s.deleteTableFromSchemaCache(name)
		metrics.ChFailedRetries.Inc()
		return err
	}

	return nil
}

// renameEventColumns resolves collisions between event columns and the
// standard columns (timestamp, action, host): every colliding column, float or
// string, is moved in place to "<name>_column".
func renameEventColumns(events *batcher.EventBatch) {
	const postfix = "_column"

	reserved := func(name string) bool {
		switch name {
		case TimestampColumn, ActionColumn, HostColumn:
			return true
		}
		return false
	}

	// The renamed keys never collide with a standard name again, so inserting
	// them while ranging over the same map cannot trigger another rename.
	for name, column := range events.Float64Values {
		if !reserved(name) {
			continue
		}
		events.Float64Values[name+postfix] = column
		delete(events.Float64Values, name)
	}
	for name, column := range events.StringValues {
		if !reserved(name) {
			continue
		}
		events.StringValues[name+postfix] = column
		delete(events.StringValues, name)
	}
}

// getTableSchemaFromEvents derives the table schema needed to store the batch:
// the three standard columns plus one Float64 column per float value key and
// one String column per string value key. A key present in both value maps
// ends up as a String column.
func getTableSchemaFromEvents(events *batcher.EventBatch) TableSchema {
	schema := TableSchema{
		TimestampColumn: ColumnDateTime64,
		ActionColumn:    ColumnString,
		HostColumn:      ColumnString,
	}
	for name := range events.Float64Values {
		schema[name] = ColumnFloat64
	}
	for name := range events.StringValues {
		schema[name] = ColumnString
	}
	return schema
}

// createOrAlterTable makes sure tableName can store every column of table.
// Without sharding this is a single table. With sharding the "shard_"-prefixed
// table is handled first and the table on top of it second.
func (s *ClickhouseSender) createOrAlterTable(tableName string, table TableSchema) error {
	if !s.enableSharding {
		return s.createOrAlterTableImpl(tableName, table, "")
	}

	shardTable := shardTableNamePrefix + tableName
	if err := s.createOrAlterTableImpl(shardTable, table, ""); err != nil {
		return err
	}
	// The shard_ table must be altered first before altering the distributed table.
	return s.createOrAlterTableImpl(tableName, table, shardTable)
}

// createOrAlterTableImpl makes sure tableName exists and contains every column
// of table with the matching type.
//
// The schema cache is consulted first; if the cached schema already covers
// table, nothing is done. Otherwise, under modifyTableMu, the server is asked
// whether the table exists:
//   - if it exists and already covers table, only the cache is refreshed;
//   - if it exists but lacks columns, the table is altered;
//   - if it does not exist, it is created — as a distributed table over
//     shardTableName when that is non-empty, as a replicated table partitioned
//     by day otherwise.
//
// The create-vs-alter decision is based solely on what the server reports. A
// stale cache entry for a table that no longer exists must not lead to an
// ALTER of a missing table, which would fail on every retry because the cache
// entry is never cleared on that path.
//
// Returns the first error reported by the underlying table helpers.
func (s *ClickhouseSender) createOrAlterTableImpl(tableName string, table TableSchema, shardTableName string) error {
	s.schemaCacheMu.Lock()
	cachedTable, ok := s.schemaCache[tableName]
	s.schemaCacheMu.Unlock()

	if ok && tableIsSubset(table, cachedTable) {
		return nil
	}

	s.modifyTableMu.Lock()
	defer s.modifyTableMu.Unlock()

	exists, err := TableExists(s.pool, tableName)
	if err != nil {
		return err
	}

	// Deliberately not seeded from the cache: it stays nil unless the server
	// confirms that the table exists.
	var gotTable TableSchema
	if exists {
		gotTable, err = DescribeTable(s.pool, tableName)
		if err != nil {
			return err
		}
		if tableIsSubset(table, gotTable) {
			log.Printf("table %s already has the required columns\n", tableName)
			s.schemaCacheMu.Lock()
			s.schemaCache[tableName] = gotTable
			s.schemaCacheMu.Unlock()
			return nil
		}
	}

	if gotTable != nil {
		err = AlterTable(s.pool, tableName, table, gotTable)
		if err != nil {
			return err
		}
	} else {
		if shardTableName != "" {
			err = CreateDistributedTable(s.pool, tableName, shardTableName)
			if err != nil {
				return err
			}
		} else {
			// Must match the pattern parsed in dropOldPartitions().
			partitionBy := fmt.Sprintf("toYYYYMMDD(%s)", TimestampColumn)
			err = CreateReplicatedTable(s.pool, tableName, table, TimestampColumn, partitionBy)
			if err != nil {
				return err
			}
		}
	}

	s.schemaCacheMu.Lock()
	s.schemaCache[tableName] = table
	s.schemaCacheMu.Unlock()
	return nil
}

// tableIsSubset reports whether every column of t1 is present in t2 with the
// same column type. An empty (or nil) t1 is a subset of any schema.
func tableIsSubset(t1 TableSchema, t2 TableSchema) bool {
	// A schema with more columns can never be contained in a smaller one.
	if len(t2) < len(t1) {
		return false
	}
	for name, want := range t1 {
		if have, found := t2[name]; !found || have != want {
			return false
		}
	}
	return true
}

// insertIntoTable writes all rows of the batch into tableName inside a single
// RunTx call, using the connection's column-oriented block API.
//
// table lists the columns to write and their types; the values come from the
// batch: timestamps, actions and hosts for the standard columns, and the
// Float64Values/StringValues entries keyed by column name for everything else.
//
// Any failure is returned with the table name prepended to the message. On
// success the ChCommits and ChInsertedRows metrics are updated.
func (s *ClickhouseSender) insertIntoTable(tableName string, table TableSchema, events *batcher.EventBatch) error {
	startTime := time.Now()
	err := RunTx(s.pool, func(conn ch.Clickhouse) error {
		// Map iteration order is unspecified, so fix one order here and keep
		// names, quoted names and types in parallel slices; the index into
		// these slices is also the column index used for the block writes.
		columns := make([]string, 0, len(table))
		quotedColumns := make([]string, 0, len(table))
		columnTypes := make([]ColumnType, 0, len(table))
		for column, columnType := range table {
			columns = append(columns, column)
			quotedColumns = append(quotedColumns, "`"+column+"`")
			columnTypes = append(columnTypes, columnType)
		}
		// One "?" placeholder per column.
		values := make([]string, len(columns))
		for i := 0; i < len(columns); i++ {
			values[i] = "?"
		}
		sql := fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s)", tableName, strings.Join(quotedColumns, ", "),
			strings.Join(values, ", "))
		// The returned statement is not used; the data is sent through the
		// block obtained below.
		// NOTE(review): presumably Prepare is what starts the insert on the
		// connection so that Block() is available — confirm against the
		// clickhouse-go documentation.
		_, err := conn.Prepare(sql)
		if err != nil {
			return err
		}

		block, err := conn.Block()
		if err != nil {
			return err
		}

		// The row count is declared up front; the values are then written
		// column by column.
		block.Reserve()
		block.NumRows += uint64(events.Length)

		for i := 0; i < len(columns); i++ {
			switch columns[i] {
			case TimestampColumn:
				for j := 0; j < events.Length; j++ {
					// DateTime64 accepts Int64 arguments
					err = block.WriteInt64(i, events.Timestamps[j])
					if err != nil {
						return err
					}
				}
			case ActionColumn:
				for j := 0; j < events.Length; j++ {
					err = block.WriteString(i, events.Actions[j])
					if err != nil {
						return err
					}
				}
			case HostColumn:
				for j := 0; j < events.Length; j++ {
					err = block.WriteString(i, events.Hosts[j])
					if err != nil {
						return err
					}
				}
			default:
				// Non-standard columns are looked up by name in the value map
				// that corresponds to the column type.
				switch columnTypes[i] {
				case ColumnFloat64:
					values := events.Float64Values[columns[i]]
					for j := 0; j < events.Length; j++ {
						err = block.WriteFloat64(i, values[j])
						if err != nil {
							return err
						}
					}
				case ColumnString:
					values := events.StringValues[columns[i]]
					for j := 0; j < events.Length; j++ {
						err = block.WriteString(i, values[j])
						if err != nil {
							return err
						}
					}
				default:
					return fmt.Errorf("unsupported column type %d", columnTypes[i])
				}
			}
		}

		return conn.WriteBlock(block)
	})
	if err != nil {
		return fmt.Errorf("error inserting into %s: %v", tableName, err)
	}

	if s.enableDebug {
		log.Printf("committed %d events with %d columns in %v\n", events.Length, events.NumColumns,
			time.Since(startTime))
	}
	metrics.ChCommits.Inc()
	metrics.ChInsertedRows.Add(int64(events.Length))

	return nil
}

// dropOldPartitions is the body of the cleanup goroutine: it sleeps for a
// random 10 to 19 minutes, runs one cleanup pass and repeats. The loop exits
// once the sender is observed as closed after a sleep.
func (s *ClickhouseSender) dropOldPartitions() {
	const (
		minSleepTime = 10
		maxSleepTime = 20
	)
	for {
		// Sleep a random amount of time to reduce the probability of a
		// collision with another instance doing a cleanup.
		jitter := rand.Intn(maxSleepTime - minSleepTime)
		pause := time.Duration(minSleepTime+jitter) * time.Minute
		time.Sleep(pause)

		if s.closed.Load() {
			return
		}
		s.dropOldPartitionsImpl()
	}
}

// dropOldPartitionsImpl performs one cleanup pass over all table partitions.
//
// For every table the retention (in days) is resolved, a threshold date in the
// server's timezone is computed in the YYYYMMDD integer form used for the
// partition names, and every partition older than the threshold is dropped.
//
// Failures to list partitions, to determine or load the server timezone, or to
// parse a partition name are counted in ChFailuresDuringCleanup. A failed drop
// is only logged, because another instance may have removed the partition
// first; such a partition is not counted as dropped by this pass. Panics are
// recovered, logged and counted as cleanup failures.
func (s *ClickhouseSender) dropOldPartitionsImpl() {
	defer func() {
		if r := recover(); r != nil {
			log.Println("ERROR: recovered in dropOldPartitionsImpl ", r, string(debug.Stack()))
			metrics.ChFailuresDuringCleanup.Inc()
		}
	}()

	startTime := time.Now()
	partitions, err := GetTablePartitions(s.pool)
	if err != nil {
		log.Printf("ERROR: getting table partitions failed: %v\n", err)
		metrics.ChFailuresDuringCleanup.Inc()
		return
	}

	timezone, err := GetServerTimezone(s.pool)
	if err != nil {
		log.Printf("ERROR: getting server timezone failed: %v\n", err)
		metrics.ChFailuresDuringCleanup.Inc()
		return
	}
	log.Printf("server timezone: %s\n", timezone)
	loc, err := time.LoadLocation(timezone)
	if err != nil {
		log.Printf("ERROR: timezone %s load failed: %v\n", timezone, err)
		metrics.ChFailuresDuringCleanup.Inc()
		return
	}

	totalDropped := 0
	for table, parts := range partitions {
		dropAfterDays := s.getDropAfterDaysFor(table)
		thresholdTime := time.Now().AddDate(0, 0, -(dropAfterDays + 1))
		year, month, day := thresholdTime.In(loc).Date()
		// See https://clickhouse.tech/docs/en/sql-reference/functions/date-time-functions/#toyyyymmdd
		thresholdDate := year*10000 + int(month)*100 + day
		log.Printf("threshold date for %s is %d (dropping all partitions earlier than %d days)\n",
			table, thresholdDate, dropAfterDays)

		for _, part := range parts {
			partDate, err := strconv.Atoi(part)
			if err != nil {
				log.Printf("ERROR: strange partition name, not an integer: %s\n", part)
				metrics.ChFailuresDuringCleanup.Inc()
				continue
			}
			if partDate >= thresholdDate {
				continue
			}

			log.Printf("partition %s of table %s is too old, dropping\n", part, table)
			if err := TableDropPartition(s.pool, table, part); err != nil {
				log.Printf("WARNING: dropping partition failed: %v\n", err)
				// Do not increment the failures metric as the parallel process may have dropped it already.
				// The partition was not dropped by this pass either, so it
				// must not show up in the dropped counters.
				continue
			}
			totalDropped++
			metrics.ChDroppedPartitions.Inc()
		}
	}

	log.Printf("finished old partition cleanup in %v, dropped %v partitions\n", time.Since(startTime), totalDropped)
}

// refreshSizes is the body of the size-refresh goroutine: it updates the table
// size metrics once immediately and then on every tick of
// refreshSizesInterval, exiting at the first tick after the sender is closed.
func (s *ClickhouseSender) refreshSizes() {
	s.updateTableSizes()

	ticker := time.NewTicker(refreshSizesInterval)
	defer ticker.Stop()

	for {
		<-ticker.C
		if s.closed.Load() {
			return
		}
		s.updateTableSizes()
	}
}

// updateTableSizes fetches the per-table information from Clickhouse and
// publishes the byte and row counts as metrics. Errors are logged and leave
// the metrics untouched; panics are recovered and logged.
func (s *ClickhouseSender) updateTableSizes() {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		log.Println("ERROR: recovered in updateTableSizes ", r, string(debug.Stack()))
	}()

	infos, err := GetTableInfos(s.pool)
	if err != nil {
		log.Printf("ERROR: getting table sizes failed: %v\n", err)
		return
	}

	bytesByTable := make(map[string]int64, len(infos))
	rowsByTable := make(map[string]int64, len(infos))
	for name := range infos {
		info := infos[name]
		bytesByTable[name] = info.Bytes
		rowsByTable[name] = info.Rows
	}

	metrics.SetChTableSizes(bytesByTable)
	metrics.SetChTableRows(rowsByTable)
}

// getDropAfterDaysFor returns the retention in days for the table called
// name: the custom per-table value when one is configured, the sender-wide
// value otherwise.
func (s *ClickhouseSender) getDropAfterDaysFor(name string) int {
	days, custom := s.customDropAfterDays[name]
	if !custom {
		return s.dropAfterDays
	}
	return days
}

// deleteTableFromSchemaCache forgets the cached schema of tableName, so the
// next createOrAlterTableImpl call for it queries the server again.
func (s *ClickhouseSender) deleteTableFromSchemaCache(tableName string) {
	s.schemaCacheMu.Lock()
	defer s.schemaCacheMu.Unlock()

	delete(s.schemaCache, tableName)
}
