package ru.yandex.webmaster3.storage.util.sql;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;

import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Required;

/**
 * {@link JdbcProvider} that routes work over the connections exposed by a {@link JdbcRouterBackend}.
 * <p>
 * Write templates and transaction templates always use the backend's write connection. Read templates
 * reuse the template stored for the current thread in {@link #transactions} when one is present;
 * otherwise the best read connection is chosen by the ordering defined in {@link ConnectionInfo}.
 *
 * @author avhaliullin
 */
public class JdbcRouterFrontend implements JdbcProvider {
    private static final Logger log = LoggerFactory.getLogger(JdbcRouterFrontend.class);

    /**
     * A read connection whose smoothed ping is at or below this value is never considered slow,
     * regardless of how it compares to the fastest connection.
     */
    private static final long FAST_CONNECTION_THRESHOLD_MS = 5;

    /**
     * Template of the transaction bound to the current thread, if any. This holder is handed to the
     * transaction templates created below. NOTE(review): presumably they set/clear it around the
     * transaction body — confirm in TransactionTemplateImpl / LightTransactionTemplateImpl.
     */
    private final ThreadLocal<IWriteJdbcTemplate> transactions = new ThreadLocal<>();

    private JdbcRouterBackend backend;

    /**
     * Returns the identity reported by the backend.
     */
    @Override
    public String getDbIdentity() {
        return backend.getDbIdentity();
    }

    /**
     * Creates a transaction template over the backend's transaction template and write connection.
     * It shares this frontend's per-thread transaction holder.
     */
    @Override
    public TransactionTemplate getTransactionTemplate() {
        return new TransactionTemplateImpl(backend.getTransactionTemplate(),
                backend.getWriteManagedConnection().getTemplate(), transactions);
    }

    /**
     * Creates a light transaction template over the backend's write connection. It shares this
     * frontend's per-thread transaction holder.
     */
    @Override
    public TransactionTemplate getLightTransactionTemplate() {
        return new LightTransactionTemplateImpl(backend.getWriteManagedConnection().getTemplate(), transactions);
    }

    /**
     * Chooses a template for read queries.
     * <p>
     * If the current thread has a transaction template registered, it is returned so that reads go
     * through the same template as the transaction. With a single read connection, that connection's
     * template is returned. Otherwise every read connection is ranked (see {@link ConnectionInfo})
     * and the best one wins.
     *
     * @return the template to run read queries on
     * @throws IllegalStateException if the backend exposes no read connections
     */
    @Override
    public IReadJdbcTemplate getReadTemplate() {
        IWriteJdbcTemplate transactionTemplate = transactions.get();
        if (transactionTemplate != null) {
            return transactionTemplate;
        }
        List<ManagedConnection<? extends IReadJdbcTemplate>> connections = backend.getReadManagedConnections();
        if (connections.isEmpty()) {
            throw new IllegalStateException("No read connections available for " + backend.getDbIdentity());
        }
        if (connections.size() == 1) {
            return connections.get(0).getTemplate();
        }

        // Snapshot each connection's smoothed ping once. The stats are read through getters and may
        // change between calls, so reusing the snapshot keeps "fastest" and "slow" consistent.
        long[] pings = new long[connections.size()];
        long fastestPing = Long.MAX_VALUE;
        int pos = 0;
        for (ManagedConnection<? extends IReadJdbcTemplate> connection : connections) {
            pings[pos] = smoothedPing(connection);
            fastestPing = Math.min(fastestPing, pings[pos]);
            pos++;
        }

        boolean debug = log.isDebugEnabled();
        List<ConnectionInfo> infos = new ArrayList<>(connections.size());
        int index = 0;
        for (ManagedConnection<? extends IReadJdbcTemplate> connection : connections) {
            if (debug) {
                log.debug("DB{}: id={ {} } stats: {\n  avgQuery: {}\n  lastQuery: {}\n  avgPing: {}\n  lastPing: {}\n  " +
                                "hasError: {}\n  inProgress: {}\n  }",
                        index, connection.getDbIdentity(), connection.getAvgQuery(), connection.getLastQuery(), connection.getAvgPing(), connection.getLastPing(),
                        connection.hasError(), connection.getQueriesInProgress()
                );
            }
            long ping = pings[index];
            // Slow only when above the absolute floor AND more than twice as slow as the fastest one.
            boolean slow = ping > FAST_CONNECTION_THRESHOLD_MS && ping > 2 * fastestPing;
            infos.add(new ConnectionInfo(connection.getTemplate(), slow, connection.hasError(),
                    connection.getQueriesInProgress(), connection.getAvgQuery(),
                    ThreadLocalRandom.current().nextInt(), index));
            index++;
        }

        if (debug) {
            Collections.sort(infos);
            for (ConnectionInfo info : infos) {
                log.debug("DB Order: {}", info.id);
            }
            return infos.get(0).template;
        }
        return Collections.min(infos).template;
    }

    /**
     * Returns the template of the backend's write connection.
     */
    @Override
    public IWriteJdbcTemplate getWriteTemplate() {
        return backend.getWriteManagedConnection().getTemplate();
    }

    @Required
    public void setBackend(JdbcRouterBackend backend) {
        this.backend = backend;
    }

    /**
     * Returns the mean of a connection's average ping and its last ping.
     */
    private static long smoothedPing(ManagedConnection<?> connection) {
        return (connection.getAvgPing() + connection.getLastPing()) / 2;
    }

    /**
     * Snapshot of one read connection's ranking criteria.
     * <p>
     * Instances are ordered lexicographically; "smaller" means "preferred":
     * <ol>
     *   <li>connections without errors first;</li>
     *   <li>then connections not marked slow;</li>
     *   <li>then fewer queries in progress;</li>
     *   <li>then lower average query time;</li>
     *   <li>then a per-call random value, to spread load between otherwise equal connections;</li>
     *   <li>finally the position in the backend's list, so the order is total.</li>
     * </ol>
     */
    private static final class ConnectionInfo implements Comparable<ConnectionInfo> {
        private final IReadJdbcTemplate template;
        private final boolean slow;
        private final boolean errors;
        private final int processingQueries;
        private final long avgQuery;
        private final int rndId;
        private final int id;

        private ConnectionInfo(IReadJdbcTemplate template, boolean slow, boolean errors, int processingQueries,
                               long avgQuery, int rndId, int id) {
            this.template = template;
            this.slow = slow;
            this.errors = errors;
            this.processingQueries = processingQueries;
            this.avgQuery = avgQuery;
            this.rndId = rndId;
            this.id = id;
        }

        /**
         * Compares criterion by criterion, so a higher-priority criterion always dominates.
         * Uses the {@code X.compare} helpers rather than subtraction, which overflows for random ints.
         */
        @Override
        public int compareTo(@NotNull ConnectionInfo o) {
            int cmp = Boolean.compare(errors, o.errors);
            if (cmp != 0) {
                return cmp;
            }
            cmp = Boolean.compare(slow, o.slow);
            if (cmp != 0) {
                return cmp;
            }
            cmp = Integer.compare(processingQueries, o.processingQueries);
            if (cmp != 0) {
                return cmp;
            }
            cmp = Long.compare(avgQuery, o.avgQuery);
            if (cmp != 0) {
                return cmp;
            }
            cmp = Integer.compare(rndId, o.rndId);
            if (cmp != 0) {
                return cmp;
            }
            return Integer.compare(id, o.id);
        }
    }
}
