package exportdb

import (
	"context"
	"fmt"
	"time"

	"a.yandex-team.ru/kikimr/public/sdk/go/ydb"
	"a.yandex-team.ru/kikimr/public/sdk/go/ydb/table"
	"a.yandex-team.ru/library/go/yandex/tvm"
	"a.yandex-team.ru/security/libs/go/ydbtvm"
	"a.yandex-team.ru/security/skotty/datalens-exporter/internal/config"
	"a.yandex-team.ru/security/skotty/datalens-exporter/internal/models"
)

// DB wraps a YDB table session pool together with the driver it was built on,
// and runs the exporter's scheme and data queries under a common path prefix.
type DB struct {
	sp     *table.SessionPool // pool every query runs through
	driver ydb.Driver         // connection backing sp; released by Close
	path   string             // prefix passed to every *Query builder
}

// NewDB does the following:
//   - dials cfg.Endpoint for database cfg.Database, authenticating with TVM
//     tickets for ydbtvm.YDBClientID obtained through tvmc;
//   - builds a session pool with a 10s idle threshold;
//   - runs the table-creation scheme query for cfg.Path.
//
// If table creation fails, the pool and the driver are released before the
// error is returned, so the caller never has to clean up a partially built DB.
// Errors are wrapped with %w so callers can inspect the underlying cause.
func NewDB(ctx context.Context, tvmc tvm.Client, cfg config.YDB) (*DB, error) {
	driverConfig := &ydb.DriverConfig{
		Database: cfg.Database,
		Credentials: &ydbtvm.TvmCredentials{
			DstID:     ydbtvm.YDBClientID,
			TvmClient: tvmc,
		},
	}

	driver, err := (&ydb.Dialer{
		DriverConfig: driverConfig,
	}).Dial(ctx, cfg.Endpoint)
	if err != nil {
		return nil, fmt.Errorf("dial error: %w", err)
	}

	sp := &table.SessionPool{
		IdleThreshold: 10 * time.Second,
		Builder: &table.Client{
			Driver: driver,
		},
	}

	if err := createTables(ctx, sp, cfg.Path); err != nil {
		// Release what was acquired so far. The table-creation error is the one
		// worth reporting, so cleanup errors are intentionally discarded.
		_ = sp.Close(ctx)
		_ = driver.Close()
		return nil, fmt.Errorf("create tables: %w", err)
	}

	return &DB{
		sp:     sp,
		driver: driver,
		path:   cfg.Path,
	}, nil
}

// Close closes the session pool first and then the driver its sessions were
// built on. The driver is closed even if closing the pool fails; the first
// error encountered is returned.
func (d *DB) Close(ctx context.Context) error {
	err := d.sp.Close(ctx)
	// driver is nil for a DB that was not constructed through NewDB.
	if d.driver != nil {
		if derr := d.driver.Close(); err == nil {
			err = derr
		}
	}
	return err
}

// DropTemporaryTables runs the drop scheme queries for the temporary users,
// ACL and departments tables under d.path, in that order. Each drop is
// retried through table.Retry.
//
// The method is best-effort: the outcome of every drop is discarded, a failed
// drop does not stop the remaining ones, and nil is always returned.
func (d *DB) DropTemporaryTables(ctx context.Context) error {
	for _, q := range []string{
		dropTemporaryUsersQuery(d.path),
		dropTemporaryACLQuery(d.path),
		dropTemporaryDepartmentsQuery(d.path),
	} {
		dropQuery := q
		dropOp := table.OperationFunc(func(ctx context.Context, s *table.Session) error {
			return s.ExecuteSchemeQuery(ctx, dropQuery)
		})

		// TODO(buglloc): check err. The result is deliberately ignored for now;
		// presumably a drop may legitimately fail when the temporary table does
		// not exist yet — confirm before surfacing these errors.
		_ = table.Retry(ctx, d.sp, dropOp)
	}

	return nil
}

