package vulnerability

import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"github.com/gofrs/uuid"
	"github.com/jackc/pgtype"

	"a.yandex-team.ru/security/impulse/api/internal/db"
	"a.yandex-team.ru/security/impulse/api/internal/sqlbuilder"
	"a.yandex-team.ru/security/impulse/models"
	"a.yandex-team.ru/security/libs/go/simplelog"
)

// vulnerabilityRepository is the database-backed implementation of Repository
// that NewVulnerabilityRepository returns.
type vulnerabilityRepository struct {
	// db supplies the SQL handle (db.PG) and the retry helper (db.Trier)
	// used by the repository methods.
	db            *db.DB
	// filterBuilder renders caller-supplied filter maps into SQL WHERE
	// fragments for FetchByScanInstances and FetchLatest; it is configured
	// in NewVulnerabilityRepository.
	filterBuilder *sqlbuilder.SQLBuilder
}

// NewVulnerabilityRepository returns a Repository backed by the given database.
//
// The map handed to sqlbuilder.New lists the filter keys the filter builder is
// configured with (presumably a whitelist of keys accepted in the filterMap of
// FetchByScanInstances / FetchLatest — confirm in sqlbuilder).
func NewVulnerabilityRepository(db *db.DB) Repository {
	allowedFilterKeys := map[string]bool{
		"id":                      true,
		"project.id":              true,
		"project.name":            true,
		"project.organization_id": true,
		"st.type_name":            true,
		"st.title":                true,
		"v.severity":              true,
		"v.status":                true,
		"v.key_properties":        true,
		"v.display_properties":    true,
		"v.category_id":           true,
		"vc.category_name":        true,
		"v.tracker_ticket":        true,
		"v.first_found_at":        true,
	}
	builder := sqlbuilder.New(allowedFilterKeys)

	return &vulnerabilityRepository{
		db:            db,
		filterBuilder: &builder,
	}
}

// fetch executes query with args and scans every row into a Vulnerability.
// The query must select exactly four columns, in this order: id, scan_id,
// key_properties, display_properties. A query error is wrapped with context;
// a scan error is returned as-is. When nothing matches, the result is an
// empty, non-nil slice.
func (v vulnerabilityRepository) fetch(ctx context.Context, query string, args ...interface{}) ([]*models.Vulnerability, error) {
	rows, queryErr := v.db.PG.QueryContext(ctx, query, args...)
	if queryErr != nil {
		return nil, fmt.Errorf("failed to get vulnerabilities: %w", queryErr)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			simplelog.Error("failed to close", "err", closeErr)
		}
	}()

	found := []*models.Vulnerability{}
	for rows.Next() {
		var item models.Vulnerability
		if scanErr := rows.Scan(&item.ID, &item.ScanID, &item.KeyProperties, &item.DisplayProperties); scanErr != nil {
			return nil, scanErr
		}
		found = append(found, &item)
	}
	return found, rows.Err()
}

// defaultOrderString returns the ORDER BY expression shared by the listing
// queries. It sorts first by review status in workflow order (not_reviewed
// first, not_an_issue last), then by severity from blocker down to info,
// then by newest vulnerability id.
//
// NOTE(review): both CASE expressions have no ELSE, so any status or severity
// not listed here (e.g. a "high" severity, if models defines one — TODO
// confirm) evaluates to NULL. PostgreSQL sorts NULLs last under ASC and first
// under DESC by default.
func (v vulnerabilityRepository) defaultOrderString() string {
	const (
		byStatus = " CASE v.status " +
			" WHEN 'not_reviewed' THEN 0 " +
			" WHEN 'to_verify' THEN 1 " +
			" WHEN 'confirmed' THEN 2 " +
			" WHEN 'not_exploitable' THEN 3 " +
			" WHEN 'not_an_issue' THEN 4 " +
			" END ASC, "
		bySeverity = " CASE v.severity " +
			" WHEN 'info' THEN 0 " +
			" WHEN 'low' THEN 1 " +
			" WHEN 'medium' THEN 2 " +
			" WHEN 'critical' THEN 3 " +
			" WHEN 'blocker' THEN 4 " +
			" END DESC, "
		byNewest = " v.id DESC "
	)
	return byStatus + bySeverity + byNewest
}

