package db

import (
	"context"
	"fmt"

	"a.yandex-team.ru/kikimr/public/sdk/go/ydb"
	"a.yandex-team.ru/kikimr/public/sdk/go/ydb/table"
	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/security/yadi/snatcher/pkg/feed"
	"a.yandex-team.ru/security/yadi/web/internal/models"
)

// LookupAddedVulns reads one page of added vulnerabilities via
// selectNewVulnsQuery, starting after opts.LastKey and bounded by opts.Limit.
//
// The query runs in an online read-only transaction, retried through
// table.Retry on c.sp. Every row is decoded with feed.TableToVuln, normalized
// with Adjust and appended to result.VulnsWithActions together with its
// "vulnAction" column. result.LastKey holds the "key" column of the last row
// read; result.IsLast is set when fewer than opts.Limit rows came back in
// total. ErrNotFound is returned when the query produced no rows.
func (c *DB) LookupAddedVulns(opts FeedLookupOpts) (result models.FeedLookupQuery, resultErr error) {
	params := table.NewQueryParameters(
		table.ValueParam("$lastKey", ydb.Uint64Value(opts.LastKey)),
		table.ValueParam("$limit", ydb.Uint64Value(opts.Limit)),
	)
	readTx := table.TxControl(table.BeginTx(table.WithOnlineReadOnly()), table.CommitTx())

	var res *table.Result
	runQuery := func(ctx context.Context, s *table.Session) error {
		stmt, err := s.Prepare(ctx, c.selectNewVulnsQuery)
		if err != nil {
			return err
		}
		_, res, err = stmt.Execute(ctx, readTx, params)
		return err
	}
	if err := table.Retry(c.ctx, c.sp, table.OperationFunc(runQuery)); err != nil {
		return result, err
	}

	for res.NextSet() {
		for res.NextRow() {
			vuln, err := feed.TableToVuln(res)
			if err != nil {
				return result, err
			}
			if err := vuln.Adjust(); err != nil {
				return result, err
			}

			res.SeekItem("key")
			result.LastKey = res.OUint64()

			res.SeekItem("vulnAction")
			action := res.OUTF8()

			result.VulnsWithActions = append(result.VulnsWithActions, models.VulnWithAction{
				Vuln:   *vuln,
				Action: action,
			})
		}
	}

	rows := res.RowCount()
	result.IsLast = rows < int(opts.Limit)
	if rows == 0 {
		return result, ErrNotFound
	}
	return result, res.Err()
}

