package db

import (
	"context"
	"database/sql/driver"
	"errors"
	"fmt"
	"io"
	"sort"
	"strconv"

	"github.com/ClickHouse/clickhouse-go"

	"a.yandex-team.ru/security/gideon/speedy-beaver/internal/config"
)

// SystemDB bundles a ClickHouse client with the pre-formatted query used to
// list partitions, and exposes partition maintenance operations on top of it.
type SystemDB struct {
	// db is the underlying ClickHouse client every operation runs through.
	db *chDB

	// listParts is the partition-listing query text, formatted once by
	// NewSystemDB from listPartitionsQF and the configured database name.
	listParts string
}

// NewSystemDB creates the ClickHouse client described by cfg (via newChDB) and
// returns a SystemDB whose partition-listing query is listPartitionsQF
// formatted with cfg.Database. An error from creating the client is returned
// as-is.
func NewSystemDB(cfg config.ClickHouse) (*SystemDB, error) {
	ch, err := newChDB(cfg)
	if err != nil {
		return nil, err
	}

	// NOTE(review): listPartitionsQF is defined elsewhere; presumably it has a
	// single verb for the database name — confirm against its definition.
	query := fmt.Sprintf(listPartitionsQF, cfg.Database)
	return &SystemDB{db: ch, listParts: query}, nil
}

// Ping checks connectivity by delegating to the underlying ClickHouse
// client's Ping with the caller's context.
func (d *SystemDB) Ping(ctx context.Context) error {
	client := d.db
	return client.Ping(ctx)
}

// ListPartitions runs the partition-listing query (d.listParts) and returns the
// partition names parsed as integers, sorted in ascending order.
//
// The query must produce exactly one column whose values are strings holding
// base-10 integers; any other shape is reported as an error. Errors reported by
// the driver while iterating rows are propagated instead of being mistaken for
// end of data. On error the returned slice is nil.
func (d *SystemDB) ListPartitions(ctx context.Context) ([]int, error) {
	var parts []int
	err := d.db.Run(ctx, func(conn clickhouse.Clickhouse) error {
		// Start empty on every invocation so that, should Run ever call this
		// callback more than once (not visible from here), results from an
		// earlier attempt are never duplicated.
		parts = nil

		tx, err := conn.Begin()
		if err != nil {
			return fmt.Errorf("failed to start tx: %w", err)
		}
		// Matters only on the error paths below; after a successful Commit the
		// rollback result is deliberately ignored.
		defer func() {
			_ = tx.Rollback()
		}()

		stmt, err := conn.Prepare(d.listParts)
		if err != nil {
			return fmt.Errorf("failed to prepare stmt: %w", err)
		}
		defer func() { _ = stmt.Close() }()

		queryer, ok := stmt.(driver.StmtQueryContext)
		if !ok {
			return fmt.Errorf("stmt %T does not implement driver.StmtQueryContext", stmt)
		}

		rows, err := queryer.QueryContext(ctx, nil)
		if err != nil {
			return fmt.Errorf("failed to exec stmt: %w", err)
		}
		// Close the iterator on every path (including early returns below) so
		// the driver can release whatever it holds for this query.
		defer func() { _ = rows.Close() }()

		if columns := rows.Columns(); len(columns) != 1 {
			return fmt.Errorf("invalid columns count: %d != %d", len(columns), 1)
		}

		row := make([]driver.Value, 1)
		for {
			// database/sql/driver: Next returns io.EOF when rows are exhausted;
			// any other error is a real failure and must not end the loop silently.
			if err := rows.Next(row); err != nil {
				if errors.Is(err, io.EOF) {
					break
				}
				return fmt.Errorf("failed to read row: %w", err)
			}

			part, ok := row[0].(string)
			if !ok {
				return fmt.Errorf("invalid partition column type: %T", row[0])
			}

			partNum, err := strconv.Atoi(part)
			if err != nil {
				return fmt.Errorf("invalid partition name %q: %w", part, err)
			}

			parts = append(parts, partNum)
		}

		if err := tx.Commit(); err != nil {
			return fmt.Errorf("failed to commit tx: %w", err)
		}
		return nil
	})
	if err != nil {
		// Never hand back a partially collected list alongside an error.
		return nil, err
	}

	sort.Ints(parts)
	return parts, nil
}

// DropPartition executes the statement produced by formatting dropPartitionsQF
// with the given partition number. Because partition is an int, the value
// interpolated into the query text is always a plain decimal number.
func (d *SystemDB) DropPartition(ctx context.Context, partition int) error {
	drop := func(conn clickhouse.Clickhouse) error {
		query := fmt.Sprintf(dropPartitionsQF, partition)

		stmt, err := conn.Prepare(query)
		if err != nil {
			return fmt.Errorf("failed to prepare stmt: %w", err)
		}
		defer func() { _ = stmt.Close() }()

		execer := stmt.(driver.StmtExecContext)
		if _, err := execer.ExecContext(ctx, nil); err != nil {
			return fmt.Errorf("exec fail: %w", err)
		}
		return nil
	}

	return d.db.Run(ctx, drop)
}

// Close shuts down the underlying ClickHouse client and returns its error.
// The context argument is currently unused.
func (d *SystemDB) Close(context.Context) error {
	client := d.db
	return client.Close()
}
