package ru.yandex.qe.dispenser.domain.dao;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.Supplier;

import javax.inject.Inject;

import com.google.common.base.Stopwatch;
import com.google.common.collect.ImmutableMap;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.postgresql.util.PSQLException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.dao.DataAccessException;
import org.springframework.dao.DataAccessResourceFailureException;
import org.springframework.jdbc.core.ConnectionCallback;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.PreparedStatementCallback;
import org.springframework.jdbc.core.ResultSetExtractor;
import org.springframework.jdbc.core.RowCallbackHandler;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.namedparam.EmptySqlParameterSource;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.jdbc.core.namedparam.SqlParameterSource;
import org.springframework.jdbc.support.GeneratedKeyHolder;
import org.springframework.jdbc.support.KeyHolder;

import ru.yandex.qe.dispenser.domain.hierarchy.Session;

/**
 * Wrapper around Spring's JDBC templates that adds execution timing/logging, retries on
 * PostgreSQL connection failures and post-query listener notification.
 *
 * <p>Named-parameter operations run on either the read-write or the read-only template,
 * chosen per call from {@link Session#IS_READONLY}. The plain-SQL {@code execute(String)},
 * {@code batchUpdate(String...)} and {@code execute(ConnectionCallback)} operations always
 * use the underlying {@link JdbcTemplate}.
 */
public final class DiJdbcTemplate {
    private static final Logger LOG = LoggerFactory.getLogger(DiJdbcTemplate.class);
    /** SQL state that marks a failure as a retryable connection failure. */
    public static final String CONNECTION_FAILURE_SQL_STATE_CODE = "08006";

    private final NamedParameterJdbcTemplate jdbcTemplate;
    private final NamedParameterJdbcTemplate readOnlyJdbcTemplate;
    private final JdbcTemplate jdbcTemplateOrigin;

    /** Number of extra attempts after the first one; negative values are treated as 0. */
    private final int retryCount;
    /** Pause in milliseconds between attempts once {@link #retriesWithoutDelay} is exceeded. */
    private final long retryDelay;
    /** Number of leading attempts that are retried immediately, without sleeping. */
    private final long retriesWithoutDelay;

    /**
     * Listeners notified after each successfully processed statement. A copy-on-write list is
     * used so listeners can be (un)registered while other threads are iterating it in
     * {@link #process(String, Object, Supplier)}.
     */
    @NotNull
    private final List<BiConsumer<String, Object>> queryListeners = new CopyOnWriteArrayList<>();

    @Inject
    public DiJdbcTemplate(@NotNull final NamedParameterJdbcTemplate jdbcTemplate,
                          @NotNull final NamedParameterJdbcTemplate readOnlyJdbcTemplate,
                          @NotNull final JdbcTemplate jdbcTemplateOrigin,
                          final int retryCount,
                          final long retryDelay,
                          final long retriesWithoutDelay) {
        this.jdbcTemplate = jdbcTemplate;
        this.readOnlyJdbcTemplate = readOnlyJdbcTemplate;
        this.jdbcTemplateOrigin = jdbcTemplateOrigin;
        this.retryCount = retryCount;
        this.retryDelay = retryDelay;
        this.retriesWithoutDelay = retriesWithoutDelay;
    }

    /** Executes a plain SQL statement on the underlying {@link JdbcTemplate}, with logging and retries. */
    public void execute(@NotNull final String sql) throws DataAccessException {
        process(sql, null, () -> jdbcTemplateOrigin.execute(sql));
    }

    /**
     * Executes several plain SQL statements as one JDBC batch. Only retries are applied; no
     * timing log or listener notification happens for this call.
     */
    public void batchUpdate(@NotNull final String... sql) throws DataAccessException {
        withRetry(() -> jdbcTemplateOrigin.batchUpdate(sql));
    }

    /** Runs {@code action} against a raw connection of the underlying {@link JdbcTemplate}, with retries only. */
    public <T> T execute(final ConnectionCallback<T> action) throws DataAccessException {
        return withRetry(() -> jdbcTemplateOrigin.execute(action));
    }

