package db

import (
	"context"
	"fmt"
	"time"

	"github.com/ClickHouse/clickhouse-go"
	"github.com/ClickHouse/clickhouse-go/lib/data"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/security/gideon/speedy-beaver/internal/config"
	"a.yandex-team.ru/security/gideon/speedy-beaver/internal/events"
	"a.yandex-team.ru/security/gideon/speedy-beaver/internal/unistat"
)

// EventSaver batches events of a single kind and writes them to ClickHouse
// from a background goroutine (loop) started by NewEventSaver.
type EventSaver struct {
	// ctx is cancelled by Close to tell the background loop to flush and exit.
	ctx        context.Context
	cancelCtx  context.CancelFunc
	// cfg supplies the batching parameters (BatchSize, BatchTick).
	cfg        config.ClickHouse
	// eventsKind is the only event kind LogEvent accepts.
	eventsKind events.EventKind
	db         *chDB
	// unistat is optional (set via SensorOption) and may be nil.
	unistat    *unistat.Sensor
	log        log.Logger
	// queue hands events from LogEvent to the background loop.
	queue      chan events.Event
	// closed is closed by the background loop when it returns; Close waits on it.
	closed     chan struct{}
}

// NewEventSaver opens the ClickHouse wrapper described by cfg and returns a
// saver that accepts only events of eventsKind.
//
// Recognised options:
//   - LoggerOption: a logger tagged with the event kind, also installed on
//     the database wrapper (the default is a no-op logger);
//   - SensorOption: a unistat sensor for batch/write/error counters.
//
// The background batching goroutine is started before returning; stop it
// with Close.
func NewEventSaver(cfg config.ClickHouse, eventsKind events.EventKind, opts ...Option) (*EventSaver, error) {
	store, err := newChDB(cfg)
	if err != nil {
		return nil, err
	}

	saver := &EventSaver{
		cfg:        cfg,
		eventsKind: eventsKind,
		db:         store,
		log:        &nop.Logger{},
		queue:      make(chan events.Event, cfg.BatchSize*3), // room for three full batches
		closed:     make(chan struct{}),
	}
	saver.ctx, saver.cancelCtx = context.WithCancel(context.Background())

	for _, opt := range opts {
		if lo, ok := opt.(LoggerOption); ok {
			saver.log = log.With(lo.Logger, log.String("kind", eventsKind.String()))
			store.log = saver.log
			continue
		}
		if so, ok := opt.(SensorOption); ok {
			saver.unistat = so.Unistat
		}
	}

	go saver.loop()
	return saver, nil
}

// CreateTables executes createTablesQ on a ClickHouse connection inside a
// transaction.
//
// Errors from beginning the transaction, preparing or executing the
// statement, and committing are returned wrapped (use errors.Is/As on them).
// On any early return the transaction is rolled back and the prepared
// statement is closed.
func (d *EventSaver) CreateTables() error {
	return d.db.Run(context.Background(), func(conn clickhouse.Clickhouse) error {
		tx, err := conn.Begin()
		if err != nil {
			return fmt.Errorf("can't begin tx: %w", err)
		}
		// Roll back on every early return so the connection is not left inside
		// an open transaction. After a successful Commit this is presumably a
		// no-op that reports an error, which is deliberately ignored.
		defer func() {
			_ = tx.Rollback()
		}()

		stmt, err := conn.Prepare(createTablesQ)
		if err != nil {
			return fmt.Errorf("can't prepare create table: %w", err)
		}
		defer func() {
			_ = stmt.Close()
		}()

		if _, err := stmt.Exec(nil); err != nil {
			return fmt.Errorf("can't create tables: %w", err)
		}

		if err := tx.Commit(); err != nil {
			return fmt.Errorf("commit failed: %w", err)
		}

		return nil
	})
}

// Ping checks the database connection by delegating to the underlying chDB
// with the caller's ctx, returning its error unchanged.
func (d *EventSaver) Ping(ctx context.Context) error {
	if err := d.db.Ping(ctx); err != nil {
		return err
	}
	return nil
}

