package vulnerabilitycategory

import (
	"context"

	"golang.org/x/xerrors"

	"a.yandex-team.ru/security/impulse/api/internal/db"
	"a.yandex-team.ru/security/impulse/models"
)

// vulnerabilityCategoryRepository implements Repository for vulnerability
// categories by issuing SQL through a shared *db.DB handle.
type vulnerabilityCategoryRepository struct {
	// db supplies the PG query handle and the Trier wrapper every query runs through.
	db *db.DB
}

// NewVulnerabilityCategoryRepository builds a Repository whose queries are
// executed against the given database handle.
func NewVulnerabilityCategoryRepository(db *db.DB) Repository {
	repo := &vulnerabilityCategoryRepository{db: db}
	return repo
}

// ListByScanType returns every vulnerability category whose scan_type_id equals
// scanTypeID. The query runs inside m.db.Trier.Try.
//
// On success the result is a non-nil (possibly empty) slice. On failure the
// error is returned together with the slice, whose contents callers should
// ignore. Errors accepted by db.IsRetriableError are wrapped with a message;
// all other errors are passed through unchanged.
// NOTE(review): the asymmetric wrapping looks deliberate but is undocumented — confirm.
func (m *vulnerabilityCategoryRepository) ListByScanType(ctx context.Context, scanTypeID int) ([]*models.VulnerabilityCategory, error) {
	const listQuery = `
SELECT
	*
FROM
	vulnerabilityCategory
WHERE
	scan_type_id = $1`

	result := []*models.VulnerabilityCategory{}
	tryErr := m.db.Trier.Try(ctx, func(ctx context.Context) error {
		selectErr := m.db.PG.SelectContext(ctx, &result, listQuery, scanTypeID)
		switch {
		case selectErr == nil:
			return nil
		case db.IsRetriableError(selectErr):
			return xerrors.Errorf("failed to list vulnerability categories: %w", selectErr)
		default:
			return selectErr
		}
	})
	return result, tryErr
}

// Insert stores category under scanTypeID unless that (scan_type_id, name) pair
// already exists, and returns the id of the row in either case.
//
// A single statement attempts the INSERT (ON CONFLICT DO NOTHING ... RETURNING
// id) and UNION ALLs it with a lookup of an existing row. QueryRow takes the
// first row produced. Errors accepted by db.IsRetriableError are wrapped with a
// message; all other errors are returned unchanged.
//
// NOTE(review): in PostgreSQL all CTEs of one statement share a snapshot. If a
// concurrent transaction inserts the same pair and commits while this
// statement waits on the conflict, `ins` yields nothing and `sel` cannot see
// the new row. Scan then fails with sql.ErrNoRows — confirm callers handle it.
func (m *vulnerabilityCategoryRepository) Insert(ctx context.Context, scanTypeID int, category string) (int, error) {
	const upsertQuery = `
WITH
	ins AS (
		INSERT INTO
			vulnerabilityCategory (scan_type_id, name)
		VALUES
			($1, $2)
		ON CONFLICT (scan_type_id, name) DO NOTHING
		RETURNING id
	),
	sel AS (
		SELECT
			id
		FROM
			vulnerabilityCategory
		WHERE
			scan_type_id = $1 AND name = $2
	)
	SELECT id FROM ins
	UNION ALL
	SELECT id FROM sel
	`

	var categoryID int
	err := m.db.Trier.Try(ctx, func(ctx context.Context) error {
		scanErr := m.db.PG.QueryRowContext(ctx, upsertQuery, scanTypeID, category).Scan(&categoryID)
		switch {
		case scanErr == nil:
			return nil
		case db.IsRetriableError(scanErr):
			return xerrors.Errorf("failed to insert vulnerability category: %w", scanErr)
		default:
			return scanErr
		}
	})
	return categoryID, err
}

// ListByScanInstance returns the vulnerability categories that have at least
// one vulnerability linked to scanInstance through vulnerability2scanInstance.
// Each category appears once, and the result is ordered by category name.
//
// On success the result is a non-nil (possibly empty) slice. On failure it is
// nil, and the error is wrapped with the scan instance ID.
func (m *vulnerabilityCategoryRepository) ListByScanInstance(ctx context.Context, scanInstance models.ScanInstance) ([]*models.VulnerabilityCategory, error) {
	// Inner joins: only categories with a vulnerability in this scan instance
	// are wanted, which the WHERE filter on v2si enforces anyway.
	// GROUP BY vc.id collapses the one-row-per-vulnerability fan-out. Selecting
	// the other vc columns requires vc.id to be the table's primary key
	// (presumably it is — confirm against the schema).
	const query = `
SELECT
	vc.id AS id,
	vc.scan_type_id AS scan_type_id,
	vc.name AS name
FROM
	vulnerabilityCategory vc
	JOIN vulnerability v ON v.category_id = vc.id
	JOIN vulnerability2scanInstance v2si ON v2si.vulnerability_id = v.id
WHERE
	v2si.scan_instance_id = $1
GROUP BY
	vc.id
ORDER BY
	name ASC`

	var categories []*models.VulnerabilityCategory
	err := m.db.Trier.Try(ctx, func(ctx context.Context) error {
		// Start every attempt from an empty result so a retry never builds on
		// rows scanned by a failed attempt.
		categories = make([]*models.VulnerabilityCategory, 0)
		// Hand the unwrapped error to the Trier; context is added once, after Try.
		return m.db.PG.SelectContext(ctx, &categories, query, scanInstance.ID)
	})
	if err != nil {
		return nil, xerrors.Errorf("failed to list vulnerability categories for scan instance %v: %w", scanInstance.ID, err)
	}
	return categories, nil
}

// GetStatisticsByScanInstanceIDAndCategoryID returns statistics for one
// vulnerability category within one scan instance. Only
// TotalVulnerabilitiesCount is filled: the number of vulnerabilities of
// category categoryID that are linked to scanInstanceID.
//
// Query and scan failures are returned to the caller, wrapped with the IDs
// involved, rather than being reported as a zero count.
func (m *vulnerabilityCategoryRepository) GetStatisticsByScanInstanceIDAndCategoryID(ctx context.Context, scanInstanceID int,
	categoryID int) (*models.VulnerabilityCategoryStatistics, error) {
	const query = "SELECT COUNT(*) AS total_vulnerabilities_count FROM vulnerability2scanInstance v2si " +
		"JOIN vulnerability v ON v.id = v2si.vulnerability_id " +
		"WHERE v2si.scan_instance_id = $1 AND v.category_id = $2"

	var stats models.VulnerabilityCategoryStatistics
	err := m.db.Trier.Try(ctx, func(ctx context.Context) error {
		// Propagate the error instead of swallowing it, so the Trier sees the
		// failure and the caller does not receive a bogus zero count.
		return m.db.PG.QueryRowContext(ctx, query, scanInstanceID, categoryID).
			Scan(&stats.TotalVulnerabilitiesCount)
	})
	if err != nil {
		return nil, xerrors.Errorf("failed to count vulnerabilities of category %d in scan instance %d: %w",
			categoryID, scanInstanceID, err)
	}
	return &stats, nil
}
