package task

import (
	"context"
	"database/sql"

	"golang.org/x/xerrors"

	"a.yandex-team.ru/security/impulse/api/internal/db"
	"a.yandex-team.ru/security/impulse/models"
)

// taskRepository is the concrete type that NewTaskRepository hands out as a
// Repository. Its methods issue their SQL through the wrapped *db.DB.
type taskRepository struct {
	// db is the database handle; the methods use its PG and Trier members.
	db *db.DB
}

// NewTaskRepository wraps the given database handle in a taskRepository and
// returns it as a Repository.
func NewTaskRepository(db *db.DB) Repository {
	return &taskRepository{db: db}
}

// Create inserts t into the task table, binding the named SQL parameters from
// the fields of t. The statement is executed through m.db.Trier.Try.
//
// An error that db.IsRetriableError reports as retriable is wrapped with the
// task ID for context; any other error is returned unchanged.
func (m *taskRepository) Create(ctx context.Context, t *models.Task) (err error) {
	const query = `
INSERT INTO
	task (task_id, organization_id, project_id, workflow_id, workflow_instance_id, parameters, analysers, sandbox_task_id, start_time, end_time, status, cron_id, callback_url, non_template_scan)
VALUES
	(:task_id, :organization_id, :project_id, :workflow_id, :workflow_instance_id, :parameters, :analysers, :sandbox_task_id, :start_time, :end_time, :status, :cron_id, :callback_url, :non_template_scan)
`

	insert := func(ctx context.Context) error {
		_, execErr := m.db.PG.NamedExecContext(ctx, query, t)
		switch {
		case execErr == nil:
			return nil
		case db.IsRetriableError(execErr):
			return xerrors.Errorf("failed to create task record %v: %w", t.TaskID, execErr)
		default:
			// Non-retriable errors are passed through without extra context.
			return execErr
		}
	}

	return m.db.Trier.Try(ctx, insert)
}

// Update writes the status, workflow_instance_id, end_time and sandbox_task_id
// of t to the row whose task_id matches t.TaskID, and returns the number of
// rows the statement reported as affected.
//
// The statement is executed through m.db.Trier.Try. Retriable execution errors
// (per db.IsRetriableError) and RowsAffected errors are wrapped with the task
// ID; non-retriable execution errors are returned unchanged.
func (m *taskRepository) Update(ctx context.Context, t *models.Task) (int64, error) {
	const query = `
UPDATE task SET
	status = :status, workflow_instance_id = :workflow_instance_id, end_time = :end_time, sandbox_task_id = :sandbox_task_id
WHERE
	task_id = :task_id
`

	var affected int64
	update := func(ctx context.Context) error {
		res, execErr := m.db.PG.NamedExecContext(ctx, query, t)
		if execErr != nil {
			if db.IsRetriableError(execErr) {
				return xerrors.Errorf("failed to update task status %v: %w", t.TaskID, execErr)
			}
			return execErr
		}

		var countErr error
		// affected lives outside the closure so the caller sees the count
		// from the last executed attempt.
		affected, countErr = res.RowsAffected()
		if countErr != nil {
			return xerrors.Errorf("failed to get number of updated rows %v: %w", t.TaskID, countErr)
		}
		return nil
	}

	tryErr := m.db.Trier.Try(ctx, update)
	return affected, tryErr
}

// GetByTaskID loads the task row whose task_id equals taskID into a freshly
// allocated models.Task. The query is executed through m.db.Trier.Try.
//
// The returned pointer is non-nil even when err is non-nil. Retriable errors
// (per db.IsRetriableError) are wrapped with the task ID; every other error is
// returned as-is.
func (m *taskRepository) GetByTaskID(ctx context.Context, taskID string) (t *models.Task, err error) {
	const query = `
SELECT
	*
FROM
	task
WHERE
	task_id = $1`

	t = &models.Task{}
	fetch := func(ctx context.Context) error {
		getErr := m.db.PG.GetContext(ctx, t, query, taskID)
		if getErr != nil && db.IsRetriableError(getErr) {
			return xerrors.Errorf("failed to get task record %s: %w", taskID, getErr)
		}
		// Either nil or a non-retriable error, passed through untouched.
		return getErr
	}

	err = m.db.Trier.Try(ctx, fetch)
	return t, err
}