// GetByVulnerabilityIDAndScanInstanceID loads vulnerabilityID as it was
// recorded in scanInstanceID.
//
// Severity, key_properties and display_properties come from the
// vulnerability2scaninstance row. Status, category, first_found_at and
// tracker_ticket come from the vulnerability row itself. Errors from
// GetContext are returned unwrapped.
func (v vulnerabilityRepository) GetByVulnerabilityIDAndScanInstanceID(ctx context.Context, vulnerabilityID int,
	scanInstanceID int) (*models.Vulnerability, error) {
	vulnerability := models.Vulnerability{}
	err := v.db.PG.GetContext(ctx, &vulnerability,
		"SELECT v.id as id, scan.project_id as project_id, project.name as project_name, "+
			" si.scan_id as scan_id, v2s.severity as severity, "+
			" v.status as status, v.category_id as category_id, vc.name as category_name, "+
			" v2s.key_properties as key_properties, v2s.display_properties as display_properties, "+
			" v.first_found_at as first_found_at, v.tracker_ticket as tracker_ticket "+
			" FROM vulnerability2scaninstance v2s LEFT JOIN scanInstance si ON v2s.scan_instance_id = si.id "+
			" LEFT JOIN scan ON si.scan_id = scan.id "+
			" LEFT JOIN project ON project.id = scan.project_id "+
			" JOIN vulnerability v ON v2s.vulnerability_id = v.id "+
			// Join the category through v.category_id, which is the column
			// selected above and the join every other query here uses. This
			// guarantees category_name describes the returned category_id.
			" JOIN vulnerabilityCategory vc ON v.category_id = vc.id "+
			" WHERE v2s.vulnerability_id = $1 AND v2s.scan_instance_id = $2 ",
		vulnerabilityID, scanInstanceID)
	if err != nil {
		return nil, err
	}
	return &vulnerability, nil
}

// GetByVulnerabilityID loads one vulnerability by id, using the values stored
// on the vulnerability row itself. The result also includes its project
// id/name, scan id and category name. Errors from GetContext are returned
// unwrapped.
func (v vulnerabilityRepository) GetByVulnerabilityID(ctx context.Context,
	vulnerabilityID int) (*models.Vulnerability, error) {
	const query = "SELECT v.id as id, scan.project_id as project_id, project.name as project_name, " +
		" scan.id as scan_id, v.severity as severity, " +
		" v.status as status, v.category_id as category_id, vc.name as category_name, " +
		" v.key_properties as key_properties, v.display_properties as display_properties, " +
		" v.first_found_at as first_found_at, v.tracker_ticket as tracker_ticket " +
		" FROM vulnerability v " +
		" LEFT JOIN scan ON v.scan_id = scan.id " +
		" LEFT JOIN project ON project.id = scan.project_id " +
		" JOIN vulnerabilityCategory vc ON v.category_id = vc.id " +
		" WHERE v.id = $1"

	var vulnerability models.Vulnerability
	if err := v.db.PG.GetContext(ctx, &vulnerability, query, vulnerabilityID); err != nil {
		return nil, err
	}
	return &vulnerability, nil
}

// FetchByScanID returns id, scan_id, key_properties and display_properties
// for every vulnerability of scanID. On success it also logs how long the
// lookup took.
func (v vulnerabilityRepository) FetchByScanID(ctx context.Context, scanID int) (res []*models.Vulnerability, err error) {
	startedAt := time.Now()

	res, err = v.fetch(ctx, `SELECT id, scan_id, key_properties, display_properties
				FROM vulnerability WHERE scan_id = $1`, scanID)
	if err != nil {
		return nil, err
	}

	simplelog.Info(fmt.Sprintf("Vulnerability FetchByScanID took: %s", time.Since(startedAt)))
	return res, nil
}

