package db

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"path"
	"time"

	"a.yandex-team.ru/kikimr/public/sdk/go/ydb"
	"a.yandex-team.ru/kikimr/public/sdk/go/ydb/table"
	"a.yandex-team.ru/library/go/yandex/tvm"
	"a.yandex-team.ru/security/libs/go/ydbtvm"
	"a.yandex-team.ru/security/xray/internal/dbmodels"
)

// DB is the storage layer of the service. It owns a YDB session pool and the
// query texts rendered for the configured table path (see New).
type DB struct {
	sp *table.SessionPool

	// Query texts, built once in New from Options.Path.
	selectAnalysisQuery           string
	selectAnalysisStatusQuery     string
	selectAnalysisLogQuery        string
	selectStageStatusQuery        string
	selectStagesStatusQuery       string
	selectLatestStagesStatusQuery string
	selectStageAnalysisQuery      string
	selectLatestStageQuery        string
	scheduleAnalyzeQuery          string
	completeAnalyzeQuery          string
}

// Options configures New.
type Options struct {
	// Database is the YDB database name put into the driver config.
	Database string

	// Path is the directory prefix under which tables are created and queried.
	Path string

	// Endpoint is the YDB address that New dials.
	Endpoint string
}

// ErrNotFound is returned by the lookup methods when the query yields no rows.
var ErrNotFound = errors.New("not found")

// ErrAlreadyAnalyzed is returned by ScheduleStageAnalysis when the requested
// stage revision is not newer than the latest recorded one and Force is unset.
var ErrAlreadyAnalyzed = errors.New("stage already analyzed")

// New dials YDB at opts.Endpoint using TVM credentials, ensures the service
// tables exist under opts.Path and returns a DB whose queries are rendered
// for that path.
//
// Errors from dialing and table creation are wrapped with %w, so callers can
// inspect them with errors.Is / errors.As. If table creation fails, the
// session pool and the driver opened here are closed before returning.
func New(ctx context.Context, tvmClient tvm.Client, opts *Options) (*DB, error) {
	config := &ydb.DriverConfig{
		Database: opts.Database,
		Credentials: &ydbtvm.TvmCredentials{
			DstID:     ydbtvm.YDBClientID,
			TvmClient: tvmClient,
		},
	}

	driver, err := (&ydb.Dialer{
		DriverConfig: config,
	}).Dial(ctx, opts.Endpoint)
	if err != nil {
		return nil, fmt.Errorf("dial error: %w", err)
	}

	tableClient := table.Client{
		Driver: driver,
	}

	sp := table.SessionPool{
		IdleThreshold: 10 * time.Second,
		Builder:       &tableClient,
	}

	if err := createTables(ctx, &sp, opts.Path); err != nil {
		// The caller gets no handle back on failure, so release the pool and
		// the connection here; their close errors are secondary to err.
		_ = sp.Close(ctx)
		_ = driver.Close()
		return nil, fmt.Errorf("create tables error: %w", err)
	}

	return &DB{
		sp:                            &sp,
		selectAnalysisQuery:           selectAnalysisQuery(opts.Path),
		selectAnalysisStatusQuery:     selectAnalysisStatusQuery(opts.Path),
		selectAnalysisLogQuery:        selectAnalysisLogQuery(opts.Path),
		selectStageStatusQuery:        selectStageStatusQuery(opts.Path),
		selectStageAnalysisQuery:      selectStageAnalysisQuery(opts.Path),
		selectLatestStageQuery:        selectLatestStageQuery(opts.Path),
		selectStagesStatusQuery:       selectStagesStatusQuery(opts.Path),
		selectLatestStagesStatusQuery: selectLatestStagesStatusQuery(opts.Path),
		scheduleAnalyzeQuery:          scheduleAnalyzeQuery(opts.Path),
		completeAnalyzeQuery:          completeAnalyzeQuery(opts.Path),
	}, nil
}

// Reset closes the session pool held by db and returns the pool's close error.
//
// NOTE(review): the driver dialed in New is not closed here, because DB keeps
// no reference to it — confirm whether that is intended.
func (db *DB) Reset(ctx context.Context) error {
	pool := db.sp
	return pool.Close(ctx)
}

