package db

import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"a.yandex-team.ru/kikimr/public/sdk/go/ydb"
	"a.yandex-team.ru/kikimr/public/sdk/go/ydb/table"
	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/security/yadi/indexer/internal/dbmodels"
)

// VersionsDelimiter separates individual version strings when the full list
// of a package's versions is stored in (and read back from) a single UTF8
// column of the packages table.
const VersionsDelimiter = ";"

// ErrNotFound is returned by LookupPackage and LookupPackageVersion when the
// lookup query yields no row.
var ErrNotFound = xerrors.New("record not found")

// DB gives access to the packages and package_versions tables in YDB.
// Construct it with New; the zero value is not usable.
type DB struct {
	// ctx is the context passed to New; every operation runs under it.
	ctx context.Context

	// sp is the session pool all queries are executed through.
	sp *table.SessionPool

	// Query texts rendered once by New for the configured table path.
	selectPackageVersionQuery string
	selectPackageQuery        string
	updatePackageQuery        string
	updatePackageLiteQuery    string
	cleanUpPackageQuery       string
}

// Options configures how New connects to YDB and where the tables live.
type Options struct {
	// Database is the YDB database; it is also the first segment of the table path.
	Database string

	// Path is joined to Database as "Database/Path" to form the tables' directory.
	Path string

	// Endpoint is the address handed to the YDB dialer.
	Endpoint string

	// AuthToken is passed to YDB as ydb.AuthTokenCredentials.
	AuthToken string
}

// New dials opts.Endpoint and ensures the package_versions and packages
// tables exist under "opts.Database/opts.Path". It returns a DB whose
// prepared query texts target that path.
//
// ctx is retained by the returned DB and used for all later operations.
// If table creation fails, the session pool and driver acquired here are
// released before returning, since the caller gets no handle to them.
func New(ctx context.Context, opts *Options) (*DB, error) {
	config := &ydb.DriverConfig{
		Database: opts.Database,
		Credentials: ydb.AuthTokenCredentials{
			AuthToken: opts.AuthToken,
		},
	}

	driver, err := (&ydb.Dialer{
		DriverConfig: config,
	}).Dial(ctx, opts.Endpoint)
	if err != nil {
		// %w keeps the dial error inspectable via errors.Is/As.
		return nil, xerrors.Errorf("dial error: %w", err)
	}

	tableClient := table.Client{
		Driver: driver,
	}

	sp := table.SessionPool{
		IdleThreshold: 10 * time.Second,
		Builder:       &tableClient,
	}

	tablePath := fmt.Sprintf("%s/%s", opts.Database, opts.Path)

	if err := createTables(ctx, &sp, tablePath); err != nil {
		// Best-effort cleanup: the creation error is what the caller needs.
		// Close the pool first, because its sessions go through the driver.
		_ = sp.Close(ctx)
		_ = driver.Close()
		return nil, xerrors.Errorf("create tables error: %w", err)
	}

	return &DB{
		ctx:                       ctx,
		sp:                        &sp,
		selectPackageVersionQuery: selectPackageVersionQuery(tablePath),
		selectPackageQuery:        selectPackageQuery(tablePath),
		updatePackageQuery:        updatePackageQuery(tablePath),
		updatePackageLiteQuery:    updatePackageLiteQuery(tablePath),
		cleanUpPackageQuery:       cleanUpPackageQuery(tablePath),
	}, nil
}

// Reset closes the session pool using the context captured by New and
// returns the pool's Close error.
//
// NOTE(review): DB does not keep the driver dialed in New, so the driver is
// not closed here. Confirm whether that is intended.
func (d *DB) Reset() error {
	pool := d.sp
	return pool.Close(d.ctx)
}