// UpdateByScanID overwrites severity, key_properties and display_properties of
// each given vulnerability, matching rows on both id and scanID.
//
// All updates run in one transaction: either every update is committed, or on
// the first error the transaction is rolled back and that error is returned.
// Properties are stored as their JSON encoding, and each Severity is
// dereferenced, so it must be non-nil. A vulnerability whose id does not
// belong to scanID is left untouched without error. The elapsed time is
// logged on success.
func (v vulnerabilityRepository) UpdateByScanID(ctx context.Context, scanID int, vulnerabilities []*models.Vulnerability) error {
	start := time.Now()

	tx, err := v.db.PG.BeginTx(ctx, nil)
	if err != nil {
		return err
	}

	// Every statement runs inside this closure so all failures share the
	// single rollback path below.
	err = func() error {
		for _, vuln := range vulnerabilities {
			keyPropertiesJSON, err := json.Marshal(vuln.KeyProperties)
			if err != nil {
				return err
			}
			displayPropertiesJSON, err := json.Marshal(vuln.DisplayProperties)
			if err != nil {
				return err
			}
			// ExecContext (not Exec) so the caller's cancellation and deadline
			// apply to every statement, not only to BeginTx.
			_, err = tx.ExecContext(ctx, `UPDATE vulnerability SET severity = $1, key_properties = $2, display_properties = $3 WHERE id=$4 AND scan_id=$5`,
				string(*vuln.Severity), string(keyPropertiesJSON), string(displayPropertiesJSON), vuln.ID, scanID)
			if err != nil {
				return err
			}
		}
		return nil
	}()
	if err != nil {
		if rbErr := tx.Rollback(); rbErr != nil {
			simplelog.Error("rollback failed", "err", rbErr)
		}
		return err
	}
	if err := tx.Commit(); err != nil {
		return err
	}

	elapsed := time.Since(start)
	simplelog.Info(fmt.Sprintf("Vulnerability UpdateByScanID took: %s", elapsed))
	return nil
}

// InsertByScanIDAndLastUpdateToken inserts vulnerabilities as new rows and
// rotates scan.last_update_token. Both happen inside one REPEATABLE READ
// transaction, with the token acting as an optimistic lock on the scan.
//
// lastUpdateToken must equal the scan's current last_update_token. If it does
// not, either at the initial read or at the final conditional UPDATE, the
// transaction is rolled back and a "lastUpdateToken Expired" error is returned.
//
// Every row is inserted with status not_reviewed and first_found_at = NOW().
// Each input's Severity is dereferenced, so it must be non-nil.
//
// On success the token is replaced by a new random UUID. The returned slice
// holds copies of the inputs whose ID is the generated id. With no input
// vulnerabilities, nothing is inserted but the token is still verified and
// rotated.
func (v vulnerabilityRepository) InsertByScanIDAndLastUpdateToken(ctx context.Context, scanID int, lastUpdateToken string,
	vulnerabilities []*models.Vulnerability) ([]*models.Vulnerability, error) {
	start := time.Now()
	insertedVulnerabilities := make([]*models.Vulnerability, 0, len(vulnerabilities))

	tx, err := v.db.PG.BeginTx(ctx, nil)
	if err != nil {
		// No transaction was started (tx is nil), so there is nothing to roll back.
		return nil, err
	}

	// All statements run inside this closure so that any failure goes through
	// the single Rollback below. Commit happens only after the closure succeeds.
	err = func() error {
		_, err := tx.ExecContext(ctx, `set transaction isolation level repeatable read`)
		if err != nil {
			return err
		}

		currentScan := models.Scan{}
		err = tx.QueryRowContext(ctx, "SELECT id, project_id, scan_type_id, last_update_token FROM scan WHERE id = $1", scanID).Scan(
			&currentScan.ID, &currentScan.ProjectID, &currentScan.ScanTypeID, &currentScan.LastUpdateToken)
		if err != nil {
			return err
		}
		if currentScan.LastUpdateToken != lastUpdateToken {
			return fmt.Errorf("lastUpdateToken Expired")
		}

		// An empty VALUES list is invalid SQL, so only issue the INSERT when
		// there is something to insert.
		if len(vulnerabilities) > 0 {
			const columnsPerRow = 6
			insertedValueStrings := make([]string, 0, len(vulnerabilities))
			insertedValueArgs := make([]interface{}, 0, len(vulnerabilities)*columnsPerRow)
			for indx, vuln := range vulnerabilities {
				keyPropertiesJSON, err := json.Marshal(vuln.KeyProperties)
				if err != nil {
					return err
				}
				displayPropertiesJSON, err := json.Marshal(vuln.DisplayProperties)
				if err != nil {
					return err
				}
				base := columnsPerRow * indx
				insertedValueStrings = append(insertedValueStrings, fmt.Sprintf("($%d, $%d, $%d, $%d, NOW(), $%d, $%d)",
					base+1, base+2, base+3, base+4, base+5, base+6))
				// NOTE(review): rows are inserted under vuln.ScanID while the
				// token guards scanID; presumably callers keep them equal — confirm.
				insertedValueArgs = append(insertedValueArgs,
					vuln.ScanID,
					string(*vuln.Severity),
					string(models.NotReviewed),
					vuln.CategoryID,
					string(keyPropertiesJSON),
					string(displayPropertiesJSON))

				// ID is filled in from the RETURNING clause below.
				insertedVulnerabilities = append(insertedVulnerabilities, &models.Vulnerability{
					ScanID:            vuln.ScanID,
					Severity:          vuln.Severity,
					CategoryID:        vuln.CategoryID,
					KeyProperties:     vuln.KeyProperties,
					DisplayProperties: vuln.DisplayProperties,
				})
			}

			insertQuery := fmt.Sprintf("INSERT INTO vulnerability (scan_id, severity, status, category_id, first_found_at, key_properties, display_properties) "+
				" VALUES %s RETURNING id ",
				strings.Join(insertedValueStrings, ","))

			rows, err := tx.QueryContext(ctx, insertQuery, insertedValueArgs...)
			if err != nil {
				return err
			}
			// Runs when the closure returns, i.e. before Commit/Rollback below.
			defer func() {
				if err := rows.Close(); err != nil {
					simplelog.Error("failed to close", "err", err)
				}
			}()

			// Generated ids are matched to inputs by position.
			// NOTE(review): PostgreSQL returns RETURNING rows in VALUES order
			// in practice, but does not formally guarantee it.
			for indx := 0; rows.Next(); indx++ {
				if err := rows.Scan(&insertedVulnerabilities[indx].ID); err != nil {
					return err
				}
			}
			if err := rows.Err(); err != nil {
				return err
			}
		}

		// Return the error instead of panicking (uuid.Must) if the random
		// source fails.
		newLastUpdateToken, err := uuid.NewV4()
		if err != nil {
			return err
		}
		// The WHERE on the old token makes the rotation conditional: if a
		// concurrent writer already rotated it, no row is updated.
		res, err := tx.ExecContext(ctx, `UPDATE scan SET last_update_token = $1 WHERE id = $2 AND last_update_token = $3`,
			newLastUpdateToken.String(), scanID, lastUpdateToken)
		if err != nil {
			return fmt.Errorf("lastUpdateToken Expired: %w", err)
		}
		countUpdatedRows, err := res.RowsAffected()
		if err != nil {
			return fmt.Errorf("lastUpdateToken Expired: %w", err)
		}
		if countUpdatedRows != 1 {
			return fmt.Errorf("lastUpdateToken Expired")
		}
		return nil
	}()
	if err != nil {
		if rbErr := tx.Rollback(); rbErr != nil {
			simplelog.Error("rollback failed", "err", rbErr)
		}
		return nil, err
	}
	if err := tx.Commit(); err != nil {
		return nil, err
	}

	elapsed := time.Since(start)
	simplelog.Info(fmt.Sprintf("Vulnerability InsertByScanIDAndLastUpdateToken took: %s", elapsed))
	return insertedVulnerabilities, nil
}

