package ru.yandex.wmtools.common.util;

import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Date;
import java.util.List;
import java.util.NavigableMap;
import java.util.TreeMap;

import javax.sql.DataSource;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.dao.DataAccessException;
import org.springframework.dao.IncorrectResultSizeDataAccessException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.PreparedStatementCreator;
import org.springframework.jdbc.core.PreparedStatementCreatorFactory;
import org.springframework.jdbc.core.ResultSetExtractor;
import org.springframework.jdbc.core.RowCallbackHandler;
import org.springframework.jdbc.core.SqlTypeValue;
import org.springframework.jdbc.core.simple.ParameterizedRowMapper;
import org.springframework.jdbc.support.GeneratedKeyHolder;

import ru.yandex.common.framework.pager.Pager;
import ru.yandex.common.util.collections.Pair;
import ru.yandex.common.util.db.OrderByClause;
import ru.yandex.wmtools.common.error.ExtraTagInfo;
import ru.yandex.wmtools.common.error.ExtraTagNameEnum;
import ru.yandex.wmtools.common.error.InternalException;
import ru.yandex.wmtools.common.error.InternalProblem;

/**
 * {@link IServiceJdbcTemplate} implementation backed by a Spring {@link JdbcTemplate}.
 * <p>
 * Every {@link RuntimeException} raised by the underlying template is rethrown as an
 * {@link InternalException} (read or write problem, depending on the operation) whose message and
 * extra tag carry the database info string supplied at construction time.
 * <p>
 * Each failure is additionally inspected for the MySQL "read-only" error (see
 * {@link #catchSqlException(Throwable)}); when it is found and the data source is a Tomcat JDBC
 * pool, the pool is closed, at most once per minute.
 * <p>
 * NOTE(review): {@code lastConnectionReset} is read and written without synchronization, so
 * concurrent failures may close the pool more than once - confirm whether that is acceptable.
 * <p>
 * User: baton
 * Date: 11.10.2007
 * Time: 18:06:46
 */
public class ServiceJdbcTemplate implements IServiceJdbcTemplate {
    private static final Logger log = LoggerFactory.getLogger(ServiceJdbcTemplate.class);

    /** Message of the SQL error that triggers a pool reset. */
    private static final String MASTER_SWITCH_MESSAGE = "The MySQL server is running with the --read-only option so it cannot execute this statement";

    /** Minimal interval between two pool resets, in milliseconds. */
    private static final long MIN_RESET_INTERVAL_MILLIS = 60000L;

    /** Upper bound for the walk along an exception cause chain; protects against cyclic chains. */
    private static final int MAX_CAUSE_CHAIN_STEPS = 1000;

    private final JdbcTemplate jdbcTemplate;

    private final ExtraTagInfo databaseInfoExtraTag;

    /** Time of the last pool reset, as returned by {@link System#currentTimeMillis()}. */
    private long lastConnectionReset;

    /**
     * @param dataSource   data source the internal {@link JdbcTemplate} is created on
     * @param databaseInfo description of the database, attached to every thrown {@link InternalException}
     */
    public ServiceJdbcTemplate(DataSource dataSource, String databaseInfo) {
        jdbcTemplate = new JdbcTemplate(dataSource);
        this.databaseInfoExtraTag = new ExtraTagInfo(ExtraTagNameEnum.DB_INFO, databaseInfo);
    }

    /**
     * Returns the underlying template. Calls made directly on the returned object are NOT
     * translated into {@link InternalException}.
     */
    @Override
    public org.springframework.jdbc.core.JdbcOperations getJdbcOperations() throws InternalException {
        // Reading a field cannot fail, so no exception translation is needed here.
        return jdbcTemplate;
    }

    /**
     * Runs the query through {@link JdbcTemplate#queryForInt}.
     *
     * @throws InternalException with a read problem if the template fails
     */
    @Override
    public int queryForInt(final String sqlString, final Object... requestParams) throws InternalException {
        try {
            return jdbcTemplate.queryForInt(sqlString, requestParams);
        } catch (RuntimeException e) {
            throw readFailure(e);
        }
    }

    /**
     * Runs the query through {@link JdbcTemplate#queryForLong}.
     *
     * @throws InternalException with a read problem if the template fails
     */
    @Override
    public long queryForLong(final String sqlString, final Object... requestParams) throws InternalException {
        try {
            return jdbcTemplate.queryForLong(sqlString, requestParams);
        } catch (RuntimeException e) {
            throw readFailure(e);
        }
    }

