package ru.yandex.intranet.d.tms.jobs;

import java.time.Duration;
import java.util.concurrent.TimeUnit;

import com.yandex.ydb.table.transaction.TransactionMode;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.longs.LongList;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import it.unimi.dsi.fastutil.longs.LongSet;
import it.unimi.dsi.fastutil.longs.LongSets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.direct.scheduler.Hourglass;
import ru.yandex.direct.scheduler.support.BaseDirectJob;
import ru.yandex.intranet.d.dao.Tenants;
import ru.yandex.intranet.d.dao.services.ServicesDao;
import ru.yandex.intranet.d.datasource.model.YdbTableClient;
import ru.yandex.intranet.d.util.Long2LongMultimap;
import ru.yandex.intranet.d.util.OneShotStopWatch;

/**
 * Cron job that collects all (transitive) parents for each service.
 *
 * <p>The job reads the direct parent of every service and walks the hierarchy breadth-first
 * from the synthetic root id {@code 0`}. It then upserts the complete set of ancestors of every
 * service in batches. It is scheduled with a period of 1800 seconds.
 *
 * @author Vladimir Zaytsev <vzay@yandex-team.ru>
 * @since 03.11.2020
 */
@Hourglass(periodInSeconds = 1800)
public class CollectServicesParents extends BaseDirectJob {

    private static final Logger LOG = LoggerFactory.getLogger(CollectServicesParents.class);

    /** Batch size passed to {@code Long2LongMultimap#forEachBatch}; each batch is upserted in its own transaction. */
    private static final int BATCH_SIZE = 100;
    /** Maximum time to block while reading all services with their parents. */
    private static final Duration TIMEOUT_TO_READ_SERVICES = Duration.ofMinutes(60);
    /** Maximum time to block while upserting a single batch. */
    private static final Duration TIMEOUT_TO_WRITE_BATCH = Duration.ofMinutes(60);

    private final YdbTableClient tableClient;
    private final ServicesDao servicesDao;

    /**
     * @param tableClient YDB client used to open retryable sessions and transactions
     * @param servicesDao DAO used to read services with their parents and to upsert the collected parents
     */
    public CollectServicesParents(
            YdbTableClient tableClient,
            ServicesDao servicesDao
    ) {
        this.tableClient = tableClient;
        this.servicesDao = servicesDao;
    }

    /**
     * Reads the direct parent of every service, computes all ancestors of every service and
     * upserts them for {@code Tenants.DEFAULT_TENANT_ID}.
     *
     * @throws RuntimeException if no services with parents were read
     */
    @Override
    public void execute() {
        LOG.info("Start collecting services parents...");
        OneShotStopWatch stopWatch = new OneShotStopWatch();
        // Invert the (service -> parent) rows into a parent -> children multimap so the hierarchy
        // can be walked top-down from the root.
        Long2LongMultimap serviceIdsByParentId =
                tableClient.usingSessionFluxRetryable(servicesDao::getAllServiceIdsWithParents)
                        .reduce(new Long2LongMultimap(), (map, service) ->
                                map.put(service.getParentId(), service.getServiceId())
                        )
                        .block(TIMEOUT_TO_READ_SERVICES);
        if (serviceIdsByParentId == null || serviceIdsByParentId.size() == 0) {
            throw new RuntimeException("Empty `abc_sync/public_services_service` table.");
        }

        Long2LongMultimap parentsByServiceId = collectAllParents(serviceIdsByParentId);

        // Each batch is written in its own serializable transaction and awaited before the next one.
        // NOTE(review): block(Duration) presumably fails once the timeout elapses (Reactor semantics) — confirm.
        parentsByServiceId.forEachBatch(BATCH_SIZE, (Long2LongMultimap batch) ->
                tableClient.usingSessionMonoRetryable(session ->
                        session.usingTxMonoRetryable(TransactionMode.SERIALIZABLE_READ_WRITE, ts ->
                                servicesDao.upsertAllParentsRetryable(ts, batch, Tenants.DEFAULT_TENANT_ID)
                        ))
                        .block(TIMEOUT_TO_WRITE_BATCH)
        );
        LOG.info("Successfully finished collecting services parents for {} services in {} seconds",
                parentsByServiceId.size(), stopWatch.elapsed(TimeUnit.SECONDS));
    }

    /**
     * Collects all (transitive) parents of every service reachable from the root.
     *
     * <p>The hierarchy is traversed breadth-first starting from the synthetic root id {@code 0}.
     * The ancestor set of a service contains its direct parent and all ancestors of that parent,
     * including {@code 0}.
     *
     * <p>Each service is expanded at most once. If a service is met again, for example because
     * it is listed under several parents or the input contains a cycle, it is skipped with a
     * warning and the first parent encountered wins. This guarantees that malformed input
     * cannot make the traversal run forever. Services that are not reachable from {@code 0} are
     * absent from the result.
     *
     * @param serviceIdsByParentId child service ids keyed by parent id; must contain the {@code 0}
     *                             key, which lists the services without a parent
     * @return all ancestor ids keyed by service id
     * @throws IllegalArgumentException if {@code serviceIdsByParentId} has no {@code 0} key
     */
    static Long2LongMultimap collectAllParents(Long2LongMultimap serviceIdsByParentId) {
        if (serviceIdsByParentId.get(0) == null) {
            throw new IllegalArgumentException(
                    "Param serviceIdsByParentId must contain the 0 key for services without parent");
        }
        Long2LongMultimap allParentsByServiceId = new Long2LongMultimap();
        // FIFO queue of ids whose children still have to be processed; consumed by index.
        LongList parentsQueue = new LongArrayList();
        // Every id ever put into the queue; prevents re-expansion and therefore infinite loops on cycles.
        LongSet enqueued = new LongOpenHashSet();
        parentsQueue.add(0);
        enqueued.add(0);
        int nextParentsQueueIndex = 0;
        while (nextParentsQueueIndex < parentsQueue.size()) {
            long parentId = parentsQueue.getLong(nextParentsQueueIndex);
            // Every non-root id in the queue already had its ancestors stored before it was enqueued.
            LongSet parentsOfParent = parentId == 0L ?
                    LongSets.EMPTY_SET :
                    allParentsByServiceId.get(parentId);
            LongSet serviceIds = serviceIdsByParentId.get(parentId);
            if (serviceIds != null) {
                serviceIds.forEach((long serviceId) -> {
                    if (!enqueued.add(serviceId)) {
                        LOG.warn("Service {} is reachable more than once (duplicate parent or cycle), "
                                + "skipping it under parent {}", serviceId, parentId);
                        return;
                    }
                    LongOpenHashSet parents = new LongOpenHashSet(parentsOfParent);
                    parents.add(parentId);
                    allParentsByServiceId.resetAll(serviceId, parents);
                    parentsQueue.add(serviceId);
                });
            }
            nextParentsQueueIndex++;
        }
        return allParentsByServiceId;
    }
}
