package ru.yandex.bannerstorage.harvester.queues;

import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.Date;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.jdbc.core.JdbcTemplate;

import static java.util.stream.Collectors.joining;

/**
 * JDBC access to a task-queue table: lets a worker claim ("lock") a batch of due tasks,
 * persist retry/error bookkeeping for them, and delete them once processed.
 *
 * <p>The SQL uses T-SQL constructs ({@code top}, bracketed identifiers, an update through a
 * CTE), so this presumably targets Microsoft SQL Server — confirm against the deployment.
 *
 * <p>Table names are interpolated into the SQL text because identifiers cannot be bound as
 * JDBC parameters; callers must only pass trusted table names.
 */
public class TaskRepository {
    private static final Logger logger = LoggerFactory.getLogger(TaskRepository.class);

    private final JdbcTemplate jdbcTemplate;

    /**
     * @param jdbcTemplate template used for every query issued by this repository
     */
    public TaskRepository(JdbcTemplate jdbcTemplate) {
        this.jdbcTemplate = jdbcTemplate;
    }

    /**
     * Claims up to {@code count} due tasks of the given type for {@code workerId}, then
     * returns the due tasks currently held by that worker.
     *
     * <p>If the worker already holds due tasks, no new tasks are claimed and only the held
     * ones are returned.
     *
     * @param table    name of the queue table (trusted; interpolated into the SQL text)
     * @param taskType value matched against the {@code type} column
     * @param workerId id written to {@code worker_id} to mark claimed rows
     * @param count    maximum number of tasks to claim and to return
     * @return due tasks locked by {@code workerId}, ordered by {@code next_try_time}
     */
    public List<Task> lockAndGetNextTasks(String table,
                                          String taskType,
                                          UUID workerId,
                                          int count) {
        String worker = workerId.toString();

        int alreadyLockedCount = jdbcTemplate.queryForObject(String.format(
                "select count(*) from %s" +
                " where [type] = ? and next_try_time <= CURRENT_TIMESTAMP and worker_id = ?",
                table
        ), Integer.class, taskType, worker);

        if (alreadyLockedCount == 0) {
            logger.info("Trying to lock next {} tasks for workerId={}", count, workerId);
            // Bind order follows the placeholders' textual order: [type] first, then worker_id.
            String lockSql = String.format(
                    ";with cte as (" +
                    "  select top %d *" +
                    "  from %s" +
                    "  where [type] = ? and next_try_time <= CURRENT_TIMESTAMP and worker_id is null" +
                    "  order by next_try_time" +
                    " ) update cte set worker_id = ?" +
                    " where worker_id is null",
                    count, table
            );
            int updated = jdbcTemplate.update(lockSql, taskType, worker);
            logger.info("Locked {} tasks", updated);
        } else {
            logger.warn("Found {} already locked tasks", alreadyLockedCount);
        }

        String getSql = String.format(
                "select top %d" +
                "  id," +
                "  [type]," +
                "  data," +
                "  created_time," +
                "  worker_id," +
                "  errors_count," +
                "  last_error," +
                "  last_error_time," +
                "  next_try_time" +
                " from %s" +
                " where [type] = ? and next_try_time <= CURRENT_TIMESTAMP and worker_id = ?" +
                " order by next_try_time",
                count, table
        );
        List<Task> tasks = jdbcTemplate.query(getSql, TaskRepository::mapTask, taskType, worker);
        logger.info("Got {} tasks", tasks.size());
        return tasks;
    }

    /**
     * Writes the retry bookkeeping ({@code errors_count}, {@code last_error},
     * {@code last_error_time}, {@code next_try_time}) of each task back to the table.
     * Only rows still owned by {@code workerId} are updated.
     *
     * @param table    name of the queue table (trusted; interpolated into the SQL text)
     * @param workerId worker that must currently own the rows
     * @param tasks    tasks to persist; a null error/next-try time is stored as SQL NULL
     */
    public void updateTasks(String table, UUID workerId, List<Task> tasks) {
        logger.info("Updating {} tasks for workerId={} from table {}", tasks.size(), workerId, table);
        if (tasks.isEmpty()) {
            return;
        }
        String sql = String.format(
                "update %s" +
                " set" +
                "  errors_count = ?," +
                "  last_error = ?," +
                "  last_error_time = ?," +
                "  next_try_time = ?" +
                " where id = ? and worker_id = ?",
                table
        );
        String worker = workerId.toString();
        List<Object[]> batchArgs = tasks.stream().map(
                task -> new Object[]{
                        task.getErrorsCount(),
                        task.getLastError(),
                        // A task that never failed has no last_error_time; keep it NULL
                        // instead of throwing a NullPointerException.
                        toTimestamp(task.getLastErrorTime()),
                        toTimestamp(task.getNextTryTime()),
                        task.getId().toString(),
                        worker}
        ).collect(Collectors.toList());

        jdbcTemplate.batchUpdate(sql, batchArgs);
    }

    /**
     * Deletes the given tasks, restricted to rows still owned by {@code workerId}.
     * An empty id list is a no-op (it would otherwise render the invalid SQL {@code in ()}).
     *
     * @param table    name of the queue table (trusted; interpolated into the SQL text)
     * @param workerId worker that must currently own the rows
     * @param taskIds  ids of the tasks to delete
     */
    public void deleteTasks(String table, UUID workerId, List<UUID> taskIds) {
        logger.info("Deleting {} tasks for workerId={} from table {}", taskIds.size(), workerId, table);
        if (taskIds.isEmpty()) {
            return;
        }
        // UUID.toString() only yields hex digits and dashes, so inlining the ids as literals
        // is injection-safe and avoids one bind parameter per id (drivers cap parameter counts).
        String sql = String.format("delete from %s where id in (%s) and worker_id = ?",
                table,
                taskIds.stream().map(uuid -> "'" + uuid + "'").collect(joining(","))
        );
        jdbcTemplate.update(sql, workerId.toString());
    }

    /**
     * Maps one row of the queue table to a {@link Task}.
     * Timestamps are read with {@code getTimestamp}: {@code getDate} returns a
     * {@code java.sql.Date} with the time of day stripped, which would corrupt
     * {@code next_try_time} once the task is written back by {@link #updateTasks}.
     */
    private static Task mapTask(ResultSet rs, int rowNum) throws SQLException {
        Task task = new Task();
        task.setId(UUID.fromString(rs.getString("id")));
        task.setType(rs.getString("type"));
        task.setData(rs.getString("data"));
        task.setCreatedTime(rs.getTimestamp("created_time"));
        task.setWorkerId(UUID.fromString(rs.getString("worker_id")));
        task.setErrorsCount(rs.getInt("errors_count"));
        task.setLastError(rs.getString("last_error"));
        task.setLastErrorTime(rs.getTimestamp("last_error_time"));
        task.setNextTryTime(rs.getTimestamp("next_try_time"));
        return task;
    }

    /** Converts a possibly-null date to a JDBC timestamp, mapping null to null (SQL NULL). */
    private static Timestamp toTimestamp(Date date) {
        return date == null ? null : new Timestamp(date.getTime());
    }
}