    /**
     * Runs the query and returns its result converted to {@code aClass}.
     *
     * @throws InternalException with a read problem if the template fails
     */
    @Override
    public <T> T queryForObject(final String sqlString, final Class<T> aClass, final Object... requestParams) throws InternalException {
        try {
            return jdbcTemplate.queryForObject(sqlString, aClass, requestParams);
        } catch (RuntimeException e) {
            throw readFailure(e);
        }
    }

    /**
     * Runs the query and maps its result with {@code parameterizedRowMapper}.
     *
     * @throws InternalException with a read problem if the template fails
     */
    @Override
    public <T> T queryForObject(final String sqlString, ParameterizedRowMapper<T> parameterizedRowMapper, final Object... requestParams) throws InternalException {
        try {
            return jdbcTemplate.queryForObject(sqlString, parameterizedRowMapper, requestParams);
        } catch (RuntimeException e) {
            throw readFailure(e);
        }
    }

    /**
     * Runs the query and maps every row with {@code parameterizedRowMapper}.
     *
     * @throws InternalException with a read problem if the template fails
     */
    @Override
    public <T> List<T> query(final String sqlString, ParameterizedRowMapper<T> parameterizedRowMapper, final Object... requestParams) throws InternalException {
        try {
            return jdbcTemplate.query(sqlString, parameterizedRowMapper, requestParams);
        } catch (RuntimeException e) {
            throw readFailure(e);
        }
    }

    /**
     * Runs the query through {@link JdbcTemplate#queryForMap}.
     *
     * @throws InternalException with a read problem if the template fails
     */
    @Override
    public java.util.Map<String, Object> queryForMap(final String sqlString, final Object... requestParams) throws InternalException {
        try {
            return jdbcTemplate.queryForMap(sqlString, requestParams);
        } catch (RuntimeException e) {
            // Goes through the same path as every other read, including master-switch detection.
            throw readFailure(e);
        }
    }

    /**
     * Runs the query through {@link JdbcTemplate#queryForList}.
     *
     * @throws InternalException with a read problem if the template fails
     */
    @Override
    public List<java.util.Map<String, Object>> queryForList(final String sqlString, final Object... requestParams) throws InternalException {
        try {
            return jdbcTemplate.queryForList(sqlString, requestParams);
        } catch (RuntimeException e) {
            throw readFailure(e);
        }
    }

    /**
     * Executes an update statement.
     *
     * @return the value returned by {@link JdbcTemplate#update}
     * @throws InternalException with a write problem if the template fails
     */
    @Override
    public int update(final String sqlString, final Object... requestParams) throws InternalException {
        try {
            return jdbcTemplate.update(sqlString, requestParams);
        } catch (RuntimeException e) {
            throw writeFailure(e);
        }
    }

    /**
     * Executes a statement with generated-key retrieval enabled.
     *
     * @return the generated key, or {@code null} if the statement reported no updated rows
     * @throws InternalException with a write problem if the statement or the key retrieval fails
     */
    @Override
    public Number insertSingle(String sqlString, Object... requestParams) throws InternalException {
        // All parameters are declared with an unknown SQL type.
        int[] paramTypes = new int[requestParams.length];
        java.util.Arrays.fill(paramTypes, SqlTypeValue.TYPE_UNKNOWN);

        PreparedStatementCreatorFactory creatorFactory = new PreparedStatementCreatorFactory(sqlString, paramTypes);
        creatorFactory.setReturnGeneratedKeys(true);
        PreparedStatementCreator statementCreator = creatorFactory.newPreparedStatementCreator(requestParams);
        GeneratedKeyHolder keyHolder = new GeneratedKeyHolder();
        try {
            int updated = getJdbcOperations().update(statementCreator, keyHolder);
            return updated > 0 ? keyHolder.getKey() : null;
        } catch (RuntimeException e) {
            throw writeFailure(e);
        }
    }

    /**
     * Reads the first column of the result as a timestamp.
     *
     * @return the value if the query yields exactly one row; {@code null} otherwise, or if the
     *         failure was caused by an {@link IncorrectResultSizeDataAccessException}
     * @throws InternalException on any other read failure
     */
    @Override
    public Date safeQueryForTimestamp(final String sqlString, final Object... requestParams) throws InternalException {
        ParameterizedRowMapper<Date> dateRowMapper = new ParameterizedRowMapper<Date>() {
            @Override
            public Date mapRow(ResultSet resultSet, int i) throws SQLException {
                return SqlUtil.safeGetTimestamp(resultSet, 1);
            }
        };

        try {
            List<Date> results = query(sqlString, dateRowMapper, requestParams);
            if (results == null || results.size() != 1) {
                return null;
            }
            return results.get(0);
        } catch (InternalException e) {
            return nullIfIncorrectResultSize(e);
        }
    }