// LookupAnalysis fetches one analysis row by id using selectAnalysisQuery in
// an online read-only transaction that commits immediately.
//
// It returns ErrNotFound when the query produces no rows. Columns are read
// positionally starting at "id", so the query's projection order must match
// the order listed below.
func (db *DB) LookupAnalysis(ctx context.Context, id string) (*dbmodels.Analysis, error) {
	txc := table.TxControl(
		table.BeginTx(table.WithOnlineReadOnly()),
		table.CommitTx(),
	)

	var rs *table.Result
	query := func(ctx context.Context, s *table.Session) error {
		prepared, prepErr := s.Prepare(ctx, db.selectAnalysisQuery)
		if prepErr != nil {
			return prepErr
		}

		params := table.NewQueryParameters(
			table.ValueParam("$id", ydb.UTF8Value(id)),
		)

		var execErr error
		_, rs, execErr = prepared.Execute(ctx, txc, params)
		return execErr
	}

	if err := table.Retry(ctx, db.sp, table.OperationFunc(query)); err != nil {
		return nil, err
	}

	if !rs.NextSet() || !rs.NextRow() {
		return nil, ErrNotFound
	}

	// Projection order: id, created_at, updated_at, result_path, log_path,
	// status, status_description.
	var analysis dbmodels.Analysis

	rs.SeekItem("id")
	analysis.ID = rs.OUTF8()

	rs.NextItem()
	analysis.CreatedAt = time.Unix(rs.OInt64(), 0)

	rs.NextItem()
	analysis.UpdatedAt = time.Unix(rs.OInt64(), 0)

	rs.NextItem()
	analysis.ResultPath = rs.OUTF8()

	rs.NextItem()
	analysis.LogPath = rs.OUTF8()

	rs.NextItem()
	analysis.Status = dbmodels.Status(rs.OInt32())

	rs.NextItem()
	analysis.StatusDescription = rs.OUTF8()

	return &analysis, rs.Err()
}

// LookupStageStatus fetches the status row of one stage revision using
// selectStageStatusQuery in an online read-only transaction.
//
// It returns ErrNotFound when no row matches. A non-empty overview column is
// decoded from JSON into the result's Overview; a decoding failure is returned
// as an error. Columns are read positionally starting at "uuid".
func (db *DB) LookupStageStatus(ctx context.Context, stage dbmodels.StageInfo) (*dbmodels.StageStatus, error) {
	txc := table.TxControl(
		table.BeginTx(table.WithOnlineReadOnly()),
		table.CommitTx(),
	)

	var rs *table.Result
	query := func(ctx context.Context, s *table.Session) error {
		prepared, prepErr := s.Prepare(ctx, db.selectStageStatusQuery)
		if prepErr != nil {
			return prepErr
		}

		params := table.NewQueryParameters(
			table.ValueParam("$uuid", ydb.UTF8Value(stage.UUID)),
			table.ValueParam("$revision", ydb.Uint32Value(stage.Revision)),
		)

		var execErr error
		_, rs, execErr = prepared.Execute(ctx, txc, params)
		return execErr
	}

	if err := table.Retry(ctx, db.sp, table.OperationFunc(query)); err != nil {
		return nil, err
	}

	if !rs.NextSet() || !rs.NextRow() {
		return nil, ErrNotFound
	}

	// Projection order: uuid, revision, id, updated_at, overview, status.
	var status dbmodels.StageStatus

	rs.SeekItem("uuid")
	status.UUID = rs.OUTF8()

	rs.NextItem()
	status.Revision = rs.OUint32()

	rs.NextItem()
	status.ID = rs.OUTF8()

	rs.NextItem()
	status.UpdatedAt = time.Unix(rs.OInt64(), 0)

	rs.NextItem()
	if raw := rs.OString(); len(raw) > 0 {
		if err := json.Unmarshal(raw, &status.Overview); err != nil {
			return nil, fmt.Errorf("failed to parse stage overview: %w", err)
		}
	}

	rs.NextItem()
	status.Status = dbmodels.Status(rs.OInt32())

	return &status, rs.Err()
}

