package ru.yandex.chemodan.app.dataapi.worker.dump;

import java.time.Duration;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import net.jodah.failsafe.RetryPolicy;
import org.joda.time.DateTime;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.MapF;
import ru.yandex.bolts.collection.Option;
import ru.yandex.chemodan.app.dataapi.api.data.filter.RecordsFilter;
import ru.yandex.chemodan.app.dataapi.api.data.filter.condition.DatabaseCondition;
import ru.yandex.chemodan.app.dataapi.api.data.record.CollectionRef;
import ru.yandex.chemodan.app.dataapi.api.data.snapshot.Snapshot;
import ru.yandex.chemodan.app.dataapi.api.db.ref.DatabaseRef;
import ru.yandex.chemodan.app.dataapi.api.db.ref.UserDatabaseSpec;
import ru.yandex.chemodan.app.dataapi.api.user.DataApiUserId;
import ru.yandex.chemodan.app.dataapi.core.datasources.disk.DiskDataSource;
import ru.yandex.chemodan.app.dataapi.core.manager.DataApiManager;
import ru.yandex.chemodan.app.dataapi.web.direct.a3.DirectDataApiBenderUtils;
import ru.yandex.chemodan.util.BatchCollectorAsync;
import ru.yandex.chemodan.util.retry.RetryManager;
import ru.yandex.chemodan.util.yt.JacksonTableEntryTypeWithoutUtfEncoding;
import ru.yandex.inside.yt.kosher.Yt;
import ru.yandex.inside.yt.kosher.cypress.CypressNodeType;
import ru.yandex.inside.yt.kosher.cypress.YPath;
import ru.yandex.inside.yt.kosher.impl.ytree.builder.YTree;
import ru.yandex.inside.yt.kosher.operations.specs.MergeSpec;
import ru.yandex.inside.yt.kosher.tables.YTableEntryTypeUtils;
import ru.yandex.inside.yt.kosher.ytree.YTreeNode;
import ru.yandex.misc.ExceptionUtils;
import ru.yandex.misc.bender.BenderMapper;
import ru.yandex.misc.db.masterSlave.MasterSlaveContextHolder;
import ru.yandex.misc.db.masterSlave.MasterSlavePolicy;
import ru.yandex.misc.log.mlf.Logger;
import ru.yandex.misc.log.mlf.LoggerFactory;
import ru.yandex.misc.thread.factory.ThreadNameIndexThreadFactory;

import static ru.yandex.chemodan.app.dataapi.utils.YtPathsUtils.YT_NODE_NAME_FORMATTER;

/**
 * @author vpronto
 */
/**
 * Dumps the snapshots of all users of one DataApi database (optionally restricted to a single collection) to YT.
 *
 * <p>Pipeline, as implemented below:
 * <ol>
 *     <li>{@link #prepareUsers()} streams batches of user ids from {@link DiskDataSource} on the fetcher pool into a
 *     small bounded queue; every batch taken from that queue is one "partition".</li>
 *     <li>For each partition the users' snapshots are fetched in parallel on the fetcher pool, grouped by collection
 *     and written by the pusher pool into chunk tables {@code redumped-group-G-partition-P-<thread name>}.</li>
 *     <li>Once a partition's writes are done its chunks are merged into {@code redumped-finished-group-G-partition-P}.
 *     Every {@link #COUNT_PARTITIONS_IN_GROUP} partitions, and once more for a trailing incomplete group, the finished
 *     partition tables of a group are merged into {@code redumped-merged-group-G}.</li>
 *     <li>Finally all group tables of a collection are merged into {@code redumped}, sorted by {@code uid, record_id},
 *     moved to {@link #getResultPath(String)} and copied next to it under a timestamped name.</li>
 * </ol>
 *
 * <p>A partition whose finished table already exists for every known collection is skipped, so an interrupted dump
 * can resume. Table names are compared on whole name segments, so partition 1 is never confused with partition 10.
 *
 * <p>{@link #process()} resets per-run state without synchronization, so it presumably must not be invoked
 * concurrently on one instance.
 *
 * @author vpronto
 */
public abstract class AbstractDumpDatabaseUsersProcessor implements AutoCloseable {
    private final Logger logger = LoggerFactory.getLogger(getClass());

