package ru.yandex.market.clickhouse.ddl;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimaps;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.jdbc.core.JdbcTemplate;
import ru.yandex.clickhouse.BalancedClickhouseDataSource;
import ru.yandex.clickhouse.ClickHouseDataSource;
import ru.yandex.clickhouse.ClickhouseJdbcUrlParser;
import ru.yandex.clickhouse.settings.ClickHouseProperties;
import ru.yandex.market.monitoring.MonitoringStatus;

import javax.sql.DataSource;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * Applies DDL and replicated maintenance queries (truncate, attach/replace partition) to the
 * servers of a ClickHouse cluster.
 * <p>
 * The cluster topology is read from {@code system.clusters} through a seed data source, and a
 * dedicated {@link ClickHouseHostDdlDao} is created for every server found there.
 *
 * @author Dmitry Andreev <a href="mailto:AndreevDm@yandex-team.ru"></a>
 * @date 28/03/2018
 */
public class ClickHouseDdlService {

    private static final Logger log = LogManager.getLogger();

    /**
     * Partition names that look like a bare {@code YYYY-MM-DD} date; such names must be quoted
     * before being embedded into a {@code DROP PARTITION} clause.
     */
    private static final Pattern PARTITION_DATE_PATTERN = Pattern.compile("^\\d{4}-\\d{2}-\\d{2}$");

    private final ClickHouseCluster cluster;
    private final JdbcTemplate seedJdbcTemplate;
    /** Per-host DAOs keyed by host name (unmodifiable). */
    private final Map<String, ClickHouseHostDdlDao> hostDdlDaos;

    private ClickHouseDdlService(ClickHouseCluster cluster, JdbcTemplate seedJdbcTemplate,
                                 Map<String, ClickHouseHostDdlDao> hostDdlDaos) {
        this.cluster = cluster;
        this.seedJdbcTemplate = seedJdbcTemplate;
        this.hostDdlDaos = Collections.unmodifiableMap(hostDdlDaos);
    }

    /**
     * Computes and applies the DDL for the given table on every host of the cluster, in parallel.
     *
     * @param tableDefinition the desired table definition
     * @return the aggregated per-server state
     */
    public TableDdlState applyDdl(ClickHouseTableDefinition tableDefinition) {
        List<TableDdlState.ServerState> serverStates = hostDdlDaos.keySet()
            .stream()
            .parallel()
            .map(host -> applyDdlOnHost(tableDefinition, host))
            .collect(Collectors.toList());
        return TableDdlState.create(serverStates);
    }

    /**
     * Applies the DDL for a single host.
     * <p>
     * If the computed DDL requires attention, nothing is applied and the host is reported as
     * CRITICAL together with that DDL. Any exception is logged and reported as CRITICAL with its
     * message. Otherwise all updates are applied and the host is reported as OK.
     */
    private TableDdlState.ServerState applyDdlOnHost(ClickHouseTableDefinition tableDefinition, String host) {
        ClickHouseHostDdlDao hostDdlDao = hostDdlDaos.get(host);
        ClickHouseCluster.Server server = cluster.getServer(host);

        try {
            DDL ddl = hostDdlDao.getDdl(tableDefinition);
            if (ddl.requireAttention()) {
                return new TableDdlState.ServerState(
                    server, MonitoringStatus.CRITICAL, ddl, null
                );
            }
            for (DdlQuery query : ddl.getUpdates()) {
                hostDdlDao.applyQuery(query);
            }
        } catch (Exception e) {
            log.warn("Exception while applying ddl on host {}", host, e);
            // Some exceptions (e.g. NullPointerException) carry no message; keep at least the type.
            String error = e.getMessage() != null ? e.getMessage() : e.toString();
            return new TableDdlState.ServerState(
                server, MonitoringStatus.CRITICAL, null, error
            );
        }
        return new TableDdlState.ServerState(server, MonitoringStatus.OK, new DDL(host), null);
    }

    /**
     * @return the cluster this service operates on
     */
    public ClickHouseCluster getCluster() {
        return cluster;
    }

    /**
     * Creates a service for the given cluster, using plain {@link JdbcTemplate}s for every host.
     *
     * @param seedHosts            hosts used to discover the cluster topology; must not be empty
     * @param clickHouseProperties connection properties (the port is taken from here)
     * @param clusterName          name of the cluster in {@code system.clusters}
     * @return a new service
     */
    public static ClickHouseDdlService create(List<String> seedHosts,
                                              ClickHouseProperties clickHouseProperties, String clusterName) {
        return create(
            seedHosts, clickHouseProperties, clusterName,
            (host, dataSource) -> new JdbcTemplate(dataSource)
        );
    }

