package batcher

import (
	"log"
	"runtime/debug"
	"runtime/pprof"
	"strconv"
	"sync"
	"time"

	"go.uber.org/atomic"

	"a.yandex-team.ru/security/osquery/osquery-sender/metrics"
	"a.yandex-team.ru/security/osquery/osquery-sender/parser"
	"a.yandex-team.ru/security/osquery/osquery-sender/util"
)

// BatcherWorkers stores events in memory in columnar format and eventually flushes them either on timeout
// or when the memory becomes full. Flushed batches are put into a bounded queue that is drained by a pool of
// worker goroutines, which hand each batch to a Submitter.
type BatcherWorkers struct {
	// maxMemory is the limit, in bytes, on memory held by eventBatcher plus the batches in workerQueue.
	maxMemory int64
	// maxDelay is the period of the background flush started by Start.
	maxDelay time.Duration

	// eventBatcher accumulates incoming events until they are flushed into workerQueue.
	eventBatcher *Batcher

	// workerQueue holds flushed batches waiting for a worker; Stop closes it.
	workerQueue chan workerJob
	// workerQueueMemorySize is the total MemorySize of batches pushed to workerQueue and not yet released by a
	// worker after submission.
	workerQueueMemorySize atomic.Int64
	// workerWg tracks the running worker goroutines so that Stop can wait for them.
	workerWg sync.WaitGroup
	// maxWorkers is the number of worker goroutines started by Start.
	maxWorkers int

	// splitDays enables splitting flushed batches per calendar day, computed in splitLoc.
	splitDays bool
	splitLoc  *time.Location

	// submitter receives the batches and the drop/flush notifications.
	submitter Submitter

	// workerLoad holds one load reporter per worker, indexed by worker number.
	workerLoad []*metrics.LoadReporter

	// name identifies this instance in pprof labels.
	name string
}

// WorkersConfig configures a BatcherWorkers created by NewWorkers.
type WorkersConfig struct {
	// MaxMemory is the limit, in bytes, on memory held by the batcher plus the worker queue. Enqueue drops
	// incoming events above it and starts flushing early once the batcher alone exceeds 3/4 of it.
	MaxMemory int64
	// MaxDelay is the period of the background flush that moves all batched events to the workers.
	MaxDelay time.Duration
	// MaxWorkers is the number of worker goroutines submitting batches concurrently.
	MaxWorkers int
	// If set to true, splits events from different days into different batches. If SplitDays is true, SplitLoc must be
	// not nil, as it will be used to determine the start of the day.
	SplitDays bool
	SplitLoc  *time.Location
}

// Submitter is the interface which the code using BatcherWorkers must implement. Its methods are called from
// several goroutines (the worker pool, the periodic flush and Enqueue callers), so implementations must be safe
// for concurrent use.
type Submitter interface {
	// SubmitEvents sends one batch of events named name. If SubmitEvents returns an error, we check if we can
	// re-add the batch to the work queue (i.e. we have enough memory); otherwise the batch is dropped and
	// OnDropAfterRetries is called. A panic inside SubmitEvents is recovered and logged, and the batch is then
	// treated as submitted.
	SubmitEvents(name string, events *EventBatch) error

	// OnDropDueToFullMemory is called when Enqueue drops incoming events because the batcher and the queue
	// together exceed the memory limit.
	OnDropDueToFullMemory()
	// OnDropDueToFullQueue is called when a batch is dropped because the worker queue is full.
	OnDropDueToFullQueue()
	// OnDropAfterRetries is called when a batch whose submission failed is dropped instead of being retried.
	OnDropAfterRetries()
	// OnFlushDueToMemoryPressure is called when Enqueue flushes part of the batcher early because the batcher
	// uses more than 3/4 of the memory limit.
	OnFlushDueToMemoryPressure()
}

// workerJob is one unit of work in the worker queue: a single batch to pass to Submitter.SubmitEvents.
type workerJob struct {
	// name is the events name, passed through to Submitter.SubmitEvents.
	name string
	// events is the batch to submit; its MemorySize is counted in workerQueueMemorySize while queued.
	events *EventBatch
	// NOTE(review): submitTime is never set in this file (pushBatches leaves it zero); verify whether other code
	// in the package uses it.
	submitTime time.Time
}

// queueSize is the capacity of the worker queue, counted in jobs (not bytes). It is some arbitrarily large
// value; once the queue is full, pushBatches drops further batches.
const queueSize = 100000

// reportLoadPeriod is the period handed to metrics.NewLoadReporter for every worker's load reporter.
const reportLoadPeriod = 5 * time.Minute

