package ru.yandex.webmaster3.storage.util.sql;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Required;
import org.springframework.jdbc.core.SingleColumnRowMapper;
import org.springframework.transaction.support.TransactionTemplate;

/**
 * @author avhaliullin
 */
/**
 * Exposes the templates of a {@link LogicalDB} as {@link ManagedConnection}s that collect per-connection
 * statistics.
 *
 * <p>The write connection wraps the main template, and one read connection wraps each slave template. Each
 * connection tracks:
 * <ul>
 *     <li>the number of queries in flight;</li>
 *     <li>whether the most recently finished query failed;</li>
 *     <li>last and average latency, tracked separately for pings and regular queries.</li>
 * </ul>
 *
 * <p>{@link #init()} schedules a periodic {@code SELECT 1} ping on every connection. This keeps latency and
 * error state current while a connection is idle. {@link #destroy()} stops those pings.
 *
 * @author avhaliullin
 */
public class JdbcRouterBackend {
    private static final Logger log = LoggerFactory.getLogger(JdbcRouterBackend.class);

    /** Initial delay before the first ping, and the interval between subsequent pings. */
    private static final long PING_PERIOD_SECONDS = 5;

    private LogicalDB db;
    private List<ManagedConnection<IReadJdbcTemplate>> managedReadConnections;
    private ManagedWriteConnection managedWriteConnection;
    /** Executor running the periodic pings; retained so that {@link #destroy()} can stop it. */
    private ScheduledExecutorService pingPool;

    /**
     * Wraps every slave template and the main template of the configured {@link LogicalDB} in managed
     * connections, and schedules a ping for each one every {@link #PING_PERIOD_SECONDS} seconds.
     *
     * <p>Must be called after {@link #setDb(LogicalDB)} and before any of the connection getters.
     */
    public void init() {
        ThreadFactory tf = new ThreadFactoryBuilder()
                // Named threads make ping workers identifiable in thread dumps.
                .setNameFormat("jdbc-ping-%d")
                .setUncaughtExceptionHandler((t, e) -> log.error("[FATAL] Uncaught exception in {}", t, e))
                .setDaemon(true)
                .build();

        // One thread per connection (all slaves + the main one), because each ping blocks on the database.
        pingPool = Executors.newScheduledThreadPool(db.getSlaveTemplates().size() + 1, tf);
        managedReadConnections = new ArrayList<>();

        for (IReadJdbcTemplate readTemplate : db.getSlaveTemplates()) {
            ManagedReadConnection readConnection = new ManagedReadConnection(readTemplate);
            pingPool.scheduleAtFixedRate(readConnection, PING_PERIOD_SECONDS, PING_PERIOD_SECONDS, TimeUnit.SECONDS);
            managedReadConnections.add(readConnection);
        }
        managedWriteConnection = new ManagedWriteConnection(db.getMainTemplate());
        pingPool.scheduleAtFixedRate(managedWriteConnection, PING_PERIOD_SECONDS, PING_PERIOD_SECONDS, TimeUnit.SECONDS);
    }

    /**
     * Stops the periodic pings started by {@link #init()}. Pings already running are allowed to finish.
     *
     * <p>Does nothing if {@code init()} has not been called, and may be called more than once.
     */
    public void destroy() {
        ScheduledExecutorService pool = pingPool;
        if (pool != null) {
            pool.shutdown();
        }
    }

    /**
     * Returns the transaction template of the underlying {@link LogicalDB}.
     *
     * @return the transaction template of the underlying {@link LogicalDB}
     */
    public TransactionTemplate getTransactionTemplate() {
        return db.getTransactionTemplate();
    }

    /**
     * Returns the managed connection wrapping the main (write) template.
     *
     * @return the managed connection wrapping the main (write) template; {@code null} before {@link #init()}
     */
    public ManagedConnection<IWriteJdbcTemplate> getWriteManagedConnection() {
        return managedWriteConnection;
    }

