package ru.yandex.direct.dbutil.wrapper;

import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

import javax.annotation.ParametersAreNonnullByDefault;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.sql.DataSource;

import org.jooq.SQLDialect;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import ru.yandex.direct.db.config.DbConfig;
import ru.yandex.direct.db.config.DbConfigEvent;
import ru.yandex.direct.db.config.DbConfigFactory;
import ru.yandex.direct.db.config.DbConfigListener;
import ru.yandex.direct.dbutil.sharding.ShardWeight;
import ru.yandex.direct.env.EnvironmentType;

/**
 * Creates and caches {@link DatabaseWrapper}s for our databases, one wrapper per database name.
 * <p>
 * Controls the DataSource inside each DatabaseWrapper: every wrapper is backed by a
 * {@link RefreshableDataSource}, which is re-pointed at a freshly created DataSource whenever
 * {@link #update(DbConfigEvent)} is invoked (this provider registers itself as a listener of
 * its {@link DbConfigFactory}).
 * <p>
 * Reads of already cached wrappers go straight to a {@link ConcurrentHashMap}; creating,
 * refreshing, evicting and closing wrappers is serialized on this instance.
 */
@ParametersAreNonnullByDefault
@Component
public class DatabaseWrapperProvider implements DbConfigListener {
    private static final Logger logger = LoggerFactory.getLogger(DatabaseWrapperProvider.class);

    /** Weight reported for a shard whose config does not specify one. */
    private static final int DEFAULT_SHARD_WEIGHT = 1;

    private final ConcurrentHashMap<String, DatabaseWrapper> cache = new ConcurrentHashMap<>();
    private final DataSourceFactory dataSourceFactory;
    private final DbConfigFactory dbConfigFactory;
    private final EnvironmentType environmentType;

    /**
     * Spring constructor. Registration as a config listener is deliberately not done here but in
     * {@link #afterConstruction()}, so {@code this} is not published before construction ends.
     */
    @Autowired
    private DatabaseWrapperProvider(
            DataSourceFactory dataSourceFactory,
            DbConfigFactory dbConfigFactory,
            EnvironmentType environmentType
    ) {
        Objects.requireNonNull(dataSourceFactory, "dataSourceFactory is required");
        Objects.requireNonNull(dbConfigFactory, "dbConfigFactory is required");
        this.dataSourceFactory = dataSourceFactory;
        this.dbConfigFactory = dbConfigFactory;
        this.environmentType = environmentType;
    }

    /**
     * Creates a provider outside of Spring and registers it as a listener of
     * {@code dbConfigFactory}, mirroring what the {@code @PostConstruct} hook does in a container.
     */
    public static DatabaseWrapperProvider newInstance(DataSourceFactory dataSourceFactory,
                                                      DbConfigFactory dbConfigFactory, EnvironmentType environmentType) {
        DatabaseWrapperProvider provider =
                new DatabaseWrapperProvider(dataSourceFactory, dbConfigFactory, environmentType);
        provider.afterConstruction();
        return provider;
    }

    /** Wraps a DataSource so it can later be swapped out by {@link #refreshDataSource}. */
    private static DataSource wrapDataSource(DataSource dataSource) {
        return new RefreshableDataSource(dataSource);
    }

    /** Validates the shard arguments shared by {@link #isAlive} and {@link #getShardWeight}. */
    private static void requireValidShard(ShardedDb db, int shard) {
        Objects.requireNonNull(db, "db");
        if (shard <= 0) {
            throw new IllegalArgumentException("shard: " + shard);
        }
    }

    @PostConstruct
    private void afterConstruction() {
        // Once the constructor has finished, hand out a reference to the fully
        // constructed object. This is safer with respect to thread-safety.
        dbConfigFactory.addListener(this);
    }

    /**
     * Returns the wrapper for a non-sharded database.
     * <p>
     * The database name is the lower-cased {@code toString()} of {@code simpleDb}. Lower-casing
     * uses {@link Locale#ROOT} so the name does not depend on the JVM default locale (for example,
     * in the Turkish locale {@code "I".toLowerCase()} yields a dotless {@code "ı"}).
     */
    public DatabaseWrapper get(SimpleDb simpleDb) {
        return get(simpleDb.toString().toLowerCase(Locale.ROOT));
    }

    /**
     * Returns the wrapper for one shard of a sharded database, looked up by the name produced by
     * {@link ShardedDb#getDbNameForShard(int)}.
     */
    public DatabaseWrapper get(ShardedDb shardedDb, int shard) {
        return get(shardedDb.getDbNameForShard(shard));
    }

    /**
     * Returns the wrapper for the given database, creating and caching it on first access.
     *
     * @param dbName database name in the {@code db:shardNo} format, e.g. {@code ppc:1}
     */
    public DatabaseWrapper get(String dbName) {
        DatabaseWrapper cached = cache.get(dbName);
        if (cached != null) {
            return cached;
        }
        // Slow path: creation is synchronized so it cannot interleave with refresh/evict.
        return createAndCacheWrapper(dbName);
    }

    /**
     * Checks whether the given shard is available.
     *
     * @return the result of {@link DatabaseWrapper#isAlive()} for the shard, or {@code false} if
     * obtaining the wrapper or checking it throws a {@link RuntimeException} (the error is logged)
     * @throws NullPointerException     if {@code db} is null
     * @throws IllegalArgumentException if {@code shard} is not positive
     */
    public boolean isAlive(ShardedDb db, int shard) {
        requireValidShard(db, shard);
        try {
            return get(db, shard).isAlive();
        } catch (RuntimeException e) {
            logger.error("Can't check is shard alive (db: {}, shard: {})", db, shard, e);
            return false;
        }
    }