    /** Returns the read-only template when the current session is flagged read-only, else the read-write one. */
    private NamedParameterJdbcTemplate getJdbcTemplate() {
        return Session.IS_READONLY.get() ? readOnlyJdbcTemplate : jdbcTemplate;
    }

    @NotNull
    public <T> T execute(@NotNull final String sql, @NotNull final SqlParameterSource paramSource, @NotNull final PreparedStatementCallback<T> action) throws DataAccessException {
        return process(sql, paramSource, () -> getJdbcTemplate().execute(sql, paramSource, action));
    }

    public <T> T query(@NotNull final String sql, @NotNull final SqlParameterSource paramSource, @NotNull final ResultSetExtractor<T> rse) throws DataAccessException {
        return process(sql, paramSource, () -> getJdbcTemplate().query(sql, paramSource, rse));
    }

    public <T> T query(@NotNull final String sql, @NotNull final ResultSetExtractor<T> rse) throws DataAccessException {
        return process(sql, EmptySqlParameterSource.INSTANCE, () -> getJdbcTemplate().query(sql, rse));
    }

    @NotNull
    public <T> List<T> query(@NotNull final String sql, @NotNull final RowMapper<T> rowMapper) throws DataAccessException {
        return process(sql, null, () -> getJdbcTemplate().query(sql, log(rowMapper)));
    }

    public void query(@NotNull final String sql, @NotNull final RowCallbackHandler rch) throws DataAccessException {
        process(sql, Collections.emptyMap(), () -> getJdbcTemplate().query(sql, log(rch)));
    }

    @NotNull
    public <T> List<T> query(@NotNull final String sql, @NotNull final SqlParameterSource paramSource, @NotNull final RowMapper<T> rowMapper) throws DataAccessException {
        return process(sql, paramSource, () -> getJdbcTemplate().query(sql, paramSource, log(rowMapper)));
    }

    @NotNull
    public <T> List<T> query(@NotNull final String sql, @NotNull final Map<String, ?> paramMap, @NotNull final RowMapper<T> rowMapper) throws DataAccessException {
        return process(sql, paramMap, () -> getJdbcTemplate().query(sql, paramMap, log(rowMapper)));
    }

    public void query(@NotNull final String sql, @NotNull final Map<String, ?> paramMap, @NotNull final RowCallbackHandler rch) throws DataAccessException {
        process(sql, paramMap, () -> getJdbcTemplate().query(sql, paramMap, log(rch)));
    }

    public void query(@NotNull final String sql, @NotNull final SqlParameterSource paramSource, @NotNull final RowCallbackHandler rch) throws DataAccessException {
        process(sql, paramSource, () -> getJdbcTemplate().query(sql, paramSource, log(rch)));
    }

    @NotNull
    public <T> T queryForObject(@NotNull final String sql, @NotNull final SqlParameterSource paramSource, @NotNull final RowMapper<T> rowMapper) throws DataAccessException {
        return process(sql, paramSource, () -> getJdbcTemplate().queryForObject(sql, paramSource, log(rowMapper)));
    }

    @NotNull
    public <T> T queryForObject(@NotNull final String sql, @NotNull final Map<String, ?> paramMap, @NotNull final RowMapper<T> rowMapper) throws DataAccessException {
        return process(sql, paramMap, () -> getJdbcTemplate().queryForObject(sql, paramMap, log(rowMapper)));
    }

    @NotNull
    public <T> T queryForObject(@NotNull final String sql, @NotNull final Map<String, ?> paramMap, @NotNull final Class<T> requiredType) throws DataAccessException {
        return process(sql, paramMap, () -> getJdbcTemplate().queryForObject(sql, paramMap, requiredType));
    }

