#include "flows_control.h"

#include <drive/backend/offers/context.h>
#include <drive/backend/offers/offers/standart.h>

#include <drive/backend/areas/areas.h>

// Static registration of the correctors defined in this file. Each TRegistrator below is a
// static data member, so it is constructed during static initialization of this translation
// unit, with the corrector's GetTypeName() as the key.
// NOTE(review): presumably the constructor adds the class to the corrector factory so it can
// be instantiated by that type name — confirm against TFactory::TRegistrator in the header.
TMinimalPriceOfferCorrector::TFactory::TRegistrator<TMinimalPriceOfferCorrector> TMinimalPriceOfferCorrector::Registrator(TMinimalPriceOfferCorrector::GetTypeName());
TFlowsOfferCorrector::TFactory::TRegistrator<TFlowsOfferCorrector> TFlowsOfferCorrector::Registrator(TFlowsOfferCorrector::GetTypeName());
TOfferSDModelCorrector::TFactory::TRegistrator<TOfferSDModelCorrector> TOfferSDModelCorrector::Registrator(TOfferSDModelCorrector::GetTypeName());

// Adds a flow-based price correction to a standard offer, chosen by the predicted destination
// area of the ride.
//
// Steps:
//  1. Runs the base corrector; any non-Success result is returned unchanged.
//  2. Requires a non-null offer report that holds an offer, a start position in the context,
//     and an offer of type TStandartOffer; otherwise returns Unimplemented.
//  3. If the riding price already has a correction from any of ConflictedCorrectors, returns
//     Success without adding anything.
//  4. Predicts the destination area; an empty prediction means nothing is applied (Success).
//  5. If FromToFees is configured, the first fee found for (location area -> destination), and
//     failing that for ("other" -> destination), is added to the riding price and Success is
//     returned.
//  6. Otherwise applies the offer's internal correction when the destination is one of the
//     areas containing the ride start point, or its flow correction when it is not. Each path
//     can be switched off by a server setting (both default to enabled), in which case
//     Unimplemented is returned.
EOfferCorrectorResult TFlowsOfferCorrector::DoApplyForOffer(IOfferReport* offer, const TVector<TDBTag>& tags, const TOffersBuildingContext& context, const TString& userId, const NDrive::IServer* server, NDrive::TInfoEntitySession& session) const {
    auto baseResult = TBase::DoApplyForOffer(offer, tags, context, userId, server, session);
    if (baseResult != EOfferCorrectorResult::Success) {
        return baseResult;
    }
    // `offer` is a raw pointer that is dereferenced below; guard it the same way
    // TOfferSDModelCorrector::DoApplyForOffer does instead of relying on every caller.
    if (!offer || !offer->GetOffer()) {
        return EOfferCorrectorResult::Unimplemented;
    }
    if (!context.GetStartPosition()) {
        return EOfferCorrectorResult::Unimplemented;
    }

    TStandartOffer* stOffer = offer->GetOfferAs<TStandartOffer>();
    if (!stOffer) {
        return EOfferCorrectorResult::Unimplemented;
    }

    // A conflicting corrector has already added its correction: leave the price untouched.
    for (auto&& i : ConflictedCorrectors) {
        if (stOffer->GetRiding().HasCorrection(i)) {
            return EOfferCorrectorResult::Success;
        }
    }

    const TString areaId = offer->PredictDestination(context, server, *this);
    if (!areaId) {
        // No predicted destination area: there is nothing to correct.
        return EOfferCorrectorResult::Success;
    }

    if (FromToFees) {
        // Explicit fee table: the first location area (in GetLocationAreaIds() order) that has a
        // fee towards the destination wins; "other" is the fallback source key.
        for (auto&& fromAreaId : context.GetLocationAreaIds()) {
            TMaybe<ui32> fee = FromToFees->GetValue(fromAreaId, areaId);
            if (fee) {
                stOffer->MutableRiding().AddPrice(*fee, "flow_correction_" + areaId);
                return EOfferCorrectorResult::Success;
            }
        }
        TMaybe<ui32> fee = FromToFees->GetValue("other", areaId);
        if (fee) {
            stOffer->MutableRiding().AddPrice(*fee, "flow_correction_other");
            return EOfferCorrectorResult::Success;
        }
    }

    // Areas around the original riding start when the context has one, else the start position.
    const TSet<TString> areaIdsStart = server->GetDriveAPI()->GetAreasDB()->GetAreaIdsInPoint(context.OptionalOriginalRidingStart().GetOrElse(*context.GetStartPosition()));
    if (areaIdsStart.contains(areaId)) {
        // Destination area also contains the start point.
        if (server->GetSettings().GetValueDef("offers.flows_corrector.internal_correction_enabled", true)) {
            offer->ApplyInternalCorrection(areaId, context, server);
        } else {
            return EOfferCorrectorResult::Unimplemented;
        }
    } else {
        // Destination area differs from every area containing the start point.
        if (server->GetSettings().GetValueDef("offers.flows_corrector.transfer_correction_enabled", true)) {
            offer->ApplyFlowCorrection(areaId, context, server);
        } else {
            return EOfferCorrectorResult::Unimplemented;
        }
    }
    return EOfferCorrectorResult::Success;
}