// LookupChangedVulns reads one page of changed vulnerabilities via
// selectChangedVulnsQuery, starting after opts.LastKey and bounded by
// opts.Limit.
//
// Every row carries a "generation" column:
//   - "new" rows are normalized with Adjust and stored as VulnsPair.New;
//   - "old" rows are stored as-is as VulnsPair.Old;
//   - any other value is an error.
//
// Rows are grouped by YadiID into models.VulnsPair values. The pairs are
// returned in the order their YadiID was first seen, so the output is
// deterministic for a given query result. Each pair's Action is the
// "vulnAction" column of the last row seen for that YadiID.
//
// result.LastKey holds the "key" column of the last row read. result.IsLast
// is set when the largest single result set has fewer than opts.Limit rows.
// ErrNotFound is returned when the query produced no rows.
func (c *DB) LookupChangedVulns(opts FeedLookupOpts) (result models.FeedLookupQuery, resultErr error) {
	readTx := table.TxControl(
		table.BeginTx(
			table.WithOnlineReadOnly(),
		),
		table.CommitTx(),
	)

	var res *table.Result
	resultErr = table.Retry(c.ctx, c.sp,
		table.OperationFunc(func(ctx context.Context, s *table.Session) (err error) {
			stmt, err := s.Prepare(ctx, c.selectChangedVulnsQuery)
			if err != nil {
				return
			}

			_, res, err = stmt.Execute(ctx, readTx, table.NewQueryParameters(
				table.ValueParam("$lastKey", ydb.Uint64Value(opts.LastKey)),
				table.ValueParam("$limit", ydb.Uint64Value(opts.Limit)),
			))

			return
		}),
	)

	if resultErr != nil {
		return
	}

	// pairs groups rows by YadiID. order records YadiIDs in first-seen order
	// so the returned slice does not depend on Go's randomized map iteration.
	pairs := make(map[string]*models.VulnsPair)
	var order []string
	var maxSetLength int
	for res.NextSet() {
		if n := res.SetRowCount(); n > maxSetLength {
			maxSetLength = n
		}
		for res.NextRow() {
			vuln, err := feed.TableToVuln(res)
			if err != nil {
				return result, err
			}

			yaID := vuln.YadiID
			if yaID == "" {
				return result, xerrors.Errorf("bad yadiID for %s/%s vuln", vuln.SrcType, vuln.ID)
			}

			pair, ok := pairs[yaID]
			if !ok {
				pair = &models.VulnsPair{}
				pairs[yaID] = pair
				order = append(order, yaID)
			}

			res.SeekItem("generation")
			switch generation := string(res.String()); generation {
			case "new":
				// Only vulnerabilities from the [src] table are fixed here;
				// vulnerabilities in the [feed] table must already be fixed.
				if err := vuln.Adjust(); err != nil {
					return result, err
				}
				pair.New = *vuln
			case "old":
				pair.Old = *vuln
			default:
				return result, xerrors.Errorf("failed to determine generation of vuln: got %s, need old or new", generation)
			}

			res.SeekItem("vulnAction")
			pair.Action = res.OUTF8()

			res.SeekItem("key")
			result.LastKey = res.OUint64()
		}
	}

	for _, yaID := range order {
		result.VulnsPairs = append(result.VulnsPairs, *pairs[yaID])
	}

	result.IsLast = maxSetLength < int(opts.Limit)

	if res.RowCount() == 0 {
		return result, ErrNotFound
	}

	return result, res.Err()
}

// LookupAllFeed reads one page of the full feed via selectFullFeedQuery,
// starting after opts.LastKey and bounded by opts.Limit.
//
// The query runs in an online read-only transaction, retried through
// table.Retry on c.sp. Rows are decoded with feed.TableToVuln and appended
// to result.AllVulns unchanged (no Adjust is applied). result.LastKey holds
// the "key" column of the last row read; result.IsLast is set when fewer
// than opts.Limit rows came back in total. ErrNotFound is returned when the
// query produced no rows.
func (c *DB) LookupAllFeed(opts FeedLookupOpts) (result models.FeedLookupQuery, resultErr error) {
	params := table.NewQueryParameters(
		table.ValueParam("$lastKey", ydb.Uint64Value(opts.LastKey)),
		table.ValueParam("$limit", ydb.Uint64Value(opts.Limit)),
	)
	readTx := table.TxControl(table.BeginTx(table.WithOnlineReadOnly()), table.CommitTx())

	var res *table.Result
	runQuery := func(ctx context.Context, s *table.Session) error {
		stmt, err := s.Prepare(ctx, c.selectFullFeedQuery)
		if err != nil {
			return err
		}
		_, res, err = stmt.Execute(ctx, readTx, params)
		return err
	}
	if err := table.Retry(c.ctx, c.sp, table.OperationFunc(runQuery)); err != nil {
		return result, err
	}

	for res.NextSet() {
		for res.NextRow() {
			vuln, err := feed.TableToVuln(res)
			if err != nil {
				return result, err
			}

			res.SeekItem("key")
			result.LastKey = res.OUint64()

			result.AllVulns = append(result.AllVulns, *vuln)
		}
	}

	rows := res.RowCount()
	result.IsLast = rows < int(opts.Limit)
	if rows == 0 {
		return result, ErrNotFound
	}
	return result, res.Err()
}

