package batcher

import (
	"errors"
	"fmt"
	"log"
	"sort"
	"strings"
	"sync"
	"time"
	"unsafe"

	"go.uber.org/atomic"

	"a.yandex-team.ru/security/osquery/osquery-sender/metrics"
	"a.yandex-team.ru/security/osquery/osquery-sender/parser"
	"a.yandex-team.ru/security/osquery/osquery-sender/util"
)

// Batcher stores events in memory in columnar format.
//
// Events are grouped by normalized name (see groupEventsByName) and every group
// is accumulated in one EventBatch. Converting events into batches happens
// before mu is taken; only merging and flushing run under the lock.
type Batcher struct {
	// Naming and row-shaping options, set by New. Entries of removePrefix and
	// removeSuffix are trimmed, in order, from event names; when cleanNames is
	// set, group and column names are also passed through cleanName;
	// appendDecorators lists event.Data keys copied into every row when present.
	removePrefix     []string
	removeSuffix     []string
	cleanNames       bool
	appendDecorators []string

	// Group name -> accumulated batch. Guarded by mu.
	events map[string]*EventBatch

	// Locks the batcher during modifications of events and totalSize.
	mu sync.Mutex

	// Running total of the memory sizes of all batches. It is written while mu
	// is held and is atomic so that MemorySize can read it without locking.
	totalSize atomic.Int64
}

// EventBatch holds a group of events (rows) in columnar format. Row i consists
// of Timestamps[i], Actions[i], Hosts[i] and element i of every column in
// StringValues and Float64Values; a row that has no value for some column holds
// the zero value ("" or 0) there.
type EventBatch struct {
	// Number of rows. All slices must have the same Length.
	Length int

	// Number of columns in StringValues and Float64Values (the sum of their
	// sizes), followed by the columns themselves keyed by column name. The two
	// maps are independent, so the same name can appear in both.
	NumColumns    int
	StringValues  map[string][]string
	Float64Values map[string][]float64

	// Predefined columns: per-row timestamp (Unix seconds), action and host.
	Timestamps []int64
	Actions    []string
	Hosts      []string

	// Approximate memory consumed by this batch. SliceSize is recomputed from
	// slice capacities by UpdateSliceSize; StringSize accumulates roughly 1.5x
	// the byte length of the copied string column values.
	SliceSize  int64
	StringSize int64
}

// MemorySize returns the approximate number of bytes held by the batch: the
// slice-capacity estimate plus the string-bytes estimate.
func (b *EventBatch) MemorySize() int64 {
	total := b.SliceSize
	total += b.StringSize
	return total
}

// Element sizes used by UpdateSliceSize to estimate slice memory.
var (
	// Size of one float64 column element.
	float64Size = int64(unsafe.Sizeof(float64(0)))

	// Size of one string column element (the string header only, not its bytes).
	stringSize = int64(unsafe.Sizeof(""))
)

// New returns an empty Batcher.
//
// Event names have every removePrefix entry and then every removeSuffix entry
// trimmed before grouping; with cleanNames set, group and column names are also
// passed through cleanName. appendDecorators lists event.Data keys that are
// copied into every row when the event has them.
func New(removePrefix []string, removeSuffix []string, cleanNames bool, appendDecorators []string) *Batcher {
	b := &Batcher{events: make(map[string]*EventBatch)}
	b.removePrefix = removePrefix
	b.removeSuffix = removeSuffix
	b.cleanNames = cleanNames
	b.appendDecorators = appendDecorators
	return b
}

// Append adds events to the batcher, stamping them with the current time.
func (b *Batcher) Append(events []*parser.ParsedEvent) {
	now := time.Now()
	b.AppendWithTimestamp(events, now)
}

// AppendWithTimestamp groups events by normalized name, converts every group
// into a columnar batch stamped with timestamp (at second precision), and
// merges those batches into the accumulated ones.
//
// The conversion is done before mu is taken, so only the merge runs under the
// lock.
func (b *Batcher) AppendWithTimestamp(events []*parser.ParsedEvent, timestamp time.Time) {
	groups := b.groupEventsByName(events)
	groupBatches := make(map[string]*EventBatch, len(groups))
	for name, group := range groups {
		groupBatches[name] = b.prepareBatch(group, timestamp)
	}

	b.mu.Lock()
	defer b.mu.Unlock()
	for name, newBatch := range groupBatches {
		oldBatch := b.events[name]
		oldSize := int64(0)
		if oldBatch != nil {
			// Read before merging: AppendBatch updates oldBatch in place.
			oldSize = oldBatch.MemorySize()
		}
		merged := AppendBatch(oldBatch, newBatch)
		b.events[name] = merged
		// Account for the growth of the merged batch rather than for the size
		// of newBatch: merging reallocates the column slices, so the merged
		// size is generally not oldSize+newBatch.MemorySize(). Keeping
		// totalSize equal to the sum of the stored batches' MemorySize() keeps
		// it consistent with what FlushTop subtracts when removing a batch.
		b.totalSize.Add(merged.MemorySize() - oldSize)
	}
}