// NewWorkers creates a BatcherWorkers that flushes eventBatcher into submitter according to config. No goroutines
// are started until Start is called. name is embedded in the pprof labels of the goroutines Start launches.
func NewWorkers(eventBatcher *Batcher, config WorkersConfig, submitter Submitter, name string) *BatcherWorkers {
	// One load reporter per worker; the slice stays nil when MaxWorkers <= 0.
	var loads []*metrics.LoadReporter
	for n := 0; n < config.MaxWorkers; n++ {
		loads = append(loads, metrics.NewLoadReporter(reportLoadPeriod))
	}

	return &BatcherWorkers{
		name:         name,
		submitter:    submitter,
		eventBatcher: eventBatcher,
		maxMemory:    config.MaxMemory,
		maxDelay:     config.MaxDelay,
		maxWorkers:   config.MaxWorkers,
		splitDays:    config.SplitDays,
		splitLoc:     config.SplitLoc,
		workerQueue:  make(chan workerJob, queueSize),
		workerLoad:   loads,
	}
}

// Start launches the periodic flush goroutine and maxWorkers worker goroutines that drain the worker queue.
// The periodic flush goroutine runs for the lifetime of the process; Stop does not stop it.
func (w *BatcherWorkers) Start() {
	go util.RunWithLabels(pprof.Labels("name", "workers-"+w.name+"-periodic-flush"), func() {
		w.periodicFlush()
	})
	for i := 0; i < w.maxWorkers; i++ {
		// Register the worker before it can run. Adding after `go` races with the worker's own Done (which can
		// drive the counter negative and panic) and with a concurrent Wait in Stop.
		w.workerWg.Add(1)
		go w.workerRun(i)
	}
}

// Stop pushes everything still held by the batcher to the worker queue, closes the queue and blocks until every
// worker has drained it and exited. After Stop, anything that pushes batches (Enqueue under memory pressure,
// ForceFlush) may panic with "send on closed channel"; the periodic flush goroutine keeps running and recovers
// such panics.
func (w *BatcherWorkers) Stop() {
	w.ForceFlush()
	close(w.workerQueue)
	w.workerWg.Wait()
}

// Enqueue appends events to the batcher.
//
// If the worker queue and the batcher together already hold more than maxMemory bytes, the whole slice is
// dropped and Submitter.OnDropDueToFullMemory is called. After appending, if the batcher holds more than 3/4 of
// maxMemory, Submitter.OnFlushDueToMemoryPressure is called and the batches returned by
// eventBatcher.FlushTop(maxMemory/4) are pushed to the workers early.
func (w *BatcherWorkers) Enqueue(events []*parser.ParsedEvent) {
	// Drop events if we overflow the memory limit. Each size is read once so that the log line reports exactly the
	// numbers the decision was made on. NOTE: There is a small window of inconsistency when the batcher size got
	// decremented but the workerQueueMemorySize has not been incremented.
	queueMemory := w.QueueMemory()
	batcherMemory := w.eventBatcher.MemorySize()
	if queueMemory+batcherMemory > w.maxMemory {
		const Mb = 1024 * 1024
		log.Printf("ERROR: dropping %d events: %dMb in batcher, %dMb in queue, %dMb limit\n",
			len(events), batcherMemory/Mb, queueMemory/Mb, w.maxMemory/Mb)
		w.submitter.OnDropDueToFullMemory()
		return
	}

	w.eventBatcher.Append(events)

	// Try to keep memory utilization from 0.5 to 0.75 of maxMemory so that there is always some room before we start
	// dropping events due to memory limit.
	if w.eventBatcher.MemorySize() > w.maxMemory*3/4 {
		w.submitter.OnFlushDueToMemoryPressure()
		w.pushBatches(w.eventBatcher.FlushTop(w.maxMemory / 4))
	}
}

// BatcherMemory reports the memory size of events held by the batcher that have not been flushed yet.
func (w *BatcherWorkers) BatcherMemory() int64 {
	held := w.eventBatcher.MemorySize()
	return held
}

// QueueMemory reports the total size of batches pushed to the worker queue that no worker has released yet.
// Batches currently being submitted are included.
func (w *BatcherWorkers) QueueMemory() int64 {
	queued := w.workerQueueMemorySize.Load()
	return queued
}

// QueueLen reports how many jobs are buffered in the worker queue. Jobs already taken by a worker are not
// counted.
func (w *BatcherWorkers) QueueLen() int {
	pending := len(w.workerQueue)
	return pending
}

// NumWorkers reports the configured number of worker goroutines (WorkersConfig.MaxWorkers).
func (w *BatcherWorkers) NumWorkers() int {
	workers := w.maxWorkers
	return workers
}