// LookupStageAnalysis fetches the analysis attached to one stage revision
// using selectStageAnalysisQuery in an online read-only transaction.
//
// It returns ErrNotFound when no row matches. A non-empty overview column is
// decoded from JSON into the result's Overview; a decoding failure is returned
// with the stage uuid/revision in the message. Columns are read positionally
// starting at "analyze.id".
func (db *DB) LookupStageAnalysis(ctx context.Context, stage dbmodels.StageInfo) (*dbmodels.StageAnalysis, error) {
	txc := table.TxControl(
		table.BeginTx(table.WithOnlineReadOnly()),
		table.CommitTx(),
	)

	var rs *table.Result
	query := func(ctx context.Context, s *table.Session) error {
		prepared, prepErr := s.Prepare(ctx, db.selectStageAnalysisQuery)
		if prepErr != nil {
			return prepErr
		}

		params := table.NewQueryParameters(
			table.ValueParam("$stageUUID", ydb.UTF8Value(stage.UUID)),
			table.ValueParam("$stageRevision", ydb.Uint32Value(stage.Revision)),
		)

		var execErr error
		_, rs, execErr = prepared.Execute(ctx, txc, params)
		return execErr
	}

	if err := table.Retry(ctx, db.sp, table.OperationFunc(query)); err != nil {
		return nil, err
	}

	if !rs.NextSet() || !rs.NextRow() {
		return nil, ErrNotFound
	}

	// Projection order: analyze.id, analyze.updated_at, analyze.result_path,
	// analyze.log_path, analyze.status, analyze.status_description, then the
	// stage overview.
	var analysis dbmodels.StageAnalysis

	rs.SeekItem("analyze.id")
	analysis.ID = rs.OUTF8()

	rs.NextItem()
	analysis.UpdatedAt = time.Unix(rs.OInt64(), 0)

	rs.NextItem()
	analysis.ResultPath = rs.OUTF8()

	rs.NextItem()
	analysis.LogPath = rs.OUTF8()

	rs.NextItem()
	analysis.Status = dbmodels.Status(rs.OInt32())

	rs.NextItem()
	analysis.StatusDescription = rs.OUTF8()

	rs.NextItem()
	if raw := rs.OString(); len(raw) > 0 {
		if err := json.Unmarshal(raw, &analysis.Overview); err != nil {
			return nil, fmt.Errorf("failed to parse stage overview (uuid=%s revision=%d): %w",
				stage.UUID, stage.Revision, err)
		}
	}

	return &analysis, rs.Err()
}

// SelectStagesStatus returns up to limit stage status rows produced by
// selectStagesStatusQuery, run in an online read-only transaction.
//
// Rows from every result set are appended in the order they are returned.
// A non-empty overview column is decoded from JSON; a decoding failure aborts
// the whole call. Columns are read positionally starting at "s.updated_at".
func (db *DB) SelectStagesStatus(ctx context.Context, limit int32) ([]dbmodels.StageStatus, error) {
	readTx := table.TxControl(
		table.BeginTx(
			table.WithOnlineReadOnly(),
		),
		table.CommitTx(),
	)

	var res *table.Result
	err := table.Retry(ctx, db.sp,
		table.OperationFunc(func(ctx context.Context, s *table.Session) (resultErr error) {
			stmt, err := s.Prepare(ctx, db.selectStagesStatusQuery)
			if err != nil {
				return err
			}

			_, res, resultErr = stmt.Execute(ctx, readTx, table.NewQueryParameters(
				table.ValueParam("$limit", ydb.Int32Value(limit)),
			))
			return
		}),
	)

	if err != nil {
		return nil, err
	}

	// limit is caller-controlled: a negative value would make make() panic,
	// and a huge one would pre-allocate far more than any real result needs.
	// Clamp only the initial capacity; append grows the slice as required.
	const maxPrealloc = 1024
	capacity := int(limit)
	if capacity < 0 {
		capacity = 0
	} else if capacity > maxPrealloc {
		capacity = maxPrealloc
	}

	result := make([]dbmodels.StageStatus, 0, capacity)
	for res.NextSet() {
		for res.NextRow() {
			var status dbmodels.StageStatus
			// ls.timestamp, s.updated_at, ls.stage_id, ls.stage_uuid, ls.stage_revision, s.overview, s.status

			res.SeekItem("s.updated_at")
			status.UpdatedAt = time.Unix(res.OInt64(), 0)

			res.NextItem()
			status.StageInfo.ID = res.OUTF8()

			res.NextItem()
			status.StageInfo.UUID = res.OUTF8()

			res.NextItem()
			status.StageInfo.Revision = res.OUint32()

			res.NextItem()
			overview := res.OString()
			if len(overview) > 0 {
				err = json.Unmarshal(overview, &status.Overview)
				if err != nil {
					return nil, fmt.Errorf("failed to parse stage overview (uuid=%s revision=%d): %w",
						status.UUID, status.Revision, err)
				}
			}

			res.NextItem()
			status.Status = dbmodels.Status(res.OInt32())

			result = append(result, status)
		}
	}

	return result, res.Err()
}

