package ru.yandex.travel.workflow.repository;

import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;

import javax.persistence.EntityManager;

import org.apache.commons.lang3.tuple.Pair;
import org.springframework.data.jpa.repository.support.SimpleJpaRepository;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.stereotype.Component;

import ru.yandex.travel.workflow.EWorkflowEventState;
import ru.yandex.travel.workflow.EWorkflowState;
import ru.yandex.travel.workflow.entities.Workflow;

/**
 * Hand-written part of the {@link Workflow} repository.
 *
 * <p>Provides native SQL (executed through {@link NamedParameterJdbcTemplate}) for picking running workflows that
 * have new events to process, optionally split by processing pool, and for walking the supervision hierarchy. All
 * {@code save*} entry points assign a processing pool id (from {@link WorkflowRepositoryProcessingPoolsConfig}) to
 * workflows that do not have one yet.
 */
@Component
public class CustomizedWorkflowRepositoryImpl extends SimpleJpaRepository<Workflow, UUID>
        implements CustomizedWorkflowRepository<Workflow> {
    /** Zone used to compute the {@code :now} parameter that is compared against {@code workflows.sleep_till}. */
    private static final ZoneId UTC = ZoneId.of("UTC");

    /** Columns selected by the scheduling queries: workflow id and the creation time of its oldest new event. */
    private static final String SCHEDULING_SELECT_PART = "w.id, min(we.created_at)";

    /**
     * Collects the ids of all workflows (transitively) supervised by {@code :id} whose state is {@code :state}.
     *
     * <p>The recursion descends through supervised workflows of any state; the state filter applies only to the
     * final result. {@code UNION} (not {@code UNION ALL}) discards rows that were already produced, so a supervision
     * cycle in the data ends the recursion instead of making the query run forever; for a proper tree the result is
     * unchanged.
     */
    private static final String FIND_SUPERVISED_WORKFLOWS_WITH_STATE =
            "WITH RECURSIVE nodes(id, supervisor_id, state) AS " +
                    "( SELECT w.id, w.supervisor_id, w.state " +
                    " FROM workflows w where supervisor_id = :id " +
                    " UNION" +
                    " SELECT w.id, w.supervisor_id, w.state " +
                    " FROM workflows w, nodes n " +
                    " WHERE n.id = w.supervisor_id) " +
                    "SELECT id FROM nodes WHERE state = :state";

    private final NamedParameterJdbcTemplate namedParameterJdbcTemplate;
    private final WorkflowRepositoryProcessingPoolsConfig processingPoolsConfig;

    public CustomizedWorkflowRepositoryImpl(EntityManager em, NamedParameterJdbcTemplate namedParameterJdbcTemplate,
                                            WorkflowRepositoryProcessingPoolsConfig processingPoolsConfig) {
        super(Workflow.class, em);
        this.namedParameterJdbcTemplate = namedParameterJdbcTemplate;
        this.processingPoolsConfig = processingPoolsConfig;
    }

    /**
     * Selects, for every pool in {@code poolLimits}, up to that pool's limit of running, non-sleeping workflows that
     * have at least one new event, ordered by their oldest new event.
     *
     * <p>The sub-query for {@code defaultPoolId} also picks up workflows whose pool id is NULL or is not a key of
     * {@code poolLimits}; the other sub-queries match their pool id exactly.
     *
     * @param idsToExclude  workflow ids that must not be returned ({@code null} is treated as empty)
     * @param defaultPoolId pool id that absorbs workflows not covered by any other pool in {@code poolLimits}
     * @param poolLimits    pool id to the maximum number of workflows selected for that pool
     * @return mutable map from workflow id to the pool id it was selected for; empty when {@code poolLimits} is empty
     */
    public Map<UUID, Integer> findWorkflowsToBeScheduledByPools(Set<UUID> idsToExclude,
                                                                int defaultPoolId,
                                                                Map<Integer, Integer> poolLimits) {
        if (poolLimits.isEmpty()) {
            // Joining zero sub-queries would produce an empty SQL string, which the database rejects.
            return new HashMap<>();
        }

        // Pools that get their own sub-query; the default pool's sub-query must not steal their workflows.
        Set<Integer> nonDefaultPoolIds = poolLimits.keySet().stream()
                .filter(id -> id != defaultPoolId)
                .collect(Collectors.toSet());

        List<String> subqueries = new ArrayList<>(poolLimits.size());
        for (Map.Entry<Integer, Integer> poolLimit : poolLimits.entrySet()) {
            int poolId = poolLimit.getKey();
            int limit = poolLimit.getValue();
            String subquery = poolId == defaultPoolId
                    ? queryWorkflowsToBeScheduledForUnion(idsToExclude, poolId, nonDefaultPoolIds, limit)
                    : queryWorkflowsToBeScheduledForUnion(idsToExclude, poolId, limit);
            // Parentheses let each sub-query keep its own ORDER BY / LIMIT inside the UNION.
            subqueries.add("(" + subquery + ")");
        }
        String query = String.join(" UNION ", subqueries);

        return namedParameterJdbcTemplate.query(
                query, new MapSqlParameterSource("now", utcNow()),
                rs -> {
                    Map<UUID, Integer> result = new HashMap<>();
                    while (rs.next()) {
                        // Column 3 is the constant "pool_id" emitted by the sub-query that produced the row.
                        result.put((UUID) rs.getObject(1), rs.getInt(3));
                    }
                    return result;
                }
        );
    }

    /**
     * Builds (does not execute) a sub-query that selects up to {@code limit} schedulable workflows of pool
     * {@code poolId}. Columns: workflow id, oldest new-event creation time, constant {@code pool_id}. The caller
     * must bind {@code :now}.
     */
    public String queryWorkflowsToBeScheduledForUnion(Set<UUID> idsToExclude, int poolId, int limit) {
        return findQuery(unionSelectPart(poolId), idsToExclude, poolId, null, false) + groupOrderAndLimit(limit);
    }

    /**
     * Builds (does not execute) a sub-query that selects up to {@code limit} schedulable workflows whose pool is
     * NULL or not in {@code poolIdsToExclude}, tagging every row with {@code poolId} as {@code pool_id}. The caller
     * must bind {@code :now}.
     */
    public String queryWorkflowsToBeScheduledForUnion(Set<UUID> idsToExclude, int poolId, Set<Integer> poolIdsToExclude,
                                                      int limit) {
        return findQuery(unionSelectPart(poolId), idsToExclude, null, poolIdsToExclude, false)
                + groupOrderAndLimit(limit);
    }

    /** Returns up to {@code limit} ids of schedulable workflows in pool {@code poolId}, oldest new event first. */
    public List<UUID> findWorkflowsToBeScheduled(Set<UUID> idsToExclude, int poolId, int limit) {
        String query = findQuery(SCHEDULING_SELECT_PART, idsToExclude, poolId, null, true)
                + groupOrderAndLimit(limit);
        MapSqlParameterSource params = new MapSqlParameterSource("now", utcNow())
                .addValue("pool_id", poolId);
        return namedParameterJdbcTemplate.query(query, params, (rs, rowNum) -> (UUID) rs.getObject(1));
    }

    /**
     * Returns up to {@code limit} ids of schedulable workflows whose pool is NULL or not in
     * {@code poolIdsToExclude}, oldest new event first.
     */
    public List<UUID> findWorkflowsToBeScheduled(Set<UUID> idsToExclude, Set<Integer> poolIdsToExclude, int limit) {
        String query = findQuery(SCHEDULING_SELECT_PART, idsToExclude, null, poolIdsToExclude, true)
                + groupOrderAndLimit(limit);
        return namedParameterJdbcTemplate.query(
                query, new MapSqlParameterSource("now", utcNow()),
                (rs, rowNum) -> (UUID) rs.getObject(1)
        );
    }

    /** Counts distinct schedulable workflows (running, not sleeping, with at least one new event). */
    @Override
    public Integer countWorkflowsToBeScheduled(Set<UUID> idsToExclude) {
        return namedParameterJdbcTemplate.queryForObject(
                findQuery("count( distinct w.id)", idsToExclude), new MapSqlParameterSource("now", utcNow()),
                Integer.class
        );
    }

    /**
     * Counts distinct schedulable workflows grouped by processing pool. Workflows without a pool are grouped
     * together by SQL and end up under a {@code null} key.
     */
    @Override
    public Map<Integer, Integer> countWorkflowsToBeScheduledPerPoolId(Set<UUID> idsToExclude) {
        String sql = findQuery("w.processing_pool_id, count(distinct w.id) cnt", idsToExclude)
                + " GROUP BY w.processing_pool_id";
        return namedParameterJdbcTemplate.query(
                sql, new MapSqlParameterSource("now", utcNow()),
                (rs, rowNum) -> {
                    Integer poolId = rs.getObject("processing_pool_id", Integer.class);
                    Integer count = rs.getObject("cnt", Long.class).intValue();
                    return Pair.of(poolId, count);
                }
        ).stream().collect(Collectors.toMap(Pair::getLeft, Pair::getRight));
    }

    /** Returns ids of all transitively supervised workflows that are in {@link EWorkflowState#WS_RUNNING}. */
    @Override
    public List<UUID> findSupervisedRunningWorkflows(UUID workflowId) {
        return findSupervisedWorkflowsWithState(workflowId, EWorkflowState.WS_RUNNING);
    }

    /** Returns ids of all workflows transitively supervised by {@code workflowId} that are in {@code state}. */
    @Override
    public List<UUID> findSupervisedWorkflowsWithState(UUID workflowId, EWorkflowState state) {
        MapSqlParameterSource params = new MapSqlParameterSource()
                .addValue("id", workflowId)
                .addValue("state", state.getNumber());
        return namedParameterJdbcTemplate.query(
                FIND_SUPERVISED_WORKFLOWS_WITH_STATE, params,
                (rs, rowNum) -> (UUID) rs.getObject(1)
        );
    }

    /** Current UTC wall-clock time, in the form bound to the {@code :now} query parameter. */
    private static LocalDateTime utcNow() {
        return ZonedDateTime.now(UTC).toLocalDateTime();
    }

    /** Select list for union sub-queries: the scheduling columns plus {@code poolId} as a constant third column. */
    private static String unionSelectPart(int poolId) {
        return SCHEDULING_SELECT_PART + ", " + poolId + " as pool_id";
    }

    /** Common tail of the scheduling queries: one row per workflow, oldest new event first, at most {@code limit}. */
    private static String groupOrderAndLimit(int limit) {
        return " GROUP BY w.id ORDER BY 2 LIMIT " + limit;
    }

    private String findQuery(String selectPart, Set<UUID> idsToExclude) {
        return findQuery(selectPart, idsToExclude, null, null, false);
    }

    /**
     * Builds the shared scheduling query: running workflows that are not sleeping (relative to {@code :now}),
     * joined with their new events.
     *
     * @param selectPart       SQL select list
     * @param idsToExclude     workflow ids to filter out ({@code null} or empty means no filter)
     * @param poolId           if non-null, restrict to this processing pool
     * @param poolIdsToExclude if non-empty, keep only workflows whose pool is NULL or not in this set
     * @param poolIdParam      when {@code true} the pool restriction uses the {@code :pool_id} parameter, which the
     *                         caller must bind; otherwise {@code poolId} is inlined as an integer literal
     * @return SQL text that still requires {@code :now} to be bound
     */
    private String findQuery(String selectPart, Set<UUID> idsToExclude, Integer poolId, Set<Integer> poolIdsToExclude,
                             boolean poolIdParam) {
        StringBuilder query = new StringBuilder("SELECT ")
                .append(selectPart)
                .append(" FROM workflows w ")
                .append("INNER JOIN workflow_events we ON w.id = we.workflow_id")
                .append(" WHERE (w.sleep_till IS NULL OR w.sleep_till < :now)")
                .append(" AND (w.state = ").append(EWorkflowState.WS_RUNNING.getNumber()).append(")")
                .append(" AND (we.state = ").append(EWorkflowEventState.WES_NEW.getNumber()).append(")");
        if (idsToExclude != null && !idsToExclude.isEmpty()) {
            // UUID.toString() produces only hex digits and dashes, so inlining these literals cannot inject SQL.
            query.append(" AND w.id NOT IN (")
                    .append(idsToExclude.stream().map(id -> "'" + id + "'").collect(Collectors.joining(",")))
                    .append(")");
        }
        if (poolId != null) {
            if (poolIdParam) {
                query.append(" AND w.processing_pool_id = :pool_id");
            } else {
                // Integer literal, consistent with the unquoted NOT IN list below; safe to inline.
                query.append(" AND w.processing_pool_id = ").append(poolId.intValue());
            }
        }
        if (poolIdsToExclude != null && !poolIdsToExclude.isEmpty()) {
            // NULL NOT IN (...) evaluates to NULL rather than true, so pool-less workflows are kept explicitly.
            query.append(" AND (w.processing_pool_id IS NULL OR w.processing_pool_id NOT IN (")
                    .append(poolIdsToExclude.stream().map(String::valueOf).collect(Collectors.joining(",")))
                    .append("))");
        }
        return query.toString();
    }

    /** Assigns the configured pool for the workflow's entity type when no pool id is set yet. */
    private void assignProcessingPoolIdIfMissing(Workflow workflow) {
        if (workflow.getProcessingPoolId() == null) {
            workflow.setProcessingPoolId(processingPoolsConfig.getProcessingPoolId(workflow.getEntityType()));
        }
    }

    @Override
    public <S extends Workflow> S save(S entity) {
        assignProcessingPoolIdIfMissing(entity);
        return super.save(entity);
    }

    @Override
    public <S extends Workflow> S saveAndFlush(S entity) {
        assignProcessingPoolIdIfMissing(entity);
        return super.saveAndFlush(entity);
    }

    @Override
    public <S extends Workflow> List<S> saveAll(Iterable<S> entities) {
        // Materialize once so a single-pass Iterable is not exhausted before super.saveAll iterates it.
        List<S> toSave = new ArrayList<>();
        for (S entity : entities) {
            assignProcessingPoolIdIfMissing(entity);
            toSave.add(entity);
        }
        return super.saveAll(toSave);
    }
}
