#include "model.h"

#include <drive/backend/offers/ranking/calcer.h>

#include <drive/backend/models/storage.h>
#include <drive/backend/proto/offer.pb.h>

// Evaluates `model` on `features`, then:
//  - records the model name on `context`;
//  - if `factorIndexFrom` is a valid index into features.Floats, appends a TPriceModelInfo with the
//    model name, the value of that factor (times 100) as "Before" and the model output (times 100)
//    as "After";
//  - if `factorIndexTo` is a valid index into features.Floats, overwrites that factor with the raw
//    (unscaled) model output;
//  - sets the price of `context` to the model output times 100.
//
// A negative index disables the corresponding step. An index at or past features.Floats.size() is
// treated the same way instead of reading/writing out of bounds.
void TMarketPricesContext::ApplyPriceModel(TPriceContext& context, const NDrive::IOfferModel& model, NDrive::TOfferFeatures& features, const i32 factorIndexFrom, const i32 factorIndexTo) {
    const auto isValidFactorIndex = [&features](const i32 index) {
        return index >= 0 && static_cast<size_t>(index) < features.Floats.size();
    };

    // The model is evaluated before any factor is overwritten below, so it sees the original inputs.
    const double price = model.Calc(features);
    context.SetPriceModelName(model.GetName());

    if (isValidFactorIndex(factorIndexFrom)) {
        TPriceModelInfo info;
        info.Name = model.GetName();
        info.Before = features.Floats[factorIndexFrom] * 100;
        info.After = static_cast<float>(price * 100);
        context.MutablePriceModelInfos().emplace_back(std::move(info));
    }
    if (isValidFactorIndex(factorIndexTo)) {
        features.Floats[factorIndexTo] = price;
    }
    context.SetPrice(price * 100);
}

// Restores this market context from `proto`: first the common TPricesContext part, then the
// feature values from proto.GetFeatures(). Returns false as soon as either step fails; the feature
// step is not attempted when the common part fails.
bool TMarketPricesContext::DeserializePricesContextFromProto(const NDrive::NProto::TMarketPricesContext& proto) {
    return TPricesContext::DeserializeContextFromProto(proto)
        && DeserializeFeatures(proto.GetFeatures());
}

// Copies feature values from `features` into this context's Features.
//
// Each repeated proto field is copied element-by-element into its destination container,
// truncated to the shorter of the two sizes. Destination elements beyond the copied prefix keep
// their current values. Fields are applied in this order:
//   Float     -> Features.Floats
//   Category  -> Features.Categories2
//   Float2    -> Features.Floats2
//   Category2 -> Features.Categories2 (overwrites overlapping positions written from Category)
// NOTE(review): `Category` landing in Categories2 presumably keeps old protos readable; confirm it
// is not meant to target a separate container.
// Every ModelsDebugInfo entry is handed to Features.StoreDebugInfo with its model name and scores.
// Always returns true.
bool TMarketPricesContext::DeserializeFeatures(const NDrive::NProto::TOfferFeatures& features) {
    // Assigns read(i) to destination[i] for every i below min(available, destination.size()).
    const auto copyPrefix = [](auto& destination, const size_t available, const auto& read) {
        const size_t count = std::min<size_t>(available, destination.size());
        for (size_t index = 0; index < count; ++index) {
            destination[index] = read(index);
        }
    };

    copyPrefix(Features.Floats, features.FloatSize(),
        [&features](const size_t index) -> decltype(auto) { return features.GetFloat(index); });
    copyPrefix(Features.Categories2, features.CategorySize(),
        [&features](const size_t index) -> decltype(auto) { return features.GetCategory(index); });
    copyPrefix(Features.Floats2, features.Float2Size(),
        [&features](const size_t index) -> decltype(auto) { return features.GetFloat2(index); });
    copyPrefix(Features.Categories2, features.Category2Size(),
        [&features](const size_t index) -> decltype(auto) { return features.GetCategory2(index); });

    for (const auto& debugInfo : features.GetModelsDebugInfo()) {
        const auto& protoScores = debugInfo.GetScores();
        TVector<double> scores(protoScores.begin(), protoScores.end());
        Features.StoreDebugInfo(debugInfo.GetModelName(), std::move(scores));
    }
    return true;
}

// Serializes this market context: the common TPricesContext part always, plus the Features message
// (built by SerializeFeatures) only when `serializeFeatures` is true.
NDrive::NProto::TMarketPricesContext TMarketPricesContext::SerializePricesContextToProto(bool serializeFeatures) const {
    NDrive::NProto::TMarketPricesContext result = TPricesContext::SerializeContextToProto();
    if (!serializeFeatures) {
        return result;
    }
    *result.MutableFeatures() = SerializeFeatures();
    return result;
}

// Builds a TOfferFeatures message from this context's Features:
//   Features.Floats2          -> repeated Float2
//   Features.Categories2      -> repeated Category2
//   Features.GetModelsDebugInfo() -> one ModelsDebugInfo element per entry (name + scores)
// NOTE(review): Features.Floats is not written here although DeserializeFeatures reads the `Float`
// field back into it; confirm this asymmetry is intentional.
NDrive::NProto::TOfferFeatures TMarketPricesContext::SerializeFeatures() const {
    NDrive::NProto::TOfferFeatures proto;
    for (const auto& value : Features.Floats2) {
        proto.AddFloat2(value);
    }
    for (const auto& value : Features.Categories2) {
        proto.AddCategory2(value);
    }
    for (const auto& entry : Features.GetModelsDebugInfo()) {
        const auto& modelName = entry.first;
        const auto& modelScores = entry.second;
        auto* debugInfo = proto.AddModelsDebugInfo();
        debugInfo->SetModelName(modelName);
        for (const auto& score : modelScores) {
            debugInfo->AddScores(score);
        }
    }
    return proto;
}