// LogEvent enqueues e for asynchronous, batched insertion into ClickHouse.
//
// It returns an error if e.Kind differs from the kind the saver was created
// for, or the context error once Close has been called. While the internal
// queue is full it blocks until the background loop makes room or the saver
// is closed.
func (d *EventSaver) LogEvent(e events.Event) error {
	if e.Kind != d.eventsKind {
		return fmt.Errorf("unexpected event kind: %s (actual) != %s (expected)", e.Kind, d.eventsKind)
	}

	// Reject events once Close has been requested. A select picks randomly
	// among ready cases, so without this check an event could still be queued
	// after the background loop has flushed and exited, and be silently lost.
	if err := d.ctx.Err(); err != nil {
		return err
	}

	select {
	case d.queue <- e:
		return nil
	case <-d.ctx.Done():
		return d.ctx.Err()
	}
}

// loop is the background worker started by NewEventSaver.
//
// It accumulates events received from d.queue and hands them to the
// kind-specific saver as soon as cfg.BatchSize events are pending, and also
// on every cfg.BatchTick regardless of batch size. Once d.ctx is cancelled
// (by Close), it flushes the pending batch plus whatever is still buffered
// in the queue, then returns and closes d.closed.
//
// Failed batches are logged, counted (when a sensor is configured) and
// dropped, not retried.
func (d *EventSaver) loop() {
	defer close(d.closed)

	var saver func(evts []events.Event) error
	switch d.eventsKind {
	case events.EventKindProcExec:
		saver = d.saveProcExecEvents
	case events.EventKindPTrace:
		saver = d.savePTraceEvents
	case events.EventKindConnect:
		saver = d.saveConnectEvents
	case events.EventKindExecveAt:
		saver = d.saveExecveAtEvents
	case events.EventKindSSHSession:
		saver = d.saveSSHSessionEvents
	case events.EventKindOpenAt:
		saver = d.saveOpenAtEvents
	default:
		// Without a fallback the first flush would call a nil func and crash
		// the whole process from this goroutine; report an error instead.
		kind := d.eventsKind
		saver = func([]events.Event) error {
			return fmt.Errorf("unsupported event kind: %s", kind)
		}
	}

	var evts []events.Event
	// save flushes the pending batch. Without force it waits until at least
	// cfg.BatchSize events are pending. The batch is discarded after the
	// attempt whether or not the write succeeded.
	save := func(force bool) {
		if len(evts) == 0 || (!force && len(evts) < d.cfg.BatchSize) {
			return
		}

		if d.unistat != nil {
			d.unistat.ChBatch(1)
			d.unistat.Write(int64(len(evts)))
		}

		err := saver(evts)
		evts = evts[:0]
		if err != nil {
			d.log.Error("failed to save events", log.Error(err))
			// unistat is nil unless a SensorOption was supplied.
			if d.unistat != nil {
				d.unistat.ChError(1)
			}
		}
	}

	t := time.NewTicker(d.cfg.BatchTick)
	defer t.Stop()

	for {
		select {
		case <-d.ctx.Done():
			// Events accepted by LogEvent before cancellation may still be
			// buffered; flush them rather than dropping them. This loop is the
			// only receiver of d.queue, so the length snapshot bounds the drain
			// and none of these receives can block.
			for pending := len(d.queue); pending > 0; pending-- {
				evts = append(evts, <-d.queue)
				save(false)
			}
			save(true)
			return
		case <-t.C:
			save(true)
		case e := <-d.queue:
			evts = append(evts, e)
			save(false)
		}
	}
}

