package ru.yandex.direct.dbutil.wrapper;

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

import one.util.streamex.EntryStream;
import org.jooq.ExecuteContext;
import org.jooq.Query;
import org.jooq.impl.DefaultExecuteListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Логирует sql-запрос. Работает при уровне debug и ниже, иначе никакого overhead'а не несет.
 * На уровне trace логирует stacktrace с фильтром по "ru.yandex.direct", на уровне debug только метод репозитория.
 * Пример записи (будет одна строка):
 * <p>
 * <code>2018-02-27:15:52:24 ppcdev2.yandex.ru,direct.web/minus_keyword.add,3146065076970765545:0:3146065076970765545
 * [jetty-worker-1-13] DEBUG ru.yandex.direct.dbutil.wrapper.SqlQueriesLogger - insert into `added_phrases_cache`
 * (`cid`, `pid`, `bids_id`, `type`, `phrase_hash`, `add_date`) values (?, ?, ?, ?, ?, ?) on duplicate key update
 * `added_phrases_cache`.`add_date` = ?; with values (15126730, 3488021006, 0, minus, 17156080074884823802,
 * 2018-02-27T15:52:24.037, 2018-02-27T15:52:24.038); KeywordCacheRepository.addKeywordsCache:126 </code>
 */
public class SqlQueriesLogger extends DefaultExecuteListener {
    private final String dbName;

    private static final Logger logger = LoggerFactory.getLogger(SqlQueriesLogger.class);

    private static final String PACKAGE_PREFIX = "ru.yandex.direct";

    // Статистика методов по количеству вызовов для всех баз
    private static volatile ConcurrentHashMap<String, ConcurrentHashMap<String, AtomicLong>> allDbCalls =
            new ConcurrentHashMap<>();

    public SqlQueriesLogger(String dbName) {
        this.dbName = dbName;
    }

    @Override
    public void executeEnd(ExecuteContext ctx) {
        if (!logger.isDebugEnabled()) {
            return;
        }
        StringBuilder logRecord = new StringBuilder();
        addQueries(ctx, logRecord);
        addStacktraceAndUpdateStats(logRecord);
        logger.debug(logRecord.toString());
    }

    /**
     * Добавляет в лог sql-запросы. Хранимые процедуры не логируются, если понадобится,
     * см. {@link ExecuteContext#routine()}
     */
    private void addQueries(ExecuteContext ctx, StringBuilder logRecord) {
        Query[] queries = ctx.batchQueries();
        if (queries.length == 0) {
            return;
        }
        for (Query query : queries) {
            logRecord.append(query.getSQL());
            List<Object> bindValues = query.getBindValues();
            addBindValues(bindValues, logRecord);
            logRecord.append("; ");
        }
    }

    /**
     * Если есть значения, которые нужно подставить (в запросе были placeholder'ы), добавляет их к записи лога
     */
    private void addBindValues(List<Object> bindValues, StringBuilder logRecord) {
        if (bindValues.isEmpty()) {
            return;
        }
        logRecord.append("; with values (");

        boolean first = true;
        for (Object bindValue : bindValues) {
            if (!first) {
                logRecord.append(", ");
            }
            if (bindValue != null && bindValue.toString().equals("")) {
                logRecord.append("''");
            } else {
                logRecord.append(bindValue);
            }
            first = false;
        }
        logRecord.append(")");
    }

    /**
     * В режиме debug добавляет в лог метод репозитория, из которого был сделан запрос. В режиме trace все методы,
     * начиная с репозитория, фильтруя по {@link SqlQueriesLogger#PACKAGE_PREFIX}
     */
    private void addStacktraceAndUpdateStats(StringBuilder logRecord) {
        StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace();
        boolean first = true;
        for (StackTraceElement stackTraceElement : stackTrace) {
            String className = stackTraceElement.getClassName();
            if (className.startsWith(PACKAGE_PREFIX) && !stackTraceElement.getClassName()
                    .equals(SqlQueriesLogger.class.getCanonicalName())) {
                if (!first) {
                    logRecord.append(", ");
                }
                String simpleClassName = className.substring(className.lastIndexOf('.') + 1);
                String methodName = stackTraceElement.getMethodName();
                logRecord.append(simpleClassName).append(".").append(methodName)
                        .append(":").append(stackTraceElement.getLineNumber());
                if (first) {
                    String method = simpleClassName + "." + methodName;
                    AtomicLong callsCount = allDbCalls
                            .computeIfAbsent(dbName, name -> new ConcurrentHashMap<>())
                            .computeIfAbsent(method, name -> new AtomicLong());
                    callsCount.incrementAndGet();
                }
                if (!logger.isTraceEnabled()) {
                    break;
                }
                first = false;
            }
        }
    }
    public static void resetAllCalsStatistics() {
        allDbCalls = new ConcurrentHashMap<>();
    }

    public static String getAllCallsStatistics() {
        StringBuilder statisticsText = new StringBuilder();
        Map<String, Map<String, Long>> allDbCallsCopy = EntryStream.of(allDbCalls)
                .mapValues(SqlQueriesLogger::copyMethodsMap)
                .toMap();
        Map<String, Long> dbCalls = EntryStream.of(allDbCallsCopy)
                .mapValues(innerMap -> innerMap.values().stream().mapToLong(Long::longValue).sum())
                .toMap();
        long totalCount = dbCalls.values().stream().mapToLong(count -> count).sum();

        statisticsText.append(String.format("SqlQueriesLogger stat. Total queries count: %d, stat by db:%n", totalCount));

        dbCalls.entrySet().stream().sorted(Map.Entry.<String, Long>comparingByValue().reversed())
                .forEach(e -> {
                    String dbName = e.getKey();
                    long count = e.getValue();
                    statisticsText.append(String.format("Stat for db: %s, calls count: %d:%n", dbName, count));
                    appendDbCallsStatistics(statisticsText, "\t", allDbCallsCopy.get(dbName));
                    statisticsText.append('\n');
                });
        return statisticsText.toString();
    }

    private static void appendDbCallsStatistics(
            StringBuilder statisticsText, String prefix, Map<String, Long> methodsCallsMap) {
        methodsCallsMap.entrySet().stream()
                .sorted(Map.Entry.<String, Long>comparingByValue().reversed())
                .forEach(e -> statisticsText.append(String.format("%s%s: %d%n", prefix, e.getKey(), e.getValue())));
    }

    private static Map<String, Long> copyMethodsMap(ConcurrentHashMap<String, AtomicLong> sourceMap) {
        return EntryStream.of(sourceMap)
                .mapValues(AtomicLong::longValue)
                .toMap();
    }

}
