package ru.yandex.direct.oneshot.oneshots.cpmbannerfixminprice;

import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.ToIntFunction;

import javax.annotation.Nullable;

import com.google.common.collect.Iterables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import ru.yandex.direct.currency.Currencies;
import ru.yandex.direct.currency.Currency;
import ru.yandex.direct.currency.CurrencyCode;
import ru.yandex.direct.oneshot.oneshots.cpmbannerfixminprice.repository.OneshotBidRepository;
import ru.yandex.direct.oneshot.oneshots.sharded.State;
import ru.yandex.direct.oneshot.worker.def.Approvers;
import ru.yandex.direct.oneshot.worker.def.Multilaunch;
import ru.yandex.direct.oneshot.worker.def.ShardedOneshot;
import ru.yandex.direct.validation.result.Defect;
import ru.yandex.direct.validation.result.ValidationResult;

/**
 * Sharded oneshot that raises CPM banner prices lying below the currency's minimum CPM price
 * ({@link Currency#getMinCpmPrice()}) up to that minimum, in the BIDS, BIDS_BASE and BIDS_MANUAL_PRICE tables.
 * <p>
 * Each {@link #execute} call processes exactly one currency on one shard: {@link State#getIteration()} is used as a
 * positional index into {@link CurrencyCode#values()}, and the oneshot finishes (returns {@code null}) after the
 * last currency. Because the index is positional, the order of {@link CurrencyCode} constants must not change
 * between launches of the same run.
 */
@Component
@Multilaunch
@Approvers({"ppalex"})
public class CpmBannerFixMinimumPriceOneshot implements ShardedOneshot<Void, State> {
    private static final Logger logger = LoggerFactory.getLogger(CpmBannerFixMinimumPriceOneshot.class);

    /**
     * Maximum number of ids passed to a single repository update call.
     */
    private static final int CHUNK_SIZE = 500;

    private final OneshotBidRepository bidRepository;

    @Autowired
    public CpmBannerFixMinimumPriceOneshot(OneshotBidRepository bidRepository) {
        this.bidRepository = bidRepository;
    }

    /**
     * Fixes too-low CPM prices on {@code shard} for the currency selected by the current iteration.
     *
     * @param inputData unused
     * @param state     progress saved by the previous call, or {@code null} on the first call
     * @param shard     shard to process
     * @return state pointing to the next currency, or {@code null} once the last currency has been processed
     * @throws IllegalStateException if the saved iteration is not a valid index into {@link CurrencyCode#values()}
     */
    @Nullable
    @Override
    public State execute(Void inputData, State state, int shard) {
        State current = state != null ? state : new State().withIteration(0);

        CurrencyCode[] currencyCodes = CurrencyCode.values();
        int iteration = current.getIteration();
        // The iteration is a positional index into the enum. A state saved by an earlier launch may no longer
        // fit it, so fail with a descriptive message instead of an ArrayIndexOutOfBoundsException.
        if (iteration < 0 || iteration >= currencyCodes.length) {
            throw new IllegalStateException(String.format(
                    "shard: %d, iteration %d is out of range [0, %d) of CurrencyCode values",
                    shard, iteration, currencyCodes.length));
        }
        Currency currency = Currencies.getCurrency(currencyCodes[iteration]);

        logger.info("shard: {}, {}, going to update values with low price (minimum is {})",
                shard, currency.getCode(), currency.getMinCpmPrice());

        fixBidsAndBidsBasePrice(shard, currency);

        fixBidsManualPrice(shard, currency);

        if (iteration == currencyCodes.length - 1) {
            logger.info("shard: {}, {}, it was a last iteration", shard, currency.getCode());
            return null;
        }
        return new State().withIteration(iteration + 1);
    }

    /**
     * Raises too-low prices in the BIDS and BIDS_BASE tables for {@code currency} on {@code shard}.
     * Both candidate sets are read before either table is updated.
     */
    private void fixBidsAndBidsBasePrice(int shard,
                                         Currency currency) {
        Map<Long, BigDecimal> bidIdsToPrice =
                bidRepository.getBidIdsToPriceForCpmBannerWithTooLowPrice(shard, currency);
        Map<Long, BigDecimal> bidBaseIdsToPrice = bidRepository
                .getBidBaseIdsToPriceForCpmBannerWithTooLowPrice(shard, currency);

        fixPrice(shard, bidIdsToPrice, currency, chunk -> bidRepository
                .updatePriceForBidsIfItsLower(shard, chunk, currency.getMinCpmPrice()), "BIDS");
        fixPrice(shard, bidBaseIdsToPrice, currency, chunk -> bidRepository
                .updatePriceForBidsBaseIfItsLower(shard, chunk, currency.getMinCpmPrice()), "BIDS_BASE");
    }

    /**
     * Raises too-low prices in the BIDS_MANUAL_PRICE table for {@code currency} on {@code shard}.
     */
    private void fixBidsManualPrice(int shard,
                                    Currency currency) {
        Map<Long, BigDecimal> bidsManualPriceIdsToPrice = bidRepository
                .getBidsManualPriceIdsToPriceForCpmBannerWithTooLowPrice(shard, currency);
        fixPrice(shard, bidsManualPriceIdsToPrice, currency, chunk -> bidRepository
                .updatePriceForBidsManualPriceIfItsLower(shard, chunk, currency.getMinCpmPrice()), "BIDS_MANUAL_PRICE");
    }

    /**
     * Applies {@code updateFunction} to the ids of {@code idToPrice} in chunks of at most {@link #CHUNK_SIZE}.
     * Each id is logged with its old and new price before its chunk is updated.
     *
     * @param shard          shard being processed (used for logging)
     * @param idToPrice      ids of rows with a too-low price mapped to their current price
     * @param currency       currency whose minimum CPM price is the new price
     * @param updateFunction updates one chunk of ids; its int result is summed into the "updated rows" total
     *                       (presumably the number of affected rows — confirm against the repository)
     * @param tableName      table name used only in log messages
     */
    private void fixPrice(int shard,
                          Map<Long, BigDecimal> idToPrice,
                          Currency currency,
                          ToIntFunction<List<Long>> updateFunction,
                          String tableName) {
        logger.info("shard: {}, {}, found {} rows in {} table with low price",
                shard, currency.getCode(), idToPrice.size(), tableName);
        if (idToPrice.isEmpty()) {
            return;
        }

        BigDecimal minPrice = currency.getMinCpmPrice();
        AtomicInteger rowsUpdated = new AtomicInteger();
        Iterables.partition(idToPrice.keySet(), CHUNK_SIZE).forEach(chunk -> {
            // Per-row audit log, so the old price of each row can be recovered afterwards if needed.
            chunk.forEach(id -> logger.info("shard: {}, {}, update {} row bid_id: {}, old price: {}, new price: {}",
                    shard, currency.getCode(), tableName, id, idToPrice.get(id), minPrice));

            rowsUpdated.addAndGet(updateFunction.applyAsInt(chunk));
        });
        logger.info("shard: {}, {}, updated {} rows in {} table",
                shard, currency.getCode(), rowsUpdated.get(), tableName);
    }

    @Override
    public ValidationResult<Void, Defect> validate(Void inputData) {
        return ValidationResult.success(inputData);
    }
}