// UpdateFromScanInstanceVulnerabilities copies severity, key_properties and
// display_properties from the vulnerability2scanInstance rows of
// scanInstanceID onto the matching vulnerability rows. On success it logs the
// elapsed time.
func (v vulnerabilityRepository) UpdateFromScanInstanceVulnerabilities(ctx context.Context, scanInstanceID int) error {
	start := time.Now()

	tx, err := v.db.PG.BeginTx(ctx, nil)
	if err != nil {
		return err
	}

	_, err = tx.ExecContext(ctx, `UPDATE vulnerability v
SET severity = v2s.severity, key_properties = v2s.key_properties, display_properties = v2s.display_properties
FROM vulnerability2scanInstance v2s
WHERE v2s.scan_instance_id = $1 AND v2s.vulnerability_id = v.id`, scanInstanceID)
	if err != nil {
		// Roll back explicitly. Otherwise the transaction stays open and keeps
		// its connection, because database/sql only ends a Tx on Commit,
		// Rollback, or context cancellation.
		if rbErr := tx.Rollback(); rbErr != nil {
			simplelog.Error("rollback failed", "err", rbErr)
		}
		return err
	}

	if err := tx.Commit(); err != nil {
		return err
	}

	elapsed := time.Since(start)
	simplelog.Info(fmt.Sprintf("Vulnerability UpdateFromScanInstanceVulnerabilities took: %s", elapsed))
	return nil
}