// SelectLatestStagesStatus returns up to limit status rows for the given
// stage uuids, produced by selectLatestStagesStatusQuery in an online
// read-only transaction. The uuids are passed as a list of UTF8 values.
//
// Rows from every result set are appended in the order they are returned.
// A non-empty overview column is decoded from JSON; a decoding failure aborts
// the whole call. Columns are read positionally starting at "updated_at".
func (db *DB) SelectLatestStagesStatus(ctx context.Context, uuids []string, limit int32) ([]dbmodels.StageStatus, error) {
	readTx := table.TxControl(
		table.BeginTx(
			table.WithOnlineReadOnly(),
		),
		table.CommitTx(),
	)

	var res *table.Result
	uuidValues := make([]ydb.Value, len(uuids))
	for i, uuid := range uuids {
		uuidValues[i] = ydb.UTF8Value(uuid)
	}

	err := table.Retry(ctx, db.sp,
		table.OperationFunc(func(ctx context.Context, s *table.Session) (resultErr error) {
			stmt, err := s.Prepare(ctx, db.selectLatestStagesStatusQuery)
			if err != nil {
				return err
			}

			_, res, resultErr = stmt.Execute(ctx, readTx, table.NewQueryParameters(
				table.ValueParam("$limit", ydb.Int32Value(limit)),
				table.ValueParam("$uuids", ydb.ListValue(uuidValues...)),
			))
			return
		}),
	)

	if err != nil {
		return nil, err
	}

	// limit is caller-controlled: a negative value would make make() panic,
	// and a huge one would pre-allocate far more than any real result needs.
	// Clamp only the initial capacity; append grows the slice as required.
	const maxPrealloc = 1024
	capacity := int(limit)
	if capacity < 0 {
		capacity = 0
	} else if capacity > maxPrealloc {
		capacity = maxPrealloc
	}

	result := make([]dbmodels.StageStatus, 0, capacity)
	for res.NextSet() {
		for res.NextRow() {
			var status dbmodels.StageStatus
			// updated_at, stage_id, stage_uuid, stage_revision, overview, status

			res.SeekItem("updated_at")
			status.UpdatedAt = time.Unix(res.OInt64(), 0)

			res.NextItem()
			status.StageInfo.ID = res.OUTF8()

			res.NextItem()
			status.StageInfo.UUID = res.OUTF8()

			res.NextItem()
			status.StageInfo.Revision = res.OUint32()

			res.NextItem()
			overview := res.OString()
			if len(overview) > 0 {
				err = json.Unmarshal(overview, &status.Overview)
				if err != nil {
					return nil, fmt.Errorf("failed to parse stage overview (uuid=%s revision=%d): %w",
						status.UUID, status.Revision, err)
				}
			}

			res.NextItem()
			status.Status = dbmodels.Status(res.OInt32())

			result = append(result, status)
		}
	}

	return result, res.Err()
}