    /**
     * Same as {@link #queryForInt}, but returns {@code null} instead of failing when the failure
     * was caused by an {@link IncorrectResultSizeDataAccessException}.
     *
     * @throws InternalException on any other read failure
     */
    @Override
    public Integer safeQueryForInt(final String sqlString, final Object... requestParams) throws InternalException {
        try {
            return queryForInt(sqlString, requestParams);
        } catch (InternalException e) {
            return nullIfIncorrectResultSize(e);
        }
    }

    /**
     * Same as {@link #queryForLong}, but returns {@code null} instead of failing when the failure
     * was caused by an {@link IncorrectResultSizeDataAccessException}.
     *
     * @throws InternalException on any other read failure
     */
    @Override
    public Long safeQueryForLong(final String sqlString, final Object... requestParams) throws InternalException {
        try {
            return queryForLong(sqlString, requestParams);
        } catch (InternalException e) {
            return nullIfIncorrectResultSize(e);
        }
    }

    /**
     * Queries for a single {@link String}, returning {@code null} instead of failing when the
     * failure was caused by an {@link IncorrectResultSizeDataAccessException}.
     *
     * @throws InternalException on any other read failure
     */
    @Override
    public String safeQueryForString(final String sqlString, final Object... requestParams) throws InternalException {
        try {
            return queryForObject(sqlString, String.class, requestParams);
        } catch (InternalException e) {
            return nullIfIncorrectResultSize(e);
        }
    }

    /**
     * Queries for a single mapped object, returning {@code null} instead of failing when the
     * failure was caused by an {@link IncorrectResultSizeDataAccessException}.
     *
     * @throws InternalException on any other read failure
     */
    @Override
    public <E> E safeQueryForObject(final String sqlString, ParameterizedRowMapper<E> parameterizedRowMapper, final Object... requestParams) throws InternalException {
        try {
            return queryForObject(sqlString, parameterizedRowMapper, requestParams);
        } catch (InternalException e) {
            return nullIfIncorrectResultSize(e);
        }
    }

    /**
     * Runs the query and collects the key/value pairs produced by {@code parameterizedRowMapper}
     * into a {@link TreeMap}. If several rows map to the same key, the last one wins.
     *
     * @throws InternalException with a read problem if the template fails
     */
    @Override
    public <K, V> NavigableMap<K, V> queryForNavigableMap(final String sqlString, ParameterizedMapRowMapper<K, V> parameterizedRowMapper, Object... requestParams) throws InternalException {
        try {
            // Considers an Object array passed into a varargs parameter as collection of arguments rather than as single argument.
            Object[] args = requestParams;
            if (args.length == 1 && args[0] instanceof Object[]) {
                args = (Object[]) args[0];
            }

            // noinspection unchecked
            return (NavigableMap<K, V>) jdbcTemplate.query(sqlString, args, new MapResultSetExtractor<K, V>(parameterizedRowMapper));
        } catch (RuntimeException e) {
            throw readFailure(e);
        }
    }

    /**
     * Runs the query and passes every row to {@code rch}.
     *
     * @throws InternalException with a read problem if the template fails
     */
    @Override
    public void query(final String sqlString, final RowCallbackHandler rch, final Object... requestParams) throws InternalException {
        try {
            jdbcTemplate.query(sqlString, requestParams, rch);
        } catch (RuntimeException e) {
            throw readFailure(e);
        }
    }

    /**
     * Extractor that turns a result set into a sorted map, one entry per row.
     */
    private static class MapResultSetExtractor<K, V> implements ResultSetExtractor {
        private final ParameterizedMapRowMapper<K, V> rowMapper;

        public MapResultSetExtractor(ParameterizedMapRowMapper<K, V> rowMapper) {
            this.rowMapper = rowMapper;
        }

        @Override
        public NavigableMap<K, V> extractData(ResultSet resultSet) throws SQLException, DataAccessException {
            NavigableMap<K, V> results = new TreeMap<K, V>();
            // Row numbering restarts at 0 for every extracted result set.
            int rowNum = 0;
            while (resultSet.next()) {
                Pair<K, V> pair = rowMapper.mapRow(resultSet, rowNum++);
                results.put(pair.first, pair.second);
            }
            return results;
        }
    }

