package ru.yandex.mail.cerberus.worker.yt_tasks.staff_sync.sync;

import lombok.AllArgsConstructor;
import lombok.Value;
import lombok.val;
import one.util.streamex.StreamEx;
import org.slf4j.Logger;
import reactor.core.publisher.Mono;
import ru.yandex.mail.cerberus.worker.api.TaskExecutionContext;
import ru.yandex.mail.cerberus.yt.staff.StaffEntity;
import ru.yandex.mail.micronaut.common.Async;

import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.function.BinaryOperator;
import java.util.stream.Collectors;

import static com.ea.async.Async.await;
import static java.lang.Boolean.FALSE;
import static java.lang.Boolean.TRUE;
import static java.util.Collections.unmodifiableList;
import static java.util.function.Function.identity;
import static ru.yandex.mail.micronaut.common.Async.runIf;
import static ru.yandex.mail.micronaut.common.CerberusUtils.mapToList;
import static ru.yandex.mail.micronaut.common.CerberusUtils.mapToMap;

/**
 * Generic {@link Synchronizer} that reconciles batches of staff entities with the DTOs known to a {@link SyncProvider}.
 *
 * <p>For every batch emitted by the provider, invalid entries are logged and dropped. The valid ones are matched by id
 * against the DTOs returned by the provider's lookup, and the resulting new and changed entities are committed back
 * through the provider. Up to {@value #MAX_CONCURRENT_BATCHES} batches are processed concurrently. All
 * entity-specific behaviour (id extraction, lookup, merging, commits, logger and display name) comes from the
 * provider.
 *
 * @param <ID>       identifier type used to match staff DTOs against existing DTOs
 * @param <Ctx>      context object carried by each batch and handed back to the provider
 * @param <Dto>      representation of an entity as returned by the provider's lookup
 * @param <StaffDto> representation of an entity as delivered in staff batches
 */
@AllArgsConstructor
class DefaultSynchronizer<ID, Ctx, Dto, StaffDto> implements Synchronizer {
    /** Upper bound on batches whose chunks are being synchronized at the same time. */
    private static final int MAX_CONCURRENT_BATCHES = 2;
    /** Number of batches requested ahead from the upstream publisher. */
    private static final int PREFETCH = 1;

    // Keys of the partitioningBy result used to split a chunk into valid and invalid entries.
    private static final Boolean VALID = TRUE;
    private static final Boolean INVALID = FALSE;

    /** Keeps the later of two timestamps when folding per-batch modification times. */
    private static final BinaryOperator<OffsetDateTime> LATER_OF = BinaryOperator.maxBy(OffsetDateTime.timeLineOrder());

    /**
     * Outcome of comparing one chunk with the existing DTOs.
     * {@code newDtos} are staff DTOs without a stored counterpart, and {@code changedDtos} are already-merged DTOs
     * that differ from what is stored.
     */
    @Value
    private static class Diff<Dto, StaffDto> {
        List<StaffDto> newDtos;
        List<Dto> changedDtos;

        /** Returns {@code true} when there is nothing to insert and nothing to update. */
        public boolean isEmpty() {
            return changedDtos.isEmpty() && newDtos.isEmpty();
        }
    }

    private final SyncProvider<ID, Ctx, Dto, StaffDto> provider;

    private String entityName() {
        return provider.getSyncEntityName();
    }

    private Logger log() {
        return provider.getLog();
    }

    /**
     * Decides what to do with a single staff DTO and records the outcome in one of the two output lists.
     * The DTO is either skipped (with a log line), queued for insertion, queued as an update, or reported as
     * up to date.
     */
    private void classifyStaffDto(StaffDto staffDto, Map<ID, Dto> dtoById, Ctx context,
                                  List<StaffDto> toInsert, List<Dto> toUpdate) {
        val id = provider.getIdForStaffDto(staffDto);
        val decision = provider.isReadyToSync(staffDto, context);

        // The provider vetoed this entity entirely.
        if (decision instanceof SyncDecision.DontSync) {
            log().info("{} with id={} sync skipped, reason='{}'", entityName(), id,
                ((SyncDecision.DontSync) decision).getReasons());
            return;
        }

        val current = dtoById.get(id);
        if (current != null) {
            // Already known: merge the staff data in and keep it only if something actually changed.
            val candidate = provider.update(current, staffDto, context);
            if (!candidate.equals(current)) {
                log().info("Update from:\n{}\nto:\n{}", current, candidate);
                toUpdate.add(candidate);
            } else {
                log().info("{} with id={} is up to date", entityName(), id);
            }
            return;
        }

        // Not known yet: an UpdateOnly decision forbids creating it.
        if (decision instanceof SyncDecision.UpdateOnly) {
            log().info("{} with id={} new instance sync skipped, reason='{}'", entityName(), id,
                ((SyncDecision.UpdateOnly) decision).getReasons());
        } else {
            toInsert.add(staffDto);
        }
    }

    /**
     * Splits a chunk of valid staff DTOs into entities to insert and entities to update.
     *
     * @param dtoById existing DTOs keyed by their id
     * @param chunk   staff DTOs to reconcile, processed in order
     * @param context batch context forwarded to the provider
     * @return an immutable diff of new and changed entities
     */
    private Diff<Dto, StaffDto> findDiff(Map<ID, Dto> dtoById, List<StaffDto> chunk, Ctx context) {
        val toInsert = new ArrayList<StaffDto>();
        val toUpdate = new ArrayList<Dto>();

        for (val staffDto : chunk) {
            classifyStaffDto(staffDto, dtoById, context, toInsert, toUpdate);
        }

        return new Diff<>(unmodifiableList(toInsert), unmodifiableList(toUpdate));
    }