// Load returns the value of GetAverage for each worker's load reporter, in worker order.
func (w *BatcherWorkers) Load() []float64 {
	averages := make([]float64, 0, len(w.workerLoad))
	for _, reporter := range w.workerLoad {
		averages = append(averages, reporter.GetAverage())
	}
	return averages
}

// ForceFlush pushes everything held by the batcher to the worker queue immediately, without waiting for the
// periodic flush.
func (w *BatcherWorkers) ForceFlush() {
	batches := w.eventBatcher.FlushAll()
	w.pushBatches(batches)
}

// periodicFlush loops forever, pushing every batch held by the batcher to the worker queue once per maxDelay.
// A panic raised during one flush is recovered and logged so that the loop keeps running. time.Tick returns nil
// for a non-positive maxDelay, in which case the loop blocks forever without flushing.
func (w *BatcherWorkers) periodicFlush() {
	flushOnce := func() {
		defer func() {
			if r := recover(); r != nil {
				log.Println("ERROR: recovered in periodicFlush ", r, string(debug.Stack()))
			}
		}()
		w.pushBatches(w.eventBatcher.FlushAll())
	}

	ticks := time.Tick(w.maxDelay)
	for range ticks {
		flushOnce()
	}
}

// pushBatches splits each batch per day (when enabled) and tries to enqueue every resulting batch for the
// workers without blocking. A batch that does not fit into the full queue is dropped, logged and reported via
// Submitter.OnDropDueToFullQueue.
func (w *BatcherWorkers) pushBatches(eventBatches map[string]*EventBatch) {
	for name, events := range eventBatches {
		for _, splitEvents := range w.splitEventsPerDay(events) {
			// Account for the batch before it becomes visible to workers. A worker may take the job and subtract its
			// size (in submitEvents) before this goroutine resumes; adding afterwards would let the counter go
			// transiently negative and make Enqueue under-estimate the memory in use.
			size := splitEvents.MemorySize()
			w.workerQueueMemorySize.Add(size)
			select {
			case w.workerQueue <- workerJob{name: name, events: splitEvents}:
			default:
				w.workerQueueMemorySize.Sub(size)
				// If we hit the queue size, there is probably some misconfiguration in the osquery: too many different
				// event names.
				log.Printf("ERROR: dropping %d events: worker queue hit %d elements\n",
					splitEvents.Length, len(w.workerQueue))
				w.submitter.OnDropDueToFullQueue()
			}
		}
	}
}

// splitEventsPerDay returns events split into one batch per calendar day (computed in splitLoc). When day
// splitting is disabled, or all events fall on one day, the input batch itself is returned without copying.
// Otherwise new batches are built, returned in the order in which their days first appear in events.
func (w *BatcherWorkers) splitEventsPerDay(events *EventBatch) []*EventBatch {
	// Fast paths: splitting disabled, or all events are from one day.
	if !w.splitDays || w.eventsInOneDay(events) {
		return []*EventBatch{events}
	}

	byDay := map[int]*EventBatch{}
	// Days in order of first appearance. Ranging over byDay instead would yield the batches in a random order on
	// every call.
	var days []int
	for i := 0; i < events.Length; i++ {
		day := timeToDate(events.Timestamps[i], w.splitLoc)
		dst, ok := byDay[day]
		if !ok {
			dst = copyBatchStructure(events)
			byDay[day] = dst
			days = append(days, day)
		}
		copyEvent(events, i, dst)
	}

	result := make([]*EventBatch, 0, len(days))
	for _, day := range days {
		batch := byDay[day]
		batch.UpdateSliceSize()
		result = append(result, batch)
	}
	return result
}

// copyBatchStructure returns an empty batch with the same column layout as events: the same NumColumns and the
// same string/float64 column names, each mapped to a nil slice. All other fields (Length, Timestamps, Actions,
// Hosts, SliceSize, StringSize) start at their zero values.
func copyBatchStructure(events *EventBatch) *EventBatch {
	stringColumns := make(map[string][]string, len(events.StringValues))
	for column := range events.StringValues {
		stringColumns[column] = nil
	}

	floatColumns := make(map[string][]float64, len(events.Float64Values))
	for column := range events.Float64Values {
		floatColumns[column] = nil
	}

	return &EventBatch{
		NumColumns:    events.NumColumns,
		StringValues:  stringColumns,
		Float64Values: floatColumns,
	}
}