    /**
     * {@inheritDoc}
     * <p>
     * {@code selectSql} is used as a {@link String#format} template: its first argument receives
     * the {@code ORDER BY} clause (empty when {@code order} is {@code null}) and its second
     * argument receives the {@code LIMIT} clause built by {@link #pageableSelect}.
     */
    @Override
    public <T> List<T> select(final String countSql, final String selectSql, final ParameterizedRowMapper<T> mapper, final OrderByClause order, final Pager pager, final Object... params) throws InternalException {
        String orderString = "";
        if (order != null) {
            orderString = String.format(" ORDER BY %1$s %2$s ", order.getColumnName(), order.getMode());
        }

        // The second placeholder is rewritten to "%1$s" so that pageableSelect can fill in the limit.
        return pageableSelect(countSql, String.format(selectSql, orderString, "%1$s"), mapper, pager, params);
    }

    /**
     * {@inheritDoc}
     * <p>
     * {@code selectSql} is used as a {@link String#format} template whose first argument receives
     * the {@code LIMIT} clause (empty when {@code pager} is {@code null}). When a pager is given,
     * {@code countSql} is executed first with the same parameters and its result is stored in
     * the pager.
     */
    @Override
    public <T> List<T> pageableSelect(final String countSql, final String selectSql, final ParameterizedRowMapper<T> mapper, final Pager pager, final Object... params) throws InternalException {
        String limitString = "";
        if (pager != null) {
            pager.setCount(queryForInt(countSql, params));
            limitString = String.format(" LIMIT %1$d, %2$d ", pager.getPageSize() * pager.getPageNum(), pager.getPageSize());
        }

        try {
            return jdbcTemplate.query(String.format(selectSql, limitString), mapper, params);
        } catch (RuntimeException e) {
            throw readFailure(e);
        }
    }


    /**
     * {@inheritDoc}
     * <p>
     * Delegates to {@link #select} without a count query and without a pager.
     */
    @Override
    public <T> List<T> orderableSelect(final String selectSql, final ParameterizedRowMapper<T> mapper, final OrderByClause order, final Object... params) throws InternalException {
        return select(null, selectSql, mapper, order, null, params);
    }

    /**
     * Inspects a failed read for a master switch and wraps it into a read-problem exception.
     */
    private InternalException readFailure(RuntimeException e) {
        catchSqlException(e);
        return new InternalException(InternalProblem.READ_DB_ERROR, createErrorMessage(e), e, databaseInfoExtraTag);
    }

    /**
     * Inspects a failed write for a master switch and wraps it into a write-problem exception.
     */
    private InternalException writeFailure(RuntimeException e) {
        catchSqlException(e);
        return new InternalException(InternalProblem.WRITE_DB_ERROR, createErrorMessage(e), e, databaseInfoExtraTag);
    }

    /**
     * Shared tail of the {@code safeQueryFor*} methods: returns {@code null} when the direct cause
     * of {@code e} is an {@link IncorrectResultSizeDataAccessException}, rethrows {@code e} otherwise.
     */
    private <T> T nullIfIncorrectResultSize(InternalException e) throws InternalException {
        catchSqlException(e);
        if (e.getCause() instanceof IncorrectResultSizeDataAccessException) {
            return null;
        }
        throw e;
    }

    /**
     * Builds the message of a wrapping exception from the database info and the original message.
     */
    private String createErrorMessage(Exception e) {
        return "DB info: " + databaseInfoExtraTag.getValue() + "; Message: " + e.getMessage();
    }

    /**
     * Looks at the root cause of {@code exp}; if it is an {@link SQLException} carrying
     * {@link #MASTER_SWITCH_MESSAGE} and the previous reset happened more than
     * {@link #MIN_RESET_INTERVAL_MILLIS} ago, closes the Tomcat JDBC pool behind the template
     * (other data source types are left untouched) and records the reset time.
     */
    private void catchSqlException(Throwable exp) {
        // Walk down to the root cause, with a step limit as a guard against cyclic chains.
        int stepsLeft = MAX_CAUSE_CHAIN_STEPS;
        Throwable rootCause = exp;
        while (--stepsLeft > 0 && rootCause.getCause() != null && rootCause.getCause() != rootCause) {
            rootCause = rootCause.getCause();
        }
        if (!(rootCause instanceof SQLException)) {
            return;
        }

        long now = System.currentTimeMillis();
        boolean masterSwitched = MASTER_SWITCH_MESSAGE.equalsIgnoreCase(rootCause.getMessage());
        if (masterSwitched && (now - lastConnectionReset) > MIN_RESET_INTERVAL_MILLIS) {
            // restarting connection
            log.info("Resetting datasource due to master switch");
            DataSource dataSource = jdbcTemplate.getDataSource();
            if (dataSource instanceof org.apache.tomcat.jdbc.pool.DataSource) {
                ((org.apache.tomcat.jdbc.pool.DataSource) dataSource).close(true);
            }
            lastConnectionReset = now;
        }
    }
}