// LookupAnalysisStatus fetches only the status and status description of the
// analysis with the given id, using selectAnalysisStatusQuery in an online
// read-only transaction.
//
// It returns ErrNotFound when the query produces no rows.
func (db *DB) LookupAnalysisStatus(ctx context.Context, id string) (*dbmodels.AnalysisStatus, error) {
	txc := table.TxControl(
		table.BeginTx(table.WithOnlineReadOnly()),
		table.CommitTx(),
	)

	var rs *table.Result
	query := func(ctx context.Context, s *table.Session) error {
		prepared, prepErr := s.Prepare(ctx, db.selectAnalysisStatusQuery)
		if prepErr != nil {
			return prepErr
		}

		params := table.NewQueryParameters(
			table.ValueParam("$id", ydb.UTF8Value(id)),
		)

		var execErr error
		_, rs, execErr = prepared.Execute(ctx, txc, params)
		return execErr
	}

	if err := table.Retry(ctx, db.sp, table.OperationFunc(query)); err != nil {
		return nil, err
	}

	if !rs.NextSet() || !rs.NextRow() {
		return nil, ErrNotFound
	}

	// Projection order: status, status_description.
	var status dbmodels.AnalysisStatus

	rs.SeekItem("status")
	status.Value = dbmodels.Status(rs.OInt32())

	rs.NextItem()
	status.Description = rs.OUTF8()

	return &status, rs.Err()
}

// LookupLatestStageRevision returns the "revision" column produced by
// selectLatestStageQuery for the given stage uuid, read in an online
// read-only transaction.
//
// It returns ErrNotFound when the query produces no rows.
func (db *DB) LookupLatestStageRevision(ctx context.Context, stageUUID string) (uint32, error) {
	txc := table.TxControl(
		table.BeginTx(table.WithOnlineReadOnly()),
		table.CommitTx(),
	)

	var rs *table.Result
	query := func(ctx context.Context, s *table.Session) error {
		prepared, prepErr := s.Prepare(ctx, db.selectLatestStageQuery)
		if prepErr != nil {
			return prepErr
		}

		params := table.NewQueryParameters(
			table.ValueParam("$uuid", ydb.UTF8Value(stageUUID)),
		)

		var execErr error
		_, rs, execErr = prepared.Execute(ctx, txc, params)
		return execErr
	}

	if err := table.Retry(ctx, db.sp, table.OperationFunc(query)); err != nil {
		return 0, err
	}

	if !rs.NextSet() || !rs.NextRow() {
		return 0, ErrNotFound
	}

	rs.SeekItem("revision")
	revision := rs.OUint32()
	return revision, rs.Err()
}

