package scaninstance

import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"github.com/jackc/pgtype"

	"a.yandex-team.ru/security/impulse/api/internal/db"
	"a.yandex-team.ru/security/impulse/models"
	"a.yandex-team.ru/security/libs/go/simplelog"
)

// scanInstanceRepository is the database-backed type that
// NewScanInstanceRepository returns as a Repository.
type scanInstanceRepository struct {
	// db is the shared database handle. The methods of this type issue SQL
	// through db.PG and run most statements via db.Trier.Try.
	db *db.DB
}

// NewScanInstanceRepository wraps the given database handle in a
// scanInstanceRepository and hands it back as a Repository.
func NewScanInstanceRepository(db *db.DB) Repository {
	repo := scanInstanceRepository{db: db}
	return &repo
}

// Create inserts scanInstance into the ScanInstance table and stores the
// generated primary key in scanInstance.ID.
//
// The insert runs through m.db.Trier.Try. An error from issuing the query is
// returned as-is when db.IsRetriableError reports false and is wrapped with
// the task ID otherwise. When the statement yields no row, the error recorded
// on the result set is returned if there is one; db.ErrNotFound is returned
// only when the result set is empty without an error.
//
// The same scanInstance pointer is returned alongside the error, so callers
// must check the error before relying on scanInstance.ID.
func (m *scanInstanceRepository) Create(ctx context.Context, scanInstance *models.ScanInstance) (*models.ScanInstance, error) {
	err := m.db.Trier.Try(ctx, func(ctx context.Context) error {
		rows, err := m.db.PG.NamedQueryContext(ctx, `
INSERT INTO
	ScanInstance (scan_id, task_id, raw_report_url, report_url, commit_hash, start_time, end_time)
VALUES
	(:scan_id, :task_id, :raw_report_url, :report_url, :commit_hash, :start_time, :end_time)
RETURNING id
`, scanInstance)
		if err != nil {
			if !db.IsRetriableError(err) {
				return err
			}
			return fmt.Errorf("failed to create scanInstance %v: %w", scanInstance.TaskID, err)
		}
		defer rows.Close()

		if !rows.Next() {
			// Next also returns false when fetching the row failed; surface
			// that error instead of masking it as "not found".
			if err := rows.Err(); err != nil {
				return err
			}
			return db.ErrNotFound
		}

		if err := rows.Scan(&scanInstance.ID); err != nil {
			return err
		}
		return rows.Err()
	})
	return scanInstance, err
}

// UpdateLastScanInstance upserts the LastScanInstance row for the
// (organization, project, scan type, scan) that scanInstance.ID belongs to, so
// that its scan_instance_id becomes scanInstance.ID.
//
// The statement runs through m.db.Trier.Try. An execution error is returned
// unchanged when db.IsRetriableError reports false; otherwise it is wrapped
// with the task ID.
func (m *scanInstanceRepository) UpdateLastScanInstance(ctx context.Context, scanInstance *models.ScanInstance) error {
	const upsert = "INSERT INTO LastScanInstance " +
		" (organization_id, project_id, scan_type_id, scan_id, scan_instance_id) " +
		" ( " +
		"  SELECT o.id, p.id, st.id, s.id, si.id " +
		"  FROM ScanInstance si " +
		"  JOIN Scan s ON si.scan_id = s.id " +
		"  JOIN ScanType st ON s.scan_type_id = st.id " +
		"  JOIN project p ON s.project_id = p.id " +
		"  JOIN organization o ON p.organization_id = o.id " +
		"  WHERE si.id = $1 " +
		" ) " +
		" ON CONFLICT (organization_id, project_id, scan_type_id, scan_id) " +
		" DO UPDATE SET scan_instance_id = $1"

	return m.db.Trier.Try(ctx, func(ctx context.Context) error {
		_, execErr := m.db.PG.ExecContext(ctx, upsert, scanInstance.ID)
		switch {
		case execErr == nil:
			return nil
		case !db.IsRetriableError(execErr):
			return execErr
		default:
			return fmt.Errorf("failed to update LastScanInstance Cache %v: %w", scanInstance.TaskID, execErr)
		}
	})
}

