package ru.yandex.direct.jobs.yt.audit;

import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

import org.jooq.DSLContext;
import org.jooq.impl.DSL;

import ru.yandex.direct.dbschema.ppc.enums.YtChecksumStatus;
import ru.yandex.direct.dbutil.wrapper.DatabaseWrapper;
import ru.yandex.direct.dbutil.wrapper.DatabaseWrapperProvider;
import ru.yandex.direct.ytwrapper.model.YtCluster;

import static java.util.stream.Collectors.toList;
import static ru.yandex.direct.dbschema.ppc.Tables.YT_CHECKSUM;

/**
 * Репозиторий для работы с mysql-таблицей yt_checksum
 * Для генерации запросов в ppc и ppcdict используется одна и так же jooq-схема (из ppc)
 * Это работает, т.к. таблицы одинаковые; а создавать разные версии методов для ppc и ppcdict было бы неудобно.
 */
public class YtChecksumMysqlRepository {
    private final DatabaseWrapperProvider databaseWrapperProvider;

    public YtChecksumMysqlRepository(DatabaseWrapperProvider databaseWrapperProvider) {
        this.databaseWrapperProvider = databaseWrapperProvider;
    }

    public Integer getMaxIteration(String dbName, YtCluster ytCluster, String table) {
        String tbl = TableChecker.makeTbl(table, ytCluster);
        DatabaseWrapper databaseWrapper = databaseWrapperProvider.get(dbName);
        try (DSLContext dslContext = databaseWrapper.getDslContext()) {
            return (Integer) dslContext.select(DSL.cast(DSL.max(YT_CHECKSUM.ITERATION), Integer.class).as("max_iter"))
                    .from(YT_CHECKSUM)
                    .where(YT_CHECKSUM.TBL.eq(tbl))
                    .fetchOne("max_iter");
        }
    }

    public List<IterationInfo> getLastFinishedIterations(String dbName, YtCluster ytCluster,
                                                         Collection<String> tables) {
        List<String> tblList = tables.stream().map(table -> TableChecker.makeTbl(table, ytCluster)).collect(toList());
        DatabaseWrapper databaseWrapper = databaseWrapperProvider.get(dbName);
        try (DSLContext dslContext = databaseWrapper.getDslContext()) {
            return dslContext.select(
                    YT_CHECKSUM.TBL,
                    YT_CHECKSUM.ITERATION,
                    DSL.max(YT_CHECKSUM.TS).as("max_ts")
            )
                    .from(YT_CHECKSUM)
                    .where(YT_CHECKSUM.TBL.in(tblList))
                    .and(YT_CHECKSUM.STATUS.eq(YtChecksumStatus.Finished))
                    .groupBy(YT_CHECKSUM.TBL, YT_CHECKSUM.ITERATION)
                    .fetch(record -> new IterationInfo(
                            record.getValue(YT_CHECKSUM.TBL),
                            (int) (long) record.getValue(YT_CHECKSUM.ITERATION),
                            (LocalDateTime) record.getValue("max_ts")
                    ));
        }
    }

    public static String removeShard(String dbName) {
        if (dbName.contains(":")) {
            return dbName.split(":")[0];
        }
        return dbName;
    }

    public List<String> getTablesList(List<String> prefixedList, String dbName) {
        String prefix = removeShard(dbName) + ":";
        List<String> allTables = getAllTables(dbName)
                .stream()
                .filter(t -> !t.equals("yt_checksum"))
                .collect(toList());
        //
        List<String> result = new ArrayList<>();
        for (String tbl : prefixedList) {
            if (tbl.strip().startsWith(prefix)) {
                String table = tbl.strip().substring(prefix.length());
                if (table.equals("*")) {
                    result.addAll(allTables);
                } else {
                    result.add(table);
                }
            }
        }
        return result;
    }

    public List<String> getAllTables(String dbName) {
        DatabaseWrapper databaseWrapper = databaseWrapperProvider.get(dbName);
        try (DSLContext dslContext = databaseWrapper.getDslContext()) {
            return dslContext.resultQuery("SHOW TABLES")
                    .fetch(record -> (String) record.get(0));
        }
    }

    public Set<String> getCheckedTables(String dbName, YtCluster ytCluster, LocalDateTime minTs) {
        DatabaseWrapper databaseWrapper = databaseWrapperProvider.get(dbName);
        try (DSLContext dslContext = databaseWrapper.getDslContext()) {
            return dslContext.selectDistinct(YT_CHECKSUM.TBL)
                    .from(YT_CHECKSUM)
                    .where(YT_CHECKSUM.TBL.like("%/" + ytCluster.getName()))
                    .and(YT_CHECKSUM.STATUS.eq(YtChecksumStatus.Finished))
                    .and(YT_CHECKSUM.ITERATION.greaterOrEqual(0L))
                    .and(YT_CHECKSUM.TS.greaterThan(minTs))
                    .fetch(YT_CHECKSUM.TBL)
                    .stream()
                    .map(v -> v.split("/")[0])
                    .collect(Collectors.toSet());
        }
    }
}