// FlushAll detaches every accumulated batch, keyed by group name, and returns
// them; the batcher is left empty with a zero memory total.
func (b *Batcher) FlushAll() map[string]*EventBatch {
	b.mu.Lock()
	defer b.mu.Unlock()

	flushed := b.events
	b.events = make(map[string]*EventBatch)
	b.totalSize.Store(0)
	return flushed
}

// FlushTop removes the largest event batches, biggest first, until the
// remaining memory size is not larger than thresholdSize, and returns the
// removed batches keyed by group name. Nothing is removed when the batcher is
// already within thresholdSize.
func (b *Batcher) FlushTop(thresholdSize int64) map[string]*EventBatch {
	b.mu.Lock()
	defer b.mu.Unlock()

	names := make([]string, 0, len(b.events))
	for name := range b.events {
		names = append(names, name)
	}
	// Largest batches first.
	sort.Slice(names, func(i int, j int) bool {
		return b.events[names[i]].MemorySize() > b.events[names[j]].MemorySize()
	})

	result := map[string]*EventBatch{}
	// Amount of memory that has to be released to get down to thresholdSize.
	toRemove := b.totalSize.Load() - thresholdSize
	removedSize := int64(0)
	for _, name := range names {
		// Check before removing: when the batcher is already within the
		// threshold (toRemove <= 0) no batch must be flushed.
		if removedSize >= toRemove {
			break
		}
		batch := b.events[name]
		result[name] = batch
		delete(b.events, name)
		removedSize += batch.MemorySize()
	}
	b.totalSize.Sub(removedSize)
	return result
}

// MemorySize reports the batcher's running total of batch memory sizes. It
// reads an atomic counter and does not take mu.
func (b *Batcher) MemorySize() int64 {
	size := b.totalSize.Load()
	return size
}

// groupEventsByName buckets events by their normalized name. Events without a
// name are dropped. A name is normalized by trimming every configured prefix,
// then every configured suffix, and finally, when cleanNames is set, passing it
// through cleanName.
func (b *Batcher) groupEventsByName(events []*parser.ParsedEvent) map[string][]*parser.ParsedEvent {
	groups := make(map[string][]*parser.ParsedEvent)
	for _, ev := range events {
		name := ev.Name
		if name == "" {
			continue
		}
		for _, p := range b.removePrefix {
			name = strings.TrimPrefix(name, p)
		}
		for _, s := range b.removeSuffix {
			name = strings.TrimSuffix(name, s)
		}
		if b.cleanNames {
			name = cleanName(name)
		}
		groups[name] = append(groups[name], ev)
	}
	return groups
}

// prepareBatch converts one group of events into a fresh EventBatch. Every row
// is stamped with timestamp truncated to Unix seconds, and SliceSize is
// computed once all rows have been added.
func (b *Batcher) prepareBatch(events []*parser.ParsedEvent, timestamp time.Time) *EventBatch {
	n := len(events)
	batch := &EventBatch{
		StringValues:  map[string][]string{},
		Float64Values: map[string][]float64{},
		Timestamps:    make([]int64, 0, n),
		Actions:       make([]string, 0, n),
		Hosts:         make([]string, 0, n),
	}
	now := timestamp.Unix()
	for i := range events {
		b.processEvent(batch, events[i], now, n)
	}
	batch.UpdateSliceSize()
	return batch
}

// processEvent appends the rows carried by a single event to batch.
//
// "added" and "removed" events contribute one row taken from their "columns"
// field. "snapshot" events contribute one row per entry of their "snapshot"
// field; when that field cannot be read, the event is handled like a single-row
// event and its "columns" field is used instead. Events whose columns cannot be
// read, and events with any other action, are logged and counted in
// metrics.IncomingParsingErrors.
func (b *Batcher) processEvent(batch *EventBatch, event *parser.ParsedEvent, now int64, numEvents int) {
	action := getEventAction(event)

	switch action {
	case "added", "removed":
		// Single-row event: handled by the "columns" path below.
	case "snapshot":
		if rows, err := b.getSnapshotFromEvent(event); err == nil {
			for _, row := range rows {
				b.appendEvent(batch, row, now, action, event, numEvents)
			}
			return
		}
		// No usable "snapshot" field: continue with the single-row "columns"
		// path below.
	default:
		// Unknown actions are reported and dropped. NOTE: We may need to support
		// diffResults here as well if at least one installation gets configured
		// with --log_result_events=false
		log.Printf("WARNING: Unknown action '%s'\n", action)
		metrics.IncomingParsingErrors.Inc()
		return
	}

	columns, err := b.getColumnsFromEvent(event)
	if err != nil {
		log.Printf("WARNING: %v\n", err)
		metrics.IncomingParsingErrors.Inc()
		return
	}
	b.appendEvent(batch, columns, now, action, event, numEvents)
}