// ScheduleStageAnalysis schedules an analysis for a stage revision inside one
// serializable read-write transaction.
//
// It first reads the latest recorded revision of data.StageUUID (via
// selectLatestStageQuery). If data.StageRevision is not newer and data.Force
// is false, it returns ErrAlreadyAnalyzed. Otherwise it calls data.Scheduler
// to obtain an analysis ID and stores it with scheduleAnalyzeQuery, passing
// the larger of the two revisions as $latestStageRevision.
//
// Any failing path, including ErrAlreadyAnalyzed and a failed commit, rolls
// the transaction back.
//
// NOTE(review): the operation runs under table.Retry, so data.Scheduler may
// be invoked more than once if a later step fails with a retryable error —
// confirm the scheduler tolerates that.
func (db *DB) ScheduleStageAnalysis(ctx context.Context, data dbmodels.ScheduleLatestStageData) error {
	return table.Retry(ctx, db.sp,
		table.OperationFunc(func(ctx context.Context, s *table.Session) (err error) {
			tx, err := s.BeginTransaction(ctx, table.TxSettings(
				table.WithSerializableReadWrite(),
			))
			if err != nil {
				return err
			}

			// err is the named result, so the deferred check sees the value of
			// every return statement. That includes `return ErrAlreadyAnalyzed`
			// and `return tx.Commit(ctx)`, which never assign to a local err.
			defer func() {
				if err != nil {
					_ = tx.Rollback(context.Background())
				}
			}()

			// first of all - get latest scheduled analysis for our stage
			stmt, err := s.Prepare(ctx, db.selectLatestStageQuery)
			if err != nil {
				return err
			}

			res, err := tx.ExecuteStatement(ctx, stmt, table.NewQueryParameters(
				table.ValueParam("$uuid", ydb.UTF8Value(data.StageUUID)),
			))
			if err != nil {
				return err
			}

			var latestRevision uint32
			if res.NextSet() && res.NextRow() {
				res.SeekItem("revision")
				latestRevision = res.OUint32()
			}
			// Don't make the scheduling decision on a revision that failed to decode.
			if err = res.Err(); err != nil {
				return err
			}

			// now do sanity check
			if data.StageRevision > latestRevision {
				latestRevision = data.StageRevision
			} else if !data.Force {
				return ErrAlreadyAnalyzed
			}

			// schedule
			data.AnalyzeID, err = data.Scheduler()
			if err != nil {
				return err
			}

			// and save
			stmt, err = s.Prepare(ctx, db.scheduleAnalyzeQuery)
			if err != nil {
				return err
			}

			_, err = tx.ExecuteStatement(ctx, stmt, table.NewQueryParameters(
				table.ValueParam("$analyzeID", ydb.UTF8Value(data.AnalyzeID)),
				table.ValueParam("$stageUUID", ydb.UTF8Value(data.StageUUID)),
				table.ValueParam("$stageRevision", ydb.Uint32Value(data.StageRevision)),
				table.ValueParam("$stageID", ydb.UTF8Value(data.StageID)),
				table.ValueParam("$latestStageRevision", ydb.Uint32Value(latestRevision)),
				table.ValueParam("$description", ydb.UTF8Value(data.Description)),
				table.ValueParam("$status", ydb.Int32Value(int32(data.Status))),
				table.ValueParam("$createdAt", ydb.Int64Value(data.CreatedAt.Unix())),
			))
			if err != nil {
				return err
			}

			return tx.Commit(ctx)
		}),
	)
}

// CompleteAnalyze records the outcome of an analysis by running
// completeAnalyzeQuery in a serializable read-write transaction that commits
// immediately. data.Overview is stored as JSON in the $overview parameter;
// a JSON encoding failure is returned before any query is issued.
func (db *DB) CompleteAnalyze(ctx context.Context, data dbmodels.UpdateAnalyzeData) error {
	encodedOverview, err := json.Marshal(data.Overview)
	if err != nil {
		return fmt.Errorf("failed to marshal stage overview: %w", err)
	}

	txc := table.TxControl(
		table.BeginTx(table.WithSerializableReadWrite()),
		table.CommitTx(),
	)

	update := func(ctx context.Context, s *table.Session) error {
		prepared, prepErr := s.Prepare(ctx, db.completeAnalyzeQuery)
		if prepErr != nil {
			return prepErr
		}

		params := table.NewQueryParameters(
			table.ValueParam("$analyzeID", ydb.UTF8Value(data.AnalyzeID)),
			table.ValueParam("$stageUUID", ydb.UTF8Value(data.StageUUID)),
			table.ValueParam("$stageRevision", ydb.Uint32Value(data.StageRevision)),
			table.ValueParam("$updatedAt", ydb.Int64Value(data.UpdatedAt.Unix())),
			table.ValueParam("$resultPath", ydb.UTF8Value(data.ResultPath)),
			table.ValueParam("$logPath", ydb.UTF8Value(data.LogPath)),
			table.ValueParam("$status", ydb.Int32Value(int32(data.Status))),
			table.ValueParam("$statusDescription", ydb.UTF8Value(data.StatusDescription)),
			table.ValueParam("$overview", ydb.StringValue(encodedOverview)),
		)

		_, _, execErr := prepared.Execute(ctx, txc, params)
		return execErr
	}

	return table.Retry(ctx, db.sp, table.OperationFunc(update))
}

