package ru.yandex.direct.core.entity.banner.service.validation.pricesales;

import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import one.util.streamex.EntryStream;

import ru.yandex.direct.core.entity.adgroup.repository.AdGroupRepository;
import ru.yandex.direct.core.entity.banner.model.BannerWithSystemFields;
import ru.yandex.direct.core.entity.banner.repository.BannerModerationRepository;
import ru.yandex.direct.core.entity.campaign.model.Campaign;
import ru.yandex.direct.core.entity.campaign.model.CampaignType;
import ru.yandex.direct.core.entity.campaign.model.CpmPriceCampaign;
import ru.yandex.direct.core.entity.campaign.repository.CampaignTypedRepository;
import ru.yandex.direct.core.entity.pricepackage.service.PricePackageService;
import ru.yandex.direct.model.ModelChanges;
import ru.yandex.direct.regions.GeoTree;
import ru.yandex.direct.validation.builder.Constraint;
import ru.yandex.direct.validation.builder.ItemValidationBuilder;
import ru.yandex.direct.validation.result.Defect;
import ru.yandex.direct.validation.result.ValidationResult;
import ru.yandex.direct.validation.wrapper.DefaultValidator;

import static ru.yandex.direct.core.entity.adgroup.service.AdGroupCpmPriceUtils.isDefaultPriority;
import static ru.yandex.direct.core.entity.banner.service.validation.defects.BannerDefects.priceSalesCampaignGeoOverlapsBannerMinusGeo;
import static ru.yandex.direct.validation.builder.Constraint.fromPredicate;

/**
 * Validator that prevents resuming a banner whose minus-regions overlap the campaign geo.
 * <p>
 * Only banners of {@link CampaignType#CPM_PRICE} campaigns are checked. All data needed for the
 * check (campaign geo, ad group priorities, banner minus-geo and the price-package geo tree) is
 * loaded eagerly in the constructor, so {@link #apply} performs in-memory lookups only.
 */
public class BannerPriceSalesMinusGeoValidator implements DefaultValidator<ModelChanges<BannerWithSystemFields>> {

    private final CampaignTypedRepository campaignTypedRepository;
    private final BannerModerationRepository bannerModerationRepository;

    private final int shard;
    /** Banners under validation, looked up by banner id. */
    private final Map<Long, BannerWithSystemFields> banners;
    /** Minus-geo of banners that belong to CPM_PRICE campaigns, looked up by banner id. */
    private final Map<Long, Set<Long>> cpmPriceBannersMinusGeo;
    /** Expanded flight-targeting geo of CPM_PRICE campaigns, looked up by campaign id. */
    private final Map<Long, Set<Long>> cpmPriceCampaignsGeo;
    /** Priorities of ad groups that contain CPM_PRICE banners, looked up by ad group id. */
    private final Map<Long, Long> cpmPriceAdGroupsPriority;
    private final GeoTree priceSalesGeoTree;

    /**
     * Creates the validator and eagerly loads everything needed to validate {@code banners}.
     *
     * @param shard                      shard used for all repository reads
     * @param banners                    banners under validation, keyed by banner id
     * @param campaigns                  campaigns of those banners, keyed by campaign id
     * @param campaignTypedRepository    source of typed CPM_PRICE campaigns (flight targeting geo)
     * @param bannerModerationRepository source of banner minus-geo
     * @param adGroupRepository          source of ad group priorities
     * @param pricePackageService        source of the geo tree used for region comparisons
     */
    BannerPriceSalesMinusGeoValidator(
            int shard,
            Map<Long, BannerWithSystemFields> banners,
            Map<Long, Campaign> campaigns,
            CampaignTypedRepository campaignTypedRepository,
            BannerModerationRepository bannerModerationRepository,
            AdGroupRepository adGroupRepository,
            PricePackageService pricePackageService
    ) {
        // These fields must be assigned before the repository-backed helpers below are called.
        this.shard = shard;
        this.banners = banners;
        this.campaignTypedRepository = campaignTypedRepository;
        this.bannerModerationRepository = bannerModerationRepository;

        Set<Long> cpmPriceCampaignIds = collectCpmPriceCampaignIds(campaigns);
        Set<Long> cpmPriceAdGroupIds = collectCpmPriceAdGroupIds(banners, cpmPriceCampaignIds);
        Set<Long> cpmPriceBannerIds = collectCpmPriceBannerIds(banners, cpmPriceCampaignIds);

        cpmPriceCampaignsGeo = getCpmPriceCampaignsGeo(cpmPriceCampaignIds);
        cpmPriceAdGroupsPriority = adGroupRepository.getAdGroupsPriority(shard, cpmPriceAdGroupIds);
        cpmPriceBannersMinusGeo = getBannersMinusGeo(cpmPriceBannerIds);

        priceSalesGeoTree = pricePackageService.getGeoTree();
    }