// InsertVulnerabilities2ScanInstance links every entry of vulnerabilities to
// scanInstance with one multi-row INSERT into vulnerability2scanInstance,
// executed inside its own transaction. Rows that collide on
// (vulnerability_id, scan_instance_id) are skipped by the statement.
//
// It returns one Vulnerability2ScanInstance per input vulnerability, in input
// order, describing the rows that were sent to the database. For an empty
// input it returns an empty slice without touching the database.
//
// Any error from JSON encoding, starting the transaction, executing the
// statement or committing is returned with a nil slice; the transaction is
// rolled back on every failure path.
//
// NOTE(review): the statement binds 6 parameters per vulnerability, and the
// PostgreSQL wire protocol limits the number of bind parameters per statement,
// so very large inputs may need batching — confirm the expected input sizes.
func (m *scanInstanceRepository) InsertVulnerabilities2ScanInstance(ctx context.Context, scanInstance *models.ScanInstance,
	vulnerabilities []*models.Vulnerability) ([]*models.Vulnerability2ScanInstance, error) {
	const columnsPerRow = 6

	start := time.Now()

	// Length 0 with capacity: the slice is filled with append, so it must not
	// start out with nil placeholders in front of the real entries.
	links := make([]*models.Vulnerability2ScanInstance, 0, len(vulnerabilities))
	if len(vulnerabilities) == 0 {
		// An INSERT with an empty VALUES list is not valid SQL.
		return links, nil
	}

	placeholders := make([]string, 0, len(vulnerabilities))
	args := make([]interface{}, 0, len(vulnerabilities)*columnsPerRow)
	for i, vuln := range vulnerabilities {
		keyPropertiesJSON, err := json.Marshal(vuln.KeyProperties)
		if err != nil {
			return nil, err
		}
		displayPropertiesJSON, err := json.Marshal(vuln.DisplayProperties)
		if err != nil {
			return nil, err
		}

		base := columnsPerRow * i
		placeholders = append(placeholders, fmt.Sprintf("($%d, $%d, $%d, $%d, $%d, $%d)",
			base+1, base+2, base+3, base+4, base+5, base+6))
		args = append(args,
			vuln.ID,
			scanInstance.ID,
			string(*vuln.Severity),
			vuln.CategoryID,
			string(keyPropertiesJSON),
			string(displayPropertiesJSON),
		)

		links = append(links, &models.Vulnerability2ScanInstance{
			VulnerabilityID:   vuln.ID,
			ScanInstanceID:    scanInstance.ID,
			Severity:          *vuln.Severity,
			CategoryID:        vuln.CategoryID,
			KeyProperties:     vuln.KeyProperties,
			DisplayProperties: vuln.DisplayProperties,
		})
	}

	insertQuery := fmt.Sprintf("INSERT INTO vulnerability2scanInstance "+
		" (vulnerability_id, scan_instance_id, severity, category_id, key_properties, display_properties) VALUES %s "+
		" ON CONFLICT (vulnerability_id, scan_instance_id) DO NOTHING ",
		strings.Join(placeholders, ","))

	tx, err := m.db.PG.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	// Release the transaction on every early return. After a successful Commit
	// the rollback only reports that the transaction is already finished, which
	// is safe to ignore.
	defer func() { _ = tx.Rollback() }()

	if _, err = tx.ExecContext(ctx, insertQuery, args...); err != nil {
		return nil, err
	}
	if err = tx.Commit(); err != nil {
		return nil, err
	}

	elapsed := time.Since(start)
	simplelog.Info(fmt.Sprintf("ScanInstance InsertVulnerabilities2ScanInstance took: %s", elapsed))
	return links, nil
}