// CreateTemporaryTables executes the scheme query built by
// createTemporaryTablesQuery for d.path, retrying through table.Retry, and
// returns the final error, if any.
func (d *DB) CreateTemporaryTables(ctx context.Context) error {
	ddl := createTemporaryTablesQuery(d.path)
	exec := table.OperationFunc(func(ctx context.Context, s *table.Session) error {
		return s.ExecuteSchemeQuery(ctx, ddl)
	})
	return table.Retry(ctx, d.sp, exec)
}

// FinalizeTables executes the scheme query built by renameTablesQuery for
// d.path, retrying through table.Retry, and returns the final error, if any.
func (d *DB) FinalizeTables(ctx context.Context) error {
	renameDDL := renameTablesQuery(d.path)
	exec := table.OperationFunc(func(ctx context.Context, s *table.Session) error {
		return s.ExecuteSchemeQuery(ctx, renameDDL)
	})
	return table.Retry(ctx, d.sp, exec)
}

// InsertTotalUsers executes the statement built by upsertTotalUsersQuery with:
//   - $date set to the current time (as a YDB Datetime);
//   - $count set to count (as Int32).
//
// The statement runs in a serializable read-write transaction that is
// committed by the same call. The whole prepare+execute sequence is retried
// through table.Retry.
//
// count must fit into int32; larger values are rejected with an error instead
// of being silently truncated by the conversion.
func (d *DB) InsertTotalUsers(ctx context.Context, count int) error {
	count32 := int32(count)
	if int(count32) != count {
		return fmt.Errorf("total users count %d overflows int32", count)
	}

	writeTx := table.TxControl(
		table.BeginTx(
			table.WithSerializableReadWrite(),
		),
		table.CommitTx(),
	)

	// Captured once, outside the retry loop, so every attempt writes the same timestamp.
	now := time.Now()
	query := upsertTotalUsersQuery(d.path)
	return table.Retry(ctx, d.sp,
		table.OperationFunc(func(ctx context.Context, s *table.Session) error {
			stmt, err := s.Prepare(ctx, query)
			if err != nil {
				return err
			}

			_, _, err = stmt.Execute(ctx, writeTx, table.NewQueryParameters(
				table.ValueParam("$date", ydb.DatetimeValueFromTime(now)),
				table.ValueParam("$count", ydb.Int32Value(count32)),
			))
			return err
		}),
	)
}

// InsertUserInfos executes the statement built by upsertUsersQuery. It binds
// $users to a list of structs with the fields login, department,
// have_ssh_keys, enrolled and enrolled_at, one struct per element of users.
//
// The statement runs in a serializable read-write transaction that is
// committed by the same call, retried through table.Retry. An empty users
// slice is a no-op and returns nil without contacting the database.
func (d *DB) InsertUserInfos(ctx context.Context, users []models.UserInfo) error {
	// Nothing to write. This also avoids sending an element-less ydb.ListValue(),
	// which has no element to infer the list type from — presumably it would
	// not match the parameter type declared by the query.
	if len(users) == 0 {
		return nil
	}

	writeTx := table.TxControl(
		table.BeginTx(
			table.WithSerializableReadWrite(),
		),
		table.CommitTx(),
	)

	ydbUsers := make([]ydb.Value, len(users))
	for i, user := range users {
		ydbUsers[i] = ydb.StructValue(
			ydb.StructFieldValue("login", ydb.UTF8Value(user.Login)),
			ydb.StructFieldValue("department", ydb.UTF8Value(user.Department)),
			ydb.StructFieldValue("have_ssh_keys", ydb.BoolValue(user.HaveSSHKeys)),
			ydb.StructFieldValue("enrolled", ydb.BoolValue(user.Enrolled)),
			// NOTE(review): if EnrolledAt can be the zero time.Time for
			// non-enrolled users, confirm how DatetimeValueFromTime encodes
			// pre-epoch instants before relying on this column.
			ydb.StructFieldValue("enrolled_at", ydb.DatetimeValueFromTime(user.EnrolledAt)),
		)
	}

	query := upsertUsersQuery(d.path)
	return table.Retry(ctx, d.sp,
		table.OperationFunc(func(ctx context.Context, s *table.Session) error {
			stmt, err := s.Prepare(ctx, query)
			if err != nil {
				return err
			}

			_, _, err = stmt.Execute(ctx, writeTx, table.NewQueryParameters(
				table.ValueParam("$users", ydb.ListValue(ydbUsers...)),
			))
			return err
		}),
	)
}

