package scantype

import (
	"context"

	"golang.org/x/xerrors"

	"a.yandex-team.ru/security/impulse/api/internal/db"
	"a.yandex-team.ru/security/impulse/models"
)

// scanTypeRepository implements Repository for scan types and their scan
// parameters. Queries use $n placeholders against m.db.PG (presumably
// PostgreSQL, judging by the field name — confirm in package db).
type scanTypeRepository struct {
	// db bundles the SQL handle (PG) used by every query method and the retry
	// helper (Trier) that some of them wrap their queries in.
	db *db.DB
}

// NewScanTypeRepository builds a Repository that runs its queries through the
// given database handle. It only stores the handle; nothing is queried here.
func NewScanTypeRepository(db *db.DB) Repository {
	repo := &scanTypeRepository{db: db}
	return repo
}

// List returns every scan type (id, type_name, title) stored in the scanType
// table. The query runs inside m.db.Trier.Try, which may execute it more than
// once.
//
// An empty table yields a non-nil, zero-length slice. On failure List returns
// a nil slice together with the error, wrapped with context via %w so callers
// can still inspect the underlying cause.
func (m *scanTypeRepository) List(ctx context.Context) (s []*models.ScanType, err error) {
	s = make([]*models.ScanType, 0)
	err = m.db.Trier.Try(ctx, func(ctx context.Context) error {
		// Hand the driver error to the Trier untouched, so any retry
		// classification it performs (presumably db.IsRetriableError — not
		// visible here, confirm) sees the original error value. Context is
		// added exactly once, below, after the final attempt.
		return m.db.PG.SelectContext(ctx, &s, `SELECT id, type_name, title FROM scanType`)
	})
	if err != nil {
		// Don't leak a possibly partially scanned slice alongside the error.
		return nil, xerrors.Errorf("failed to list scan types: %w", err)
	}
	return s, nil
}

func (m *scanTypeRepository) GetByScanTypeName(ctx context.Context, scanTypeName string) (*models.ScanType, error) {
	scanType := models.ScanType{}
	err := m.db.PG.GetContext(ctx, &scanType, "SELECT * FROM scanType WHERE type_name = $1", scanTypeName)
	if err != nil {
		return nil, err
	}
	return &scanType, nil
}

// GetByID loads the scanType row with primary key scanTypeID. The query is
// executed once, without the Trier retry wrapper. Any error from GetContext
// is returned as-is, with a nil result.
func (m *scanTypeRepository) GetByID(ctx context.Context, scanTypeID int) (*models.ScanType, error) {
	result := new(models.ScanType)
	err := m.db.PG.GetContext(ctx, result, "SELECT * FROM scanType WHERE id = $1", scanTypeID)
	switch {
	case err != nil:
		return nil, err
	default:
		return result, nil
	}
}

func (m *scanTypeRepository) ListByProjectID(ctx context.Context, projectID int) ([]*models.ScanType, error) {
	scanTypes := make([]*models.ScanType, 0)
	err := m.db.Trier.Try(ctx, func(ctx context.Context) error {
		return m.db.PG.SelectContext(ctx, &scanTypes,
			"SELECT st.id as id, st.type_name as type_name, st.title as title "+
				" FROM scan s JOIN scantype st on s.scan_type_id = st.id WHERE s.project_id = $1 "+
				" ORDER BY st.title ASC", projectID)
	})
	if err != nil {
		return nil, err
	}
	return scanTypes, nil
}

// GetScanTypeParameters returns all scanParameter rows whose scan_type_id is
// scanTypeID. The query runs inside m.db.Trier.Try, which may execute it more
// than once.
//
// No matching rows yields a non-nil, zero-length slice. On failure it returns
// a nil slice together with the error, wrapped with context via %w.
func (m *scanTypeRepository) GetScanTypeParameters(ctx context.Context, scanTypeID int) (s []*models.ScanParameter, err error) {
	s = make([]*models.ScanParameter, 0)
	err = m.db.Trier.Try(ctx, func(ctx context.Context) error {
		// Hand the driver error to the Trier untouched, so any retry
		// classification it performs (presumably db.IsRetriableError — not
		// visible here, confirm) sees the original error value. Context is
		// added exactly once, below, after the final attempt.
		return m.db.PG.SelectContext(ctx, &s, `SELECT * FROM scanParameter WHERE scan_type_id = $1`, scanTypeID)
	})
	if err != nil {
		// Don't leak a possibly partially scanned slice alongside the error.
		return nil, xerrors.Errorf("failed to list scan parameters: %w", err)
	}
	return s, nil
}

func (m *scanTypeRepository) GetScanTypeDisplayName(ctx context.Context, scanTypeName string) (displayName string, err error) {
	err = m.db.PG.GetContext(ctx, &displayName, "SELECT title FROM scanType WHERE type_name = $1", scanTypeName)
	if err != nil {
		return "", err
	}
	return
}