    /**
     * Returns the weight of the given shard as configured in its {@link DbConfig}, falling back to
     * {@value #DEFAULT_SHARD_WEIGHT} when the config has no weight.
     *
     * @throws NullPointerException     if {@code db} is null
     * @throws IllegalArgumentException if {@code shard} is not positive
     */
    public ShardWeight getShardWeight(ShardedDb db, int shard) {
        requireValidShard(db, shard);
        DbConfig config = dbConfigFactory.get(db.getDbNameForShard(shard));
        return new ShardWeight(
                shard,
                Optional.ofNullable(config.getWeight()).orElse(DEFAULT_SHARD_WEIGHT));
    }

    /**
     * Creates, caches and returns the wrapper for {@code dbName}, or returns the one another
     * caller cached first.
     * <p>
     * Needs synchronization for consistent building and refreshing DataSource'es from DbConfig.
     */
    private synchronized DatabaseWrapper createAndCacheWrapper(String dbName) {
        // Config lookup and DataSource creation happen only when the entry is really missing.
        return cache.computeIfAbsent(dbName, key -> {
            logger.debug("Creating and caching DatabaseWrapper for {}", key);
            DbConfig config = dbConfigFactory.get(key);
            return new DatabaseWrapper(key, wrapDataSource(createDataSource(config)), SQLDialect.MYSQL,
                    environmentType);
        });
    }

    /**
     * Refreshes the DataSource of every cached wrapper from the current config. Wrappers whose db
     * is no longer configured, or whose refresh fails, are evicted and closed.
     * <p>
     * Needs synchronization for consistent building and refreshing DataSource'es from DbConfig.
     */
    @Override
    public synchronized void update(DbConfigEvent event) {
        logger.debug("Refreshing all DatabaseWrapper's DataSources");
        cache.forEach((dbName, wrapper) -> safeRefreshDataSource(dbName, wrapper.getDataSource()));
    }

    /**
     * Closes the DataSource of every cached wrapper and empties the cache.
     * <p>
     * Synchronized like {@link #update(DbConfigEvent)} so a concurrent refresh cannot create new
     * DataSources while wrappers are being closed. Clearing the cache keeps later refreshes (this
     * provider stays a registered config listener) from re-opening DataSources of closed wrappers.
     */
    @PreDestroy
    public synchronized void close() throws Exception {
        cache.forEach((dbName, wrapper) -> closeDataSource(dbName, wrapper.getDataSource()));
        cache.clear();
    }

    private DataSource createDataSource(DbConfig config) {
        return dataSourceFactory.createDataSource(config);
    }

    /**
     * Bullet-proof variant of {@link #refreshDataSource(String, DataSource)}: a
     * {@link RuntimeException} during the refresh is logged and the DataSource is evicted from
     * the cache instead of propagating the error.
     */
    private void safeRefreshDataSource(String dbName, DataSource dataSource) {
        try {
            refreshDataSource(dbName, dataSource);
        } catch (RuntimeException error) {
            logger.error("Error on updating DataSource {}. Evict it from cache", dbName, error);
            evict(dbName, dataSource);
        }
    }

    /**
     * Re-points the cached {@link RefreshableDataSource} at a DataSource built from the current
     * config and closes the one it replaced; evicts the entry if {@code dbName} is no longer in
     * the config.
     */
    private void refreshDataSource(String dbName, DataSource dataSource) {
        logger.debug("Refreshing DatabaseWrapper DataSource for {}", dbName);
        if (!dbConfigFactory.has(dbName)) {
            logger.warn("DataSource by name {} wasn't found in config. Will be evicted", dbName);
            evict(dbName, dataSource);
        } else {
            // Cast before creating the new DataSource so a failed cast cannot leak a fresh pool.
            RefreshableDataSource refreshable = (RefreshableDataSource) dataSource;
            DbConfig config = dbConfigFactory.get(dbName);
            DataSource newWrappedDs = createDataSource(config);
            DataSource oldWrappedDs = refreshable.refresh(newWrappedDs);
            closeDataSource(dbName, oldWrappedDs);
        }
    }

    /**
     * Removes the wrapper named {@code dbName} from the cache and closes {@code dataSource}.
     * <p>
     * The cache stores {@link DatabaseWrapper}s, not DataSources, so the entry is matched by the
     * wrapper's DataSource: it is removed only if it is still backed by this exact instance.
     */
    private synchronized void evict(String dbName, DataSource dataSource) {
        logger.info("Evicting dataSource {} from cache", dbName);
        cache.computeIfPresent(dbName,
                (key, wrapper) -> wrapper.getDataSource() == dataSource ? null : wrapper);
        closeDataSource(dbName, dataSource);
    }

    /** Closes {@code dataSource} if it is {@link AutoCloseable}; errors are logged, not thrown. */
    private void closeDataSource(String dbName, DataSource dataSource) {
        if (dataSource instanceof AutoCloseable) {
            try {
                ((AutoCloseable) dataSource).close();
            } catch (Exception e) {
                logger.error("Error closing DataSource for {}", dbName, e);
            }
        }
    }
}