// getEventAction returns event.Data["action"] when it is a string, and "" when
// the field is missing or has any other type. The well-known actions "added",
// "removed" and "snapshot" are interned by hand to reduce memory usage: the
// returned value is a string literal instead of the parsed string.
func getEventAction(event *parser.ParsedEvent) string {
	action, isString := event.Data["action"].(string)
	if !isString {
		return ""
	}
	switch action {
	case "added":
		return "added"
	case "removed":
		return "removed"
	case "snapshot":
		return "snapshot"
	}
	return action
}

// errNoColumns is returned by getColumnsFromEvent when the event has no
// "columns" field. It is a fixed value, so returning it does not format the
// event.
var errNoColumns = errors.New("no 'columns' in event")

// getColumnsFromEvent returns a copy of event.Data["columns"] with the
// configured decorators merged in (see appendDecoratorsFromEvent). It fails with
// errNoColumns when the field is missing or nil, and with a descriptive error
// when the field is not a map[string]interface{}.
func (b *Batcher) getColumnsFromEvent(event *parser.ParsedEvent) (map[string]interface{}, error) {
	raw := event.Data["columns"]
	if raw == nil {
		return nil, errNoColumns
	}
	if columns, ok := raw.(map[string]interface{}); ok {
		return b.appendDecoratorsFromEvent(event, columns), nil
	}
	return nil, fmt.Errorf("'columns' has wrong type: %#v", *event)
}

// errNoSnapshot is returned by getSnapshotFromEvent for events without a
// "snapshot" field. Like errNoColumns it is preallocated: processEvent reaches
// this case for every "snapshot" event that carries its row in "columns", and
// it discards the error, so formatting the whole event here would be wasted
// work on every such event.
var errNoSnapshot = errors.New("no 'snapshot' in event")

// getSnapshotFromEvent returns the rows of a batched snapshot event: one column
// map per element of event.Data["snapshot"], each a copy with the configured
// decorators merged in (see appendDecoratorsFromEvent). It fails with
// errNoSnapshot when the field is missing or nil, and with a descriptive error
// when it is not a []map[string]interface{}.
func (b *Batcher) getSnapshotFromEvent(event *parser.ParsedEvent) ([]map[string]interface{}, error) {
	snapshotValue := event.Data["snapshot"]
	if snapshotValue == nil {
		return nil, errNoSnapshot
	}
	rows, ok := snapshotValue.([]map[string]interface{})
	if !ok {
		return nil, fmt.Errorf("'snapshot' has wrong type: %#v", *event)
	}
	ret := make([]map[string]interface{}, 0, len(rows))
	for _, columns := range rows {
		ret = append(ret, b.appendDecoratorsFromEvent(event, columns))
	}
	return ret, nil
}

// appendEvent appends one row to batch: the values in columns, plus now, action
// and event.Host in the predefined columns.
//
// Column names are passed through cleanName when cleanNames is set. float64
// values go to Float64Values and string values (copied) to StringValues; values
// of any other type are skipped with a warning. Every column that does not
// receive a value for this row (absent key, skipped value, or a value stored in
// the column of the other type under the same name) is padded with 0 or "", so
// all columns keep exactly batch.Length entries. If several keys clean to the
// same column name, only one of them (whichever map iteration visits first) is
// stored.
//
// numEvents is used as the capacity hint for newly created columns.
func (b *Batcher) appendEvent(batch *EventBatch, columns map[string]interface{}, now int64, action string, event *parser.ParsedEvent, numEvents int) {
	row := batch.Length
	// Capacity for columns created by this row. One snapshot event expands into
	// several rows, so row may already be >= numEvents, and make panics when the
	// requested length exceeds the capacity.
	newColumnCap := numEvents
	if newColumnCap <= row {
		newColumnCap = row + 1
	}

	filled := 0
	for key, value := range columns {
		columnName := key
		if b.cleanNames {
			columnName = cleanName(columnName)
		}
		switch value := value.(type) {
		case float64:
			column, ok := batch.Float64Values[columnName]
			if !ok {
				column = make([]float64, row, newColumnCap)
				batch.NumColumns++
			}
			if len(column) > row {
				// Another key already filled this column for the current row.
				continue
			}
			batch.Float64Values[columnName] = append(column, value)
			filled++
		case string:
			column, ok := batch.StringValues[columnName]
			if !ok {
				column = make([]string, row, newColumnCap)
				batch.NumColumns++
			}
			if len(column) > row {
				// Another key already filled this column for the current row.
				continue
			}
			// The value references original JSON, copy it to reduce memory usage.
			valueCopy := util.CopyString(value)
			batch.StringValues[columnName] = append(column, valueCopy)
			// Go memory allocator does not allocate exactly the asked number of bytes, multiply the size by 1.5 to
			// approximate the actually allocated memory.
			batch.StringSize += int64(len(valueCopy) + len(valueCopy)/2)
			filled++
		default:
			log.Printf("WARNING: field '%s' has unsupported type %T: %#v", key, value, *event)
		}
	}

	if filled != len(batch.Float64Values)+len(batch.StringValues) {
		// Some columns got no value for this row; exactly those still hold row
		// entries. Pad them so every column stays as long as the batch.
		for columnName, column := range batch.Float64Values {
			if len(column) == row {
				batch.Float64Values[columnName] = append(column, 0)
			}
		}
		for columnName, column := range batch.StringValues {
			if len(column) == row {
				batch.StringValues[columnName] = append(column, "")
			}
		}
	}

	batch.Timestamps = append(batch.Timestamps, now)
	batch.Actions = append(batch.Actions, action)
	batch.Hosts = append(batch.Hosts, event.Host)
	batch.Length++
}