// LookupPackageVersion fetches the stored record for one (name, version)
// pair. It runs selectPackageVersionQuery in an online read-only
// transaction, retried via table.Retry.
//
// It returns ErrNotFound when the query yields no row. On any error the
// returned *PackageVersion is nil, never partially filled.
func (d *DB) LookupPackageVersion(name, version string) (*dbmodels.PackageVersion, error) {
	readTx := table.TxControl(
		table.BeginTx(
			table.WithOnlineReadOnly(),
		),
		table.CommitTx(),
	)

	var res *table.Result
	err := table.Retry(d.ctx, d.sp,
		table.OperationFunc(func(ctx context.Context, s *table.Session) error {
			stmt, err := s.Prepare(ctx, d.selectPackageVersionQuery)
			if err != nil {
				return err
			}

			_, res, err = stmt.Execute(ctx, readTx, table.NewQueryParameters(
				table.ValueParam("$name", ydb.UTF8Value(name)),
				table.ValueParam("$version", ydb.UTF8Value(version)),
			))
			return err
		}),
	)
	if err != nil {
		return nil, err
	}

	if !res.NextSet() || !res.NextRow() {
		// Report a real scan failure instead of masking it as "not found".
		if err := res.Err(); err != nil {
			return nil, err
		}
		return nil, ErrNotFound
	}

	result := &dbmodels.PackageVersion{
		Name:    name,
		Version: version,
	}

	// Columns are read positionally starting at "license". The order
	// (license, pkg_url, requirements, updated_at) presumably matches
	// selectPackageVersionQuery's select list; verify if that query changes.
	res.SeekItem("license")
	result.License = res.OUTF8()

	res.NextItem()
	result.PkgURL = res.OUTF8()

	res.NextItem()
	requirements := res.OJSON()

	res.NextItem()
	updatedAt := res.OInt64()

	// Check scan errors before decoding, so a broken read is not reported
	// as a JSON decode failure.
	if err := res.Err(); err != nil {
		return nil, err
	}

	if len(requirements) > 0 {
		if err := json.Unmarshal([]byte(requirements), &result.Requirements); err != nil {
			return nil, xerrors.Errorf("failed to decode requirements: %w", err)
		}
	}

	result.UpdatedAt = time.Unix(updatedAt, 0)
	return result, nil
}

// LookupPackage fetches the stored record for the package called name. It
// runs selectPackageQuery in an online read-only transaction, retried via
// table.Retry.
//
// The versions column is split on VersionsDelimiter. An empty column leaves
// Versions nil. ErrNotFound is returned when the query yields no row. On any
// error the returned *Package is nil, never partially filled.
func (d *DB) LookupPackage(name string) (*dbmodels.Package, error) {
	readTx := table.TxControl(
		table.BeginTx(
			table.WithOnlineReadOnly(),
		),
		table.CommitTx(),
	)

	var res *table.Result
	err := table.Retry(d.ctx, d.sp,
		table.OperationFunc(func(ctx context.Context, s *table.Session) error {
			stmt, err := s.Prepare(ctx, d.selectPackageQuery)
			if err != nil {
				return err
			}

			_, res, err = stmt.Execute(ctx, readTx, table.NewQueryParameters(
				table.ValueParam("$name", ydb.UTF8Value(name)),
			))
			return err
		}),
	)
	if err != nil {
		return nil, err
	}

	if !res.NextSet() || !res.NextRow() {
		// Report a real scan failure instead of masking it as "not found".
		if err := res.Err(); err != nil {
			return nil, err
		}
		return nil, ErrNotFound
	}

	result := &dbmodels.Package{
		Name: name,
	}

	// Columns are read positionally starting at "pkg_name". The order
	// (pkg_name, source, versions, updated_at) presumably matches
	// selectPackageQuery's select list; verify if that query changes.
	res.SeekItem("pkg_name")
	result.PkgName = res.OUTF8()

	res.NextItem()
	result.Source = res.OUTF8()

	res.NextItem()
	if versions := res.OUTF8(); len(versions) > 0 {
		result.Versions = strings.Split(versions, VersionsDelimiter)
	}

	res.NextItem()
	result.UpdatedAt = time.Unix(res.OInt64(), 0)

	if err := res.Err(); err != nil {
		return nil, err
	}
	return result, nil
}

// UpdatePackage persists data.Package in a serializable read-write
// transaction, retried via table.Retry. It passes these parameters:
//   - $name, $pkgName, $source
//   - $allVersions: data.Package.Versions joined with VersionsDelimiter
//   - $updatedAt: the current Unix time
//
// When data.NewVersions has at least one entry, updatePackageQuery is used.
// The new versions are then passed as $versions, a list of structs with
// version, license, pkg_url and requirements fields. Otherwise
// updatePackageLiteQuery is used and $versions is not sent, so only the
// package itself is touched.
func (d *DB) UpdatePackage(data *dbmodels.UpdatePackageData) error {
	writeTx := table.TxControl(
		table.BeginTx(
			table.WithSerializableReadWrite(),
		),
		table.CommitTx(),
	)

	query := d.updatePackageLiteQuery
	withVersions := len(data.NewVersions) > 0
	var ydbVersions ydb.Value
	if withVersions {
		// One or more new versions: switch to the full query and ship them.
		versions := make([]ydb.Value, len(data.NewVersions))
		for i, version := range data.NewVersions {
			versions[i] = ydb.StructValue(
				ydb.StructFieldValue("version", ydb.UTF8Value(version.Version)),
				ydb.StructFieldValue("license", ydb.UTF8Value(version.License)),
				ydb.StructFieldValue("pkg_url", ydb.UTF8Value(version.PkgURL)),
				ydb.StructFieldValue("requirements", ydb.JSONValue(string(version.Requirements))),
			)
		}
		query = d.updatePackageQuery
		ydbVersions = ydb.ListValue(versions...)
	}

	allVersions := strings.Join(data.Package.Versions, VersionsDelimiter)

	return table.Retry(d.ctx, d.sp,
		table.OperationFunc(func(ctx context.Context, s *table.Session) error {
			stmt, err := s.Prepare(ctx, query)
			if err != nil {
				return err
			}

			params := []table.ParameterOption{
				table.ValueParam("$name", ydb.UTF8Value(data.Package.Name)),
				table.ValueParam("$pkgName", ydb.UTF8Value(data.Package.PkgName)),
				table.ValueParam("$source", ydb.UTF8Value(data.Package.Source)),
				table.ValueParam("$allVersions", ydb.UTF8Value(allVersions)),
				// Taken per attempt, as before, so a retried write records
				// when it actually ran.
				table.ValueParam("$updatedAt", ydb.Int64Value(time.Now().Unix())),
			}
			if withVersions {
				params = append(params, table.ValueParam("$versions", ydbVersions))
			}

			_, _, err = stmt.Execute(ctx, writeTx, table.NewQueryParameters(params...))
			return err
		}),
	)
}

