package db

import (
	"context"
	"crypto/tls"
	"database/sql"
	"fmt"
	"net/url"
	"strings"
	"time"

	"github.com/ClickHouse/clickhouse-go"

	"a.yandex-team.ru/library/go/certifi"
	"a.yandex-team.ru/security/gideon/gideon/pkg/events"
	"a.yandex-team.ru/security/gideon/viewer/internal/config"
	"a.yandex-team.ru/security/gideon/viewer/internal/models"
)

// baseSelect is the "SELECT <columns> FROM gevents" prefix shared by the event
// queries in this file. It is computed once at package initialisation.
//
// The column order is significant: it has to match, position by position, the
// destinations passed to rows.Scan by every query built on top of baseSelect.
var baseSelect = func() string {
	columns := strings.Join([]string{
		// Common event header.
		"TS", "Source", "Host", "Kind",
		// Process information.
		"Proc_Pid", "Proc_Name", "Proc_Ppid", "Proc_ParentName", "Proc_Uid",
		"Proc_SessionID", "Proc_Container", "Proc_PodID", "Proc_PodSetID", "Proc_NannyID",
		// Process exec details.
		"ProcExec_Exe", "ProcExec_Args",
		// Ptrace details.
		"Ptrace_RetCode", "Ptrace_Request", "Ptrace_Target",
		// Connect details.
		"Connect_Family", "Connect_DstAddr", "Connect_DstPort",
		// SSH session details.
		"SSHSession_ID", "SSHSession_User", "SSHSession_TTY", "SSHSession_Kind",
		// OpenAt details.
		"OpenAt_RetCode", "OpenAt_FD", "OpenAt_Filename", "OpenAt_Flags",
	}, ",")

	return fmt.Sprintf("SELECT %s FROM gevents", columns)
}()

// DB wraps the database/sql handle that all queries against the gevents table
// in this package go through.
type DB struct {
	// db is the underlying connection pool opened by NewDB.
	db *sql.DB
}

// SuggestRequest describes a value-suggestion lookup performed by SuggestKey.
type SuggestRequest struct {
	// FullSearch selects substring matching (like with a %...% pattern) instead
	// of prefix matching (startsWith) for ValuePrefix.
	FullSearch  bool
	// Key is the column whose distinct values are returned. It is inserted
	// into the SQL text as-is.
	Key         string
	// ValuePrefix is the user-typed fragment to match Key against; when empty,
	// no extra condition is added.
	ValuePrefix string
	// Where is the base WHERE condition, inserted into the SQL text as-is.
	Where       string
	// Args are the values bound to the placeholders of Where.
	Args        []interface{}
}

// QuerySSHSessionRequest identifies the SSH session event looked up by
// QuerySSHSession.
type QuerySSHSessionRequest struct {
	// TS is the timestamp around which the session event is searched; it must
	// be non-zero.
	TS           uint64
	// SSHSessionID is matched against the SSHSession_ID column; it must be
	// non-empty.
	SSHSessionID string
	// SessionID, when greater than zero, additionally restricts the lookup by
	// the Proc_SessionID column.
	SessionID    uint32
}

// NewDB opens a ClickHouse connection pool described by cfg and verifies it
// with a ping before returning.
//
// The internal CA bundle is registered with the driver under the TLS config
// name "mdb" on a best-effort basis: a failure to load the bundle or to
// register it is deliberately ignored here.
// NOTE(review): when registration is skipped the URI still references
// tls_config=mdb; confirm the driver's fallback is what is wanted.
//
// On failure a nil *DB is returned together with an error wrapping the cause.
// If the ping fails, the freshly opened pool is closed before returning so it
// does not leak.
func NewDB(cfg config.ClickHouse) (*DB, error) {
	// Best effort: errors are intentionally dropped, see the function comment.
	caCertPool, _ := certifi.NewCertPoolInternal()
	if caCertPool != nil {
		_ = clickhouse.RegisterTLSConfig("mdb", &tls.Config{
			RootCAs: caCertPool,
		})
	}

	sqlDB, err := sql.Open("clickhouse", cfg.URI(url.Values{
		"tls_config":               {"mdb"},
		"connection_open_strategy": {"time_random"},
		"pool_size":                {"1000"},
	}))
	if err != nil {
		return nil, fmt.Errorf("ch open failed: %w", err)
	}

	if err := sqlDB.Ping(); err != nil {
		// The caller never receives the handle on this path, so it must be
		// released here. The ping error is the one worth reporting.
		_ = sqlDB.Close()
		return nil, fmt.Errorf("ping failed: %w", err)
	}

	return &DB{
		db: sqlDB,
	}, nil
}