// LookupOneFeedVuln returns the single feed vulnerability matching opts.
//
// The lookup key is chosen as follows:
//   - when both opts.SrcType and opts.SrcID are set, it uses
//     selectOneFeedVulnQueryBySrc;
//   - otherwise, when opts.YaID is set, it uses selectOneFeedVulnQueryByYaID;
//   - if neither key is present, an error is returned without querying.
//
// ErrNotFound is returned when no row matches. An error is returned when more
// than one row matches or when decoding the row fails. On error, the returned
// vulnerability is the zero value.
func (c *DB) LookupOneFeedVuln(opts VulnLookupOpts) (vuln feed.Vulnerability, resultErr error) {
	var (
		query  string
		params *table.QueryParameters
	)
	switch {
	case len(opts.SrcType) > 0 && len(opts.SrcID) > 0:
		params = table.NewQueryParameters(
			table.ValueParam("$srcType", ydb.UTF8Value(opts.SrcType)),
			table.ValueParam("$srcId", ydb.UTF8Value(opts.SrcID)),
		)
		query = c.selectOneFeedVulnQueryBySrc
	case len(opts.YaID) > 0:
		params = table.NewQueryParameters(
			table.ValueParam("$yaId", ydb.UTF8Value(opts.YaID)),
		)
		query = c.selectOneFeedVulnQueryByYaID
	default:
		return feed.Vulnerability{}, fmt.Errorf("empty yaId or srcType and srcId")
	}

	readTx := table.TxControl(
		table.BeginTx(
			table.WithOnlineReadOnly(),
		),
		table.CommitTx(),
	)

	var res *table.Result
	resultErr = table.Retry(c.ctx, c.sp,
		table.OperationFunc(func(ctx context.Context, s *table.Session) (err error) {
			stmt, err := s.Prepare(ctx, query)
			if err != nil {
				return
			}

			_, res, err = stmt.Execute(ctx, readTx, params)
			return
		}),
	)

	if resultErr != nil {
		return
	}

	rowsNum := res.RowCount()
	if rowsNum < 1 {
		resultErr = ErrNotFound
		return
	}

	if rowsNum > 1 {
		resultErr = xerrors.Errorf("got more than expected: %d", rowsNum)
		return
	}

	var dbVuln *feed.Vulnerability
	for res.NextSet() {
		for res.NextRow() {
			dbVuln, resultErr = feed.TableToVuln(res)
			if resultErr != nil {
				return
			}
		}
	}

	// Check scan errors before touching the decoded value, and never
	// dereference a nil pointer if no row was actually iterated.
	if err := res.Err(); err != nil {
		return feed.Vulnerability{}, err
	}
	if dbVuln == nil {
		return feed.Vulnerability{}, ErrNotFound
	}

	return *dbVuln, nil
}