// Scores a standard offer's features with CalcScore and prices it by the first predictor
// that recognizes the score.
//
// Returns:
//  - the base corrector's result when it is not Success;
//  - Unimplemented when there is no offer, the offer is not a TStandartOffer, or CalcScore
//    produced no score;
//  - Success otherwise.
//
// On Success exactly one price line is added to the riding price:
//  - For the first matching predictor, the line is "correction_<name>_<predictor id>". Its
//    amount comes from PriceDeltaByValue, or 0 when there is none. That predictor's id is
//    also inserted into the context's surge types.
//  - When no predictor matches, a zero-amount "correction_<name>_UNDEFINED" line is added.
EOfferCorrectorResult TOfferSDModelCorrector::DoApplyForOffer(IOfferReport* offer, const TVector<TDBTag>& tags, const TOffersBuildingContext& context, const TString& userId, const NDrive::IServer* server, NDrive::TInfoEntitySession& session) const {
    const auto parentVerdict = TBase::DoApplyForOffer(offer, tags, context, userId, server, session);
    if (parentVerdict != EOfferCorrectorResult::Success) {
        return parentVerdict;
    }

    const bool hasOffer = offer && offer->GetOffer();
    if (!hasOffer) {
        return EOfferCorrectorResult::Unimplemented;
    }
    TStandartOffer* const standartOffer = offer->GetOfferAs<TStandartOffer>();
    if (standartOffer == nullptr) {
        return EOfferCorrectorResult::Unimplemented;
    }

    const TMaybe<double> score = CalcScore(standartOffer->GetFeatures(), context.MutableFlowControlContext(), server, nullptr);
    if (!score) {
        return EOfferCorrectorResult::Unimplemented;
    }

    for (auto&& predictor : Predictors) {
        auto detected = predictor.GetValue(*score);
        if (!detected) {
            continue;
        }
        // First predictor that recognizes the score decides the delta; no delta means 0.
        const double detectedValue = predictor.GetDetectedValueDef(*detected);
        const TMaybe<double> priceDelta = PriceDeltaByValue.GetValue(detectedValue);
        context.MutableSurgeTypes().insert(predictor.GetId());
        standartOffer->MutableRiding().AddPrice(priceDelta.GetOrElse(0), "correction_" + GetName() + "_" + predictor.GetId());
        return EOfferCorrectorResult::Success;
    }

    // No predictor matched the score: record a zero-amount "_UNDEFINED" line instead.
    standartOffer->MutableRiding().AddPrice(0, "correction_" + GetName() + "_UNDEFINED");
    return EOfferCorrectorResult::Success;
}
