package ru.yandex.direct.grid.processing.service.forecast;

import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import one.util.streamex.EntryStream;
import one.util.streamex.StreamEx;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import ru.yandex.direct.advq.SearchKeywordResult;
import ru.yandex.direct.advq.search.AdvqRequestKeyword;
import ru.yandex.direct.advq.search.SearchItem;
import ru.yandex.direct.core.entity.auction.exception.BsAuctionUnavailableException;
import ru.yandex.direct.core.entity.keyword.model.ForecastCtr;
import ru.yandex.direct.core.entity.keyword.service.KeywordForecastService;
import ru.yandex.direct.core.entity.showcondition.model.ShowStatRequest;
import ru.yandex.direct.core.entity.showcondition.service.ShowStatService;
import ru.yandex.direct.core.entity.statistics.AdvqHits;
import ru.yandex.direct.core.entity.statistics.AdvqHitsDeviceType;
import ru.yandex.direct.core.entity.statistics.repository.AdvqHitsRepository;
import ru.yandex.direct.dbutil.model.ClientId;
import ru.yandex.direct.dbutil.sharding.ShardHelper;
import ru.yandex.direct.grid.processing.context.container.GridGraphQLContext;
import ru.yandex.direct.grid.processing.model.forecast.GdAge;
import ru.yandex.direct.grid.processing.model.forecast.GdDeviceType;
import ru.yandex.direct.grid.processing.model.forecast.GdForecast;
import ru.yandex.direct.grid.processing.model.forecast.GdForecastContainer;
import ru.yandex.direct.grid.processing.model.forecast.GdGender;
import ru.yandex.direct.grid.processing.model.forecast.GdShowStat;
import ru.yandex.direct.grid.processing.model.forecast.GdShowStatContainer;
import ru.yandex.direct.grid.processing.model.showcondition.GdAuctionData;
import ru.yandex.direct.grid.processing.service.forecast.clicks.ClicksForecaster;
import ru.yandex.direct.grid.processing.service.showcondition.keywords.ShowConditionDataService;
import ru.yandex.direct.utils.CollectionUtils;

import static com.google.common.base.Preconditions.checkState;
import static java.lang.Math.round;
import static java.util.function.Function.identity;
import static ru.yandex.direct.grid.processing.model.forecast.GdAge.ALL;
import static ru.yandex.direct.grid.processing.model.forecast.GdAge.UNKNOWN;
import static ru.yandex.direct.grid.processing.model.forecast.GdAge._0_17;
import static ru.yandex.direct.grid.processing.model.forecast.GdAge._18_24;
import static ru.yandex.direct.grid.processing.model.forecast.GdAge._25_34;
import static ru.yandex.direct.grid.processing.model.forecast.GdAge._35_44;
import static ru.yandex.direct.grid.processing.model.forecast.GdAge._45_54;
import static ru.yandex.direct.grid.processing.model.forecast.GdAge._55_;
import static ru.yandex.direct.utils.CommonUtils.ifNotNull;
import static ru.yandex.direct.utils.FunctionalUtils.listToSet;
import static ru.yandex.direct.utils.FunctionalUtils.mapList;
import static ru.yandex.direct.utils.FunctionalUtils.mapSet;

/**
 * Builds show statistics and click/conversion forecasts for the grid GraphQL API.
 *
 * <p>Show estimates come either from {@link ShowStatService} (when keywords are given) or from
 * per-region ADVQ hits stored in {@link AdvqHitsRepository} (when they are not).
 */
@Service
public class ForecastService {

    private static final Logger logger = LoggerFactory.getLogger(ForecastService.class);

    /**
     * Region id queried by the non-keyword path when the request has no geo
     * (presumably the root, "global" region — confirm against the geo tree).
     */
    private static final Long GLOBAL_REGION_ID = 10000L;

    /**
     * Share of the audience by gender (rows) and age group (columns).
     * The cells add up to approximately 1.0, so selecting every cell leaves show counts unchanged.
     */
    private static final Table<GdGender, GdAge, Double> SOCDEM_REACH_PCT_BY_AGE_AND_GENDER;