// ApproveVuln applies the pending action recorded for the source
// vulnerability identified by reqVuln.SrcType and reqVuln.ID.
//
// The current source vulnerability and its "vulnAction" are read with
// selectOneSrcVulnWithActionQuery. Exactly one row is expected; otherwise
// ErrNotFound is returned. The action then decides what happens:
//   - "delete": the action record is removed via DeleteAction.
//   - "new" or "update": the vulnerability is normalized with Adjust,
//     checked with StrictValidate and written via ChangeVuln.
//   - anything else is an error.
//
// NOTE(review): the read and the write are separate transactions. Each
// Execute begins and commits its own, so the stored action may change
// between them. Confirm whether callers rely on this being atomic.
func (c *DB) ApproveVuln(reqVuln *feed.Vulnerability) (resultErr error) {
	writeTx := table.TxControl(
		table.BeginTx(
			table.WithSerializableReadWrite(),
		),
		table.CommitTx(),
	)

	var res *table.Result

	// Get the actual source vulnerability with the same srcType, srcId,
	// together with its pending action.
	resultErr = table.Retry(c.ctx, c.sp,
		table.OperationFunc(func(ctx context.Context, s *table.Session) (err error) {
			stmt, err := s.Prepare(ctx, c.selectOneSrcVulnWithActionQuery)
			if err != nil {
				return
			}

			_, res, err = stmt.Execute(ctx, writeTx, table.NewQueryParameters(
				table.ValueParam("$srcType", ydb.UTF8Value(reqVuln.SrcType)),
				table.ValueParam("$srcId", ydb.UTF8Value(reqVuln.ID)),
			))
			return
		}),
	)

	if resultErr != nil {
		return
	}

	if res.RowCount() != 1 {
		return ErrNotFound
	}

	var action string
	var srcVuln *feed.Vulnerability

	for res.NextSet() {
		for res.NextRow() {
			res.SeekItem("vulnAction")
			action = res.OUTF8()

			srcVuln, resultErr = feed.TableToVuln(res)
			if resultErr != nil {
				return
			}
		}
	}

	// Surface scan errors from the read before mutating anything. The write
	// paths below never look at this result again.
	if err := res.Err(); err != nil {
		return err
	}
	if srcVuln == nil {
		return ErrNotFound
	}

	// Act on the vuln depending on its record in the [actions] table.
	switch action {
	case "delete":
		// Just delete the action record.
		return c.DeleteAction(srcVuln.SrcType, srcVuln.ID)
	case "new", "update":
		// Write the vuln with a rounded cvss_score to the [feed] table.
		if err := srcVuln.Adjust(); err != nil {
			return xerrors.Errorf("can not adjust %s vuln: %w", srcVuln.ID, err)
		}

		if err := srcVuln.StrictValidate(); err != nil {
			return xerrors.Errorf("can not validate %s vuln: %w", srcVuln.ID, err)
		}

		return c.ChangeVuln(srcVuln)
	default:
		return xerrors.Errorf("can't approve: vulnerability %s without action", srcVuln.ID)
	}
}

// ChangeVuln writes vuln through changeVulnQuery. The value is converted with
// feed.VulnToFeedTable and passed as a one-element "$vuln" list.
//
// The statement runs in a serializable read-write transaction that commits
// immediately, retried through table.Retry on c.sp. It returns the query
// error if any, otherwise the error recorded on the query result.
func (c *DB) ChangeVuln(vuln *feed.Vulnerability) (resultErr error) {
	row := feed.VulnToFeedTable(*vuln)
	params := table.NewQueryParameters(
		table.ValueParam("$vuln", ydb.ListValue(row)),
	)
	writeTx := table.TxControl(table.BeginTx(table.WithSerializableReadWrite()), table.CommitTx())

	var out *table.Result
	upsert := func(ctx context.Context, s *table.Session) error {
		stmt, err := s.Prepare(ctx, c.changeVulnQuery)
		if err != nil {
			return err
		}
		_, out, err = stmt.Execute(ctx, writeTx, params)
		return err
	}
	if err := table.Retry(c.ctx, c.sp, table.OperationFunc(upsert)); err != nil {
		return err
	}
	return out.Err()
}

// DeleteAction removes the action record keyed by (srcType, srcID) using
// deleteActionQuery.
//
// The statement runs in a serializable read-write transaction that commits
// immediately, retried through table.Retry on c.sp. The query result is
// discarded; only the execution error is returned.
func (c *DB) DeleteAction(srcType string, srcID string) (resultErr error) {
	params := table.NewQueryParameters(
		table.ValueParam("$srcType", ydb.UTF8Value(srcType)),
		table.ValueParam("$srcId", ydb.UTF8Value(srcID)),
	)
	writeTx := table.TxControl(table.BeginTx(table.WithSerializableReadWrite()), table.CommitTx())

	deleteOp := table.OperationFunc(func(ctx context.Context, s *table.Session) error {
		stmt, err := s.Prepare(ctx, c.deleteActionQuery)
		if err != nil {
			return err
		}
		_, _, err = stmt.Execute(ctx, writeTx, params)
		return err
	})
	return table.Retry(c.ctx, c.sp, deleteOp)
}