// InsertACLs executes the statement built by upsertACLQuery. It binds $acl to a
// list of structs with the fields login (UTF8) and head_uid (Uint64), one
// struct per element of acls.
//
// The statement runs in a serializable read-write transaction that is
// committed by the same call, retried through table.Retry. An empty acls slice
// is a no-op and returns nil without contacting the database.
func (d *DB) InsertACLs(ctx context.Context, acls []models.ACL) error {
	// Nothing to write. This also avoids sending an element-less ydb.ListValue(),
	// which has no element to infer the list type from — presumably it would
	// not match the parameter type declared by the query.
	if len(acls) == 0 {
		return nil
	}

	writeTx := table.TxControl(
		table.BeginTx(
			table.WithSerializableReadWrite(),
		),
		table.CommitTx(),
	)

	ydbACLs := make([]ydb.Value, len(acls))
	for i, acl := range acls {
		ydbACLs[i] = ydb.StructValue(
			ydb.StructFieldValue("login", ydb.UTF8Value(acl.Login)),
			ydb.StructFieldValue("head_uid", ydb.Uint64Value(acl.HeadUID)),
		)
	}

	query := upsertACLQuery(d.path)
	return table.Retry(ctx, d.sp,
		table.OperationFunc(func(ctx context.Context, s *table.Session) error {
			stmt, err := s.Prepare(ctx, query)
			if err != nil {
				return err
			}

			_, _, err = stmt.Execute(ctx, writeTx, table.NewQueryParameters(
				table.ValueParam("$acl", ydb.ListValue(ydbACLs...)),
			))
			return err
		}),
	)
}

// InsertDepartments executes the statement built by upsertDepartmentsQuery. It
// binds $departments to a list of structs with the UTF8 fields login and
// department, one struct per element of deps.
//
// The statement runs in a serializable read-write transaction that is
// committed by the same call, retried through table.Retry. An empty deps slice
// is a no-op and returns nil without contacting the database.
func (d *DB) InsertDepartments(ctx context.Context, deps []models.Department) error {
	// Nothing to write. This also avoids sending an element-less ydb.ListValue(),
	// which has no element to infer the list type from — presumably it would
	// not match the parameter type declared by the query.
	if len(deps) == 0 {
		return nil
	}

	writeTx := table.TxControl(
		table.BeginTx(
			table.WithSerializableReadWrite(),
		),
		table.CommitTx(),
	)

	ydbDeps := make([]ydb.Value, len(deps))
	for i, dep := range deps {
		ydbDeps[i] = ydb.StructValue(
			ydb.StructFieldValue("login", ydb.UTF8Value(dep.Login)),
			ydb.StructFieldValue("department", ydb.UTF8Value(dep.Department)),
		)
	}

	query := upsertDepartmentsQuery(d.path)
	return table.Retry(ctx, d.sp,
		table.OperationFunc(func(ctx context.Context, s *table.Session) error {
			stmt, err := s.Prepare(ctx, query)
			if err != nil {
				return err
			}

			_, _, err = stmt.Execute(ctx, writeTx, table.NewQueryParameters(
				table.ValueParam("$departments", ydb.ListValue(ydbDeps...)),
			))
			return err
		}),
	)
}

// createTables executes the scheme query built by createTablesQuery for prefix
// on a session taken from sp, retrying through table.Retry, and returns the
// final error, if any.
func createTables(ctx context.Context, sp *table.SessionPool, prefix string) error {
	ddl := createTablesQuery(prefix)
	exec := table.OperationFunc(func(ctx context.Context, s *table.Session) error {
		return s.ExecuteSchemeQuery(ctx, ddl)
	})
	return table.Retry(ctx, sp, exec)
}