// copyEvent appends the idx-th event of src (every string and float64 column, timestamp, action and host) to
// dst. It bumps dst.Length and adds the string payload to dst.StringSize. It does not update dst.SliceSize;
// callers run UpdateSliceSize once copying is done.
func copyEvent(src *EventBatch, idx int, dst *EventBatch) {
	dst.Length++

	for column, values := range src.StringValues {
		s := values[idx]
		dst.StringValues[column] = append(dst.StringValues[column], s)
		// See comment in batcher.go regarding the x1.5 factor.
		dst.StringSize += int64(len(s) + len(s)/2)
	}
	for column, values := range src.Float64Values {
		dst.Float64Values[column] = append(dst.Float64Values[column], values[idx])
	}

	dst.Timestamps = append(dst.Timestamps, src.Timestamps[idx])
	dst.Actions = append(dst.Actions, src.Actions[idx])
	dst.Hosts = append(dst.Hosts, src.Hosts[idx])
}

// eventsInOneDay reports whether every event in the batch falls on the same calendar day in splitLoc. An empty
// batch counts as a single day.
func (w *BatcherWorkers) eventsInOneDay(events *EventBatch) bool {
	if events.Length == 0 {
		return true
	}

	want := timeToDate(events.Timestamps[0], w.splitLoc)
	i := 1
	// Stop at the first event whose day differs; reaching the end means every day matched.
	for i < events.Length && timeToDate(events.Timestamps[i], w.splitLoc) == want {
		i++
	}
	return i >= events.Length
}

// timeToDate returns the calendar date of the Unix timestamp (in seconds) as seen in loc, encoded as the integer
// YYYYMMDD (for example 2021-03-07 becomes 20210307). loc must be non-nil: time.Time.In panics on a nil location.
func timeToDate(timestamp int64, loc *time.Location) int {
	local := time.Unix(timestamp, 0).In(loc)
	return local.Year()*10000 + int(local.Month())*100 + local.Day()
}

// workerRun is the loop of worker goroutine num. It submits jobs from workerQueue until Stop closes the queue and
// the remaining jobs are drained, and it reports the time spent on each job to workerLoad[num]. Jobs whose
// submission fails are passed to retryFailedJob.
func (w *BatcherWorkers) workerRun(num int) {
	defer w.workerWg.Done()

	// Loop-invariant part of the pprof labels.
	workerName := w.name + "-batcher-worker-" + strconv.Itoa(num)
	for job := range w.workerQueue {
		labels := pprof.Labels("name", workerName, "events_name", job.name, "events_len", strconv.Itoa(job.events.Length))
		util.RunWithLabels(labels, func() {
			startTime := time.Now()
			err := w.submitEvents(job.name, job.events)
			w.workerLoad[num].ReportWork(time.Since(startTime))
			if err != nil {
				w.retryFailedJob(job, err)
			}
		})
	}
}

// retryFailedJob re-enqueues job, whose submission failed with submitErr, if the queue, the batcher and the job
// together stay within 3/4 of maxMemory. Otherwise, or if re-enqueueing panics, the job is dropped and
// Submitter.OnDropAfterRetries is called.
// NOTE(review): there is no retry limit; a batch that keeps failing is retried for as long as memory allows.
func (w *BatcherWorkers) retryFailedJob(job workerJob, submitErr error) {
	defer func() {
		// Stop closes workerQueue while workers are still draining it, so re-enqueueing a batch that failed during
		// shutdown panics with "send on closed channel". Left unrecovered, that panic would crash the whole process
		// from this worker goroutine; treat it as a drop instead.
		if r := recover(); r != nil {
			log.Printf("ERROR: dropping %d events, retry failed: %v (submit error: %v)\n", job.events.Length, r, submitErr)
			w.submitter.OnDropAfterRetries()
		}
	}()

	log.Printf("Retrying sending %d events", job.events.Length)
	// Try to retry the batch if there is enough memory (and there is space in queue). We drop events earlier
	// so that we still have space for new events.
	totalMemorySize := w.QueueMemory() + w.eventBatcher.MemorySize() + job.events.MemorySize()
	if totalMemorySize > w.maxMemory*3/4 {
		log.Printf("ERROR: sending events failed, not enough memory to retry: %v\n", submitErr)
		w.submitter.OnDropAfterRetries()
		return
	}
	w.pushBatches(map[string]*EventBatch{job.name: job.events})
}

// submitEvents hands one batch to the Submitter and then, whatever the outcome, subtracts the batch's
// MemorySize from workerQueueMemorySize. A panic in SubmitEvents is recovered and logged, and the function then
// returns a nil error.
// NOTE(review): because of that nil error, a batch whose submission panicked is neither retried nor reported
// as dropped.
func (w *BatcherWorkers) submitEvents(name string, events *EventBatch) error {
	// Deferred first, so it runs last: the size is released after the panic check below.
	defer func() {
		w.workerQueueMemorySize.Sub(events.MemorySize())
	}()
	defer func() {
		if r := recover(); r != nil {
			log.Println("ERROR: recovered in submitEvents ", r, string(debug.Stack()))
		}
	}()

	return w.submitter.SubmitEvents(name, events)
}