    private static final DatabaseCondition NON_EMPTY_DB_CONDITION = DatabaseCondition.recordsCount().gt(0);

    /** Prefix of per-thread chunk tables: {@code redumped-group-G-partition-P-<thread name>}. */
    public static final String REDUMPED_PARTITION_PREFIX = "redumped-group-";
    /** Prefix of a partition's merged table: {@code redumped-finished-group-G-partition-P}. */
    public static final String REDUMPED_FINISHED_PARTITION_PREFIX = "redumped-finished-group-";
    /** Prefix of a group's merged table: {@code redumped-merged-group-G}. */
    public static final String REDUMPED_FINISH_GROUP_PREFIX = "redumped-merged-group-";
    public static final String PARTITION_SUFFIX = "-partition-";
    /** Number of consecutive partitions whose finished tables are merged into one group table. */
    public static final int COUNT_PARTITIONS_IN_GROUP = 128;

    /** How long the main loop waits for the next batch before re-checking whether the producer has finished. */
    private static final long PREFETCH_POLL_TIMEOUT_SECONDS = 10;

    /** How long the producer waits for free queue space before re-checking whether the dump was stopped. */
    private static final long PREFETCH_OFFER_TIMEOUT_SECONDS = 1;

    /** Separator between a partition/group name and whatever follows it in a table name. */
    private static final String NAME_DELIMITER = "-";

    private static final MapF<String, YTreeNode> COMPRESSION_ATTRIBUTES = Cf.map(
            "compression_codec", YTree.stringNode("brotli_8"),
            "erasure_codec", YTree.stringNode("lrc_12_2_2")
    );

    private static final BenderMapper benderMapper = DirectDataApiBenderUtils.mapper();
    private static final ObjectMapper jacksonMapper = YTableEntryTypeUtils.getDefaultObjectMapper();
    protected static final JsonNodeFactory jsonNodeFactory = new JsonNodeFactory(false);

    protected final DataApiManager dataApiManager;
    protected final DiskDataSource diskDataSource;
    protected final RetryPolicy dbRetryPolicy;
    protected final DatabaseRef dbRef;
    protected final Option<String> collectionId;
    private final Yt yt;

    private final int threads;
    private final int maxNumberOfUsersInSingleChunk;
    private final RetryPolicy ytRetryPolicy;

    private final int maxSkippedUsers;
    private int currentPartition = 0;

    private AtomicInteger skippedUsers;
    private MapF<String, YPath> firstWrite;
    private ForkJoinPool fetcherSupplier;
    private ExecutorService pusherExecutor;
    private Set<String> processedCollections;
    protected YPath generalPath;
    private final ArrayBlockingQueue<ListF<DataApiUserId>> prefetchedUsers = new ArrayBlockingQueue<>(3, true);

    /** Set by {@link #close()} so the producer stops blocking on a queue nobody drains any more. */
    private volatile boolean stopped;

    /**
     * Finish (merge) tasks of the partitions of the current group, in submission order.
     * Only touched from the thread running {@link #process()}.
     */
    private ListF<CompletableFuture<Void>> pendingPartitionFinishes = Cf.arrayList();

    public AbstractDumpDatabaseUsersProcessor(
            DataApiManager dataApiManager,
            DiskDataSource diskDataSource,
            Option<String> app, String databaseId, Option<String> collectionId,
            Yt yt, int maxNumberOfUsersInSingleChunk, int maxSkippedUsers,
            RetryPolicy dbRetryPolicy, RetryPolicy ytRetryPolicy, int threads) {
        this.dataApiManager = dataApiManager;
        this.diskDataSource = diskDataSource;
        this.dbRetryPolicy = dbRetryPolicy;
        this.dbRef = DatabaseRef.cons(app, databaseId);
        this.collectionId = collectionId;
        this.yt = yt;
        this.maxNumberOfUsersInSingleChunk = maxNumberOfUsersInSingleChunk;
        this.maxSkippedUsers = maxSkippedUsers;
        this.ytRetryPolicy = ytRetryPolicy;
        this.threads = threads;
    }