    /** Commits staff DTOs that have no existing counterpart and logs what the provider reports as added. */
    private CompletableFuture<Void> commitNew(List<StaffDto> staffDtos, Ctx context) {
        val inserted = await(provider.commitNew(staffDtos, context));
        log().info("Add {} new {}s:\n{}", inserted.size(), entityName(), inserted);
        return Async.done();
    }

    /** Commits already-merged DTOs that differ from their stored versions. */
    private CompletableFuture<Void> commitChanged(List<Dto> dtos, Ctx context) {
        await(provider.commitChanged(dtos, context));
        log().info("Update {} {}s", dtos.size(), entityName());
        return Async.done();
    }

    /**
     * Starts the insert and update commits for a diff, skipping whichever side is empty.
     * The two commits are not chained, and the returned future completes when both have finished.
     */
    private CompletableFuture<Void> commitDiff(Diff<Dto, StaffDto> diff, Ctx context) {
        val pendingInserts = diff.getNewDtos();
        val pendingUpdates = diff.getChangedDtos();

        return CompletableFuture.allOf(
            runIf(!pendingInserts.isEmpty(), () -> commitNew(pendingInserts, context)),
            runIf(!pendingUpdates.isEmpty(), () -> commitChanged(pendingUpdates, context))
        );
    }

    /** Logs the raw nodes of invalid staff entries, if there are any. */
    private void logInvalidEntries(List<?> invalidEntries) {
        if (invalidEntries.isEmpty()) {
            return;
        }

        val invalidNodes = StreamEx.of(invalidEntries)
            .select(StaffEntity.Invalid.class)
            .map(StaffEntity.Invalid::getNode)
            .toImmutableList();
        log().error("Skipping invalid {}s: {}", entityName(), invalidNodes);
    }

    /**
     * Reconciles one batch: drops invalid entries, looks up the existing DTOs, computes the diff and commits it.
     *
     * @return a future that completes once all commits for this batch are done, or immediately when there is
     *     nothing to do
     */
    private CompletableFuture<Void> synchronizeChunk(Batch<StaffDto, Ctx> batch) {
        val entries = batch.getValues();
        val context = batch.getContext();

        log().info("Process chunk with size: {}", entries.size());
        if (entries.isEmpty()) {
            log().info("An empty chunk received, nothing to process");
            return Async.done();
        }

        val byValidity = StreamEx.of(entries)
            .partitioningBy(StaffEntity::isValid, Collectors.toUnmodifiableList());
        logInvalidEntries(byValidity.get(INVALID));

        val staffDtos = mapToList(byValidity.get(VALID), StaffEntity::getEntity);
        if (staffDtos.isEmpty()) {
            log().info("Valid {}s not found in this chunk", entityName());
            return Async.done();
        }

        val known = await(provider.findExisting(staffDtos, context));
        log().info("{} existing {}s found", known.size(), entityName());

        val diff = findDiff(mapToMap(known, provider::getIdForDto, identity()), staffDtos, context);
        log().info("'{}' new and '{}' changed {}s found",
            diff.getNewDtos().size(), diff.getChangedDtos().size(), entityName());

        return diff.isEmpty() ? Async.done() : commitDiff(diff, context);
    }

    /**
     * Synchronizes one batch and then emits its latest modification time.
     * The Mono is empty when the batch reports no modification time.
     */
    private Mono<OffsetDateTime> syncOneBatch(Batch<StaffDto, Ctx> batch) {
        val latestChange = batch.findMaxModificationTime(provider::getModifiedAt);
        return Mono.fromFuture(synchronizeChunk(batch)).then(Mono.justOrEmpty(latestChange));
    }

    @Override
    public String getName() {
        return String.format("%s synchronizer", entityName());
    }

    @Override
    public int getOrder() {
        val order = provider.getOrder();
        return order.getValue();
    }

    /**
     * Pulls batches starting from {@code syncPoint} and synchronizes them, with at most
     * {@value #MAX_CONCURRENT_BATCHES} batches in flight.
     *
     * <p>The interruption point of {@code executionContext} is hit on every upstream signal.
     *
     * @return the latest modification time across all processed batches, or an empty Mono if no batch reported one
     *     (presumably used by the caller as the next sync point; confirm against caller)
     */
    @Override
    public Mono<OffsetDateTime> synchronize(TaskExecutionContext executionContext, Optional<OffsetDateTime> syncPoint) {
        log().info("Start staff {} synchronization with chunk size = {}",
            entityName(), provider.getMaxChunkSize());

        return provider.batches(syncPoint)
            .doOnEach(signal -> {
                executionContext.setInterruptionPoint();
                if (signal.isOnComplete()) {
                    log().info("No more chunks");
                    return;
                }
                if (signal.hasValue()) {
                    log().info("Chunk found [size={}]", signal.get().getValues().size());
                }
            })
            .flatMap(this::syncOneBatch, MAX_CONCURRENT_BATCHES, PREFETCH)
            .reduce(LATER_OF)
            .doOnSuccess(ignored -> log().info("{}s successfully synchronized", entityName()))
            .doOnError(e -> log().error("{}s synchronization failed", entityName(), e));
    }
}