// appendDecoratorsFromEvent returns a new map holding every entry of columns
// plus, for each configured decorator key present in event.Data, that key's
// value (overwriting a column of the same name). columns is not modified.
func (b *Batcher) appendDecoratorsFromEvent(event *parser.ParsedEvent, columns map[string]interface{}) map[string]interface{} {
	merged := make(map[string]interface{}, len(columns))
	for k, v := range columns {
		merged[k] = v
	}
	for _, name := range b.appendDecorators {
		v, found := event.Data[name]
		if !found {
			continue
		}
		merged[name] = v
	}
	return merged
}

// AppendBatch merges newBatch into batch and returns the merged batch.
//
// batch may be nil, in which case a new, empty batch is allocated first;
// otherwise its slices and maps are modified in place. A column present on only
// one side is padded with zero values (0 or "") for the rows of the other side,
// so (assuming both inputs hold Length entries per column) every column of the
// result holds batch.Length+newBatch.Length values. NumColumns, Length,
// StringSize and SliceSize are brought up to date.
func AppendBatch(batch *EventBatch, newBatch *EventBatch) *EventBatch {
	if batch == nil {
		batch = &EventBatch{
			Float64Values: map[string][]float64{},
			StringValues:  map[string][]string{},
		}
	}
	oldRows, newRows := batch.Length, newBatch.Length

	// Columns newBatch lacks: fill its rows with zero values.
	for name, col := range batch.Float64Values {
		if _, shared := newBatch.Float64Values[name]; !shared {
			batch.Float64Values[name] = append(col, make([]float64, newRows)...)
		}
	}
	for name, col := range batch.StringValues {
		if _, shared := newBatch.StringValues[name]; !shared {
			batch.StringValues[name] = append(col, make([]string, newRows)...)
		}
	}

	// Columns newBatch has: append its values, back-filling zero values for the
	// existing rows when the column is new to batch.
	for name, incoming := range newBatch.Float64Values {
		col, exists := batch.Float64Values[name]
		if !exists {
			col = make([]float64, oldRows)
		}
		batch.Float64Values[name] = append(col, incoming...)
	}
	for name, incoming := range newBatch.StringValues {
		col, exists := batch.StringValues[name]
		if !exists {
			col = make([]string, oldRows)
		}
		batch.StringValues[name] = append(col, incoming...)
	}

	batch.Timestamps = append(batch.Timestamps, newBatch.Timestamps...)
	batch.Actions = append(batch.Actions, newBatch.Actions...)
	batch.Hosts = append(batch.Hosts, newBatch.Hosts...)

	batch.NumColumns = len(batch.Float64Values) + len(batch.StringValues)
	batch.Length = oldRows + newRows
	batch.StringSize += newBatch.StringSize
	batch.UpdateSliceSize()
	return batch
}

// UpdateSliceSize recomputes SliceSize from the capacities of every slice the
// batch owns: the value columns plus the predefined Timestamps, Actions and
// Hosts columns, which the estimate previously left out. Only slice backing
// arrays are counted here; the bytes of string column values are estimated
// separately in StringSize.
func (b *EventBatch) UpdateSliceSize() {
	const int64Size = int64(unsafe.Sizeof(int64(0)))

	size := int64(cap(b.Timestamps)) * int64Size
	size += int64(cap(b.Actions)+cap(b.Hosts)) * stringSize
	for _, column := range b.Float64Values {
		size += int64(cap(column)) * float64Size
	}
	for _, column := range b.StringValues {
		size += int64(cap(column)) * stringSize
	}
	b.SliceSize = size
}