// saveEvent writes evts to ClickHouse as one block inside a transaction.
//
// query is prepared on the connection, and a block obtained from the
// connection is then filled with len(evts) rows. For each row, fillBaseEvent
// (below) writes the 14 common columns, in this order:
//   Ts, Source, Host, Kind, Proc.Pid, Proc.Name, Proc.Ppid, Proc.ParentName,
//   Proc.Uid, Proc.SessionId, Proc.Container, Proc.PodId, Proc.PodSetId,
//   Proc.NannyId.
// fillEvent is then called with the index of the first kind-specific column
// and must write the remaining columns of that row. query's column list must
// therefore match this layout.
//
// Any error rolls the transaction back and is returned wrapped. An empty
// evts is a no-op.
func (d *EventSaver) saveEvent(query string, fillEvent func(block *data.Block, c int, e events.Event) error, evts []events.Event) error {
	// Nothing to insert: skip the database round-trip instead of sending a
	// zero-row block.
	if len(evts) == 0 {
		return nil
	}

	// context.Background() rather than d.ctx: the final flush in loop runs
	// after Close has already cancelled d.ctx.
	return d.db.Run(context.Background(), func(conn clickhouse.Clickhouse) error {
		tx, err := conn.Begin()
		if err != nil {
			return fmt.Errorf("failed to start tx: %w", err)
		}
		// Roll back on every early return; its error is ignored (after a
		// successful Commit there is nothing left to roll back).
		defer func() {
			_ = tx.Rollback()
		}()

		// The statement itself is unused: rows are sent through
		// conn.Block/conn.WriteBlock below.
		if _, err := conn.Prepare(query); err != nil {
			return fmt.Errorf("failed to prepare stmt: %w", err)
		}

		block, err := conn.Block()
		if err != nil {
			return fmt.Errorf("can't allocate CH block: %w", err)
		}

		block.Reserve()
		block.NumRows = uint64(len(evts))

		// fillBaseEvent writes the common columns starting at column c and
		// returns the index of the next unwritten column.
		fillBaseEvent := func(c int, e events.Event) (int, error) {
			if err := block.WriteUInt64(c, e.Ts); err != nil {
				return c, err
			}

			c++
			if err := block.WriteString(c, e.Source); err != nil {
				return c, err
			}

			c++
			if err := block.WriteString(c, e.Host); err != nil {
				return c, err
			}

			c++
			if err := block.WriteInt32(c, int32(e.Kind)); err != nil {
				return c, err
			}

			c++
			if err := block.WriteUInt32(c, e.Proc.Pid); err != nil {
				return c, err
			}

			c++
			if err := block.WriteString(c, e.Proc.Name); err != nil {
				return c, err
			}

			c++
			if err := block.WriteUInt32(c, e.Proc.Ppid); err != nil {
				return c, err
			}

			c++
			if err := block.WriteString(c, e.Proc.ParentName); err != nil {
				return c, err
			}

			c++
			if err := block.WriteUInt32(c, e.Proc.Uid); err != nil {
				return c, err
			}

			c++
			if err := block.WriteUInt32(c, e.Proc.SessionId); err != nil {
				return c, err
			}

			c++
			if err := block.WriteString(c, e.Proc.Container); err != nil {
				return c, err
			}

			c++
			if err := block.WriteString(c, e.Proc.PodId); err != nil {
				return c, err
			}

			c++
			if err := block.WriteString(c, e.Proc.PodSetId); err != nil {
				return c, err
			}

			c++
			if err := block.WriteString(c, e.Proc.NannyId); err != nil {
				return c, err
			}

			c++
			return c, nil
		}

		for _, e := range evts {
			c, err := fillBaseEvent(0, e)
			if err != nil {
				return fmt.Errorf("failed to fill base event info: %w", err)
			}

			if err := fillEvent(block, c, e); err != nil {
				return fmt.Errorf("failed to fill concrete event info: %w", err)
			}
		}

		if err := conn.WriteBlock(block); err != nil {
			return fmt.Errorf("failed to write block: %w", err)
		}

		if err := tx.Commit(); err != nil {
			return fmt.Errorf("failed to commit tx: %w", err)
		}
		return nil
	})
}

// saveProcExecEvents writes a batch of ProcExec events via saveEvent using
// insertProcExecQ. Kind-specific columns, following the common ones:
//
//	ProcExec_Exe  String
//	ProcExec_Args Array(String)
func (d *EventSaver) saveProcExecEvents(evts []events.Event) error {
	fill := func(block *data.Block, c int, e events.Event) error {
		pe := e.GetProcExec()
		if err := block.WriteString(c, pe.Exe); err != nil {
			return err
		}
		return block.WriteArray(c+1, pe.Args)
	}
	return d.saveEvent(insertProcExecQ, fill, evts)
}

// saveSSHSessionEvents writes a batch of SSH session events via saveEvent
// using insertSSHSessionQ. Kind-specific columns, following the common ones:
// SSHSession_ID, SSHSession_User and SSHSession_TTY (written as strings), then
// SSHSession_Kind (written as Int32).
func (d *EventSaver) saveSSHSessionEvents(evts []events.Event) error {
	fill := func(block *data.Block, c int, e events.Event) error {
		ss := e.GetSshSession()
		if err := block.WriteString(c, ss.Id); err != nil {
			return err
		}
		if err := block.WriteString(c+1, ss.User); err != nil {
			return err
		}
		if err := block.WriteString(c+2, ss.Tty); err != nil {
			return err
		}
		return block.WriteInt32(c+3, int32(ss.Kind))
	}
	return d.saveEvent(insertSSHSessionQ, fill, evts)
}