// createTables creates the service tables (analyzes, stages, latest_stages,
// latest_stage_analyzes) under prefix, one at a time and in that order.
//
// Each CreateTable call runs under its own table.Retry. The first failure
// stops the sequence and is returned wrapped as
// "failed to create <name> table: ...".
func createTables(ctx context.Context, sp *table.SessionPool, prefix string) error {
	// Each spec pairs a table name (used for both the path and the error
	// message) with the CreateTable call that defines its schema.
	specs := []struct {
		name   string
		create func(ctx context.Context, s *table.Session, tablePath string) error
	}{
		{
			name: "analyzes",
			create: func(ctx context.Context, s *table.Session, tablePath string) error {
				return s.CreateTable(ctx, tablePath,
					table.WithColumn("key", ydb.Optional(ydb.TypeUint64)),
					table.WithColumn("id", ydb.Optional(ydb.TypeUTF8)),
					table.WithColumn("created_at", ydb.Optional(ydb.TypeInt64)),
					table.WithColumn("updated_at", ydb.Optional(ydb.TypeInt64)),
					table.WithColumn("result_path", ydb.Optional(ydb.TypeUTF8)),
					table.WithColumn("log_path", ydb.Optional(ydb.TypeUTF8)),
					table.WithColumn("status", ydb.Optional(ydb.TypeInt32)),
					table.WithColumn("status_description", ydb.Optional(ydb.TypeUTF8)),
					table.WithPrimaryKeyColumn("key", "id"),
				)
			},
		},
		{
			name: "stages",
			create: func(ctx context.Context, s *table.Session, tablePath string) error {
				return s.CreateTable(ctx, tablePath,
					table.WithColumn("key", ydb.Optional(ydb.TypeUint64)),
					table.WithColumn("uuid", ydb.Optional(ydb.TypeUTF8)),
					table.WithColumn("revision", ydb.Optional(ydb.TypeUint32)),
					table.WithColumn("id", ydb.Optional(ydb.TypeUTF8)),
					table.WithColumn("description", ydb.Optional(ydb.TypeUTF8)),
					table.WithColumn("updated_at", ydb.Optional(ydb.TypeInt64)),
					table.WithColumn("analyze_id", ydb.Optional(ydb.TypeUTF8)),
					table.WithColumn("overview", ydb.Optional(ydb.TypeString)),
					table.WithColumn("status", ydb.Optional(ydb.TypeInt32)),
					table.WithPrimaryKeyColumn("key", "uuid", "revision"),
				)
			},
		},
		{
			name: "latest_stages",
			create: func(ctx context.Context, s *table.Session, tablePath string) error {
				return s.CreateTable(ctx, tablePath,
					table.WithColumn("key", ydb.Optional(ydb.TypeUint64)),
					table.WithColumn("uuid", ydb.Optional(ydb.TypeUTF8)),
					table.WithColumn("revision", ydb.Optional(ydb.TypeUint32)),
					table.WithPrimaryKeyColumn("key", "uuid"),
				)
			},
		},
		{
			name: "latest_stage_analyzes",
			create: func(ctx context.Context, s *table.Session, tablePath string) error {
				return s.CreateTable(ctx, tablePath,
					table.WithColumn("created_at", ydb.Optional(ydb.TypeInt64)),
					table.WithColumn("stage_uuid", ydb.Optional(ydb.TypeUTF8)),
					table.WithColumn("stage_revision", ydb.Optional(ydb.TypeUint32)),
					table.WithColumn("stage_id", ydb.Optional(ydb.TypeUTF8)),
					table.WithColumn("analyze_id", ydb.Optional(ydb.TypeUTF8)),
					table.WithPrimaryKeyColumn("created_at", "stage_uuid", "stage_revision"),
				)
			},
		},
	}

	for _, spec := range specs {
		spec := spec // captured by the retry closure below
		tablePath := path.Join(prefix, spec.name)

		err := table.Retry(ctx, sp,
			table.OperationFunc(func(ctx context.Context, s *table.Session) error {
				return spec.create(ctx, s, tablePath)
			}),
		)
		if err != nil {
			return fmt.Errorf("failed to create %s table: %w", spec.name, err)
		}
	}

	return nil
}