// CleanUpPackage runs cleanUpPackageQuery for pkg.Name, passing pkg.Versions
// as the $versions list of UTF8 values. It executes in a serializable
// read-write transaction committed on success and retried via table.Retry.
//
// NOTE(review): which rows are removed (the listed versions or the unlisted
// ones) is decided by cleanUpPackageQuery, built outside this file; confirm
// there. An empty pkg.Versions yields an empty ydb.ListValue; confirm the SDK
// and the query's $versions declaration accept that.
func (d *DB) CleanUpPackage(pkg *dbmodels.Package) error {
	listed := make([]ydb.Value, 0, len(pkg.Versions))
	for _, v := range pkg.Versions {
		listed = append(listed, ydb.UTF8Value(v))
	}

	tx := table.TxControl(
		table.BeginTx(table.WithSerializableReadWrite()),
		table.CommitTx(),
	)

	op := func(ctx context.Context, s *table.Session) error {
		stmt, prepErr := s.Prepare(ctx, d.cleanUpPackageQuery)
		if prepErr != nil {
			return prepErr
		}

		params := table.NewQueryParameters(
			table.ValueParam("$name", ydb.UTF8Value(pkg.Name)),
			table.ValueParam("$versions", ydb.ListValue(listed...)),
		)
		_, _, execErr := stmt.Execute(ctx, tx, params)
		return execErr
	}

	return table.Retry(d.ctx, d.sp, table.OperationFunc(op))
}

// createTables creates two tables under the prefix directory, in this order:
//   - package_versions, keyed by (key, name, version)
//   - packages, keyed by (key, name)
//
// Every column is Optional. Each CreateTable call is retried via table.Retry.
// The first failure stops the process and is returned, wrapped with the name
// of the table that could not be created.
func createTables(ctx context.Context, sp *table.SessionPool, prefix string) error {
	// ensure retries create for the table called name under prefix and wraps
	// any failure with that name.
	ensure := func(name string, create func(ctx context.Context, s *table.Session, path string) error) error {
		path := fmt.Sprintf("%s/%s", prefix, name)
		op := func(ctx context.Context, s *table.Session) error {
			return create(ctx, s, path)
		}
		if err := table.Retry(ctx, sp, table.OperationFunc(op)); err != nil {
			return xerrors.Errorf("failed to create %s table: %w", name, err)
		}
		return nil
	}

	if err := ensure("package_versions", func(ctx context.Context, s *table.Session, path string) error {
		return s.CreateTable(ctx, path,
			table.WithColumn("key", ydb.Optional(ydb.TypeUint64)),
			table.WithColumn("name", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("version", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("license", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("pkg_url", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("requirements", ydb.Optional(ydb.TypeJSON)),
			table.WithColumn("updated_at", ydb.Optional(ydb.TypeInt64)),
			table.WithPrimaryKeyColumn("key", "name", "version"),
		)
	}); err != nil {
		return err
	}

	return ensure("packages", func(ctx context.Context, s *table.Session, path string) error {
		return s.CreateTable(ctx, path,
			table.WithColumn("key", ydb.Optional(ydb.TypeUint64)),
			table.WithColumn("name", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("pkg_name", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("source", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("versions", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("updated_at", ydb.Optional(ydb.TypeInt64)),
			table.WithPrimaryKeyColumn("key", "name"),
		)
	})
}
