package exportdb

import (
	"context"
	"fmt"
	"time"

	"a.yandex-team.ru/kikimr/public/sdk/go/ydb"
	"a.yandex-team.ru/kikimr/public/sdk/go/ydb/table"
	"a.yandex-team.ru/security/ssh-exporter/internal/config"
	"a.yandex-team.ru/security/ssh-exporter/internal/models"
)

// DB is the exporter's handle to YDB: a table session pool plus the directory
// prefix that is passed to every query builder in this package.
type DB struct {
	// sp is the session pool every query and scheme operation runs on.
	sp   *table.SessionPool
	// path is the prefix handed to the *Query builders (e.g. upsertSSHUsageQuery).
	path string
}

// NewDB dials cfg.Endpoint with token credentials for cfg.Database, builds a
// table session pool on top of the resulting driver and runs createTables
// under cfg.Path before returning the handle.
//
// On any error the resources created so far (session pool and driver) are
// released before returning, so callers never have anything to clean up when
// a non-nil error is returned. Dial and table-creation errors are wrapped with
// %w so callers can still inspect the underlying cause.
func NewDB(ctx context.Context, cfg config.YDB) (*DB, error) {
	driverConfig := &ydb.DriverConfig{
		Database:    cfg.Database,
		Credentials: &ydb.AuthTokenCredentials{AuthToken: cfg.Token},
	}

	driver, err := (&ydb.Dialer{
		DriverConfig: driverConfig,
	}).Dial(ctx, cfg.Endpoint)
	if err != nil {
		return nil, fmt.Errorf("dial error: %w", err)
	}

	sp := &table.SessionPool{
		IdleThreshold: 10 * time.Second,
		Builder: &table.Client{
			Driver: driver,
		},
	}

	if err := createTables(ctx, sp, cfg.Path); err != nil {
		// Best-effort cleanup so the pool's sessions and the driver's
		// connections are not leaked. The createTables error is the one the
		// caller needs, so cleanup errors are deliberately ignored.
		_ = sp.Close(ctx)
		_ = driver.Close()
		return nil, fmt.Errorf("create tables: %w", err)
	}

	return &DB{
		sp:   sp,
		path: cfg.Path,
	}, nil
}

// Close shuts down the session pool, passing ctx through to the pool's Close,
// and returns whatever error the pool reports.
//
// NOTE(review): only the session pool is closed here; the driver dialed in
// NewDB is not closed by this method.
func (d *DB) Close(ctx context.Context) error {
	pool := d.sp
	err := pool.Close(ctx)
	return err
}

// DropTables drops the row ACL, SSH access and projects tables under d.path,
// in that order, each through table.Retry.
//
// Failures of individual DROP statements are tolerated and the remaining
// statements are still attempted. Presumably this is so that dropping tables
// that do not exist yet is not fatal — TODO confirm.
//
// If ctx is cancelled or its deadline expires, the remaining statements cannot
// run at all. In that case ctx.Err() is returned rather than reporting success.
func (d *DB) DropTables(ctx context.Context) error {
	reqs := []string{
		dropRowACLQuery(d.path),
		dropSSHAccessQuery(d.path),
		dropProjectsQuery(d.path),
	}

	for _, query := range reqs {
		err := table.Retry(ctx, d.sp,
			table.OperationFunc(func(ctx context.Context, s *table.Session) error {
				return s.ExecuteSchemeQuery(ctx, query)
			}),
		)
		// A dead context means neither this nor any later statement ran;
		// returning nil here would misreport the tables as dropped.
		if err != nil && ctx.Err() != nil {
			return ctx.Err()
		}
	}

	// TODO(buglloc): surface other errors once "table does not exist" can be
	// told apart from real failures.
	return nil
}

// CreateCommonTables executes the scheme query produced by
// createCommonTablesQuery for d.path. It runs through table.Retry on the
// session pool and returns Retry's final error.
func (d *DB) CreateCommonTables(ctx context.Context) error {
	ddl := createCommonTablesQuery(d.path)

	createOp := func(ctx context.Context, s *table.Session) error {
		return s.ExecuteSchemeQuery(ctx, ddl)
	}

	return table.Retry(ctx, d.sp, table.OperationFunc(createOp))
}

// InsertSSHUsage writes usages with the query built by upsertSSHUsageQuery.
//
// Each element becomes one struct in the "$ssh_usage" list parameter. The
// struct fields are system_id, sync_time, staff_user, target_user and count.
// The statement runs in a serializable read-write transaction that is
// committed with the statement itself. The whole operation goes through
// table.Retry, and the statement is re-prepared on every attempt.
// An empty slice is a no-op.
func (d *DB) InsertSSHUsage(ctx context.Context, usages []models.SSHUsage) error {
	if len(usages) == 0 {
		return nil
	}

	rows := make([]ydb.Value, 0, len(usages))
	for _, u := range usages {
		rows = append(rows, ydb.StructValue(
			ydb.StructFieldValue("system_id", ydb.UTF8Value(u.SystemID)),
			ydb.StructFieldValue("sync_time", ydb.DatetimeValueFromTime(u.SyncTime)),
			ydb.StructFieldValue("staff_user", ydb.UTF8Value(u.StaffUser)),
			ydb.StructFieldValue("target_user", ydb.UTF8Value(u.TargetUser)),
			ydb.StructFieldValue("count", ydb.Uint32Value(u.Count)),
		))
	}

	tx := table.TxControl(table.BeginTx(table.WithSerializableReadWrite()), table.CommitTx())
	yql := upsertSSHUsageQuery(d.path)

	upsert := func(ctx context.Context, s *table.Session) error {
		stmt, err := s.Prepare(ctx, yql)
		if err != nil {
			return err
		}
		params := table.NewQueryParameters(
			table.ValueParam("$ssh_usage", ydb.ListValue(rows...)),
		)
		_, _, err = stmt.Execute(ctx, tx, params)
		return err
	}

	return table.Retry(ctx, d.sp, table.OperationFunc(upsert))
}