// savePTraceEvents writes a batch of ptrace events via saveEvent using
// insertPtraceQ. Kind-specific columns, following the common ones:
//
//	Ptrace_RetCode Int64
//	Ptrace_Request Int64
//	Ptrace_Target  Int32
func (d *EventSaver) savePTraceEvents(evts []events.Event) error {
	fill := func(block *data.Block, c int, e events.Event) error {
		pt := e.GetPtrace()
		if err := block.WriteInt64(c, pt.RetCode); err != nil {
			return err
		}
		if err := block.WriteInt64(c+1, pt.Request); err != nil {
			return err
		}
		return block.WriteInt32(c+2, pt.Target)
	}
	return d.saveEvent(insertPtraceQ, fill, evts)
}

// saveConnectEvents writes a batch of connect events via saveEvent using
// insertConnectQ. Kind-specific columns, following the common ones:
//
//	Connect_Family  Int32
//	Connect_SrcAddr String
//	Connect_SrcPort UInt32
//	Connect_DstAddr String
//	Connect_DstPort UInt32
func (d *EventSaver) saveConnectEvents(evts []events.Event) error {
	fill := func(block *data.Block, c int, e events.Event) error {
		cn := e.GetConnect()
		if err := block.WriteInt32(c, int32(cn.Family)); err != nil {
			return err
		}
		if err := block.WriteString(c+1, cn.SrcAddr); err != nil {
			return err
		}
		if err := block.WriteUInt32(c+2, cn.SrcPort); err != nil {
			return err
		}
		if err := block.WriteString(c+3, cn.DstAddr); err != nil {
			return err
		}
		return block.WriteUInt32(c+4, cn.DstPort)
	}
	return d.saveEvent(insertConnectQ, fill, evts)
}

// saveExecveAtEvents writes a batch of execveat events via saveEvent using
// insertExecveAtQ. Kind-specific columns, following the common ones:
//
//	ExecveAt_RetCode  Int64
//	ExecveAt_FD       (written with WriteInt64)
//	ExecveAt_Filename String
//	ExecveAt_Args     Array(String)
//
// NOTE(review): the original column list documented ExecveAt_FD as Int32,
// but the value is written with WriteInt64 (OpenAt_FD is documented as Int64).
// Confirm against insertExecveAtQ / the table schema.
func (d *EventSaver) saveExecveAtEvents(evts []events.Event) error {
	fill := func(block *data.Block, c int, e events.Event) error {
		ex := e.GetExecveAt()
		if err := block.WriteInt64(c, ex.RetCode); err != nil {
			return err
		}
		if err := block.WriteInt64(c+1, ex.Fd); err != nil {
			return err
		}
		if err := block.WriteString(c+2, ex.Filename); err != nil {
			return err
		}
		return block.WriteArray(c+3, ex.Args)
	}
	return d.saveEvent(insertExecveAtQ, fill, evts)
}
// saveOpenAtEvents writes a batch of openat events via saveEvent using
// insertOpenAtQ. Kind-specific columns, following the common ones:
//
//	OpenAt_RetCode  Int64
//	OpenAt_FD       Int64
//	OpenAt_Filename String
//	OpenAt_Flags    Int32
func (d *EventSaver) saveOpenAtEvents(evts []events.Event) error {
	fill := func(block *data.Block, c int, e events.Event) error {
		oa := e.GetOpenAt()
		if err := block.WriteInt64(c, oa.RetCode); err != nil {
			return err
		}
		if err := block.WriteInt64(c+1, oa.Fd); err != nil {
			return err
		}
		if err := block.WriteString(c+2, oa.Filename); err != nil {
			return err
		}
		return block.WriteInt32(c+3, oa.Flags)
	}
	return d.saveEvent(insertOpenAtQ, fill, evts)
}

// Close stops the background loop, waits for it to flush pending events and
// exit, and then closes the database wrapper, returning its error.
//
// If ctx is done before the loop exits, Close returns ctx.Err() and leaves
// the database wrapper open.
func (d *EventSaver) Close(ctx context.Context) error {
	d.cancelCtx()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-d.closed:
		return d.db.Close()
	}
}
