package ru.yandex.direct.dbutil.wrapper;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import javax.annotation.ParametersAreNonnullByDefault;
import javax.sql.DataSource;

import org.jooq.DSLContext;
import org.jooq.ExecuteContext;
import org.jooq.Query;
import org.jooq.SQLDialect;
import org.jooq.TransactionalRunnable;
import org.jooq.conf.ParamType;
import org.jooq.conf.Settings;
import org.jooq.impl.DSL;
import org.jooq.impl.DefaultConfiguration;
import org.jooq.impl.DefaultExecuteListener;
import org.jooq.tools.LoggerListener;
import org.jooq.tools.StopWatchListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.dao.DataAccessException;
import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.jdbc.core.namedparam.SqlParameterSource;
import org.springframework.jdbc.datasource.TransactionAwareDataSourceProxy;

import ru.yandex.clickhouse.ClickHousePreparedStatement;
import ru.yandex.clickhouse.response.ClickHouseResponse;
import ru.yandex.direct.dbutil.exception.DbUtilSqlException;
import ru.yandex.direct.env.EnvironmentType;
import ru.yandex.direct.tracing.Trace;
import ru.yandex.direct.tracing.TraceProfile;
import ru.yandex.direct.tracing.util.TraceUtil;
import ru.yandex.direct.utils.SystemUtils;

/**
 * Wraps an underlying DSLContext and JdbcTemplate adding tracing information
 */
@ParametersAreNonnullByDefault
public class DatabaseWrapper {
    private static final Logger logger = LoggerFactory.getLogger(DatabaseWrapper.class);

    private final String dbname;
    private final DataSource dataSource;
    private final JdbcOperations jdbcTemplate;
    private final NamedParameterJdbcOperations namedParameterJdbcTemplate;

    private final SQLDialect sqlDialect;
    private final DSLContext dslContext;

    private final Function<String, String> queryTagRenderer;

    /**
     * Note: Отключает запрос sql warnings в jooq, поскольку в MySql jdbc драйвере выгрузка предупреждений
     * приводит к дополнительному sql запросу. Как-будто бы в протоколе обмена с MySql есть поле с числом предупреждений,
     * но в jdbc нет способа узнать число предупреждений без их реального получения.
     *
     * @param dbname
     * @param dataSource
     * @param sqlDialect
     * @see org.jooq.impl.Utils#consumeWarnings(ExecuteContext, org.jooq.ExecuteListener)
     * @see java.sql.Statement#getWarnings()
     */
    public DatabaseWrapper(String dbname, DataSource dataSource, SQLDialect sqlDialect, EnvironmentType environmentType) {
        this.dbname = dbname;
        this.dataSource = dataSource;
        this.sqlDialect = sqlDialect;
        this.jdbcTemplate = new JdbcTemplate(dataSource);
        this.namedParameterJdbcTemplate = new NamedParameterJdbcTemplate(jdbcTemplate);
        this.dslContext = DSL.using(new DefaultConfiguration()
                .set(new TransactionAwareDataSourceProxy(dataSource))
                .set(new Settings().withRenderSchema(false).withFetchWarnings(false).withExecuteLogging(false))
                .set(() -> new ProfilingListener(dbname),
                        () -> new QueryTagger(environmentType),
                        () -> new LoggerListener(),
                        () -> new StopWatchListener(),
                        () -> new StackTraceLoggerListener(dbname),
                        () -> new SqlQueriesLogger(dbname)
                )
                .set(sqlDialect));
        this.queryTagRenderer = new QueryTagger(environmentType)::render;
    }

    /**
     * Returns the database name
     */
    public String getDbname() {
        return dbname;
    }

    /**
     * Returns the underlying DataSource instance
     */
    public DataSource getDataSource() {
        return dataSource;
    }

    /**
     * Returns the underlying JdbcTemplate instance
     */
    public JdbcOperations getJdbcTemplate() {
        return jdbcTemplate;
    }

    /**
     * Returns the underlying DSL context for SQL generation
     */
    public DSLContext getDslContext() {
        return dslContext;
    }

    /**
     * Returns the underlying NamedParameterJdbcTemplate instance
     */
    public NamedParameterJdbcOperations getNamedParameterJdbcTemplate() {
        return namedParameterJdbcTemplate;
    }

    /**
     * Tag SQL query with trace comment.
     */
    public String render(String sql) {
        return queryTagRenderer.apply(sql);
    }

    /**
     * Calls the specified lambda with the underlying jdbc template as argument
     */
    public <R> R query(Function<JdbcOperations, R> function) throws DataAccessException {
        TraceProfile profile = Trace.current().profile("db:query", dbname);
        try {
            return function.apply(jdbcTemplate);
        } finally {
            profile.close();
        }
    }

    /**
     * Calls the specified lambda with the underlying jdbc template as argument
     */
    public <R> R queryWithNamedParameters(Function<NamedParameterJdbcOperations, R> function)
            throws DataAccessException {
        TraceProfile profile = Trace.current().profile("db:query", dbname);
        try {
            return function.apply(namedParameterJdbcTemplate);
        } finally {
            profile.close();
        }
    }

    /**
     * Makes query using the underlying jdbc template instance
     */
    public <T> List<T> query(Query query, RowMapper<T> rowMapper) throws DataAccessException {
        String sql = query.getSQL(ParamType.INDEXED);
        Object[] args = query.getBindValues().toArray();
        return query(jdbc -> jdbc.query(sql, args, rowMapper));
    }

    /**
     * Makes query using the underlying jdbc template instance
     */
    public <T> List<T> query(String sql, Object[] args, RowMapper<T> rowMapper) throws DataAccessException {
        return query(jdbc -> jdbc.query(sql, args, rowMapper));
    }

