package ru.yandex.direct.oneshot.oneshots.updatecreativesgeo;

import java.util.Collection;
import java.util.List;

import javax.annotation.Nullable;
import javax.annotation.ParametersAreNonnullByDefault;

import org.jooq.DSLContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

import ru.yandex.direct.core.entity.banner.type.creative.BannerCreativeRepository;
import ru.yandex.direct.core.entity.client.service.ClientGeoService;
import ru.yandex.direct.dbutil.model.ClientId;
import ru.yandex.direct.dbutil.wrapper.DslContextProvider;
import ru.yandex.direct.oneshot.base.ShardedIterativeOneshotWithoutInput;
import ru.yandex.direct.oneshot.worker.def.Approvers;
import ru.yandex.direct.oneshot.worker.def.Multilaunch;
import ru.yandex.direct.oneshot.worker.def.PausedStatusOnFail;
import ru.yandex.direct.oneshot.worker.def.Retries;

/**
 * Oneshot that recalculates and stores the geo of creatives, shard by shard, in batches.
 *
 * <p>Each {@link #execute} call asks {@code CreativesRepository} for a batch of at most
 * {@code BATCH_SIZE} creatives, grouped by client id. The batch is positioned relative to the cursor
 * kept in {@code State.lastCreativeId}; presumably the query selects ids greater than the cursor
 * (verify in {@code CreativesRepository#getCreativeIdsToClientIds}).
 *
 * <p>For every client in the batch, the joined geo of its creatives is computed against the client's
 * translocal geo tree and written back. The cursor is then advanced to the largest creative id seen.
 * An empty batch means the shard is done, and {@code null} is returned as the next state.
 */
@Component
@Approvers({"elwood"})
@Multilaunch
@Retries(value = 10, timeoutSeconds = 60)
@PausedStatusOnFail
@ParametersAreNonnullByDefault
public class UpdateCreativesGeo extends ShardedIterativeOneshotWithoutInput<State> {
    private static final Logger logger = LoggerFactory.getLogger(UpdateCreativesGeo.class);
    private static final int BATCH_SIZE = 2000;

    private final ClientGeoService clientGeoService;
    private final BannerCreativeRepository bannerCreativeRepository;
    private final DslContextProvider dslContextProvider;
    private final CreativesRepository creativesRepository;

    public UpdateCreativesGeo(
            ClientGeoService clientGeoService,
            BannerCreativeRepository bannerCreativeRepository,
            DslContextProvider dslContextProvider,
            CreativesRepository creativesRepository
    ) {
        this.clientGeoService = clientGeoService;
        this.bannerCreativeRepository = bannerCreativeRepository;
        this.dslContextProvider = dslContextProvider;
        this.creativesRepository = creativesRepository;
    }

    /**
     * Processes one batch of creatives on the given shard.
     *
     * @param prevState state returned by the previous iteration, or {@code null} on the first call
     *                  (processing then starts from a cursor of {@code -1})
     * @param shard     shard to process
     * @return the state to pass to the next iteration, or {@code null} when the shard has no more
     *         creatives to process
     * @throws IllegalStateException if the repository returns clients without any creative ids,
     *                               so the cursor cannot be advanced
     */
    @Nullable
    @Override
    protected State execute(@Nullable State prevState, int shard) {
        State currentState = prevState != null ? prevState : new State().withLastCreativeId(-1L);

        logger.info("Current state {} for shard {}", currentState, shard);
        var context = dslContextProvider.ppc(shard);

        // Keyed by client id; each value holds that client's creative ids from the current batch.
        var creativeIdsByClientId = creativesRepository.getCreativeIdsToClientIds(
                context, BATCH_SIZE, currentState.getLastCreativeId());

        State nextState;
        if (creativeIdsByClientId.isEmpty()) {
            nextState = null;
            logger.info("Processing finished for shard {}", shard);
        } else {
            for (var entry : creativeIdsByClientId.entrySet()) {
                updateCreativesGeo(context, entry.getValue(), ClientId.fromLong(entry.getKey()));
            }

            // Individual lists may be empty (see the guard in updateCreativesGeo). If all of them
            // are, there is no id to advance the cursor to, so fail loudly instead of using a bare
            // Optional.get().
            long lastCreativeId = creativeIdsByClientId.values().stream()
                    .flatMap(Collection::stream)
                    .max(Long::compareTo)
                    .orElseThrow(() -> new IllegalStateException(
                            "Batch for shard " + shard + " contains clients but no creative ids"));

            // The incoming state object is updated in place so any other fields it carries survive.
            currentState.setLastCreativeId(lastCreativeId);
            nextState = currentState;
            logger.info("Processed {} creatives on shard {}",
                    creativeIdsByClientId.values().stream().mapToLong(Collection::size).sum(), shard);
        }

        logger.info("Next state {} for shard {}", nextState, shard);
        return nextState;
    }

    /**
     * Recomputes and stores the geo of one client's creatives.
     *
     * <p>The geo is computed by {@code BannerCreativeRepository#getJoinedGeo} against the client's
     * translocal geo tree. Does nothing when {@code creativeIds} is empty.
     *
     * @param context     shard DSL context to read from and write to
     * @param creativeIds ids of the client's creatives to update
     * @param clientId    owner of the creatives, used to pick the translocal geo tree
     */
    private void updateCreativesGeo(DSLContext context, List<Long> creativeIds, ClientId clientId) {
        if (creativeIds.isEmpty()) {
            return;
        }

        var geoTree = clientGeoService.getClientTranslocalGeoTree(clientId);
        var geoByCreativeId = bannerCreativeRepository.getJoinedGeo(context.configuration(), geoTree, creativeIds);
        creativesRepository.updateCreativesGeo(context, geoByCreativeId);
    }
}
