package ydb

import (
	"context"
	"fmt"
	"path"
	"time"

	"a.yandex-team.ru/kikimr/public/sdk/go/ydb"
	"a.yandex-team.ru/kikimr/public/sdk/go/ydb/table"
	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/security/yadi/snatcher/pkg/feed"
)

// DB is the handle returned by New. It holds the YDB session pool together
// with the query texts that New renders for the configured table path prefix.
type DB struct {
	// ctx is the context New was called with.
	// NOTE(review): storing a context in a struct is discouraged; confirm how
	// the methods defined elsewhere use it before changing this.
	ctx context.Context

	// sp is the YDB session pool created by New.
	sp *table.SessionPool

	// commitSize is the number of rows per one commit.
	commitSize int

	// Query texts rendered by New for the table path prefix.
	selectAllVulnsQuery      string
	updateActionsQuery       string
	updateSrcQuery           string
	updateDeleteActionsQuery string
}

// Options configures New.
type Options struct {
	Database  string // YDB database name; also the first element of the table path prefix
	Path      string // directory joined onto Database to form the table path prefix
	Endpoint  string // address handed to the YDB dialer
	AuthToken string // token used for ydb.AuthTokenCredentials
	SrcType   string // NOTE(review): not read by New; confirm where it is consumed
}

// UpdateData bundles a vulnerability with the source type, action and
// platform it is recorded under.
// NOTE(review): field meanings are inferred from their names; confirm against
// the update methods that consume this type.
type UpdateData struct {
	SrcType       string
	Action        string
	Platform      string
	Vulnerability feed.Vulnerability
}

// roTX begins an online read-only transaction and sets CommitTx, so the
// transaction is committed as part of the same query call.
var roTX = table.TxControl(table.BeginTx(table.WithOnlineReadOnly()), table.CommitTx())

// rwTX begins a serializable read-write transaction and sets CommitTx, so the
// transaction is committed as part of the same query call.
var rwTX = table.TxControl(table.BeginTx(table.WithSerializableReadWrite()), table.CommitTx())

// New connects to YDB and prepares a DB for use. It:
//   - dials opts.Endpoint for database opts.Database using opts.AuthToken;
//   - builds a session pool on top of the resulting driver;
//   - creates the "src" and "actions" tables under path.Join(opts.Database, opts.Path);
//   - renders the package's query texts for that same table path prefix.
//
// opts must be non-nil. On any error no DB is returned. In that case the
// session pool and driver acquired so far are released before returning,
// because the caller would otherwise have no handle with which to close them.
//
// NOTE(review): ctx is retained in the returned DB (presumably for methods
// defined elsewhere); storing a context in a struct is discouraged.
func New(ctx context.Context, opts *Options) (*DB, error) {
	if opts == nil {
		return nil, xerrors.New("options must not be nil")
	}

	config := &ydb.DriverConfig{
		Database: opts.Database,
		Credentials: ydb.AuthTokenCredentials{
			AuthToken: opts.AuthToken,
		},
	}

	driver, err := (&ydb.Dialer{
		DriverConfig: config,
	}).Dial(ctx, opts.Endpoint)
	if err != nil {
		// %w (not %v) keeps the cause inspectable with errors.Is/errors.As.
		return nil, xerrors.Errorf("dial error: %w", err)
	}

	tableClient := table.Client{
		Driver: driver,
	}

	sp := table.SessionPool{
		IdleThreshold: 10 * time.Second,
		Builder:       &tableClient,
	}

	tablePathPrefix := path.Join(opts.Database, opts.Path)

	err = createTables(ctx, &sp, tablePathPrefix)
	if err != nil {
		// Release the pool (and its sessions) and then the driver; their close
		// errors are secondary to err and deliberately ignored.
		_ = sp.Close(ctx)
		_ = driver.Close()
		return nil, xerrors.Errorf("create tables error: %w", err)
	}

	return &DB{
		ctx:                      ctx,
		sp:                       &sp,
		commitSize:               100, // rows per one commit
		selectAllVulnsQuery:      selectAllVulnsQuery(tablePathPrefix),
		updateActionsQuery:       updateActionsQuery(tablePathPrefix),
		updateSrcQuery:           updateSrcQuery(tablePathPrefix),
		updateDeleteActionsQuery: updateDeleteActionsQuery(tablePathPrefix),
	}, nil
}