// ListByTaskID returns the scan instances whose task_id equals taskID, each
// joined with its scan and scan type so that scan_type_name is populated,
// ordered by end_time descending.
//
// The query runs through m.db.Trier.Try. On error the result is nil; when
// nothing matches, an empty non-nil slice is returned.
func (m *scanInstanceRepository) ListByTaskID(ctx context.Context, taskID string) ([]*models.ScanInstance, error) {
	const query = "SELECT si.id as id, si.scan_id as scan_id, si.task_id as task_id, " +
		" si.raw_report_url as raw_report_url, si.report_url as report_url, si.commit_hash as commit_hash, " +
		" si.start_time as start_time, si.end_time as end_time, st.type_name as scan_type_name " +
		" FROM scanInstance si, scan s, scanType st " +
		" WHERE task_id = $1 AND si.scan_id = s.id AND s.scan_type_id = st.id " +
		" ORDER BY end_time DESC"

	found := []*models.ScanInstance{}
	selectByTask := func(ctx context.Context) error {
		return m.db.PG.SelectContext(ctx, &found, query, taskID)
	}
	if err := m.db.Trier.Try(ctx, selectByTask); err != nil {
		return nil, err
	}
	return found, nil
}

// GetByID loads the ScanInstance row with the given id.
//
// The query is issued directly on m.db.PG, without going through
// m.db.Trier.Try. Any error from GetContext is returned unchanged together
// with a nil result.
func (m *scanInstanceRepository) GetByID(ctx context.Context, id int) (*models.ScanInstance, error) {
	const query = "SELECT * FROM ScanInstance WHERE id = $1"

	found := new(models.ScanInstance)
	if err := m.db.PG.GetContext(ctx, found, query, id); err != nil {
		return nil, err
	}
	return found, nil
}

// ListByScanID returns the scan instances whose scan_id equals scanID, each
// joined with its scan and scan type so that scan_type_name is populated,
// ordered by end_time descending.
//
// The query runs through m.db.Trier.Try. On error the result is nil; when
// nothing matches, an empty non-nil slice is returned.
func (m *scanInstanceRepository) ListByScanID(ctx context.Context, scanID int) ([]*models.ScanInstance, error) {
	const query = "SELECT si.id as id, si.scan_id as scan_id, si.task_id as task_id, " +
		" si.raw_report_url as raw_report_url, si.report_url as report_url, si.commit_hash as commit_hash, " +
		" si.start_time as start_time, si.end_time as end_time, st.type_name as scan_type_name " +
		" FROM scanInstance si, scan s, scanType st " +
		" WHERE scan_id = $1 AND si.scan_id = s.id AND s.scan_type_id = st.id " +
		" ORDER BY end_time DESC"

	found := []*models.ScanInstance{}
	selectByScan := func(ctx context.Context) error {
		return m.db.PG.SelectContext(ctx, &found, query, scanID)
	}
	if err := m.db.Trier.Try(ctx, selectByScan); err != nil {
		return nil, err
	}
	return found, nil
}