// CheckFinished reports whether the number of distinct scan type names
// recorded for the task (via scaninstance -> scan -> scantype) equals the
// length of the task's analysers JSON array. Both numbers are produced by one
// query that unions them and checks that only a single distinct value remains.
//
// The query is executed through m.db.Trier.Try. Retriable errors (per
// db.IsRetriableError) are wrapped with the task ID; other errors are returned
// unchanged.
//
// NOTE(review): an aggregate without GROUP BY yields a row even when nothing
// matches, so for an unknown taskID the first branch yields 0 while the second
// yields no row, and the comparison appears to evaluate to true. The same
// seems to apply when analysers is NULL. Confirm that callers only pass
// existing tasks with a populated analysers column.
func (m *taskRepository) CheckFinished(ctx context.Context, taskID string) (result bool, err error) {
	err = m.db.Trier.Try(ctx, func(ctx context.Context) error {
		queryErr := m.db.PG.GetContext(ctx, &result, `
SELECT COUNT(DISTINCT(count)) = 1 FROM (
		SELECT
			COUNT(DISTINCT(st.type_name))
		FROM
			scaninstance si
		JOIN 
			task t ON si.task_id = t.task_id
		JOIN scan s ON si.scan_id = s.id
		JOIN scantype st ON s.scan_type_id = st.id
		WHERE
			t.task_id = $1
	UNION ALL
		SELECT
			json_array_length(analysers)
		FROM
			task
		WHERE task_id = $1
) AS sub`, taskID)

		if queryErr != nil && db.IsRetriableError(queryErr) {
			// The previous message ("failed to get task record") was copied
			// from GetByTaskID and did not describe this operation.
			return xerrors.Errorf("failed to check whether task %s is finished: %w", taskID, queryErr)
		}
		return queryErr
	})
	return result, err
}

// GetLastTemplate returns the task with the latest start_time for projectID
// among rows where non_template_scan is false, together with the project name
// (an empty string when the LEFT JOIN finds no project row).
//
// It returns (nil, nil) when no such task exists. On any other error the
// freshly allocated DTO is still returned alongside the error. The query is
// executed through m.db.Trier.Try; retriable errors (per db.IsRetriableError)
// are wrapped with the project ID, other errors are passed through.
func (m *taskRepository) GetLastTemplate(ctx context.Context, projectID int) (t *models.TaskResponseDTO, err error) {
	t = new(models.TaskResponseDTO)
	err = m.db.Trier.Try(ctx, func(ctx context.Context) error {
		queryErr := m.db.PG.GetContext(ctx, t, `
SELECT
	t.task_id as task_id, t.organization_id as organization_id, t.project_id as project_id, t.workflow_id as workflow_id,
	t.workflow_instance_id as workflow_instance_id, t.parameters as parameters, t.analysers as analysers, t.sandbox_task_id as sandbox_task_id,
	t.start_time as start_time, t.end_time as end_time, t.status as status, t.cron_id as cron_id, t.callback_url as callback_url, t.non_template_scan as non_template_scan,
	COALESCE(p.name, '') as project_name
FROM
	task t
LEFT JOIN
	project p
ON
	t.project_id = p.id
WHERE
	project_id = $1 AND non_template_scan = false
ORDER BY start_time DESC
LIMIT 1`, projectID)

		if queryErr != nil && db.IsRetriableError(queryErr) {
			return xerrors.Errorf("failed to get last task record for project %v: %w", projectID, queryErr)
		}
		return queryErr
	})

	// xerrors.Is walks the wrap chain, so "no rows" is still recognised if the
	// error was wrapped on its way out of Try; a plain == would miss that.
	if xerrors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}

	return t, err
}