// FetchByScanInstanceID returns one page (limit/offset) of the vulnerabilities
// recorded for scanInstanceID, joined with category, scan type and project.
// Severity and status come from the vulnerability row; key/display properties
// come from the scan-instance row. Rows are ordered by defaultOrderString.
// When nothing matches, the result is an empty, non-nil slice.
func (v vulnerabilityRepository) FetchByScanInstanceID(ctx context.Context, scanInstanceID int, limit int, offset int) ([]*models.Vulnerability, error) {
	query := "SELECT v.id as id, s.project_id as project_id, project.name as project_name, " +
		" v.scan_id as scan_id, v2si.scan_instance_id as scan_instance_id, st.type_name as scan_type_name, " +
		" st.title as scan_type_title, v.severity as severity, v.status as status, " +
		" v2si.key_properties as key_properties, v2si.display_properties as display_properties, " +
		" v.category_id as category_id, vc.name as category_name,  v.first_found_at as first_found_at  " +
		" FROM Vulnerability2ScanInstance v2si " +
		" JOIN Vulnerability v ON v.id = v2si.vulnerability_id " +
		" JOIN VulnerabilityCategory vc ON v.category_id = vc.id " +
		" JOIN Scan s ON v.scan_id = s.id " +
		" JOIN ScanType st ON s.scan_type_id = st.id " +
		" JOIN Project ON s.project_id = project.id " +
		" WHERE v2si.scan_instance_id = $1 " +
		" ORDER BY " + v.defaultOrderString() + " " +
		" LIMIT $2 OFFSET $3 "

	page := []*models.Vulnerability{}
	if err := v.db.PG.SelectContext(ctx, &page, query, scanInstanceID, limit, offset); err != nil {
		return nil, err
	}
	return page, nil
}

// FetchByScanInstanceIDAndCategoryID is FetchByScanInstanceID restricted to
// vulnerabilities whose category_id equals categoryID. It returns one page
// (limit/offset) ordered by defaultOrderString. When nothing matches, the
// result is an empty, non-nil slice.
func (v vulnerabilityRepository) FetchByScanInstanceIDAndCategoryID(ctx context.Context, scanInstanceID int, categoryID int,
	limit int, offset int) ([]*models.Vulnerability, error) {
	query := "SELECT v.id as id, s.project_id as project_id, project.name as project_name, " +
		" v.scan_id as scan_id, v2si.scan_instance_id as scan_instance_id, " +
		" st.type_name as scan_type_name, st.title as scan_type_title, " +
		" v.severity as severity, v.status as status, " +
		" v2si.key_properties as key_properties, v2si.display_properties as display_properties, " +
		" v.category_id as category_id, vc.name as category_name,  v.first_found_at as first_found_at  " +
		" FROM Vulnerability2ScanInstance v2si " +
		" JOIN Vulnerability v ON v.id = v2si.vulnerability_id " +
		" JOIN VulnerabilityCategory vc ON v.category_id = vc.id " +
		" JOIN Scan s ON v.scan_id = s.id " +
		" JOIN ScanType st ON s.scan_type_id = st.id " +
		" JOIN Project ON s.project_id = project.id " +
		" WHERE v2si.scan_instance_id = $1 AND v.category_id = $2 " +
		" ORDER BY " + v.defaultOrderString() + " " +
		" LIMIT $3 OFFSET $4 "

	page := []*models.Vulnerability{}
	if err := v.db.PG.SelectContext(ctx, &page, query, scanInstanceID, categoryID, limit, offset); err != nil {
		return nil, err
	}
	return page, nil
}

