package workflow

import (
	"context"
	"fmt"
	"strings"

	"golang.org/x/xerrors"

	"a.yandex-team.ru/security/impulse/api/internal/db"
	"a.yandex-team.ru/security/impulse/models"
)

// workflowRepository implements Repository on top of the database handle it
// holds; SQL statements are issued through that handle's PG connection.
type workflowRepository struct {
	db *db.DB
}

// NewWorkflowRepository returns a Repository that reads and writes workflow
// data through the given database handle.
func NewWorkflowRepository(database *db.DB) Repository {
	repo := &workflowRepository{db: database}
	return repo
}

// Update upserts w into the workflow table keyed by id. A missing row is
// inserted; an existing row has its name, description and url overwritten.
// The statement is executed inside m.db.Trier.Try, and the number of rows
// reported by the driver's RowsAffected is returned.
//
// Non-retriable execution errors are returned unwrapped. Retriable ones are
// wrapped with the workflow id for context.
func (m *workflowRepository) Update(ctx context.Context, w *models.Workflow) (int64, error) {
	const upsertQuery = `
INSERT INTO
	workflow (id, name, description, url)
VALUES
	(:id, :name, :description, :url)
ON CONFLICT (id)
DO
UPDATE SET
	name = :name, description = :description, url = :url
`

	var affected int64
	tryErr := m.db.Trier.Try(ctx, func(ctx context.Context) error {
		result, execErr := m.db.PG.NamedExecContext(ctx, upsertQuery, w)
		if execErr != nil {
			if db.IsRetriableError(execErr) {
				return xerrors.Errorf("failed to create or update workflow info %v: %w", w.ID, execErr)
			}
			return execErr
		}

		// The count is recorded even when RowsAffected also reports an error,
		// mirroring a plain multi-value assignment.
		n, countErr := result.RowsAffected()
		affected = n
		if countErr != nil {
			return xerrors.Errorf("failed to get number of updated rows %v: %w", w.ID, countErr)
		}
		return nil
	})
	return affected, tryErr
}

// UpdateWorkflowScanTypes replaces the scan types linked to workflowID in
// workflow2scanType with scanTypeIds.
//
// Existing links for the workflow are deleted, then one (workflowID, id) row
// is inserted per element of scanTypeIds. Both steps are sent in a single
// ExecContext call inside m.db.Trier.Try. $1 is bound to workflowID and
// $2..$N+1 to the scan type ids.
//
// An empty or nil scanTypeIds only deletes the existing links; previously
// this case panicked while trimming an empty placeholder string.
//
// Non-retriable errors are returned unwrapped. Retriable ones are wrapped
// with the workflow id.
func (m *workflowRepository) UpdateWorkflowScanTypes(ctx context.Context, workflowID string, scanTypeIds []int) (err error) {
	// Build the statement and its arguments once; they do not change between
	// retry attempts.
	query := `
DELETE FROM
	workflow2scanType
WHERE
	workflow_id = $1;`
	args := make([]interface{}, 0, len(scanTypeIds)+1)
	args = append(args, workflowID)

	if len(scanTypeIds) > 0 {
		placeholders := make([]string, 0, len(scanTypeIds))
		for i, scanTypeID := range scanTypeIds {
			args = append(args, scanTypeID)
			// Argument positions start at $2 because $1 is the workflow id.
			placeholders = append(placeholders, fmt.Sprintf("($1, $%d)", i+2))
		}
		// NOTE(review): a parameterized query containing two statements is
		// rejected by PostgreSQL's extended query protocol (lib/pq, pgx without
		// simple protocol). Confirm the driver mode here, or move the DELETE and
		// INSERT into an explicit transaction.
		query += `
INSERT INTO
	workflow2scanType (workflow_id, scan_type_id)
VALUES ` + strings.Join(placeholders, ",")
	}

	err = m.db.Trier.Try(ctx, func(ctx context.Context) error {
		_, execErr := m.db.PG.ExecContext(ctx, query, args...)
		if execErr == nil || !db.IsRetriableError(execErr) {
			return execErr
		}
		return xerrors.Errorf("failed to update workflow scan types %v: %w", workflowID, execErr)
	})
	return err
}