// List returns one page of tasks matching organizationID and projectID,
// ordered by start_time descending, with limit and offset applied in SQL. Each
// row carries the workflow name, or an empty string when the LEFT JOIN finds
// no workflow row.
//
// The result is a non-nil slice that is empty when nothing matches. The query
// is executed through m.db.Trier.Try; retriable errors (per
// db.IsRetriableError) are wrapped, other errors are returned unchanged.
func (m *taskRepository) List(ctx context.Context, organizationID, projectID, limit, offset int) (t []*models.TaskResponseDTO, err error) {
	const query = `
SELECT
	t.task_id as task_id, t.organization_id as organization_id, t.project_id as project_id, t.workflow_id as workflow_id,
	t.workflow_instance_id as workflow_instance_id, t.parameters as parameters, t.analysers as analysers, t.sandbox_task_id as sandbox_task_id,
	t.start_time as start_time, t.end_time as end_time, t.status as status, t.cron_id as cron_id, t.callback_url as callback_url, t.non_template_scan as non_template_scan,
	COALESCE(w.name, '') as workflow
FROM
	task t
LEFT JOIN
	workflow w
ON
	t.workflow_id = w.id
WHERE
	organization_id = $1 AND project_id = $2
ORDER BY
	start_time DESC
LIMIT $3
OFFSET $4`

	t = []*models.TaskResponseDTO{}
	selectPage := func(ctx context.Context) error {
		selErr := m.db.PG.SelectContext(ctx, &t, query, organizationID, projectID, limit, offset)
		if selErr != nil && db.IsRetriableError(selErr) {
			return xerrors.Errorf("failed to list tasks: %w", selErr)
		}
		return selErr
	}

	err = m.db.Trier.Try(ctx, selectPage)
	return t, err
}

// ListByOrganizationID returns one page of the organization's tasks where
// non_template_scan is false, ordered by start_time descending, with limit and
// offset applied in SQL. Each row carries the project name, or an empty string
// when the LEFT JOIN finds no project row.
//
// The result is a non-nil slice that is empty when nothing matches. The query
// is executed through m.db.Trier.Try; retriable errors (per
// db.IsRetriableError) are wrapped, other errors are returned unchanged.
func (m *taskRepository) ListByOrganizationID(ctx context.Context, organizationID, limit, offset int) (t []*models.TaskResponseDTO, err error) {
	const query = `
SELECT
	t.task_id as task_id, t.organization_id as organization_id, t.project_id as project_id, t.workflow_id as workflow_id,
	t.workflow_instance_id as workflow_instance_id, t.parameters as parameters, t.analysers as analysers, t.sandbox_task_id as sandbox_task_id,
	t.start_time as start_time, t.end_time as end_time, t.status as status, t.cron_id as cron_id, t.callback_url as callback_url, t.non_template_scan as non_template_scan,
	COALESCE(p.name, '') as project_name
FROM
	task t
LEFT JOIN
	project p
ON
	t.project_id = p.id
WHERE
	t.organization_id = $1 AND t.non_template_scan = false
ORDER BY
	t.start_time DESC
LIMIT $2
OFFSET $3`

	t = []*models.TaskResponseDTO{}
	selectPage := func(ctx context.Context) error {
		selErr := m.db.PG.SelectContext(ctx, &t, query, organizationID, limit, offset)
		if selErr != nil && db.IsRetriableError(selErr) {
			return xerrors.Errorf("failed to list tasks: %w", selErr)
		}
		return selErr
	}

	err = m.db.Trier.Try(ctx, selectPage)
	return t, err
}

// GetTotal counts the task rows matching both organizationID and projectID.
//
// The query is executed through m.db.Trier.Try; retriable errors (per
// db.IsRetriableError) are wrapped, other errors are returned unchanged.
func (m *taskRepository) GetTotal(ctx context.Context, organizationID, projectID int) (total int, err error) {
	const query = `SELECT COUNT(*) as count FROM task WHERE organization_id = $1 AND project_id = $2`

	count := func(ctx context.Context) error {
		row := m.db.PG.QueryRowContext(ctx, query, organizationID, projectID)
		scanErr := row.Scan(&total)
		if scanErr != nil && db.IsRetriableError(scanErr) {
			return xerrors.Errorf("failed to get tasks total: %w", scanErr)
		}
		return scanErr
	}

	err = m.db.Trier.Try(ctx, count)
	return total, err
}

// GetTotalByOrganizationID counts the organization's task rows where
// non_template_scan is false.
//
// The query is executed through m.db.Trier.Try; retriable errors (per
// db.IsRetriableError) are wrapped, other errors are returned unchanged.
func (m *taskRepository) GetTotalByOrganizationID(ctx context.Context, organizationID int) (total int, err error) {
	const query = `SELECT COUNT(*) as count FROM task WHERE organization_id = $1 AND non_template_scan = false`

	count := func(ctx context.Context) error {
		row := m.db.PG.QueryRowContext(ctx, query, organizationID)
		scanErr := row.Scan(&total)
		if scanErr != nil && db.IsRetriableError(scanErr) {
			return xerrors.Errorf("failed to get tasks total: %w", scanErr)
		}
		return scanErr
	}

	err = m.db.Trier.Try(ctx, count)
	return total, err
}