// createTables creates the "src" and "actions" tables under prefix. Each
// CreateTable call runs through table.Retry on sp.
//
// The src table is created first. If that fails, its error is returned
// (wrapped) and the actions table is not attempted.
//
// NOTE(review): both CreateTable calls are issued unconditionally; whether
// this succeeds when the tables already exist depends on YDB's CreateTable
// semantics — confirm before relying on repeated startups.
func createTables(ctx context.Context, sp *table.SessionPool, prefix string) error {
	srcPath := fmt.Sprintf("%s/%s", prefix, "src")
	createSrc := table.OperationFunc(func(ctx context.Context, s *table.Session) error {
		return s.CreateTable(ctx, srcPath,
			// Primary key columns.
			table.WithColumn("key", ydb.Optional(ydb.TypeUint64)),
			table.WithColumn("srcType", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("srcId", ydb.Optional(ydb.TypeUTF8)),

			// Tracked fields.
			table.WithColumn("cvssScore", ydb.Optional(ydb.TypeFloat)),
			table.WithColumn("vulnVersions", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("title", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("pkgName", ydb.Optional(ydb.TypeUTF8)),

			// Other descriptive fields.
			table.WithColumn("externalReferences", ydb.Optional(ydb.TypeJSON)),
			table.WithColumn("description", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("lang", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("patchedVersions", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("patchExists", ydb.Optional(ydb.TypeBool)),
			table.WithColumn("disclosedAt", ydb.Optional(ydb.TypeInt64)),
			table.WithColumn("richDescription", ydb.Optional(ydb.TypeBool)),

			// Bookkeeping.
			table.WithColumn("updatedAt", ydb.Optional(ydb.TypeInt64)),
			table.WithColumn("yaId", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("isDeleted", ydb.Optional(ydb.TypeBool)),

			table.WithPrimaryKeyColumn("key", "srcType", "srcId"),
		)
	})
	if err := table.Retry(ctx, sp, createSrc); err != nil {
		return xerrors.Errorf("failed to create src table: %w", err)
	}

	actionsPath := fmt.Sprintf("%s/%s", prefix, "actions")
	createActions := table.OperationFunc(func(ctx context.Context, s *table.Session) error {
		return s.CreateTable(ctx, actionsPath,
			table.WithColumn("srcType", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("srcId", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("vulnAction", ydb.Optional(ydb.TypeUTF8)),
			table.WithPrimaryKeyColumn("srcType", "srcId"),
		)
	})
	if err := table.Retry(ctx, sp, createActions); err != nil {
		return xerrors.Errorf("failed to create actions table: %w", err)
	}

	return nil
}

// chunkCommit splits list into consecutive chunks of at most chunkSize
// elements, preserving order; only the last chunk may be shorter. Judging by
// its name it is meant for batching rows per commit (cf. DB.commitSize).
//
// A non-positive chunkSize is rejected with an error. An empty or nil list
// yields a nil result and no error.
//
// Each chunk is a sub-slice of list whose capacity is capped at its length
// (full slice expression). Appending to one chunk therefore reallocates
// instead of overwriting the first elements of the following chunk. The
// chunks still share list's backing array, so in-place element writes are
// visible through list.
func chunkCommit(list []ydb.Value, chunkSize int) ([][]ydb.Value, error) {
	if chunkSize <= 0 {
		return nil, xerrors.New("commit chunk size must be positive")
	}
	if len(list) == 0 {
		return nil, nil
	}

	// Ceil division written so that a very large chunkSize cannot overflow.
	commits := make([][]ydb.Value, 0, (len(list)-1)/chunkSize+1)
	for start := 0; start < len(list); start += chunkSize {
		end := start + chunkSize
		if end > len(list) {
			end = len(list)
		}
		commits = append(commits, list[start:end:end])
	}

	return commits, nil
}