    /** Runs the query and returns the first mapped row, or an empty {@link Optional} when there are no rows. */
    @NotNull
    public <T> Optional<T> queryForOptional(@NotNull final String sql, @NotNull final Map<String, ?> paramMap, @NotNull final RowMapper<T> rowMapper) throws DataAccessException {
        return query(sql, paramMap, rowMapper).stream().findFirst();
    }

    /** Runs the query and collects the mapped rows into a {@link HashSet}. */
    @NotNull
    public <T> Set<T> queryForSet(@NotNull final String sql, @NotNull final RowMapper<T> rowMapper) throws DataAccessException {
        // The delegated query(...) already applies log(rowMapper); wrapping here as well would trace each row twice.
        return new HashSet<>(query(sql, rowMapper));
    }

    /** Runs the query with named parameters and collects the mapped rows into a {@link HashSet}. */
    @NotNull
    public <T> Set<T> queryForSet(@NotNull final String sql, @NotNull final Map<String, ?> paramMap, @NotNull final RowMapper<T> rowMapper) throws DataAccessException {
        // The delegated query(...) already applies log(rowMapper); wrapping here as well would trace each row twice.
        return new HashSet<>(query(sql, paramMap, rowMapper));
    }

    public int update(@NotNull final String sql) {
        return update(sql, Collections.emptyMap());
    }

    public int update(@NotNull final String sql, @NotNull final String key, @NotNull final Object value) {
        return update(sql, ImmutableMap.of(key, value));
    }

    public int update(@NotNull final String sql, @NotNull final String key1, @NotNull final Object value1, @NotNull final String key2, @NotNull final Object value2) {
        return update(sql, ImmutableMap.of(key1, value1, key2, value2));
    }

    public int update(@NotNull final String sql, @NotNull final SqlParameterSource paramSource) throws DataAccessException {
        return process(sql, paramSource, () -> getJdbcTemplate().update(sql, paramSource));
    }

    public int update(@NotNull final String sql, @NotNull final Map<String, ?> paramMap) throws DataAccessException {
        return process(sql, paramMap, () -> getJdbcTemplate().update(sql, paramMap));
    }

    public int update(@NotNull final String sql, @NotNull final Map<String, ?> paramMap, @NotNull final KeyHolder generatedKeyHolder) throws DataAccessException {
        return process(sql, paramMap, () -> getJdbcTemplate().update(sql, new MapSqlParameterSource(paramMap), generatedKeyHolder));
    }

    public int update(@NotNull final String sql, @NotNull final Map<String, ?> paramSource, @NotNull final KeyHolder generatedKeyHolder, @NotNull final String... keyColumnNames) throws DataAccessException {
        return process(sql, paramSource, () -> getJdbcTemplate().update(sql, new MapSqlParameterSource(paramSource), generatedKeyHolder, keyColumnNames));
    }

    /**
     * Executes an insert and returns the generated {@code id} column.
     *
     * @throws NullPointerException if no {@code id} key was generated
     */
    public long insert(@NotNull final String sql, @NotNull final Map<String, ?> paramSource) {
        return Objects.requireNonNull(insertIfAbsent(sql, paramSource), "No generated 'id' field!");
    }

    /**
     * Executes an insert requesting the {@code id} column as generated key.
     *
     * @return the generated {@code id}, or {@code null} when the key holder holds no key
     */
    @Nullable
    public Long insertIfAbsent(@NotNull final String sql, @NotNull final Map<String, ?> paramSource) {
        final GeneratedKeyHolder keyHolder = new GeneratedKeyHolder();
        process(sql, paramSource, () -> getJdbcTemplate().update(sql, new MapSqlParameterSource(paramSource), keyHolder, new String[]{"id"}));
        return Optional.ofNullable(keyHolder.getKey()).map(Number::longValue).orElse(null);
    }