// GetStatisticsByID aggregates the vulnerabilities linked to the scan instance
// with the given id: the total count, plus per-severity and overall counts of
// vulnerabilities that are "not false" (status not_reviewed, confirmed or
// to_verify) and of those still not_reviewed.
//
// The query runs through m.db.Trier.Try. sql.ErrNoRows is not treated as a
// failure: in that case the zero-valued statistics are returned. Any other
// error yields a nil result.
func (m *scanInstanceRepository) GetStatisticsByID(ctx context.Context, id int) (*models.ScanInstanceStatistics, error) {
	const query = "SELECT COUNT(*) as total_vulnerabilities_count, " +
		" COALESCE(SUM(CASE WHEN v.status IN ('not_reviewed', 'confirmed', 'to_verify') THEN 1 ELSE 0 END),0) as total_not_false_count, " +
		" COALESCE(SUM(CASE WHEN v.status = 'not_reviewed' THEN 1 ELSE 0 END),0) as total_not_reviewed_count, " +
		" COALESCE(SUM(CASE WHEN v.severity = 'blocker' AND v.status IN ('not_reviewed', 'confirmed', 'to_verify') THEN 1 ELSE 0 END),0) as blocker_not_false_count, " +
		" COALESCE(SUM(CASE WHEN v.severity = 'blocker' AND v.status = 'not_reviewed' THEN 1 ELSE 0 END),0) as blocker_not_reviewed_count, " +
		" COALESCE(SUM(CASE WHEN v.severity = 'critical' AND v.status IN ('not_reviewed', 'confirmed', 'to_verify') THEN 1 ELSE 0 END),0) as critical_not_false_count, " +
		" COALESCE(SUM(CASE WHEN v.severity = 'critical' AND v.status = 'not_reviewed' THEN 1 ELSE 0 END),0) as critical_not_reviewed_count, " +
		" COALESCE(SUM(CASE WHEN v.severity = 'medium' AND v.status IN ('not_reviewed', 'confirmed', 'to_verify') THEN 1 ELSE 0 END),0) as medium_not_false_count, " +
		" COALESCE(SUM(CASE WHEN v.severity = 'medium' AND v.status = 'not_reviewed' THEN 1 ELSE 0 END),0) as medium_not_reviewed_count, " +
		" COALESCE(SUM(CASE WHEN v.severity = 'low' AND v.status IN ('not_reviewed', 'confirmed', 'to_verify') THEN 1 ELSE 0 END),0) as low_not_false_count, " +
		" COALESCE(SUM(CASE WHEN v.severity = 'low' AND v.status = 'not_reviewed' THEN 1 ELSE 0 END),0) as low_not_reviewed_count, " +
		" COALESCE(SUM(CASE WHEN v.severity = 'info' AND v.status IN ('not_reviewed', 'confirmed', 'to_verify') THEN 1 ELSE 0 END),0) as info_not_false_count, " +
		" COALESCE(SUM(CASE WHEN v.severity = 'info' AND v.status = 'not_reviewed' THEN 1 ELSE 0 END),0) as info_not_reviewed_count " +
		" FROM vulnerability2scaninstance v2si " +
		" JOIN vulnerability v ON v.id = v2si.vulnerability_id " +
		" WHERE v2si.scan_instance_id = $1 "

	stats := new(models.ScanInstanceStatistics)
	err := m.db.Trier.Try(ctx, func(ctx context.Context) error {
		return m.db.PG.GetContext(ctx, stats, query, id)
	})
	// A missing row is tolerated; every other error aborts.
	if err == nil || err == sql.ErrNoRows {
		return stats, nil
	}
	return nil, err
}