    /**
     * Drops every active partition of the table, once per shard.
     *
     * @param tableName the table to truncate
     */
    public void truncateTable(TableName tableName) {
        log.info("Truncating table {}.{}", tableName.getDatabase(), tableName.getTable());
        executeReplicatedQuery((jdbcTemplate, server) -> {
            List<String> partitions = jdbcTemplate.queryForList(
                "SELECT DISTINCT partition FROM system.parts WHERE database = ? AND table = ? AND active",
                String.class,
                tableName.getDatabase(), tableName.getTable()
            );

            partitions = checkPartitionQuotes(partitions);

            log.info(
                "Found {} partitions on server {} for table {}: {}",
                partitions.size(), server, tableName.getFullName(), partitions
            );
            return partitions.stream()
                .map(p -> String.format(
                    "ALTER TABLE %s.%s DROP PARTITION %s", tableName.getDatabase(), tableName.getTable(), p)
                )
                .collect(Collectors.toList());
        });
    }

    /**
     * Wraps partition names that look like a bare {@code YYYY-MM-DD} date in single quotes; all
     * other names are returned unchanged.
     *
     * @param partitions partition names
     * @return a new list with date-like names quoted, in the same order
     */
    List<String> checkPartitionQuotes(List<String> partitions) {
        return partitions.stream()
            // matcher(p).find() is exactly what Pattern.asPredicate() does, without a per-element lambda.
            .map(p -> PARTITION_DATE_PATTERN.matcher(p).find() ? String.format("'%s'", p) : p)
            .collect(Collectors.toList());
    }

    /**
     * Runs {@code ALTER TABLE dest ATTACH PARTITION partition FROM source} once per shard.
     *
     * @param partition partition expression, embedded into the query verbatim
     */
    public void attachPartitionFromTable(TableName sourceTable, TableName destTable, String partition) {
        executeReplicatedQuery(
            String.format(
                "ALTER TABLE %s ATTACH PARTITION %s FROM %s",
                destTable.getFullName(), partition, sourceTable.getFullName()
            )
        );
    }

    /**
     * Runs {@code ALTER TABLE dest REPLACE PARTITION partition FROM source} once per shard.
     *
     * @param partition partition expression, embedded into the query verbatim
     */
    public void replacePartitionFromTable(TableName sourceTable, TableName destTable, String partition) {
        executeReplicatedQuery(
            String.format(
                "ALTER TABLE %s REPLACE PARTITION %s FROM %s",
                destTable.getFullName(), partition, sourceTable.getFullName()
            )
        );
    }

    /**
     * Executes a single query on one replica of every shard.
     *
     * @param query the query to run
     * @throws RuntimeException if some shard has no replica that executed the query successfully
     */
    public void executeReplicatedQuery(String query) {
        log.info("Applying query on cluster {}: {}", cluster.getName(), query);
        executeReplicatedQuery((jdbcTemplate, server) -> Collections.singletonList(query));
    }

    /**
     * Executes the supplied queries on one replica of every shard.
     * <p>
     * Shards are processed in parallel. Within a shard, replicas are tried one after another until
     * the first one reports OK; the remaining replicas of that shard are skipped.
     *
     * @param queriesSupplier produces the queries to run for a given server
     * @throws RuntimeException if some shard has no replica that reported OK
     */
    @VisibleForTesting
    void executeReplicatedQuery(QuerySupplier queriesSupplier) {
        ListMultimap<Integer, ClickHouseHostDdlDao> shardDaos = Multimaps.index(
            hostDdlDaos.values(), dao -> dao.getServer().getShardNumber()
        );

        // Each shard is handled by exactly one task, and results are gathered by the collector
        // instead of by mutating a shared (non-thread-safe) multimap from parallel tasks.
        Map<Integer, List<ServerQueryResult>> shardResults = shardDaos.keySet()
            .stream()
            .parallel()
            .collect(Collectors.toMap(
                shardId -> shardId,
                shardId -> applyUntilFirstSuccess(shardDaos.get(shardId), queriesSupplier)
            ));

        Map<Integer, String> shardErrors = new LinkedHashMap<>();

        /* Looking for any shard with all none OK replica's statuses */
        for (Integer shardNum : shardDaos.keySet()) {
            List<ServerQueryResult> shardQueryResults = shardResults.get(shardNum);

            if (shardQueryResults.stream().anyMatch(r -> r.getStatus() == ServerQueryResult.Status.OK)) {
                continue;
            }
            String shardError = shardQueryResults
                .stream()
                .map(r -> String.format("Host %s error: '%s'", r.getServer().getHost(), r.getErrorString()))
                .collect(Collectors.joining("; "));
            shardErrors.put(shardNum, shardError);
        }

        if (!shardErrors.isEmpty()) {
            throw new RuntimeException("Failed to apply DDL: " + shardErrors.toString());
        }
    }

    /**
     * Applies the queries to the replicas of one shard in order, stopping after the first replica
     * that reports OK.
     *
     * @return the results of every replica that was tried, in the order tried
     */
    private static List<ServerQueryResult> applyUntilFirstSuccess(List<ClickHouseHostDdlDao> replicaDaos,
                                                                  QuerySupplier queriesSupplier) {
        List<ServerQueryResult> results = new ArrayList<>();
        for (ClickHouseHostDdlDao replicaDao : replicaDaos) {
            ServerQueryResult result = replicaDao.applyQuery(queriesSupplier);
            results.add(result);
            if (result.getStatus() == ServerQueryResult.Status.OK) {
                break;
            }
        }
        return results;
    }