    /** Executes an insert and returns the generated values of {@code keyFields}. */
    @NotNull
    public Map<String, Object> insertAndReturnAutoGenKey(@NotNull final String sql, @NotNull final Map<String, ?> paramSource, final String[] keyFields) {
        final GeneratedKeyHolder keyHolder = new GeneratedKeyHolder();
        process(sql, paramSource, () -> getJdbcTemplate().update(sql, new MapSqlParameterSource(paramSource), keyHolder, keyFields));
        return Objects.requireNonNull(keyHolder.getKeys(), "No generated fields!");
    }

    /** Executes an insert and returns whatever generated keys the driver reports. */
    @NotNull
    public Map<String, Object> insertAndReturnAutoGenKey(@NotNull final String sql, @NotNull final Map<String, ?> paramSource) {
        final GeneratedKeyHolder keyHolder = new GeneratedKeyHolder();
        process(sql, paramSource, () -> getJdbcTemplate().update(sql, new MapSqlParameterSource(paramSource), keyHolder));
        return Objects.requireNonNull(keyHolder.getKeys(), "No generated fields!");
    }

    @NotNull
    public int[] batchUpdate(@NotNull final String sql, @NotNull final Collection<SqlParameterSource> batchArgs) {
        return batchUpdate(sql, batchArgs.toArray(new SqlParameterSource[0]));
    }

    @NotNull
    @SuppressWarnings("unchecked")
    public int[] batchUpdate(@NotNull final String sql, @NotNull final List<Map<String, ?>> batchArgs) {
        return batchUpdate(sql, batchArgs.toArray(new Map[batchArgs.size()]));
    }

    @NotNull
    public int[] batchUpdate(@NotNull final String sql, @NotNull final SqlParameterSource... batchArgs) {
        return process(sql, batchArgs, () -> getJdbcTemplate().batchUpdate(sql, batchArgs));
    }

    @NotNull
    public int[] batchUpdate(@NotNull final String sql, @NotNull final Map<String, ?>[] batchArgs) {
        return process(sql, batchArgs, () -> getJdbcTemplate().batchUpdate(sql, batchArgs));
    }

    /** Registers a listener invoked with {@code (sql, params)} after each successfully processed statement. */
    public boolean addQueryListener(@NotNull final BiConsumer<String, Object> queryListener) {
        return queryListeners.add(queryListener);
    }

    /** Unregisters a previously added listener; returns whether it was registered. */
    public boolean removeQueryListener(@NotNull final BiConsumer<String, Object> queryListener) {
        return queryListeners.remove(queryListener);
    }

    private void process(@NotNull final String sql, @Nullable final Object params, @NotNull final Runnable query) throws DataAccessException {
        process(sql, params, () -> {
            query.run();
            return null;
        });
    }

    /**
     * Runs {@code query} with retries, logs its duration and notifies the query listeners.
     *
     * <p>Parameters are rendered into log messages only when DEBUG is enabled. Statements taking
     * 50 ms or more are logged at WARN level. Exceptions thrown by listeners are logged and do not
     * affect the result.
     */
    private <T> T process(@NotNull final String sql, @Nullable final Object params, @NotNull final Supplier<T> query) throws DataAccessException {
        final Stopwatch stopwatch = Stopwatch.createStarted();
        String paramsString = "";
        if (LOG.isDebugEnabled() && params != null) {
            paramsString = " with params " + SqlUtils.toString(params);
        }

        final T result;
        try {
            result = withRetry(query);
        } catch (RuntimeException e) {
            LOG.error("Finished with error in " + stopwatch.elapsed(TimeUnit.MILLISECONDS) + " ms: '" + sql + "'" + paramsString, e);
            throw e;
        }
        final long time = stopwatch.elapsed(TimeUnit.MILLISECONDS);
        if (time < 50) {
            LOG.debug("Finished in {} ms: '{}'{}", time, sql, paramsString);
        } else {
            LOG.warn("Finished in {} ms: '{}'{} ( longrequest )", time, sql, paramsString);
        }

        // Iterates a copy-on-write snapshot, so concurrent (un)registration cannot break this loop.
        queryListeners.forEach(listener -> {
            try {
                listener.accept(sql, params);
            } catch (RuntimeException e) {
                LOG.error("Query listener '" + listener + "' failed with error", e);
            }
        });

        return result;
    }