// Get loads the workflow whose id equals workflowID into a freshly allocated
// models.Workflow. The query runs inside m.db.Trier.Try.
//
// The returned pointer is always non-nil, even when err is set, so callers
// must check err first. Non-retriable errors (whatever the driver/GetContext
// reports, e.g. for a missing row) are returned unwrapped. Retriable ones are
// wrapped with the workflow id.
func (m *workflowRepository) Get(ctx context.Context, workflowID string) (w *models.Workflow, err error) {
	const selectByID = `
SELECT
	id, name, description, url
FROM
	workflow
WHERE
	id = $1`

	w = new(models.Workflow)
	err = m.db.Trier.Try(ctx, func(ctx context.Context) error {
		queryErr := m.db.PG.GetContext(ctx, w, selectByID, workflowID)
		if queryErr == nil || !db.IsRetriableError(queryErr) {
			return queryErr
		}
		return xerrors.Errorf("failed to get workflow %s: %w", workflowID, queryErr)
	})
	return w, err
}

// ListWorkflows returns every row of the workflow table ordered by name
// ascending. The query runs inside m.db.Trier.Try.
//
// The result starts as a non-nil empty slice and is returned even on error.
// Non-retriable errors are returned unwrapped. Retriable ones are wrapped.
func (m *workflowRepository) ListWorkflows(ctx context.Context) (w []*models.Workflow, err error) {
	const listQuery = `SELECT id, name, description, url FROM workflow ORDER BY name ASC`

	workflows := make([]*models.Workflow, 0)
	err = m.db.Trier.Try(ctx, func(ctx context.Context) error {
		selectErr := m.db.PG.SelectContext(ctx, &workflows, listQuery)
		switch {
		case selectErr == nil:
			return nil
		case !db.IsRetriableError(selectErr):
			return selectErr
		default:
			return xerrors.Errorf("failed to list workflows: %w", selectErr)
		}
	})
	return workflows, err
}

// ListWorkflowScanTypeTitles returns (workflow_id, title) pairs produced by
// joining workflow2scanType with scanType on scan_type_id. When workflowID is
// empty the filter matches every row, so pairs for all workflows are listed.
// The query runs inside m.db.Trier.Try.
//
// The result starts as a non-nil empty slice and is returned even on error.
// Non-retriable errors are returned unwrapped. Retriable ones are wrapped.
func (m *workflowRepository) ListWorkflowScanTypeTitles(ctx context.Context, workflowID string) (w []*models.Workflow2ScanTypeTitle, err error) {
	const titlesQuery = `
SELECT
	workflow_id, title
FROM
	workflow2scanType wt
INNER JOIN
	scanType t
ON
	wt.scan_type_id = t.id
WHERE
	workflow_id = $1 OR $1 = ''`

	titles := make([]*models.Workflow2ScanTypeTitle, 0)
	err = m.db.Trier.Try(ctx, func(ctx context.Context) error {
		selectErr := m.db.PG.SelectContext(ctx, &titles, titlesQuery, workflowID)
		if selectErr == nil || !db.IsRetriableError(selectErr) {
			return selectErr
		}
		return xerrors.Errorf("failed to list workflow scan types: %w", selectErr)
	})
	return titles, err
}

// ListWorkflowScanTypes returns the (workflow_id, scan_type_id) links stored
// in workflow2scanType for workflowID. On error the returned slice is nil.
//
// The query runs inside m.db.Trier.Try, like the other read methods of this
// repository. Non-retriable errors are returned unwrapped. Retriable ones are
// wrapped with the workflow id.
func (m *workflowRepository) ListWorkflowScanTypes(ctx context.Context, workflowID string) ([]*models.Workflow2ScanType, error) {
	workflowScanTypes := make([]*models.Workflow2ScanType, 0)
	err := m.db.Trier.Try(ctx, func(ctx context.Context) error {
		// Start every attempt from an empty slice. If the scanner appends into
		// the destination, a retry after a partially scanned result cannot then
		// leave rows from the failed attempt behind.
		workflowScanTypes = workflowScanTypes[:0]
		selectErr := m.db.PG.SelectContext(ctx, &workflowScanTypes,
			"SELECT workflow_id, scan_type_id FROM workflow2scanType WHERE workflow_id = $1", workflowID)
		if selectErr == nil || !db.IsRetriableError(selectErr) {
			return selectErr
		}
		return xerrors.Errorf("failed to list workflow scan types %s: %w", workflowID, selectErr)
	})
	if err != nil {
		return nil, err
	}
	return workflowScanTypes, nil
}