    static {
        // Values taken from https://st.yandex-team.ru/DIRECT-132191#5fd8fbbae7fe074a52f8ad93
        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER = HashBasedTable.create();

        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.FEMALE, _0_17, 0.020);
        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.FEMALE, _18_24, 0.036);
        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.FEMALE, _25_34, 0.103);
        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.FEMALE, _35_44, 0.143);
        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.FEMALE, _45_54, 0.098);
        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.FEMALE, _55_, 0.096);
        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.FEMALE, UNKNOWN, 0.013);

        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.MALE, _0_17, 0.014);
        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.MALE, _18_24, 0.031);
        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.MALE, _25_34, 0.107);
        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.MALE, _35_44, 0.120);
        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.MALE, _45_54, 0.082);
        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.MALE, _55_, 0.083);
        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.MALE, UNKNOWN, 0.010);

        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.UNKNOWN, _0_17, 0.001);
        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.UNKNOWN, _18_24, 0.001);
        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.UNKNOWN, _25_34, 0.004);
        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.UNKNOWN, _35_44, 0.004);
        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.UNKNOWN, _45_54, 0.003);
        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.UNKNOWN, _55_, 0.006);
        SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.put(GdGender.UNKNOWN, UNKNOWN, 0.025);
    }

    /** Number of conversions expected per click; forecast conversions = round(clicks * rate). */
    private static final double CONVERSIONS_TO_CLICKS_RATE = 0.1;

    private final ShowStatService showStatService;
    private final AdvqHitsRepository advqHitsRepository;
    private final ShowConditionDataService showConditionDataService;
    private final ShardHelper shardHelper;
    private final ClicksForecaster clicksForecaster;
    private final KeywordForecastService keywordForecastService;

    @Autowired
    public ForecastService(ShowStatService showStatService,
                           AdvqHitsRepository advqHitsRepository,
                           ShowConditionDataService showConditionDataService,
                           ShardHelper shardHelper,
                           ClicksForecaster clicksForecaster,
                           KeywordForecastService keywordForecastService) {
        this.showStatService = showStatService;
        this.advqHitsRepository = advqHitsRepository;
        this.showConditionDataService = showConditionDataService;
        this.shardHelper = shardHelper;
        this.clicksForecaster = clicksForecaster;
        this.keywordForecastService = keywordForecastService;
    }

    /**
     * Estimates the number of shows for the requested targeting.
     *
     * @param input targeting; when it contains keywords, per-phrase statistics from
     *              {@link ShowStatService} are summed, otherwise stored ADVQ hits by region and
     *              device type are used
     * @return show statistics with the total number of shows
     */
    public GdShowStat getShowStat(GdShowStatContainer input) {
        return input.getKeywords() != null ? getKeywordsShowStat(input) : getNonKeywordsShowStat(input);
    }

    /** Sums the total counts of all non-empty per-phrase results (after the socdem correction). */
    private GdShowStat getKeywordsShowStat(GdShowStatContainer input) {
        Map<AdvqRequestKeyword, SearchKeywordResult> keywordShowStatByPhrase = getKeywordShowStatByPhrase(input);
        Long shows = EntryStream.of(keywordShowStatByPhrase)
                .values()
                .map(SearchKeywordResult::getResult)
                .nonNull()
                .map(SearchItem::getTotalCount)
                .reduce(Long::sum)
                .orElse(0L);

        return new GdShowStat()
                .withShows(shows);
    }

    /**
     * Requests per-keyword show statistics and scales them in place by the requested
     * gender/age share (see {@link #applySocDem}).
     */
    private Map<AdvqRequestKeyword, SearchKeywordResult> getKeywordShowStatByPhrase(GdShowStatContainer input) {
        ShowStatRequest request = convertRequest(input);

        Map<AdvqRequestKeyword, SearchKeywordResult> resultByKeyword = showStatService.getStatShowByPhrase(request);
        applySocDem(resultByKeyword, input);

        return resultByKeyword;
    }

    /**
     * Converts the GraphQL input into a {@link ShowStatRequest}.
     *
     * <p>Geo is sent only when non-empty. When keyword ids are present they are paired
     * positionally with phrases ({@code StreamEx.zip} requires both lists to have the same size);
     * otherwise only phrases are sent.
     */
    private static ShowStatRequest convertRequest(GdShowStatContainer input) {
        var keywords = input.getKeywords();
        var request = new ShowStatRequest()
                .withCommonMinusPhrases(keywords.getCommonMinusPhrases())
                .withCampaignId(keywords.getCampaignId())
                .withAdGroupId(keywords.getAdGroupId())
                .withLibraryMinusPhrasesIds(keywords.getLibraryMinusPhrasesIds())
                .withDeviceTypes(mapSet(input.getDeviceTypes(), GdDeviceType::toSource))
                .withGender(GdGender.toSource(input.getGender()))
                .withAges(mapSet(input.getAges(), GdAge::toSource))
                .withIsContentPromotionVideo(false)
                .withNeedSearchedWith(false);

        if (!CollectionUtils.isEmpty(input.getGeo())) {
            request.withGeo(StringUtils.join(input.getGeo(), ","));
        }

        if (keywords.getKeywordIds() != null) {
            List<AdvqRequestKeyword> advqRequestKeywords = StreamEx.zip(
                    keywords.getKeywordIds(),
                    keywords.getPhrases(),
                    (keywordId, phrase) -> new AdvqRequestKeyword(phrase, keywordId))
                    .toList();
            request.withKeywords(advqRequestKeywords);
        } else {
            request.withPhrases(keywords.getPhrases());
        }

        return request;
    }

    /**
     * Estimates shows without keywords from stored ADVQ hits for the requested regions and
     * device types.
     *
     * <p>NOTE(review): unlike the keyword path, gender/ages are not applied here — confirm this
     * is intended.
     */
    private GdShowStat getNonKeywordsShowStat(GdShowStatContainer input) {
        List<Long> regionIds = normalizeGeo(input.getGeo());
        List<AdvqHitsDeviceType> deviceTypes = normalizeDeviceTypes(input.getDeviceTypes());

        List<AdvqHits> advqHits = advqHitsRepository.getAdvqHits(getRegionIdsForQuery(regionIds), deviceTypes);

        return new GdShowStat()
                .withShows(calcShows(advqHits, listToSet(regionIds)));
    }

    /**
     * Returns the requested signed region ids, or only {@link #GLOBAL_REGION_ID} when no geo is
     * given.
     */
    private static List<Long> normalizeGeo(List<Long> geo) {
        // An empty geo is treated like a missing one, consistent with convertRequest, which only
        // sends geo to ADVQ when it is non-empty. Otherwise no regions at all would be queried.
        return CollectionUtils.isEmpty(geo) ? List.of(GLOBAL_REGION_ID) : geo;
    }

    /** Maps requested device types to ADVQ ones; {@code null} or {@code ALL} means all three. */
    private static List<AdvqHitsDeviceType> normalizeDeviceTypes(Set<GdDeviceType> deviceTypes) {
        if (deviceTypes == null || deviceTypes.contains(GdDeviceType.ALL)) {
            return List.of(AdvqHitsDeviceType.DESKTOP, AdvqHitsDeviceType.PHONE, AdvqHitsDeviceType.TABLET);
        }

        return mapList(deviceTypes, ForecastService::convertSingleDeviceType);
    }

    /** Strips the sign from region ids: excluded (negative) regions are queried by absolute id. */
    private static List<Long> getRegionIdsForQuery(List<Long> regionIds) {
        return StreamEx.of(regionIds)
                .map(Math::abs)
                .toList();
    }

    /**
     * Converts a concrete device type; {@code ALL} is expected to be expanded by the caller.
     *
     * @throws IllegalArgumentException for any device type without an ADVQ counterpart
     */
    private static AdvqHitsDeviceType convertSingleDeviceType(GdDeviceType deviceType) {
        switch (deviceType) {
            case DESKTOP:
                return AdvqHitsDeviceType.DESKTOP;
            case PHONE:
                return AdvqHitsDeviceType.PHONE;
            case TABLET:
                return AdvqHitsDeviceType.TABLET;
            default:
                throw new IllegalArgumentException("Unsupported device type: " + deviceType);
        }
    }

    /**
     * Sums shows over the returned hits: a region requested with a positive id adds its shows,
     * a region requested with a negative id (an exclusion) subtracts them.
     *
     * @throws IllegalStateException if a returned region was requested both as included and as
     *                               excluded, or was not requested at all
     */
    private static Long calcShows(List<AdvqHits> advqHits, Set<Long> signedRegionIds) {
        return StreamEx.of(advqHits)
                .map(hits -> {
                    Long regionId = hits.getRegionId();

                    boolean containsPos = signedRegionIds.contains(regionId);
                    boolean containsNeg = signedRegionIds.contains(-regionId);
                    checkState(containsPos ^ containsNeg,
                            "Region %s must be requested exactly once, either included or excluded", regionId);

                    return containsPos ? hits.getShows() : -hits.getShows();
                })
                .reduce(Long::sum)
                .orElse(0L);
    }

    /**
     * Scales every keyword's total count in place by the share of the audience matching the
     * requested gender and ages. Does nothing when neither filter is set.
     */
    private static void applySocDem(Map<AdvqRequestKeyword, SearchKeywordResult> resultByKeyword,
                                    GdShowStatContainer input) {
        if (input.getAges() == null && input.getGender() == null) {
            return;
        }

        double socDemReachPct = calcSocDemReachPct(input);

        EntryStream.of(resultByKeyword)
                .values()
                .map(SearchKeywordResult::getResult)
                .nonNull()
                .forEach(result -> {
                    long shows = result.getTotalCount();
                    long showsWithAppliedSocDem = round(shows * socDemReachPct);
                    result.getStat().setTotalCount(showsWithAppliedSocDem);
                });
    }

    /** Sums the audience shares of all (gender, age) cells selected by the input. */
    private static double calcSocDemReachPct(GdShowStatContainer input) {
        double socDemReachPct = 0.0;

        for (Table.Cell<GdGender, GdAge, Double> socDemReachPart : SOCDEM_REACH_PCT_BY_AGE_AND_GENDER.cellSet()) {
            if (!includes(input.getGender(), socDemReachPart.getRowKey())) {
                continue;
            }
            if (!includes(input.getAges(), socDemReachPart.getColumnKey())) {
                continue;
            }

            socDemReachPct += socDemReachPart.getValue();
        }

        return socDemReachPct;
    }

    /** Whether the requested gender selects {@code obj}; {@code null} or {@code ALL} selects all. */
    private static boolean includes(GdGender subj, GdGender obj) {
        if (subj == null || subj == GdGender.ALL) {
            return true;
        }

        return subj == obj;
    }

    /** Whether the requested ages select {@code obj}; {@code null} or containing {@code ALL} selects all. */
    private static boolean includes(Set<GdAge> subj, GdAge obj) {
        if (subj == null || subj.contains(ALL)) {
            return true;
        }

        return subj.contains(obj);
    }

    /**
     * Forecasts weekly clicks and conversions for the target keywords.
     *
     * <p>Side effect: the phrases of {@code input.getTarget().getKeywords()} are replaced with
     * their deduplicated list.
     *
     * @return a forecast with {@code isForecastAvailable = false} when the BS auction is
     *         unavailable or any phrase lacks a usable ADVQ response; otherwise clicks and
     *         conversions ({@code round(clicks * 0.1)})
     */
    public GdForecast getForecast(GridGraphQLContext context, GdForecastContainer input) {
        ClientId clientId = context.getSubjectUser().getClientId();
        int shard = shardHelper.getShardByClientId(clientId);

        // Deduplicate phrases and write them back, so that the later calls receiving `input`
        // see each phrase once.
        // NOTE(review): only phrases are deduplicated. If keywordIds are also set, convertRequest
        // zips them with the now possibly shorter phrase list. Confirm keywordIds are never sent here.
        var keywords = input.getTarget().getKeywords();
        List<String> phrases = StreamEx.of(keywords.getPhrases())
                .distinct()
                .toList();
        keywords.setPhrases(phrases);

        Map<String, ForecastCtr> forecastCtrByPhrase = keywordForecastService.getForecastByPhrase(phrases);
        Map<String, GdAuctionData> auctionDataByPhrase;
        try {
            auctionDataByPhrase = showConditionDataService.getGdAuctionDataByPhrase(shard, clientId, input, forecastCtrByPhrase);
        } catch (BsAuctionUnavailableException e) {
            return new GdForecast()
                    .withIsForecastAvailable(false);
        }
        Map<String, SearchKeywordResult> keywordShowStatByPhrase =
                EntryStream.of(getKeywordShowStatByPhrase(input.getTarget()))
                        .mapKeys(AdvqRequestKeyword::getPhrase)
                        .toMap();

        if (!allResponsesOk(phrases, keywordShowStatByPhrase)) {
            return new GdForecast()
                    .withIsForecastAvailable(false);
        }

        Long clicks = clicksForecaster.forecastClicks(input.getWeeklyBudget(),
                ifNotNull(input.getCpa(), BigDecimal::doubleValue), phrases, auctionDataByPhrase,
                keywordShowStatByPhrase, forecastCtrByPhrase);

        return new GdForecast()
                .withClicks(clicks)
                .withConversions(round(clicks * CONVERSIONS_TO_CLICKS_RATE))
                .withIsForecastAvailable(true);
    }

    /** Whether every phrase has a present, non-empty, error-free ADVQ result (logs the first failure). */
    private static boolean allResponsesOk(List<String> phrases, Map<String, SearchKeywordResult> keywordShowStatByPhrase) {
        return StreamEx.of(phrases)
                .mapToEntry(identity(), keywordShowStatByPhrase::get)
                .allMatch(ForecastService::isResponseOk);
    }

    /** Checks a single ADVQ result, logging a warning with the phrase when it is unusable. */
    private static boolean isResponseOk(String phrase, SearchKeywordResult result) {
        if (result == null) {
            logger.warn("No response from ADVQ to phrase: {}", phrase);
            return false;
        }

        if (result.isEmpty()) {
            logger.warn("Empty response from ADVQ to phrase: {}", phrase);
            return false;
        }

        if (result.hasErrors()) {
            logger.warn("Got errors from ADVQ: {} to phrase: {}", result.getErrors(), phrase);
            return false;
        }

        return true;
    }
}