    /**
     * Invokes {@code supplier}, retrying only on PostgreSQL connection failures.
     *
     * <p>At most {@code 1 + max(0, retryCount)} attempts are made. The first
     * {@code retriesWithoutDelay} retries happen immediately; later ones sleep {@code retryDelay}
     * ms first. Any other exception is rethrown at once. If the thread is interrupted while
     * sleeping, the interrupt flag is restored and the last connection failure is rethrown
     * without further attempts.
     */
    private <T> T withRetry(@NotNull final Supplier<T> supplier) throws DataAccessException {
        // Guard against a negative retryCount: always make at least one attempt so that a
        // non-null failure exists to rethrow.
        final int maxAttempts = 1 + Math.max(0, retryCount);
        DataAccessResourceFailureException lastFailure = null;
        for (int attempt = 0; attempt < maxAttempts; attempt++) {
            try {
                return supplier.get();
            } catch (DataAccessResourceFailureException e) {
                if (!isConnectionFailure(e)) {
                    throw e;
                }
                LOG.warn("Couldn't execute query. Attempt: " + (attempt + 1), e);
                lastFailure = e;
            }
            final boolean hasMoreAttempts = attempt + 1 < maxAttempts;
            if (hasMoreAttempts && attempt >= retriesWithoutDelay) {
                try {
                    Thread.sleep(retryDelay);
                } catch (InterruptedException e) {
                    // Preserve the interrupt for callers and stop retrying instead of spinning on.
                    Thread.currentThread().interrupt();
                    LOG.warn("Thread was interrupted during sleep", e);
                    LOG.error("Failed to execute query after {} attempts", attempt + 1);
                    throw lastFailure;
                }
            }
        }
        LOG.error("Failed to execute query after {} attempts", maxAttempts);
        throw lastFailure;
    }

    /**
     * Returns whether {@code e} is caused by a {@link PSQLException} whose SQL state is
     * {@link #CONNECTION_FAILURE_SQL_STATE_CODE}. A {@code null} SQL state is treated as "not a
     * connection failure".
     */
    private static boolean isConnectionFailure(@NotNull final DataAccessResourceFailureException e) {
        final Throwable cause = e.getCause();
        return cause instanceof PSQLException
                && CONNECTION_FAILURE_SQL_STATE_CODE.equals(((PSQLException) cause).getSQLState());
    }

    /** Wraps {@code rse} so the result set is traced before extraction when TRACE logging is enabled. */
    @NotNull
    private <T> ResultSetExtractor<T> log(@NotNull final ResultSetExtractor<T> rse) {
        if (!LOG.isTraceEnabled()) {
            return rse;
        }
        return rs -> {
            if (LOG.isTraceEnabled()) {
                LOG.trace("Result set extracting: {}", SqlUtils.toString(rs));
            }
            return rse.extractData(rs);
        };
    }

    /** Wraps {@code rm} so each row is traced before mapping when TRACE logging is enabled. */
    @NotNull
    private <T> RowMapper<T> log(@NotNull final RowMapper<T> rm) {
        if (!LOG.isTraceEnabled()) {
            return rm;
        }
        return (rs, i) -> {
            if (LOG.isTraceEnabled()) {
                LOG.trace("Row mapping: {}", SqlUtils.toString(rs));
            }
            return rm.mapRow(rs, i);
        };
    }

    /** Wraps {@code rch} so each row is traced before handling when TRACE logging is enabled. */
    @NotNull
    private RowCallbackHandler log(@NotNull final RowCallbackHandler rch) {
        if (!LOG.isTraceEnabled()) {
            return rch;
        }
        return rs -> {
            if (LOG.isTraceEnabled()) {
                LOG.trace("Row callback handling: {}", SqlUtils.toString(rs));
            }
            rch.processRow(rs);
        };
    }
}