// GetSummaryStatisticsFromScanInstances aggregates, in a single query, the
// vulnerabilities linked to any of the given scan instances: the total count,
// plus per-severity and overall counts of vulnerabilities that are "not false"
// (status not_reviewed, confirmed or to_verify) and of those still
// not_reviewed.
//
// The IDs of scanInstances are passed to the database as one pgtype.Int4Array
// and matched with ANY($1). The query runs through m.db.Trier.Try. Any error
// from building the array or running the query yields a nil result.
func (m *scanInstanceRepository) GetSummaryStatisticsFromScanInstances(ctx context.Context,
	scanInstances []*models.ScanInstance) (*models.ScanInstanceStatistics, error) {
	scanInstanceStatistics := models.ScanInstanceStatistics{}

	// Length 0 with capacity: the slice is filled with append, so it must not
	// start out with zero-valued placeholder IDs in front of the real ones.
	scanInstancesIDs := make([]int, 0, len(scanInstances))
	for _, scanInstance := range scanInstances {
		scanInstancesIDs = append(scanInstancesIDs, scanInstance.ID)
	}

	ids := &pgtype.Int4Array{}
	if err := ids.Set(scanInstancesIDs); err != nil {
		return nil, err
	}

	err := m.db.Trier.Try(ctx, func(ctx context.Context) error {
		return m.db.PG.GetContext(ctx, &scanInstanceStatistics,
			"SELECT COUNT(*) as total_vulnerabilities_count, "+
				" COALESCE(SUM(CASE WHEN v.status IN ('not_reviewed', 'confirmed', 'to_verify') THEN 1 ELSE 0 END),0) as total_not_false_count, "+
				" COALESCE(SUM(CASE WHEN v.status = 'not_reviewed' THEN 1 ELSE 0 END),0) as total_not_reviewed_count, "+
				" COALESCE(SUM(CASE WHEN v.severity = 'blocker' AND v.status IN ('not_reviewed', 'confirmed', 'to_verify') THEN 1 ELSE 0 END),0) as blocker_not_false_count, "+
				" COALESCE(SUM(CASE WHEN v.severity = 'blocker' AND v.status = 'not_reviewed' THEN 1 ELSE 0 END),0) as blocker_not_reviewed_count, "+
				" COALESCE(SUM(CASE WHEN v.severity = 'critical' AND v.status IN ('not_reviewed', 'confirmed', 'to_verify') THEN 1 ELSE 0 END),0) as critical_not_false_count, "+
				" COALESCE(SUM(CASE WHEN v.severity = 'critical' AND v.status = 'not_reviewed' THEN 1 ELSE 0 END),0) as critical_not_reviewed_count, "+
				" COALESCE(SUM(CASE WHEN v.severity = 'medium' AND v.status IN ('not_reviewed', 'confirmed', 'to_verify') THEN 1 ELSE 0 END),0) as medium_not_false_count, "+
				" COALESCE(SUM(CASE WHEN v.severity = 'medium' AND v.status = 'not_reviewed' THEN 1 ELSE 0 END),0) as medium_not_reviewed_count, "+
				" COALESCE(SUM(CASE WHEN v.severity = 'low' AND v.status IN ('not_reviewed', 'confirmed', 'to_verify') THEN 1 ELSE 0 END),0) as low_not_false_count, "+
				" COALESCE(SUM(CASE WHEN v.severity = 'low' AND v.status = 'not_reviewed' THEN 1 ELSE 0 END),0) as low_not_reviewed_count, "+
				" COALESCE(SUM(CASE WHEN v.severity = 'info' AND v.status IN ('not_reviewed', 'confirmed', 'to_verify') THEN 1 ELSE 0 END),0) as info_not_false_count, "+
				" COALESCE(SUM(CASE WHEN v.severity = 'info' AND v.status = 'not_reviewed' THEN 1 ELSE 0 END),0) as info_not_reviewed_count "+
				" FROM vulnerability2scaninstance v2si "+
				" JOIN vulnerability v ON v.id = v2si.vulnerability_id "+
				" WHERE v2si.scan_instance_id = ANY($1) ",
			ids)
	})
	if err != nil {
		return nil, err
	}
	return &scanInstanceStatistics, nil
}

// GetLastScanInstancesByOrganizationID returns, for the organization with the
// given ID, the scanInstance rows that have the highest id within each
// (organization, project, scan) group.
//
// The query runs through m.db.Trier.Try. On error the result is nil; when
// nothing matches, an empty non-nil slice is returned.
func (m *scanInstanceRepository) GetLastScanInstancesByOrganizationID(ctx context.Context,
	organizationID int) ([]*models.ScanInstance, error) {
	const query = "SELECT * " +
		" FROM scanInstance" +
		" WHERE id IN ( " +
		" 	SELECT MAX(si.id) as last_scan_instance_id " +
		" 	FROM scanInstance si " +
		"	JOIN scan s ON si.scan_id = s.id " +
		"	JOIN project p ON p.id = s.project_id " +
		"	JOIN organization o ON o.id = p.organization_id " +
		"	WHERE o.id = $1 " +
		"	GROUP BY o.id, p.id, si.scan_id " +
		" )"

	latest := []*models.ScanInstance{}
	selectLatest := func(ctx context.Context) error {
		return m.db.PG.SelectContext(ctx, &latest, query, organizationID)
	}
	if err := m.db.Trier.Try(ctx, selectLatest); err != nil {
		return nil, err
	}
	return latest, nil
}