    /** Creates the pool that runs the user producer and the parallel snapshot fetching. */
    private ForkJoinPool fetcherStreamSupplier(int threads) {
        final ForkJoinPool.ForkJoinWorkerThreadFactory factory = pool -> {
            final ForkJoinWorkerThread worker = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool);
            worker.setName("dump-yt-fetcher-"  + getGeneralPath().name() + "-" + worker.getPoolIndex());
            return worker;
        };
        return new ForkJoinPool(threads, factory, null, false);
    }

    /**
     * Creates the pool for YT writes and merges. It has no queue: when all threads are busy the submitting
     * thread runs the task itself (CallerRunsPolicy).
     */
    private ExecutorService pusherExecutor(int threadCount) {
        return new ThreadPoolExecutor(threadCount / 2, threadCount * 2,
                100, TimeUnit.SECONDS,
                new SynchronousQueue<>(),
                new ThreadNameIndexThreadFactory("dump-yt-pusher-" + getGeneralPath().name()),
                new ThreadPoolExecutor.CallerRunsPolicy());
    }

    /**
     * Runs one complete dump: consumes user batches (partitions) until the producer is done, then builds the final
     * sorted tables. Always calls {@link #close()} before returning or throwing.
     */
    public void process() {
        try {
            currentPartition = 0;
            stopped = false;
            processedCollections = ConcurrentHashMap.newKeySet();
            collectionId.ifPresent(s -> processedCollections.add(s));
            skippedUsers = new AtomicInteger();
            firstWrite = Cf.concurrentHashMap();
            pendingPartitionFinishes = Cf.arrayList();
            generalPath = getGeneralPath();
            logger.info("About to start dump with threads: {}, maxNumberOfUsersInChunk: {} at {}, generalPath: {}",
                    threads, maxNumberOfUsersInSingleChunk, dbRef, generalPath);

            pusherExecutor = pusherExecutor(threads);
            fetcherSupplier = fetcherStreamSupplier(threads);
            CompletableFuture<Boolean> finished = prepareUsers();
            while (!finished(finished)) {
                ListF<DataApiUserId> users = prefetchedUsers.poll(PREFETCH_POLL_TIMEOUT_SECONDS, TimeUnit.SECONDS);
                if (users == null) {
                    // Timed out on an empty queue: loop back and re-check whether the producer has finished.
                    continue;
                }
                processUsersFromSpecificPartitionRecoverable(users);
                if (currentPartition % COUNT_PARTITIONS_IN_GROUP == 0) {
                    mergeGroupPartitions(getGroupNumber(currentPartition) - 1);
                }
            }
            // A trailing incomplete group is not merged inside the loop; without this its partitions
            // would never reach the final result.
            if (currentPartition % COUNT_PARTITIONS_IN_GROUP != 0) {
                mergeGroupPartitions(getGroupNumber(currentPartition));
            }
            mergeAndSortYtTables();
        } catch (InterruptedException e) {
            throw ExceptionUtils.translate(e);
        } finally {
            close();
        }
    }

    /**
     * Returns true once the producer has completed and every queued batch was taken.
     * Throws if the producer completed exceptionally.
     */
    private boolean finished(CompletableFuture<Boolean> finished) {
        return finished.getNow(false) && prefetchedUsers.isEmpty();
    }

    /** Starts the producer that pushes batches of users with non-empty databases into the prefetch queue. */
    protected CompletableFuture<Boolean> prepareUsers() {
        return CompletableFuture.supplyAsync(() -> {
            diskDataSource.getDatabaseUsers(dbRef)
                    .withDbCond(NON_EMPTY_DB_CONDITION)
                    .withRetryPolicy(dbRetryPolicy)
                    .safe()
                    .forEachRemaining(t -> addToQueue(t));
            return true;
        }, fetcherSupplier);
    }

    /**
     * Enqueues a batch, waiting for free space.
     * If the dump is stopped meanwhile the batch is dropped, because nobody will consume it.
     * An interrupt is re-asserted and fails the producer instead of silently losing the batch.
     */
    private void addToQueue(ListF<DataApiUserId> dataApiUserIds) {
        try {
            while (!stopped) {
                if (prefetchedUsers.offer(dataApiUserIds, PREFETCH_OFFER_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
                    return;
                }
            }
            logger.warn("Dump is stopped, dropping batch of {} users", dataApiUserIds.size());
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw ExceptionUtils.translate(e);
        }
    }

    /**
     * Processes one partition unless it is already finished, then advances {@code currentPartition}.
     * Failures are logged, and the partition's users are counted as skipped.
     */
    protected void processUsersFromSpecificPartitionRecoverable(ListF<DataApiUserId> users) {
        logger.info("Processing users from partition: {}, batch {}", currentPartition, users.size());
        try {
            if (needProcessPartition()) {
                ListF<CompletableFuture<Boolean>> suppliers = Cf.arrayList();
                processUsersFromSpecificPartition(users, suppliers);
                pendingPartitionFinishes.add(finishPartition(suppliers));
            } else {
                logger.info("Skipping partition {}", currentPartition);
            }
        } catch (Exception e) {
            // processUsersFromSpecificPartition handles its own failures, so getting here means the
            // partition was never fetched.
            skippedUsers.addAndGet(users.length());
            logger.error("Can't process users", e);
        } finally {
            currentPartition++;
        }
    }

    /**
     * Fetches the snapshots of {@code users} in parallel on the fetcher pool and hands them, in chunks of at most
     * {@code maxNumberOfUsersInSingleChunk} users, to {@link #mergeAndSend}. Each started YT write's future is added
     * to {@code suppliers}. On failure all users of the partition are counted as skipped.
     */
    private void processUsersFromSpecificPartition(ListF<DataApiUserId> users, ListF<CompletableFuture<Boolean>> suppliers) {
        try {
            long start = System.currentTimeMillis();
            fetcherSupplier.submit(
                    () -> users.stream().parallel().map(u -> getSnapshotYtRepresentationWithRetries(u))
                            .filter(Option::isPresent)
                            .map(Option::get)
                            .filter(c -> c.isNotEmpty())
                            .collect(BatchCollectorAsync.batchCollector(maxNumberOfUsersInSingleChunk, objects -> {
                                mergeAndSend(objects, suppliers);
                            }))).get();
            logger.info("Finished partition: {} for {}", currentPartition,
                    Duration.ofMillis(System.currentTimeMillis() - start));
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            skippedUsers.addAndGet(users.length());
            logger.warn("Can't process users at partition: {}", currentPartition, e);
        }
    }

    /**
     * Decides whether the current partition must be (re)processed.
     *
     * <p>Leftover chunk tables of the partition are always removed. The partition counts as done only when at least
     * one collection is known and every known collection already has its finished table. Otherwise partially
     * finished tables are removed as well.
     */
    private boolean needProcessPartition() {
        Set<String> minedCollections = fetchCollections();
        logger.info("Checking partition: {}", currentPartition);
        String finishedPattern = getRedumpedFinishedPartitionPattern(currentPartition);
        long finishedCollections = minedCollections.stream().filter(col -> {
            try {
                YPath parent = getCollectionYPath(col);
                logger.info("Checking finished yt tables at: {}", parent);
                return !listChildTables(parent, finishedPattern, true).isEmpty();
            } catch (Exception e) {
                logger.error("Can't check partition {}", currentPartition, e);
                return false;
            }
        }).count();
        cleanTemp(minedCollections, getRedumpedPartitionPattern(currentPartition), true);
        logger.info("Total count of collections: {}, finished at partition {}: {}", minedCollections.size(),
                currentPartition, finishedCollections);
        // With no known collections nothing can have been dumped yet, so "all of zero finished" must not count.
        boolean finished = !minedCollections.isEmpty() && minedCollections.size() == finishedCollections;
        if (!finished) {
            logger.info("Cleaning full dump for partition {}", currentPartition);
            cleanTemp(minedCollections, finishedPattern, true);
        }
        return !finished;
    }

    /**
     * Returns the collections to check: the configured collection, or else every child node of
     * {@code generalPath}.
     */
    private Set<String> fetchCollections() {
        Set<String> collections = Cf.hashSet();
        if (collectionId.isPresent()) {
            collections.add(collectionId.get());
        } else {
            yt.cypress().list(generalPath).forEach(p -> collections.add(p.getValue()));
        }
        return collections;
    }

    /**
     * Removes every table whose name starts with {@code pattern} under each collection's directory.
     * Failures are logged per collection.
     */
    public void cleanTemp(Set<String> collections, String pattern) {
        cleanTemp(collections, pattern, false);
    }

    /**
     * Removes matching tables under each collection's directory. When {@code wholeSegment} is true only names equal
     * to {@code pattern} or continuing with {@code "-"} match (see {@link #nameMatches}).
     */
    private void cleanTemp(Set<String> collections, String pattern, boolean wholeSegment) {
        logger.info("About to clean collections at yt: {}, with pattern: {}", collections, pattern);
        collections.forEach(col -> {
            try {
                YPath parent = getCollectionYPath(col);
                listChildTables(parent, pattern, wholeSegment).forEach(p -> yt.cypress().remove(p));
            } catch (Exception e) {
                logger.error("Can't remove temp files", e);
            }
        });
    }

    /**
     * Waits for all pusher tasks, then for each processed collection builds the final result table.
     * It merges the group tables into {@code redumped} and sorts it by {@code uid, record_id}. It then removes the
     * group tables, moves the result to {@link #getResultPath(String)} and copies it next to itself under a
     * timestamped name.
     */
    private void mergeAndSortYtTables() {
        // Wait for every pending write, partition merge and group merge to finish.
        pusherExecutor.shutdown();
        try {
            if (!pusherExecutor.awaitTermination(1, TimeUnit.HOURS)) {
                logger.warn("Pusher tasks are still running after 1 hour, merging whatever is finished");
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw ExceptionUtils.translate(e);
        }
        processedCollections.forEach(col -> {
            try {
                logger.info("Processing collection: {}", col);
                YPath parent = getCollectionYPath(col);
                logger.info("Creating result at: {}", parent);

                ListF<YPath> paths = listChildTables(parent, REDUMPED_FINISH_GROUP_PREFIX, false);
                if (paths.isEmpty()) {
                    // Never replace an existing result with an empty one.
                    logger.warn("Nothing to merge for collection {} at {}", col, parent);
                    return;
                }

                YPath target = parent.child("redumped").withAdditionalAttributes(COMPRESSION_ATTRIBUTES);

                yt.operations().mergeAndGetOp(
                    MergeSpec.builder()
                        .setInputTables(Cf.toList(paths))
                        .setOutputTable(target)
                        .setCombineChunks(true)
                        .build()
                ).awaitAndThrowIfNotSuccess();
                yt.operations().sortAndGetOp(target, target, Cf.list("uid", "record_id")).awaitAndThrowIfNotSuccess();

                paths.forEach(p -> yt.cypress().remove(p));

                YPath resultPath = getResultPath(col);
                yt.cypress().move(target, resultPath, false, true, false);
                yt.cypress().copy(resultPath,
                        resultPath.parent().child(YT_NODE_NAME_FORMATTER.print(DateTime.now())), false, true, false);
            } catch (Exception e) {
                logger.error("Can't merge snapshots", e);
            }
        });
    }

    /**
     * Asynchronously merges all finished partition tables of {@code groupNumber} into that group's table, for every
     * collection processed so far.
     * It first waits for the finish tasks of the group's partitions, so no partition merge still in flight is missed.
     */
    private void mergeGroupPartitions(int groupNumber) {
        Set<String> coll = Cf.toSet(processedCollections);
        // Take over this group's finish tasks; the next group starts with a fresh list.
        ListF<CompletableFuture<Void>> partitionFinishes = pendingPartitionFinishes;
        pendingPartitionFinishes = Cf.arrayList();
        // Trailing PARTITION_SUFFIX keeps group 1 from also matching groups 10, 11, ...
        String groupPartitionsPrefix = getRedumpedFinishedPartitionPrefix(groupNumber) + PARTITION_SUFFIX;
        CompletableFuture.runAsync(() -> {
            logger.info("Waiting {} partitions of group {}", partitionFinishes.size(), groupNumber);
            awaitAll(partitionFinishes);
            coll.forEach(col -> {
                try {
                    logger.info("Finished partition for group {}", groupNumber);
                    YPath parent = getCollectionYPath(col);

                    ListF<YPath> pathsForMerge = listChildTables(parent, groupPartitionsPrefix, false);
                    YPath target = parent.child(getRedumpedFinishedGroupPattern(groupNumber)).withAdditionalAttributes(COMPRESSION_ATTRIBUTES);

                    logger.info("Merging group {} partitions: {} -> {}", groupNumber, pathsForMerge, target);
                    if (pathsForMerge.isEmpty()) {
                        return;
                    }

                    yt.operations().mergeAndGetOp(
                            MergeSpec.builder()
                            .setInputTables(Cf.toList(pathsForMerge))
                            .setOutputTable(target)
                            .setCombineChunks(true)
                            .build()
                    ).awaitAndThrowIfNotSuccess();

                    pathsForMerge.forEach(p -> yt.cypress().remove(p));
                } catch (Exception e) {
                    logger.error("Can't merge partitions for group {}", groupNumber, e);
                }
            });
        }, pusherExecutor);
    }

    /**
     * Asynchronously waits for the current partition's writes. It then merges the partition's chunk tables into its
     * finished table and removes the chunks, per collection.
     *
     * @return the future of that task, so the group merge can wait for it
     */
    private CompletableFuture<Void> finishPartition(ListF<CompletableFuture<Boolean>> suppliers) {
        int part = currentPartition;
        Set<String> coll = Cf.toSet(processedCollections);
        return CompletableFuture.runAsync(() -> {
            logger.info("Waiting pushing {}", suppliers.size());
            awaitAll(suppliers);
            coll.forEach(col -> {
                try {
                    logger.info("Finished currentPartition: {}", part);
                    YPath parent = getCollectionYPath(col);

                    // Whole-segment match: chunks are "<pattern>-<thread>", and partition 1 must not pick up 10.
                    ListF<YPath> pathsForMerge = listChildTables(parent, getRedumpedPartitionPattern(part), true);
                    YPath target = parent.child(getRedumpedFinishedPartitionPattern(part))
                            .withAdditionalAttributes(COMPRESSION_ATTRIBUTES);
                    logger.info("Merging partition: {}, {} -> {}", part, pathsForMerge, target);

                    if (pathsForMerge.isEmpty()) {
                        return;
                    }
                    yt.operations().mergeAndGetOp(
                        MergeSpec.builder()
                            .setInputTables(Cf.toList(pathsForMerge))
                            .setOutputTable(target)
                            .setCombineChunks(true)
                            .build()
                    ).awaitAndThrowIfNotSuccess();
                    pathsForMerge.forEach(p -> yt.cypress().remove(p));

                } catch (Exception e) {
                    logger.error("Can't finish partition: {}", part, e);
                }
            });
        }, pusherExecutor);
    }

    /**
     * Blocks until every future completes. Failed futures are logged and ignored. On interrupt the flag is restored
     * and waiting stops.
     */
    private void awaitAll(ListF<? extends CompletableFuture<?>> futures) {
        for (CompletableFuture<?> future : futures) {
            try {
                future.get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                logger.warn("Give up to waite", e);
                return;
            } catch (Exception e) {
                logger.warn("Give up to waite", e);
            }
        }
    }

    /**
     * Concatenates the per-collection rows of a chunk of users and starts writing them to YT.
     * Synchronized because it is called from several fetcher threads and appends to the shared {@code suppliers}.
     */
    private synchronized void mergeAndSend(List<MapF<String, ListF<JsonNode>>> rows,
                                           ListF<CompletableFuture<Boolean>> suppliers) {
        if (rows.isEmpty()) {
            return;
        }
        logger.info("About to merge {} users", rows.size());
        MapF<String, ListF<JsonNode>> rowsGroupedByCollection = Cf.hashMap();
        rows.forEach(row -> {
            row.keys().forEach(collection -> {
                rowsGroupedByCollection.computeIfAbsent(collection, s ->Cf.arrayList())
                        .addAll(row.getTs(collection));
            });
        });
        writeRowsToYtWithRetries(rowsGroupedByCollection, rows.size(), suppliers);
    }

    /** Records the collections as processed and starts one write per collection. */
    private void writeRowsToYtWithRetries(MapF<String, ListF<JsonNode>> rowsGroupedByCollection, int usersInChunk,
                                          ListF<CompletableFuture<Boolean>> suppliers) {
        logger.info("About to write users to yt collections: {}, count: {}", rowsGroupedByCollection.keys().size(), usersInChunk);
        if (rowsGroupedByCollection.isEmpty()) {
            return;
        }
        processedCollections.addAll(rowsGroupedByCollection.keySet());
        rowsGroupedByCollection.forEach(
                (collection, rows) -> writeRowsToYtWithRetries(getCollectionYPath(collection), rows, usersInChunk, suppliers));
    }

    /**
     * Asynchronously writes rows into this thread's chunk table of the current partition under {@code p}.
     * The first write to a chunk overwrites it; later writes append. Retries use {@code ytRetryPolicy}; if the write
     * finally fails, the chunk's users are counted as skipped.
     */
    private void writeRowsToYtWithRetries(YPath p, ListF<JsonNode> ytTableRows, int usersInChunk, ListF<CompletableFuture<Boolean>> suppliers) {
        int part = currentPartition;
        suppliers.add(CompletableFuture.supplyAsync(() -> {
            YPath path = p.child(getChunkName(part));
            new RetryManager()
                    .withRetryPolicy(ytRetryPolicy)
                    .withLogging("Writing records for " + usersInChunk + " users in yt table: " + path)
                    .withFailureCallback(() -> skippedUsers.addAndGet(usersInChunk))
                    .runSafe(() -> {
                        YPath pathWithAppendIfNeeded = isFirstWrite(part, path) ? path : path.append(true);

                        yt.cypress().create(pathWithAppendIfNeeded, CypressNodeType.TABLE, true, true,
                                Cf.map("dynamic", YTree.booleanNode(false))
                        );
                        yt.tables().write(pathWithAppendIfNeeded, new JacksonTableEntryTypeWithoutUtfEncoding(), ytTableRows);

                    });
            return true;
        }, pusherExecutor));

    }

    /**
     * Loads a user's snapshot (master/slave policy R_SM, with retries) and converts its records to YT rows grouped by
     * {@code collection_id}. Returns empty when there is no snapshot or conversion fails (the error is logged).
     */
    private Option<MapF<String, ListF<JsonNode>>> getSnapshotYtRepresentationWithRetries(DataApiUserId uid) {

        try {
            Option<Snapshot> snapshot = MasterSlaveContextHolder.withPolicy(MasterSlavePolicy.R_SM, () -> getSnapshotWithRetries(uid));
            if (snapshot.isPresent()) {
                MapF<String, ListF<JsonNode>> rowsGroupedByCollection = Cf.hashMap();
                // NOTE(review): new String(...) decodes with the platform default charset; confirm that matches
                // what serializeJson produces.
                String serializedSnapshot = new String(benderMapper.serializeJson(snapshot.map(Snapshot::toPojo).get()));
                JsonNode jsonSnapshot = jacksonMapper.readTree(serializedSnapshot);

                for (JsonNode record : jsonSnapshot.get("records").get("items")) {
                    String collectionId = record.get("collection_id").asText();
                    rowsGroupedByCollection.computeIfAbsent(collectionId, s -> Cf.arrayList())
                            .add(convertRecordToYtRow(uid, record));
                }

                return Option.of(rowsGroupedByCollection);
            } else {
                logger.warn("Snapshot is empty for {}", uid);
            }
        } catch (Exception e) {
            logger.error("Dumping {} for uid {} failed.", getSnapshotFullName(), uid, e);

            ExceptionUtils.throwIfUnrecoverable(e);
        }
        return Option.empty();
    }

    /** Reads the user's snapshot of {@link #dbRef()} with {@code dbRetryPolicy}; empty if all attempts fail. */
    private Option<Snapshot> getSnapshotWithRetries(DataApiUserId user) {
        return new RetryManager<Snapshot>()
                .withRetryPolicy(dbRetryPolicy)
                .withLogging("Getting snapshot for database " + user + "." + dbRef())
                .getSafe(() -> dataApiManager.getSnapshot(new UserDatabaseSpec(user, dbRef()), getFilter()));
    }

    protected abstract JsonNode convertRecordToYtRow(DataApiUserId uid, JsonNode record);

    protected abstract YPath getGeneralPath();

    protected abstract YPath getResultPath(String cols);

    /** Directory of a collection's tables: {@code generalPath} itself when dumping a single collection. */
    protected YPath getCollectionYPath(String collection) {
        return collectionId.isPresent() ? generalPath : generalPath.child(collection);
    }

    /** Lists the children of {@code parent} whose names match {@code pattern} (see {@link #nameMatches}). */
    private ListF<YPath> listChildTables(YPath parent, String pattern, boolean wholeSegment) {
        return Cf.wrap(yt.cypress().list(parent)).filterMap(p ->
                Option.when(nameMatches(p.getValue(), pattern, wholeSegment), parent.child(p.getValue())));
    }

    /**
     * Tests a table name against a pattern. Plain prefix match when {@code wholeSegment} is false. Otherwise the name
     * must equal the pattern or continue with {@code "-"}, so "...-partition-1" does not match "...-partition-10".
     */
    private static boolean nameMatches(String name, String pattern, boolean wholeSegment) {
        if (!wholeSegment) {
            return name.startsWith(pattern);
        }
        return name.equals(pattern) || name.startsWith(pattern + NAME_DELIMITER);
    }

    /** Chunk table name of a partition for the calling thread. */
    private String getChunkName(int partition) {
        return getRedumpedPartitionPattern(partition) + "-" + Thread.currentThread().getName();
    }

    /** Returns true exactly once per (chunk, path) key, i.e. on the first write to that chunk table. */
    private boolean isFirstWrite(int partition, YPath path) {
        return firstWrite.put(getChunkName(partition) + path, path) == null;
    }

    private RecordsFilter getFilter() {
        return collectionRefO()
                .map(RecordsFilter.DEFAULT::withColRef)
                .getOrElse(RecordsFilter.DEFAULT);
    }

    private Option<CollectionRef> collectionRefO() {
        return collectionId.map(colId -> dbRef().consColRef(colId));
    }

    private String getSnapshotFullName() {
        return dbRef().appNameO() + "." + dbRef().databaseId() + "." + collectionId.getOrElse("");
    }

    public String getRedumpedPartitionPattern(int partition) {
        return getPartitionPattern(REDUMPED_PARTITION_PREFIX, partition);
    }

    public String getRedumpedFinishedPartitionPattern(int partition) {
        return getPartitionPattern(REDUMPED_FINISHED_PARTITION_PREFIX, partition);
    }

    public String getRedumpedFinishedPartitionPrefix(int groupNumber) {
        return REDUMPED_FINISHED_PARTITION_PREFIX + groupNumber;
    }

    public String getRedumpedFinishedGroupPattern(int groupNumber) {
        return REDUMPED_FINISH_GROUP_PREFIX + groupNumber;
    }

    /** Builds {@code <prefix><group>-partition-<partition>}. */
    private String getPartitionPattern(String prefix, int partition) {
        return prefix + getGroupNumber(partition) + PARTITION_SUFFIX + partition;
    }

    /** Group a partition belongs to: consecutive blocks of {@link #COUNT_PARTITIONS_IN_GROUP} partitions. */
    public static int getGroupNumber(int partition) {
        return partition / COUNT_PARTITIONS_IN_GROUP;
    }

    protected DatabaseRef dbRef() {
        return dbRef;
    }

    /** Number of users skipped in the current/last run; 0 before {@link #process()} was ever called. */
    public int getSkippedUsers() {
        return skippedUsers == null ? 0 : skippedUsers.get();
    }

    public boolean wereTooManyUsersSkipped() {
        return getSkippedUsers() > maxSkippedUsers;
    }

    /**
     * Stops the producer, shuts down both pools (without waiting) and removes leftover chunk tables.
     * Safe to call before {@link #process()} ran and more than once.
     */
    @Override
    public void close() {
        logger.info("Stopping dump, skippedUsers: {}", skippedUsers);
        stopped = true;
        prefetchedUsers.clear();
        if (firstWrite != null) {
            firstWrite.clear();
        }
        if (pusherExecutor != null) {
            pusherExecutor.shutdown();
        }
        if (fetcherSupplier != null) {
            fetcherSupplier.shutdown();
        }
        if (processedCollections != null && generalPath != null) {
            cleanTemp(processedCollections, REDUMPED_PARTITION_PREFIX);
        }
    }
}