    /** Returns the ids of campaigns whose type is {@link CampaignType#CPM_PRICE}. */
    private static Set<Long> collectCpmPriceCampaignIds(Map<Long, Campaign> campaigns) {
        return EntryStream.of(campaigns)
                .mapValues(Campaign::getType)
                .filterValues(CampaignType.CPM_PRICE::equals)
                .keys()
                .toSet();
    }

    /** Returns the ad group ids of banners that belong to one of {@code cpmPriceCampaignIds}. */
    private static Set<Long> collectCpmPriceAdGroupIds(Map<Long, BannerWithSystemFields> banners,
                                                       Set<Long> cpmPriceCampaignIds) {
        return banners.values().stream()
                .filter(banner -> cpmPriceCampaignIds.contains(banner.getCampaignId()))
                .map(BannerWithSystemFields::getAdGroupId)
                .collect(Collectors.toSet());
    }

    /** Returns the ids of banners that belong to one of {@code cpmPriceCampaignIds}. */
    private static Set<Long> collectCpmPriceBannerIds(Map<Long, BannerWithSystemFields> banners,
                                                      Set<Long> cpmPriceCampaignIds) {
        return banners.values().stream()
                .filter(banner -> cpmPriceCampaignIds.contains(banner.getCampaignId()))
                .map(BannerWithSystemFields::getId)
                .collect(Collectors.toSet());
    }

    /**
     * Loads the CPM_PRICE campaigns and returns, per campaign id, a copy of the expanded geo
     * from the campaign's flight targetings snapshot.
     */
    private Map<Long, Set<Long>> getCpmPriceCampaignsGeo(Set<Long> cpmPriceCampaignIds) {
        // The ids were filtered to CampaignType.CPM_PRICE, so the repository is expected to return
        // CpmPriceCampaign instances only; the generic cast cannot be checked at runtime.
        @SuppressWarnings("unchecked")
        Map<Long, CpmPriceCampaign> cpmPriceCampaigns =
                (Map<Long, CpmPriceCampaign>) campaignTypedRepository.getTypedCampaignsMap(shard, cpmPriceCampaignIds);
        return EntryStream.of(cpmPriceCampaigns)
                .mapValues(campaign -> campaign.getFlightTargetingsSnapshot().getGeoExpanded())
                .<Set<Long>>mapValues(HashSet::new)
                .toMap();
    }

    /** Loads the minus-geo of the given banners and returns a {@link HashSet} copy per banner. */
    private Map<Long, Set<Long>> getBannersMinusGeo(Set<Long> bannerIds) {
        return EntryStream.of(bannerModerationRepository.getBannersMinusGeo(shard, bannerIds))
                .<Set<Long>>mapValues(HashSet::new)
                .toMap();
    }

    @Override
    public ValidationResult<ModelChanges<BannerWithSystemFields>, Defect> apply(
            ModelChanges<BannerWithSystemFields> modelChanges) {
        return ItemValidationBuilder.of(modelChanges, Defect.class)
                .check(bannerMinusRegionsNotSubtractCampaignGeo())
                .getResult();
    }

    /**
     * Builds a constraint that fails with {@code priceSalesCampaignGeoOverlapsBannerMinusGeo()}
     * when the banner's minus-geo overlaps the geo of its CPM_PRICE campaign.
     */
    private Constraint<ModelChanges<BannerWithSystemFields>, Defect> bannerMinusRegionsNotSubtractCampaignGeo() {
        return fromPredicate(mc -> {
            Long bannerId = mc.getId();

            // Banners missing from the minus-geo map (not in a CPM_PRICE campaign, or presumably
            // without minus-geo loaded by the repository) are not checked.
            if (!cpmPriceBannersMinusGeo.containsKey(bannerId)) {
                return true;
            }
            BannerWithSystemFields banner = banners.get(bannerId);
            // Only ad groups whose priority isDefaultPriority(...) accepts are checked.
            if (!isDefaultPriority(cpmPriceAdGroupsPriority.get(banner.getAdGroupId()))) {
                return true;
            }

            Set<Long> bannerMinusGeo = cpmPriceBannersMinusGeo.get(bannerId);
            Set<Long> campaignGeo = cpmPriceCampaignsGeo.get(banner.getCampaignId());
            // NOTE(review): presumably true when any minus-region (or a subregion of it) is
            // included in the campaign geo — confirm against GeoTree.
            return !priceSalesGeoTree.isAnyRegionOrSubRegionIncludedIn(bannerMinusGeo, campaignGeo);
        }, priceSalesCampaignGeoOverlapsBannerMinusGeo());
    }
}