// Looks up the riding and parking price models named in the current full prices context and applies
// each one that is found, riding first. Does nothing when there is no context or the server has no
// models storage. An empty model name, or a name the storage does not resolve, skips that model.
// `server` is dereferenced unconditionally and must be non-null.
void TFullPricesContextBuilder::ApplyModels(const NDrive::IServer* server) {
    auto context = GetFullPricesContext();
    if (!context) {
        return;
    }
    const NDrive::TModelsStorage* storage = server->GetModelsStorage();
    if (!storage) {
        return;
    }

    // Resolves a model name to a model, or to nullptr when the name is empty.
    const auto findModel = [storage](const auto& modelName) -> NDrive::TOfferModelConstPtr {
        if (!modelName) {
            return nullptr;
        }
        return storage->GetOfferModel(modelName);
    };

    if (const NDrive::TOfferModelConstPtr ridingModel = findModel(context->GetRiding().GetPriceModelName())) {
        ApplyRidingPriceModel(*ridingModel);
    }
    if (const NDrive::TOfferModelConstPtr parkingModel = findModel(context->GetParking().GetPriceModelName())) {
        ApplyParkingPriceModel(*parkingModel);
    }
}

// Applies `model` to the riding price of the current full prices context. The FI_PRICE factor is
// recorded as the "before" value, and the raw model output is written back into FI_MODEL_PRICE.
// No-op when there is no context.
void TFullPricesContextBuilder::ApplyRidingPriceModel(const NDrive::IOfferModel& model) {
    if (auto context = GetFullPricesContext()) {
        TMarketPricesContext::ApplyPriceModel(
            context->MutableRiding(),
            model,
            context->MutableFeatures(),
            NDriveOfferFactors::FI_PRICE,
            NDriveOfferFactors::FI_MODEL_PRICE);
    }
}

// Applies `model` to the parking price of the current full prices context. The FI_WAITING_PRICE
// factor is recorded as the "before" value, and the model output is not written back into any
// factor. No-op when there is no context.
void TFullPricesContextBuilder::ApplyParkingPriceModel(const NDrive::IOfferModel& model) {
    // A negative target index tells ApplyPriceModel not to store the model output in the features.
    constexpr i32 noTargetFactor = -1;
    if (auto context = GetFullPricesContext()) {
        TMarketPricesContext::ApplyPriceModel(
            context->MutableParking(),
            model,
            context->MutableFeatures(),
            NDriveOfferFactors::FI_WAITING_PRICE,
            noTargetFactor);
    }
}

// Seeds the current full prices context with `base` features and then recalculates them.
// No-op (no recalculation either) when there is no context.
void TFullPricesContextBuilder::CalculateFeatures(const NDrive::TOfferFeatures& base) {
    auto context = GetFullPricesContext();
    if (context) {
        context->SetFeatures(base);
        RecalculateFeatures();
    }
}

// Runs the builder-specific price recalculation (DoRecalcPrices), then refreshes the features so
// they reflect the new prices.
void TFullPricesContextBuilder::RecalcPrices(const NDrive::IServer* server) {
    this->DoRecalcPrices(server);
    this->RecalculateFeatures();
}

// Serializes the full context: the market part goes into Market (features included only when
// `serializeFeatures` is true), and the equilibrium part goes into Equilibrium.
NDrive::NProto::TFullPricesContext TFullPricesContext::SerializePricesContextToProto(bool serializeFeatures) const {
    NDrive::NProto::TFullPricesContext proto;
    auto& market = *proto.MutableMarket();
    market = TMarketPricesContext::SerializePricesContextToProto(serializeFeatures);
    auto& equilibrium = *proto.MutableEquilibrium();
    equilibrium = Equilibrium.SerializePricesContextToProto();
    return proto;
}

// Restores the full context from `proto`: the market part first, then the equilibrium part.
// Returns false on the first failure; the equilibrium part is not attempted when the market part fails.
bool TFullPricesContext::DeserializePricesContextFromProto(const NDrive::NProto::TFullPricesContext& proto) {
    if (!TMarketPricesContext::DeserializePricesContextFromProto(proto.GetMarket())) {
        return false;
    }
    return Equilibrium.DeserializePricesContextFromProto(proto.GetEquilibrium());
}

// Pack price for duration `d`, `distance` and `leasingDayPrice`; forwards to the market calculation.
ui32 TFullPricesContext::CalcPackPrice(const TDuration d, const double distance, const ui32 leasingDayPrice) const {
    return this->CalcMarketPackPrice(d, distance, leasingDayPrice);
}

// Fixed price for duration `d` and `distance`; forwards to the market calculation.
ui32 TFullPricesContext::CalcFixPrice(const TDuration d, const double distance) const {
    return this->CalcMarketFixPrice(d, distance);
}