    /**
     * @return the template bound to the balanced seed data source
     */
    public JdbcTemplate getSeedJdbcTemplate() {
        return seedJdbcTemplate;
    }

    /**
     * Creates a service for the given cluster.
     *
     * @param seedHosts            hosts used to discover the cluster topology; must not be empty
     * @param clickHouseProperties connection properties (the port is taken from here)
     * @param clusterName          name of the cluster in {@code system.clusters}
     * @param jdbcTemplateFactory  creates templates for the seed ({@code host == null}) and per-host data sources
     * @return a new service
     * @throws IllegalArgumentException if {@code seedHosts} is empty
     */
    @VisibleForTesting
    static ClickHouseDdlService create(List<String> seedHosts, ClickHouseProperties clickHouseProperties,
                                       String clusterName, JdbcTemplateFactory jdbcTemplateFactory) {

        Preconditions.checkArgument(!seedHosts.isEmpty(), "ClickHouse host(s) not provided.");
        String url = ClickhouseJdbcUrlParser.JDBC_CLICKHOUSE_PREFIX + "//" +
            seedHosts.stream()
                .map(host -> host + ":" + clickHouseProperties.getPort())
                .collect(Collectors.joining(","));
        DataSource dataSource = new BalancedClickhouseDataSource(url, clickHouseProperties);
        JdbcTemplate jdbcTemplate = jdbcTemplateFactory.create(null, dataSource);

        ClickHouseCluster cluster = getCluster(jdbcTemplate, clusterName);

        Map<String, ClickHouseHostDdlDao> hostDdlDaos = new HashMap<>();
        for (ClickHouseCluster.Server server : cluster.getServers()) {
            hostDdlDaos.put(server.getHost(), createHostDdlDao(server, clickHouseProperties, jdbcTemplateFactory));
        }

        return new ClickHouseDdlService(cluster, jdbcTemplate, hostDdlDaos);
    }

    /**
     * Creates a DAO bound to a non-balanced data source pointing at exactly one server.
     */
    private static ClickHouseHostDdlDao createHostDdlDao(ClickHouseCluster.Server server,
                                                         ClickHouseProperties properties,
                                                         JdbcTemplateFactory jdbcTemplateFactory) {
        DataSource dataSource = new ClickHouseDataSource(
            ClickhouseJdbcUrlParser.JDBC_CLICKHOUSE_PREFIX + "//" +
                server.getHost() + ":" + properties.getPort(), properties
        );
        JdbcTemplate jdbcTemplate = jdbcTemplateFactory.create(server.getHost(), dataSource);
        return new ClickHouseHostDdlDao(server, jdbcTemplate);
    }

    /**
     * Reads the topology of a cluster from {@code system.clusters}.
     *
     * @param jdbcTemplate template used to query {@code system.clusters}
     * @param cluster      cluster name
     * @return the cluster with all of its servers
     * @throws IllegalStateException if no rows exist for the cluster
     */
    public static ClickHouseCluster getCluster(JdbcTemplate jdbcTemplate, String cluster) {
        List<ClickHouseCluster.Server> servers = jdbcTemplate.query(
            "SELECT * FROM system.clusters WHERE cluster = ?",
            (rs, rowNum) -> new ClickHouseCluster.Server(
                rs.getString("host_name"),
                rs.getInt("shard_num"),
                rs.getInt("replica_num")
            ),
            cluster
        );
        // Guava's Preconditions only substitutes %s placeholders (not {}).
        Preconditions.checkState(!servers.isEmpty(), "Clickhouse cluster %s not exists", cluster);
        return new ClickHouseCluster(cluster, servers);
    }

    /**
     * Returns the template bound directly to the given host.
     *
     * @param s host name as listed in {@code system.clusters}
     * @return the host's template
     * @throws NullPointerException if the host is not part of the cluster
     */
    public JdbcTemplate getHostJdbcTemplate(String s) {
        return hostDdlDaos.get(s).getJdbcTemplate();
    }

    @FunctionalInterface
    protected interface JdbcTemplateFactory {
        /**
         * @param host       server hostname, null for seed datasource
         * @param dataSource data source to wrap
         * @return a template backed by {@code dataSource}
         */
        JdbcTemplate create(String host, DataSource dataSource);
    }

    /**
     * Produces the queries to execute on a particular server; may query the server first.
     */
    @FunctionalInterface
    public interface QuerySupplier {
        List<String> get(JdbcTemplate jdbcTemplate, ClickHouseCluster.Server server);
    }

    /**
     * Returns the first existing definition of the table found on any host.
     *
     * @param tableName the table to look up
     * @return the table definition
     * @throws RuntimeException if no host has the table
     */
    public ClickHouseTableDefinition getClickHouseTableDefinition(TableName tableName) {
        for (ClickHouseHostDdlDao clickHouseHostDdlDao : hostDdlDaos.values()) {
            Optional<ClickHouseTableDefinition> td = clickHouseHostDdlDao.getExistedTableDefinition(tableName);
            if (td.isPresent()) {
                return td.get();
            }
        }
        throw new RuntimeException(String.format("Table definition of '%s' table is not found", tableName));
    }
}