    /**
     * Makes query using the underlying jdbc template instance
     */
    public void rawQuery(String sql, Object[] args, Consumer<ResultSet> processer) {
        try (
                TraceProfile ignore = Trace.current().profile("db:query", dbname);
                Connection connection = dataSource.getConnection();
                PreparedStatement preparedStatement = connection.prepareStatement(sql);
        ) {
            for (int i = 0; i < args.length; i++) {
                preparedStatement.setObject(i + 1, args[i]);
            }
            try (ResultSet resultSet = preparedStatement.executeQuery();) {
                processer.accept(resultSet);
            }
        } catch (SQLException e) {
            throw new DbUtilSqlException(e);
        }
    }

    /**
     * Makes query using the underlying jdbc template instance
     */
    public <T> List<T> query(String sql, RowMapper<T> rowMapper, Object... args) throws DataAccessException {
        return query(jdbc -> jdbc.query(sql, rowMapper, args));
    }

    /**
     * Makes query using the underlying jdbc template instance
     */
    public <T> List<T> query(String sql, SqlParameterSource params, RowMapper<T> rowMapper) throws DataAccessException {
        return queryWithNamedParameters(jdbc -> jdbc.query(sql, params, rowMapper));
    }

    /**
     * Ping one of connection
     *
     * @return true if connections is alive, false overwise
     * @throws IllegalStateException if sql dialect is not supported
     */
    public boolean isAlive() {
        if (sqlDialect.equals(SQLDialect.MYSQL)) {
            try (TraceProfile ignore = Trace.current().profile("db:ping", dbname);) {
                getJdbcTemplate().query("SELECT 1", (rs) -> {
                });
                return true;
            } catch (RuntimeException e) {
                logger.error("Can't ping database " + dbname, e);
                return false;
            }
        } else {
            throw new IllegalStateException(String.format("Ping for %s(%s) is not supported", dbname, sqlDialect));
        }
    }

    /**
     * Выполнить данное действие {@param transaction} в рамках транзакции
     */
    public void transaction(TransactionalRunnable transactional) {
        dslContext.transaction(transactional);
    }

    /**
     * Выполнить запрос в Clickhouse, вернуть результат в виде ClickHouseResponse.
     * В этом виде есть доступ к totals/extremes/...
     */
    public ClickHouseResponse clickhouseQuery(String sql, Object[] args) {
        try (
                TraceProfile ignore = Trace.current().profile("db:query", dbname);
                Connection connection = dataSource.getConnection();
                PreparedStatement preparedStatement = connection.prepareStatement(sql);
        ) {
            if (!(preparedStatement instanceof ClickHousePreparedStatement)) {
                throw new IllegalStateException(String.format("clickhouseQuery is not supported for %s",
                        preparedStatement.getClass().getCanonicalName()));
            }
            for (int i = 0; i < args.length; i++) {
                preparedStatement.setObject(i + 1, args[i]);
            }
            return ((ClickHousePreparedStatement) preparedStatement).executeQueryClickhouseResponse();
        } catch (SQLException e) {
            throw new DbUtilSqlException(e);
        }
    }

    public static class QueryTagger extends DefaultExecuteListener {
        private static final Pattern FIRST_WORD_PATTERN = Pattern.compile("^(\\s*\\w+)\\s(.*)", Pattern.DOTALL);

        private final String developmentMessage;

        QueryTagger(EnvironmentType environmentType) {
            developmentMessage = environmentType.isDevelopment() ?
                    String.format("/* %s@%s */", System.getProperty("user.name"), SystemUtils.hostname()) :
                    null;
        }

        @Override
        public void renderEnd(ExecuteContext ctx) {
            String traceComment = TraceUtil.getTraceSqlComment();
            String[] batchSql = ctx.batchSQL();
            for (int i = 0; i < batchSql.length; i++) {
                String sql = batchSql[i];

                // sql может быть null, см. jOOQ-ий  BatchMultiple.execute, так как данный метод будет
                // вызваться для каждого statement-а из batch-а. Также если sql == null, то начиная с данного
                // элемента идут еще не сформированные куски batch, поэтому выходим
                if (sql == null) {
                    break;
                }

                sql = render(sql, traceComment);
                if (i == 0) {
                    ctx.sql(sql);
                }
                batchSql[i] = sql;
            }
        }

        String render(String sql) {
            return render(sql, TraceUtil.getTraceSqlComment());
        }

        /**
         * @param sql          SQL-запрос
         * @param traceComment комментарий трейсинга, добавляемый в конец запроса
         * @return SQL-выражение с комментарием. Если trace-комментарий уже содержится
         * в переданном запросе, последний вернётся без изменений.
         */
        private String render(String sql, String traceComment) {
            // Так как данная часть может выполняться многократно в случае batch из нескольких statement-ов (см. выше),
            // то мы должны проверять что мы уже не обрабатывали данный statement
            Matcher firstWordMatcher = FIRST_WORD_PATTERN.matcher(sql);
            if (sql.contains(traceComment)) { // Мы уже видели данный batch?
                return sql;
            }

            if (developmentMessage != null) {
                traceComment = traceComment + " " + developmentMessage;
            }

            if (firstWordMatcher.find()) {
                String firstWord = firstWordMatcher.group(1);
                String remainChars = firstWordMatcher.group(2);
                sql = firstWord + " " + traceComment + " " + remainChars;
            } else if (!sql.startsWith("{")) {
                // в вызовы процедур не добавляем комментарий
                sql = traceComment + " " + sql;
            }
            return sql;
        }
    }
}