// FetchByScanInstances returns one page (limit/offset) of vulnerabilities
// recorded in any of scanInstances, optionally narrowed by filterMap. Results
// are ordered by defaultOrderString, and the query is retried through
// db.Trier.
//
// filterMap is rendered by filterBuilder.BuildNotFormattedWhere into a
// fragment that still contains "$%d" placeholders (presumably one per
// returned argument — confirm in sqlbuilder). They are numbered here with
// fmt.Sprintf, after the scan-instance id array ($1) and before LIMIT/OFFSET.
func (v vulnerabilityRepository) FetchByScanInstances(ctx context.Context, scanInstances []*models.ScanInstance,
	limit int, offset int, filterMap map[string]interface{}) ([]*models.Vulnerability, error) {
	// Length 0 with capacity n: make([]int, n) followed by append would put n
	// zero ids in front of the real ones.
	scanInstancesIDs := make([]int, 0, len(scanInstances))
	for _, scanInstance := range scanInstances {
		scanInstancesIDs = append(scanInstancesIDs, scanInstance.ID)
	}
	ids := &pgtype.Int4Array{}
	if err := ids.Set(scanInstancesIDs); err != nil {
		return nil, err
	}

	filterExpression, filterArguments, err := v.filterBuilder.BuildNotFormattedWhere(filterMap)
	if err != nil {
		return nil, err
	}
	if len(filterExpression) > 0 {
		filterExpression = " AND " + filterExpression
	}

	// Argument order must match placeholder order in the query: the id
	// array, then the filter values, then limit, then offset.
	queryArguments := make([]interface{}, 0, len(filterArguments)+3)
	queryArguments = append(queryArguments, ids)
	queryArguments = append(queryArguments, filterArguments...)
	queryArguments = append(queryArguments, limit, offset)

	positionNumbers := make([]interface{}, len(queryArguments))
	for i := range positionNumbers {
		positionNumbers[i] = i + 1
	}
	selectQuery := fmt.Sprintf("SELECT v.id as id, s.project_id as project_id, project.name as project_name, "+
		" v.scan_id as scan_id, v2si.scan_instance_id as scan_instance_id, "+
		" st.type_name as scan_type_name, st.title as scan_type_title, v.severity as severity, v.status as status, "+
		" v2si.key_properties as key_properties, v2si.display_properties as display_properties, "+
		" v.category_id as category_id, vc.name as category_name,  v.first_found_at as first_found_at "+
		" FROM Vulnerability2ScanInstance v2si "+
		" JOIN Vulnerability v ON v.id = v2si.vulnerability_id "+
		" JOIN VulnerabilityCategory vc ON v.category_id = vc.id "+
		" JOIN Scan s ON v.scan_id = s.id "+
		" JOIN ScanType st ON s.scan_type_id = st.id "+
		" JOIN Project ON s.project_id = project.id "+
		" WHERE v2si.scan_instance_id = ANY ($%d)"+
		" "+filterExpression+
		" ORDER BY "+v.defaultOrderString()+" "+
		" LIMIT $%d OFFSET $%d ", positionNumbers...)

	vulnerabilities := make([]*models.Vulnerability, 0)
	err = v.db.Trier.Try(ctx, func(ctx context.Context) error {
		// Start every attempt from an empty slice. A failed attempt may
		// already have appended rows into the destination (sqlx's Select
		// appends — confirm for the version in use), and a retry must not
		// duplicate them.
		vulnerabilities = make([]*models.Vulnerability, 0)
		return v.db.PG.SelectContext(ctx, &vulnerabilities, selectQuery, queryArguments...)
	})
	if err != nil && err != sql.ErrNoRows {
		return nil, err
	}
	return vulnerabilities, nil
}