// InsertProjects writes projects with the query built by upsertProjectsQuery.
//
// Each element becomes one {system_id, project} struct in the "$projects"
// list parameter. The statement runs in a serializable read-write transaction
// committed with the statement. It is retried via table.Retry and re-prepared
// on every attempt. An empty slice is a no-op.
func (d *DB) InsertProjects(ctx context.Context, projects []models.Project) error {
	if len(projects) == 0 {
		return nil
	}

	rows := make([]ydb.Value, 0, len(projects))
	for _, p := range projects {
		rows = append(rows, ydb.StructValue(
			ydb.StructFieldValue("system_id", ydb.UTF8Value(p.SystemID)),
			ydb.StructFieldValue("project", ydb.UTF8Value(p.Project)),
		))
	}

	tx := table.TxControl(table.BeginTx(table.WithSerializableReadWrite()), table.CommitTx())
	yql := upsertProjectsQuery(d.path)

	upsert := func(ctx context.Context, s *table.Session) error {
		stmt, err := s.Prepare(ctx, yql)
		if err != nil {
			return err
		}
		params := table.NewQueryParameters(
			table.ValueParam("$projects", ydb.ListValue(rows...)),
		)
		_, _, err = stmt.Execute(ctx, tx, params)
		return err
	}

	return table.Retry(ctx, d.sp, table.OperationFunc(upsert))
}

// InsertACL writes row ACL entries with the query built by upsertRowACLQuery.
//
// Each element becomes one {system_id, user_id, staff_user} struct in the
// "$acl" list parameter. The statement runs in a serializable read-write
// transaction committed with the statement. It is retried via table.Retry and
// re-prepared on every attempt. An empty slice is a no-op.
func (d *DB) InsertACL(ctx context.Context, acls []models.RowACL) error {
	if len(acls) == 0 {
		return nil
	}

	rows := make([]ydb.Value, 0, len(acls))
	for _, a := range acls {
		rows = append(rows, ydb.StructValue(
			ydb.StructFieldValue("system_id", ydb.UTF8Value(a.SystemID)),
			ydb.StructFieldValue("user_id", ydb.Uint64Value(a.UserID)),
			ydb.StructFieldValue("staff_user", ydb.UTF8Value(a.StaffUser)),
		))
	}

	tx := table.TxControl(table.BeginTx(table.WithSerializableReadWrite()), table.CommitTx())
	yql := upsertRowACLQuery(d.path)

	upsert := func(ctx context.Context, s *table.Session) error {
		stmt, err := s.Prepare(ctx, yql)
		if err != nil {
			return err
		}
		params := table.NewQueryParameters(
			table.ValueParam("$acl", ydb.ListValue(rows...)),
		)
		_, _, err = stmt.Execute(ctx, tx, params)
		return err
	}

	return table.Retry(ctx, d.sp, table.OperationFunc(upsert))
}

// InsertSSHPermissions writes perms with the query built by
// upsertSSHPermissionsQuery.
//
// Each element becomes one {system_id, staff_user} struct in the "$permission"
// list parameter. The statement runs in a serializable read-write transaction
// committed with the statement. It is retried via table.Retry and re-prepared
// on every attempt. An empty slice is a no-op.
func (d *DB) InsertSSHPermissions(ctx context.Context, perms []models.SSHPermission) error {
	if len(perms) == 0 {
		return nil
	}

	rows := make([]ydb.Value, 0, len(perms))
	for _, p := range perms {
		rows = append(rows, ydb.StructValue(
			ydb.StructFieldValue("system_id", ydb.UTF8Value(p.SystemID)),
			ydb.StructFieldValue("staff_user", ydb.UTF8Value(p.StaffUser)),
		))
	}

	tx := table.TxControl(table.BeginTx(table.WithSerializableReadWrite()), table.CommitTx())
	yql := upsertSSHPermissionsQuery(d.path)

	upsert := func(ctx context.Context, s *table.Session) error {
		stmt, err := s.Prepare(ctx, yql)
		if err != nil {
			return err
		}
		params := table.NewQueryParameters(
			table.ValueParam("$permission", ydb.ListValue(rows...)),
		)
		_, _, err = stmt.Execute(ctx, tx, params)
		return err
	}

	return table.Retry(ctx, d.sp, table.OperationFunc(upsert))
}

// createTables executes the scheme query produced by createSSHUsageQuery for
// prefix on sessions from sp. It runs through table.Retry and returns Retry's
// final error.
func createTables(ctx context.Context, sp *table.SessionPool, prefix string) error {
	ddl := createSSHUsageQuery(prefix)

	var op table.OperationFunc = func(ctx context.Context, s *table.Session) error {
		return s.ExecuteSchemeQuery(ctx, ddl)
	}

	return table.Retry(ctx, sp, op)
}