    /**
     * Returns a new mutable list of every connection that can execute reads.
     *
     * <p>The list contains all slave connections in configuration order, followed by the write connection.
     *
     * @return a fresh list that callers may modify freely
     */
    public List<ManagedConnection<? extends IReadJdbcTemplate>> getReadManagedConnections() {
        List<ManagedConnection<? extends IReadJdbcTemplate>> res = new ArrayList<>(managedReadConnections.size() + 1);
        res.addAll(managedReadConnections);
        res.add(managedWriteConnection);
        return res;
    }

    /**
     * Returns the identity string of the underlying {@link LogicalDB}.
     *
     * @return the identity string of the underlying {@link LogicalDB}
     */
    public String getDbIdentity() {
        return db.getDbIdentity();
    }

    /**
     * Sets the logical database whose templates this backend manages.
     *
     * @param db the logical database whose templates this backend manages
     */
    @Required
    public void setDb(LogicalDB db) {
        this.db = db;
    }

    /**
     * Base class for managed connections.
     *
     * <p>It wraps a template so that every query executed through the wrapper updates this connection's
     * statistics. When {@link #run() run}, it pings the database through the same wrapper.
     *
     * @param <T> template type exposed by {@link #getTemplate()}
     */
    private abstract class AbstractManagedConnection<T extends IReadJdbcTemplate> implements ManagedConnection<T>, Runnable {
        private final Logger log = LoggerFactory.getLogger(AbstractManagedConnection.class);
        private static final String PING_QUERY = "SELECT 1";
        /** Number of most recent samples used for the average latencies. */
        private static final int QUEUE_SIZE = 100;

        private final T template;

        private final AverageTracker pingTracker = new AverageTracker(QUEUE_SIZE);
        private final AverageTracker queryTracker = new AverageTracker(QUEUE_SIZE);

        private final AtomicBoolean hasError = new AtomicBoolean(false);
        private final AtomicInteger queriesInProgress = new AtomicInteger(0);

        private AbstractManagedConnection(T template) {
            // wrapTemplate runs before the subclass constructor body has executed. Implementations must
            // therefore not read subclass fields here; they may only capture references for later use.
            this.template = wrapTemplate(template);
        }

        /** Called by the wrapped template before a query starts; increments the in-flight counter. */
        protected final void onQueryStart(DelegatingReadJdbcTemplate.StartQueryInfo info) {
            queriesInProgress.incrementAndGet();
        }

        /**
         * Called by the wrapped template after a query finishes.
         *
         * <p>Decrements the in-flight counter and sets the error flag from this query's outcome. It then records
         * the execution time in either the ping tracker or the regular-query tracker.
         */
        protected final void onQueryFinish(DelegatingReadJdbcTemplate.FinishQueryInfo info) {
            // Decrement first so that nothing below can leave the in-flight counter inflated.
            queriesInProgress.decrementAndGet();

            String query = info.queryInfo.query;
            // Constant-first comparison stays null-safe even if a query text is missing.
            boolean isPing = PING_QUERY.equals(query);
            if (!isPing && log.isDebugEnabled()) {
                log.debug("SQL query executed in {}ms. Query = {}", info.executionTimeMs, query);
            }
            hasError.set(info.exception != null);
            if (isPing) {
                pingTracker.trackNext(info.executionTimeMs);
            } else {
                queryTracker.trackNext(info.executionTimeMs);
            }
        }

        /**
         * Executes the ping query through the wrapped template, so its latency and outcome are recorded.
         *
         * <p>All failures are logged and swallowed. A task scheduled with {@code scheduleAtFixedRate} that
         * throws would have all of its subsequent runs suppressed.
         */
        @Override
        public void run() {
            try {
                template.query(PING_QUERY, new SingleColumnRowMapper<Integer>());
            } catch (Throwable e) {
                log.warn("Ping query failed", e);
            }
        }