// FetchLatest returns one page (limit/offset) of vulnerabilities joined
// through LastScanInstance (presumably the most recent instance of each scan
// — confirm against the view/table definition). Results are optionally
// narrowed by filterMap, ordered by descending vulnerability id, and the
// query is retried through db.Trier.
//
// filterMap is rendered by filterBuilder.BuildNotFormattedWhere into a
// fragment with "$%d" placeholders. They are numbered here with fmt.Sprintf,
// ahead of the LIMIT and OFFSET placeholders. When the fragment is empty, no
// WHERE clause is emitted.
func (v vulnerabilityRepository) FetchLatest(ctx context.Context,
	limit int, offset int, filterMap map[string]interface{}) ([]*models.Vulnerability, error) {
	filterExpression, filterArguments, err := v.filterBuilder.BuildNotFormattedWhere(filterMap)
	if err != nil {
		return nil, err
	}
	// A bare "WHERE" followed by "ORDER BY" is a syntax error, so emit the
	// clause only when there is something to filter on.
	whereClause := ""
	if len(filterExpression) > 0 {
		whereClause = " WHERE " + filterExpression
	}

	// Placeholder order: filter values first, then limit, then offset. Build
	// a fresh slice rather than appending to the builder's result.
	queryArguments := make([]interface{}, 0, len(filterArguments)+2)
	queryArguments = append(queryArguments, filterArguments...)
	queryArguments = append(queryArguments, limit, offset)

	positionNumbers := make([]interface{}, len(queryArguments))
	for i := range positionNumbers {
		positionNumbers[i] = i + 1
	}
	selectQuery := fmt.Sprintf("SELECT v.id as id, s.project_id as project_id, project.name as project_name, "+
		" v.scan_id as scan_id, v2si.scan_instance_id as scan_instance_id, "+
		" st.type_name as scan_type_name, st.title as scan_type_title, v.severity as severity, v.status as status, "+
		" v2si.key_properties as key_properties, v2si.display_properties as display_properties, "+
		" v.category_id as category_id, vc.name as category_name,  v.first_found_at as first_found_at "+
		" FROM Vulnerability2ScanInstance v2si "+
		" JOIN Vulnerability v ON v.id = v2si.vulnerability_id "+
		" JOIN VulnerabilityCategory vc ON v.category_id = vc.id "+
		" JOIN LastScanInstance lsi ON v.scan_id = lsi.scan_id AND v2si.scan_instance_id = lsi.scan_instance_id "+
		" JOIN Scan s ON lsi.scan_id = s.id "+
		" JOIN ScanType st ON s.scan_type_id = st.id "+
		" JOIN Project ON s.project_id = project.id "+
		whereClause+
		" ORDER BY v.id DESC "+
		" LIMIT $%d OFFSET $%d ", positionNumbers...)

	vulnerabilities := make([]*models.Vulnerability, 0)
	err = v.db.Trier.Try(ctx, func(ctx context.Context) error {
		// Start every attempt from an empty slice so a retry cannot append
		// onto rows scanned by a failed attempt.
		vulnerabilities = make([]*models.Vulnerability, 0)
		return v.db.PG.SelectContext(ctx, &vulnerabilities, selectQuery, queryArguments...)
	})
	if err != nil && err != sql.ErrNoRows {
		return nil, err
	}
	return vulnerabilities, nil
}

// ListVulnerabilityCategories returns the distinct categories (id, name) that
// belong to scan type scanTypeID and have at least one vulnerability in scan
// scanID. The query runs through db.Trier. Errors that db.IsRetriableError
// accepts are wrapped with context; all other errors are returned unchanged.
func (v vulnerabilityRepository) ListVulnerabilityCategories(ctx context.Context, scanTypeID, scanID int) ([]*models.VulnerabilityCategoryResponseDTO, error) {
	categories := make([]*models.VulnerabilityCategoryResponseDTO, 0)
	err := v.db.Trier.Try(ctx, func(ctx context.Context) error {
		selectErr := v.db.PG.SelectContext(ctx, &categories, `
SELECT DISTINCT
	c.id as id, c.name as name
FROM
	VulnerabilityCategory c
JOIN
	Vulnerability v
ON
	v.category_id = c.id
WHERE
	c.scan_type_id = $1 AND v.scan_id = $2`,
			scanTypeID, scanID)

		switch {
		case selectErr == nil:
			return nil
		case db.IsRetriableError(selectErr):
			return fmt.Errorf("failed to list vulnerability categories: %w", selectErr)
		default:
			return selectErr
		}
	})
	return categories, err
}

// UpdateStatusByIDs sets status on every vulnerability whose id appears in
// IDs, retrying through db.Trier. The ids are sent as a single int4 array
// parameter. Ids that match no row are simply not updated.
func (v vulnerabilityRepository) UpdateStatusByIDs(ctx context.Context, status models.StatusType, IDs []int) error {
	var idArray pgtype.Int4Array
	if err := idArray.Set(IDs); err != nil {
		return err
	}
	return v.db.Trier.Try(ctx, func(ctx context.Context) error {
		_, execErr := v.db.PG.ExecContext(ctx, "UPDATE vulnerability SET status = $1 WHERE id = ANY($2)",
			status, idArray)
		return execErr
	})
}

// UpdateTrackerTicketByID stores trackerTicket on the vulnerability with the
// given id, retrying through db.Trier. An id that matches no row is not
// reported as an error.
func (v vulnerabilityRepository) UpdateTrackerTicketByID(ctx context.Context, trackerTicket string, id int) error {
	return v.db.Trier.Try(ctx, func(ctx context.Context) error {
		_, execErr := v.db.PG.ExecContext(ctx, "UPDATE vulnerability SET tracker_ticket = $1 WHERE id = $2",
			trackerTicket, id)
		return execErr
	})
}