// Ping checks the database connection by delegating to sql.DB.PingContext with
// the given ctx and returns its error unchanged.
func (d *DB) Ping(ctx context.Context) error {
	return d.db.PingContext(ctx)
}

// QueryEvents returns at most 1000 events matching the given condition,
// ordered by TS ascending.
//
// where is appended verbatim after "WHERE" in the SQL text and must therefore
// only contain placeholders for user-supplied values; args are bound to those
// placeholders. ErrEmptyQuery is returned when where is empty or args is empty.
//
// The result is a nil slice when no rows match. Any query, scan or iteration
// error is returned as-is together with a nil slice.
func (d *DB) QueryEvents(ctx context.Context, where string, args []interface{}) ([]models.JsEvent, error) {
	if where == "" || len(args) == 0 {
		return nil, ErrEmptyQuery
	}

	rows, err := d.db.QueryContext(ctx, baseSelect+" WHERE "+where+" ORDER BY TS ASC LIMIT 1000", args...)
	if err != nil {
		return nil, err
	}
	// Release the connection on every exit path, including the early return
	// on a Scan error below. Close is idempotent, so this is safe after a
	// fully drained result set too.
	defer rows.Close()

	var out []models.JsEvent
	for rows.Next() {
		var event models.JsEvent
		// Destinations must stay in the same order as the baseSelect columns.
		err := rows.Scan(
			&event.TS,
			&event.Source,
			&event.Host,
			&event.Kind,

			&event.Proc.Pid,
			&event.Proc.Name,
			&event.Proc.Ppid,
			&event.Proc.ParentName,
			&event.Proc.UID,
			&event.Proc.SessionID,
			&event.Proc.Container,
			&event.Proc.PodID,
			&event.Proc.PodSetID,
			&event.Proc.NannyServiceID,

			&event.ProcExec.Exe,
			&event.ProcExec.Args,

			&event.Ptrace.RetCode,
			&event.Ptrace.Request,
			&event.Ptrace.Target,

			&event.Connect.Family,
			&event.Connect.DstAddr,
			&event.Connect.DstPort,

			&event.SSHSession.ID,
			&event.SSHSession.User,
			&event.SSHSession.TTY,
			&event.SSHSession.Kind,

			&event.OpenAt.RetCode,
			&event.OpenAt.FD,
			&event.OpenAt.Filename,
			&event.OpenAt.Flags,
		)
		if err != nil {
			return nil, err
		}

		out = append(out, event)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return out, nil
}

// QuerySessionEvents returns all events that belong to the session described
// by info, ordered by TS ascending.
//
// Events are selected by exact Host, Proc_SessionID and Proc_PodID match within
// the inclusive range [info.TS, info.TS + 24h]. The 24h offset is added as a
// time.Duration value, i.e. in nanoseconds.
// NOTE(review): this presumes TS is stored in nanoseconds - confirm.
//
// The result is a nil slice when no rows match. Any query, scan or iteration
// error is returned as-is together with a nil slice.
func (d *DB) QuerySessionEvents(ctx context.Context, info models.SSHSessionInfo) ([]models.JsEvent, error) {
	rows, err := d.db.QueryContext(
		ctx,
		baseSelect+" WHERE TS >= ? AND TS <= ? AND Host = ? AND Proc_SessionID = ? AND Proc_PodID = ? ORDER BY TS ASC",
		info.TS,
		info.TS+uint64(24*time.Hour),
		info.Host,
		info.SessionID,
		info.PodID,
	)
	if err != nil {
		return nil, err
	}
	// Release the connection on every exit path, including the early return
	// on a Scan error below.
	defer rows.Close()

	var out []models.JsEvent
	for rows.Next() {
		var event models.JsEvent
		// Destinations must stay in the same order as the baseSelect columns.
		err := rows.Scan(
			&event.TS,
			&event.Source,
			&event.Host,
			&event.Kind,

			&event.Proc.Pid,
			&event.Proc.Name,
			&event.Proc.Ppid,
			&event.Proc.ParentName,
			&event.Proc.UID,
			&event.Proc.SessionID,
			&event.Proc.Container,
			&event.Proc.PodID,
			&event.Proc.PodSetID,
			&event.Proc.NannyServiceID,

			&event.ProcExec.Exe,
			&event.ProcExec.Args,

			&event.Ptrace.RetCode,
			&event.Ptrace.Request,
			&event.Ptrace.Target,

			&event.Connect.Family,
			&event.Connect.DstAddr,
			&event.Connect.DstPort,

			&event.SSHSession.ID,
			&event.SSHSession.User,
			&event.SSHSession.TTY,
			&event.SSHSession.Kind,

			&event.OpenAt.RetCode,
			&event.OpenAt.FD,
			&event.OpenAt.Filename,
			&event.OpenAt.Flags,
		)
		if err != nil {
			return nil, err
		}

		out = append(out, event)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return out, nil
}

// QuerySessionsEvents returns the events of several sessions at once, ordered
// by TS ascending.
//
// Events are selected from the inclusive range [fromTS, toTS] whose
// (Host, Proc_SessionID, Proc_PodID) tuple equals that of any entry in infos.
// When infos is empty the database is not queried and (nil, nil) is returned.
//
// The result is a nil slice when no rows match. Any query, scan or iteration
// error is returned as-is together with a nil slice.
func (d *DB) QuerySessionsEvents(ctx context.Context, fromTS, toTS int64, infos []models.SSHSessionInfo) ([]models.JsEvent, error) {
	if len(infos) == 0 {
		return nil, nil
	}

	var query strings.Builder
	query.WriteString(baseSelect)
	query.WriteString(" WHERE TS >= ? AND TS <= ? AND (Host, Proc_SessionID, Proc_PodID) IN (")

	// Two range bounds plus three values per session tuple.
	args := make([]interface{}, 0, 2+3*len(infos))
	args = append(args, fromTS, toTS)
	for i, info := range infos {
		if i != 0 {
			query.WriteByte(',')
		}
		query.WriteString("(?, ?, ?)")
		args = append(args, info.Host, info.SessionID, info.PodID)
	}
	query.WriteString(") ORDER BY TS ASC")

	rows, err := d.db.QueryContext(ctx, query.String(), args...)
	if err != nil {
		return nil, err
	}
	// Release the connection on every exit path, including the early return
	// on a Scan error below.
	defer rows.Close()

	var out []models.JsEvent
	for rows.Next() {
		var event models.JsEvent
		// Destinations must stay in the same order as the baseSelect columns.
		err := rows.Scan(
			&event.TS,
			&event.Source,
			&event.Host,
			&event.Kind,

			&event.Proc.Pid,
			&event.Proc.Name,
			&event.Proc.Ppid,
			&event.Proc.ParentName,
			&event.Proc.UID,
			&event.Proc.SessionID,
			&event.Proc.Container,
			&event.Proc.PodID,
			&event.Proc.PodSetID,
			&event.Proc.NannyServiceID,

			&event.ProcExec.Exe,
			&event.ProcExec.Args,

			&event.Ptrace.RetCode,
			&event.Ptrace.Request,
			&event.Ptrace.Target,

			&event.Connect.Family,
			&event.Connect.DstAddr,
			&event.Connect.DstPort,

			&event.SSHSession.ID,
			&event.SSHSession.User,
			&event.SSHSession.TTY,
			&event.SSHSession.Kind,

			&event.OpenAt.RetCode,
			&event.OpenAt.FD,
			&event.OpenAt.Filename,
			&event.OpenAt.Flags,
		)
		if err != nil {
			return nil, err
		}

		out = append(out, event)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return out, nil
}

// QuerySSHSession looks up the earliest SSH session event matching req.
//
// The search covers the inclusive range [req.TS - 1h, req.TS + 1h] (the lower
// bound is clamped at zero), is restricted to events of kind
// events.EventKind_EK_SSH_SESSION with SSHSession_ID equal to req.SSHSessionID,
// and is additionally filtered by Proc_SessionID when req.SessionID > 0.
// Only the header, process and SSH session columns are selected; the other
// sections of the returned event keep their zero values.
//
// Errors:
//   - ErrEmptyQuery when req.SSHSessionID is empty or req.TS is zero;
//   - ErrSessionNotFound when no row matches;
//   - any query, iteration or scan error, returned as-is.
//
// On any error the returned event is the zero models.JsEvent.
func (d *DB) QuerySSHSession(ctx context.Context, req QuerySSHSessionRequest) (models.JsEvent, error) {
	if req.SSHSessionID == "" || req.TS == 0 {
		return models.JsEvent{}, ErrEmptyQuery
	}

	query := "SELECT " +
		"TS, " +
		"Source, " +
		"Host, " +
		"Kind, " +
		"Proc_Pid, " +
		"Proc_Name, " +
		"Proc_Ppid, " +
		"Proc_ParentName, " +
		"Proc_Uid, " +
		"Proc_SessionID, " +
		"Proc_Container, " +
		"Proc_PodID, " +
		"Proc_PodSetID, " +
		"Proc_NannyID, " +
		"SSHSession_ID, " +
		"SSHSession_User, " +
		"SSHSession_TTY, " +
		"SSHSession_Kind " +
		"FROM gevents WHERE TS >= ? AND TS <= ? AND Kind = ? AND SSHSession_ID = ?"

	// Search one hour on each side of the requested timestamp. The lower bound
	// is clamped because unsigned subtraction would otherwise wrap around to a
	// huge value and make the range empty.
	window := uint64(time.Hour)
	var fromTS uint64
	if req.TS > window {
		fromTS = req.TS - window
	}

	args := []interface{}{
		fromTS,
		req.TS + window,
		events.EventKind_EK_SSH_SESSION,
		req.SSHSessionID,
	}

	if req.SessionID > 0 {
		query += " AND Proc_SessionID = ?"
		args = append(args, req.SessionID)
	}

	rows, err := d.db.QueryContext(ctx, query+" ORDER BY TS ASC LIMIT 1", args...)
	if err != nil {
		return models.JsEvent{}, err
	}
	// Only a single Next call is made below, so the result set is never
	// drained; it has to be closed explicitly to release the connection.
	defer rows.Close()

	if !rows.Next() {
		// Next returns false both for an empty result and for an iteration
		// error; rows.Err tells the two apart.
		if err := rows.Err(); err != nil {
			return models.JsEvent{}, err
		}
		return models.JsEvent{}, ErrSessionNotFound
	}

	var event models.JsEvent
	// Destinations must stay in the same order as the columns selected above.
	err = rows.Scan(
		&event.TS,
		&event.Source,
		&event.Host,
		&event.Kind,

		&event.Proc.Pid,
		&event.Proc.Name,
		&event.Proc.Ppid,
		&event.Proc.ParentName,
		&event.Proc.UID,
		&event.Proc.SessionID,
		&event.Proc.Container,
		&event.Proc.PodID,
		&event.Proc.PodSetID,
		&event.Proc.NannyServiceID,

		&event.SSHSession.ID,
		&event.SSHSession.User,
		&event.SSHSession.TTY,
		&event.SSHSession.Kind,
	)
	if err != nil {
		// Do not hand out a partially populated event.
		return models.JsEvent{}, err
	}

	return event, nil
}

// SuggestKey returns up to 100 distinct values of the column req.Key from rows
// matching req.Where, ordered by TS ascending.
//
// When req.ValuePrefix is non-empty the result is further narrowed to values
// that start with it (startsWith), or that contain it anywhere when
// req.FullSearch is set (like with a %...% pattern). The prefix itself is
// always passed as a bound argument.
//
// ErrEmptyQuery is returned when req.Args is empty. The result is a nil slice
// when no rows match. Any query, scan or iteration error is returned as-is
// together with a nil slice.
//
// NOTE(review): req.Key and req.Where are inserted into the SQL text verbatim;
// callers must only pass values from a trusted, validated set.
func (d *DB) SuggestKey(ctx context.Context, req SuggestRequest) ([]string, error) {
	if len(req.Args) == 0 {
		return nil, ErrEmptyQuery
	}

	// Work on a private copy: appending to req.Args directly could write into
	// spare capacity of the caller's backing array.
	args := make([]interface{}, 0, len(req.Args)+1)
	args = append(args, req.Args...)

	var where strings.Builder
	where.WriteString(req.Where)
	if req.ValuePrefix != "" {
		if where.Len() > 0 {
			where.WriteString(" AND ")
		}

		if req.FullSearch {
			where.WriteString("like(")
			where.WriteString(req.Key)
			where.WriteString(", ? )")
			args = append(args, "%"+req.ValuePrefix+"%")
		} else {
			where.WriteString("startsWith(")
			where.WriteString(req.Key)
			where.WriteString(", ? )")
			args = append(args, req.ValuePrefix)
		}
	}

	query := fmt.Sprintf("SELECT DISTINCT %s FROM gevents WHERE %s ORDER BY TS ASC LIMIT 100 ", req.Key, where.String())
	rows, err := d.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	// Release the connection on every exit path, including the early return
	// on a Scan error below.
	defer rows.Close()

	var out []string
	for rows.Next() {
		var value string
		if err := rows.Scan(&value); err != nil {
			return nil, err
		}

		out = append(out, value)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return out, nil
}

// Close closes the underlying sql.DB and returns its error unchanged. The
// context argument is accepted but not used.
func (d *DB) Close(_ context.Context) error {
	return d.db.Close()
}