        /**
         * Wraps the raw template so that its queries report to {@link #onQueryStart} and {@link #onQueryFinish}.
         * Invoked from the constructor, so implementations must not depend on subclass state.
         */
        protected abstract T wrapTemplate(T template);

        @Override
        public String getDbIdentity() {
            return template.getDbIdentity();
        }

        @Override
        public T getTemplate() {
            return template;
        }

        @Override
        public long getAvgPing() {
            return pingTracker.getAvg();
        }

        @Override
        public long getLastPing() {
            return pingTracker.getLast();
        }

        @Override
        public long getAvgQuery() {
            return queryTracker.getAvg();
        }

        @Override
        public long getLastQuery() {
            return queryTracker.getLast();
        }

        /** Returns whether the most recently finished query on this connection (ping or regular) failed. */
        @Override
        public boolean hasError() {
            return hasError.get();
        }

        @Override
        public int getQueriesInProgress() {
            return queriesInProgress.get();
        }
    }

    /** Managed connection around a slave (read-only) template. */
    private class ManagedReadConnection extends AbstractManagedConnection<IReadJdbcTemplate> {
        private ManagedReadConnection(IReadJdbcTemplate template) {
            super(template);
        }

        @Override
        protected IReadJdbcTemplate wrapTemplate(final IReadJdbcTemplate template) {
            return new DelegatingReadJdbcTemplate() {
                @Override
                public String getDbIdentity() {
                    return template.getDbIdentity();
                }

                @Override
                protected IReadJdbcTemplate getTemplate() {
                    return template;
                }

                @Override
                protected String prepareQuery(String s) {
                    return s;
                }

                @Override
                protected void onQueryStart(StartQueryInfo info) {
                    ManagedReadConnection.this.onQueryStart(info);
                }

                @Override
                protected void onQueryFinish(FinishQueryInfo info) {
                    ManagedReadConnection.this.onQueryFinish(info);
                }
            };
        }
    }

    /** Managed connection around the main (write) template. */
    private class ManagedWriteConnection extends AbstractManagedConnection<IWriteJdbcTemplate> {
        private ManagedWriteConnection(IWriteJdbcTemplate template) {
            super(template);
        }

        @Override
        protected IWriteJdbcTemplate wrapTemplate(final IWriteJdbcTemplate template) {
            return new DelegatingWriteJdbcTemplate() {
                @Override
                public String getDbIdentity() {
                    return template.getDbIdentity();
                }

                @Override
                protected IWriteJdbcTemplate getTemplate() {
                    return template;
                }

                @Override
                protected String prepareQuery(String s) {
                    return s;
                }

                @Override
                protected void onQueryStart(StartQueryInfo info) {
                    ManagedWriteConnection.this.onQueryStart(info);
                }

                @Override
                protected void onQueryFinish(FinishQueryInfo info) {
                    ManagedWriteConnection.this.onQueryFinish(info);
                }
            };
        }
    }

    /**
     * Sliding-window tracker over the most recent {@code maxValues} samples.
     *
     * <p>Updates are serialized by {@code synchronized}. Readers get the latest published average and last
     * value through {@code volatile} fields without locking. The two values may therefore come from
     * different updates.
     */
    private static class AverageTracker {
        private final Queue<Long> queue;
        private final int maxValues;
        private long sum = 0;
        private volatile long avg = 0;
        private volatile long last = 0;

        public AverageTracker(int maxValues) {
            this.maxValues = maxValues;
            this.queue = new ArrayDeque<>(maxValues);
        }

        /** Adds a sample, evicting the oldest one once the window is full; the average is truncated to a long. */
        public synchronized void trackNext(long value) {
            if (queue.size() >= maxValues) {
                sum -= queue.poll();
            }
            sum += value;
            queue.offer(value);
            avg = sum / queue.size();
            last = value;
        }

        public long getAvg() {
            return avg;
        }

        public long getLast() {
            return last;
        }
    }
}
