#include "model.h"

#include <drive/backend/offers/factors/factors_info.h>
#include <drive/backend/offers/price/price.h>
#include <drive/backend/offers/ranking/calcer.h>

#include <drive/backend/logging/events.h>
#include <drive/backend/proto/models.pb.h>

#include <drive/library/cpp/catboost/multiclass.h>
#include <drive/library/cpp/schedule/sd.h>

#include <catboost/private/libs/text_features/helpers.h>

#include <kernel/catboost/catboost_calcer.h>
#include <kernel/relevfml/relev_fml.h>

#include <library/cpp/json/json_reader.h>
#include <library/cpp/logger/global/global.h>
#include <library/cpp/lua/eval.h>
#include <library/cpp/protobuf/json/proto2json.h>

#include <rtline/library/json/adapters.h>
#include <rtline/library/json/exception.h>
#include <rtline/library/json/parse.h>
#include <rtline/library/lmtree/model.h>
#include <rtline/util/algorithm/container.h>
#include <rtline/util/algorithm/ptr.h>
#include <rtline/util/types/exception.h>

#include <util/digest/fnv.h>
#include <util/random/fast.h>
#include <util/random/normal.h>
#include <util/stream/file.h>
#include <util/stream/str.h>
#include <util/string/builder.h>
#include <util/system/tls.h>

// Thread-local slot holding a shared TLuaEval instance (declared through Y_STATIC_THREAD).
// NOTE(review): no use of LuaEvaluator is visible in this part of the file; presumably
// TLuaModel evaluates its script through it — confirm against the rest of the file.
Y_STATIC_THREAD(TSimpleSharedPtr<TLuaEval>) LuaEvaluator;

namespace {
    // Builds a T from its serialized form by calling T::Load on a string stream.
    // An empty blob means "no object": a null holder is returned and Load is never called.
    template <class T>
    THolder<T> ContructFromString(const TString& data) {
        THolder<T> object;
        if (!data.empty()) {
            object = MakeHolder<T>();
            TStringInput stream(data);
            object->Load(&stream);
        }
        return object;
    }

    // SRelevanceFormula is not loaded from a stream: it is decoded from its string form
    // with the factor count of the offer factor set. Empty input still yields a null holder.
    template <>
    THolder<SRelevanceFormula> ContructFromString<SRelevanceFormula>(const TString& data) {
        THolder<SRelevanceFormula> formula;
        if (!data.empty()) {
            formula = MakeHolder<SRelevanceFormula>();
            Decode(formula.Get(), data, NDriveOfferFactors::FI_FACTOR_COUNT);
        }
        return formula;
    }
}

// Parses a binary-serialized TOfferModel and builds the matching model.
// Returns nullptr when the bytes are not a valid TOfferModel or no model can be built from it.
THolder<NDrive::IOfferModel> NDrive::IOfferModel::Construct(const TString& data) {
    NDrive::NProto::TOfferModel parsed;
    if (parsed.ParseFromString(data)) {
        return Construct(parsed);
    }
    return nullptr;
}

// Builds a concrete offer model from its protobuf description.
// The first model payload that is present in `proto` (checked in the order below) decides
// which model type is created; the proto-level name is passed to every model.
// Returns nullptr when no known payload is set, or when a payload fails to parse
// (the failure is reported through ERROR_LOG in those branches).
THolder<NDrive::IOfferModel> NDrive::IOfferModel::Construct(const NDrive::NProto::TOfferModel& proto) {
    const TString& name = proto.GetName();
    // Catboost model with optional polynom and optional linear-models tree.
    if (proto.HasCatboostModel()) {
        const auto& meta = proto.GetCatboostModel();
        TCatboostModel::TOptions options;
        options.CatboostData = meta.GetCatboost();
        options.PolynomData = meta.GetPolynom();
        // The pointer refers into `proto`; it is only valid while `proto` is alive.
        options.LMTreeProto = meta.HasLMTree() ? &meta.GetLMTree() : nullptr;
        options.Name = name;
        options.UseNewFeatures = meta.GetUseNewFeatures();
        return MakeHolder<TCatboostModel>(options);
    }
    // Catboost multiclass predictor.
    if (proto.HasCatboostMulticlassModel()) {
        const auto& meta = proto.GetCatboostMulticlassModel();
        TCatboostMulticlassModel::TOptions options;
        options.CatboostData = meta.GetCatboost();
        options.Name = name;
        options.SoftMax = meta.GetSoftMax();
        options.UseNewFeatures = meta.GetUseNewFeatures();
        return MakeHolder<TCatboostMulticlassModel>(options);
    }
    // Base factor plus a fixed addition when the inner score exceeds a threshold.
    if (proto.HasAdditionThresholdModel()) {
        NDrive::TAdditionThresholdModel::TOptions options;
        options.Name = name;

        const auto& meta = proto.GetAdditionThresholdModel();
        options.BaseIndex = meta.GetBaseIndex();
        options.Threshold = meta.GetThreshold();
        options.Addition = meta.GetAddition();
        options.ModelData = meta.GetCatboostData();
        options.PolynomData = meta.GetPolynomData();
        return MakeHolder<TAdditionThresholdModel>(options);
    }
    // Constant value model.
    if (proto.HasConstantModel()) {
        const auto& meta = proto.GetConstantModel();
        return MakeHolder<TConstantOfferModel>(name, meta.GetValue());
    }
    // Hash-driven price jitter.
    if (proto.HasJitterModel()) {
        const auto& meta = proto.GetJitterModel();
        return MakeHolder<TJitterOfferModel>(name, meta.GetHashType(), meta.GetMinMultiplier(), meta.GetMaxMultiplier(), meta.GetNormalDispersion(), meta.GetNormalMean());
    }
    // Supply/demand history model; every optional field keeps the TOptions default when absent.
    if (proto.HasSupplyDemandHistoryModel()) {
        const auto& meta = proto.GetSupplyDemandHistoryModel();
        auto history = ContructFromString<NDrive::TSupplyDemandHistory>(meta.GetHistory());
        auto polynom = ContructFromString<SRelevanceFormula>(meta.GetPolynom());
        TSupplyDemandHistoryModel::TOptions options;
        options.SigmoidDelta = meta.HasSigmoidDelta() ? meta.GetSigmoidDelta() : options.SigmoidDelta;
        options.SigmoidMultiplier = meta.HasSigmoidMultiplier() ? meta.GetSigmoidMultiplier() : options.SigmoidMultiplier;
        options.FutureDepth = meta.HasFutureDepth() ? meta.GetFutureDepth() : options.FutureDepth;
        options.FutureDemandDiscount = meta.HasFutureDemandDiscount() ? meta.GetFutureDemandDiscount() : options.FutureDemandDiscount;
        options.FutureSupplyDiscount = meta.HasFutureSupplyDiscount() ? meta.GetFutureSupplyDiscount() : options.FutureSupplyDiscount;
        // FutureStep is stored in the proto as a number of microseconds.
        options.FutureStep = meta.HasFutureStep() ? TDuration::MicroSeconds(meta.GetFutureStep()) : options.FutureStep;
        return MakeHolder<TSupplyDemandHistoryModel>(name, std::move(history), std::move(polynom), options);
    }
    // Variation optimizer; variation settings keep the TOptions defaults when absent.
    if (proto.HasVariationOptimizerModel()) {
        const auto& meta = proto.GetVariationOptimizerModel();
        TVariationOptimizerModel::TOptions options;
        options.CatboostData = meta.GetCatboost();
        options.Catboost2Data = meta.GetCatboost2();
        options.PolynomData = meta.GetPolynom();
        options.Name = name;
        options.VariationIndex = meta.HasVariationIndex() ? meta.GetVariationIndex() : options.VariationIndex;
        options.VariationAbsoluteMax = meta.HasVariationAbsoluteMax() ? meta.GetVariationAbsoluteMax() : options.VariationAbsoluteMax;
        options.VariationAbsoluteMin = meta.HasVariationAbsoluteMin() ? meta.GetVariationAbsoluteMin() : options.VariationAbsoluteMin;
        options.VariationMax = meta.HasVariationMax() ? meta.GetVariationMax() : options.VariationMax;
        options.VariationMin = meta.HasVariationMin() ? meta.GetVariationMin() : options.VariationMin;
        options.VariationStep = meta.HasVariationStep() ? meta.GetVariationStep() : options.VariationStep;
        return MakeHolder<TVariationOptimizerModel>(options);
    }
    // Geo-local model; Elements are loaded from their serialized blob.
    if (proto.HasGeoLocalModel()) {
        TGeoLocalModel::TOptions options;
        options.Name = name;
        const auto& meta = proto.GetGeoLocalModel();
        options.IsFridayWeekend = meta.GetIsFridayWeekend();
        options.Polynom = meta.GetPolynomData();
        TStringInput si(meta.GetElementsData());
        Load(&si, options.Elements);
        if (meta.HasBaseIndex()) {
            options.BaseIndex = meta.GetBaseIndex();
        }
        if (meta.HasVersion()) {
            // The version is stored as a string; an unparsable value rejects the whole model.
            if (!TryFromString(meta.GetVersion(), options.Version)) {
                ERROR_LOG << name << ": cannot parse Version from " << meta.GetVersion() << Endl;
                return nullptr;
            }
        }
        return MakeHolder<TGeoLocalModel>(options);
    }
    // Rounding model.
    if (proto.HasRoundModel()) {
        const auto& meta = proto.GetRoundModel();
        TRoundModel::TOptions options;
        options.Name = name;
        options.MaxDiscount = meta.GetMaxDiscount();
        return MakeHolder<TRoundModel>(options);
    }
    // Random choice from a set of values.
    if (proto.HasRandomSetModel()) {
        const auto& meta = proto.GetRandomSetModel();
        TRandomSetModel::TOptions options;
        options.Name = name;
        options.HashType = meta.GetHashType();
        options.Values = MakeSet<double>(meta.GetValue());
        return MakeHolder<TRandomSetModel>(options);
    }
    // Time schedule model; its config is a JSON document stored as a string inside the proto.
    if (proto.HasTimeScheduleModel()) {
        const auto& meta = proto.GetTimeScheduleModel();
        NJson::TJsonValue configJson;
        if (!NJson::ReadJsonFastTree(meta.GetConfig(), &configJson)) {
            ERROR_LOG << name << ": cannot parse Json from " << meta.GetConfig() << Endl;
            return nullptr;
        }
        auto config = MakeHolder<TPriceByTimeConfig>();
        if (!config->DeserializeFromJson(configJson)) {
            ERROR_LOG << name << ": cannot parse Config from " << configJson.GetStringRobust() << Endl;
            return nullptr;
        }
        return MakeHolder<TTimeScheduleModel>(name, std::move(config));
    }
    // Meta model: each submodel is itself a serialized TOfferModel; any failure rejects the whole model.
    if (proto.HasMetaMulticlassModel()) {
        TMetaMulticlassModel::TSubmodels submodels;
        for (auto&& data : proto.GetMetaMulticlassModel().GetSubmodelData()) {
            auto submodel = IOfferModel::Construct(data);
            if (!submodel) {
                ERROR_LOG << name << ": cannot construct submodel from " << Base64Encode(data) << Endl;
                return nullptr;
            }
            submodels.push_back(std::move(submodel));
        }
        return MakeHolder<TMetaMulticlassModel>(name, std::move(submodels));
    }
    // Lua script model.
    if (proto.HasLuaModel()) {
        const auto& meta = proto.GetLuaModel();
        return MakeHolder<TLuaModel>(name, meta.GetScript());
    }
    // No known payload is set.
    return nullptr;
}

// Builds a concrete offer model from its JSON description.
// The "type" field is compared against each model's Type() to pick the branch; "name" is mandatory.
// Declared noexcept(false): malformed input is reported through exceptions (Y_ENSURE and,
// presumably, the *Safe accessors and NJson::FromJson) rather than by returning nullptr.
// Returns nullptr only when "type" matches none of the known model types.
THolder<NDrive::IOfferModel> NDrive::IOfferModel::Construct(const NJson::TJsonValue& json) noexcept(false) {
    const TString& name = json["name"].GetStringSafe();
    const TString& type = json["type"].GetStringRobust();
    // Catboost model: "data" and "polynom" are read with the non-Safe getter here.
    if (type == TCatboostModel::Type()) {
        TCatboostModel::TOptions options;
        options.Name = name;
        options.CatboostData = json["data"].GetString();
        options.PolynomData = json["polynom"].GetString();
        options.LMTreeJson = json["lmtree"];
        options.UseNewFeatures = NJson::FromJson<TMaybe<bool>>(json["use_new_features"]).GetOrElse(false);
        return MakeHolder<TCatboostModel>(options);
    }
    // Catboost multiclass predictor: "data" is read with the Safe getter.
    if (type == TCatboostMulticlassModel::Type()) {
        TCatboostMulticlassModel::TOptions options;
        options.Name = name;
        options.CatboostData = json["data"].GetStringSafe();
        options.SoftMax = NJson::FromJson<TMaybe<bool>>(json["soft_max"]).GetOrElse(false);
        options.UseNewFeatures = NJson::FromJson<TMaybe<bool>>(json["use_new_features"]).GetOrElse(false);
        return MakeHolder<TCatboostMulticlassModel>(options);
    }
    // Addition-threshold model: "base_index" is optional, "threshold" and "addition" use Safe getters.
    if (type == TAdditionThresholdModel::Type()) {
        NDrive::TAdditionThresholdModel::TOptions options;
        options.Name = name;
        options.ModelData = json["data"].GetString();
        options.PolynomData = json["polynom"].GetString();
        if (json["base_index"].IsDefined()) {
            options.BaseIndex = json["base_index"].GetUIntegerSafe();
        }
        options.Threshold = json["threshold"].GetDoubleSafe();
        options.Addition = json["addition"].GetDoubleSafe();
        return MakeHolder<TAdditionThresholdModel>(options);
    }
    // Constant value model.
    if (type == TConstantOfferModel::Type()) {
        double value = json["value"].GetDoubleRobust();
        return MakeHolder<TConstantOfferModel>(name, value);
    }
    // Jitter model: the normal-distribution parameters and hash type fall back to 0 / "" when absent.
    if (type == TJitterOfferModel::Type()) {
        double minMultiplier = json["min_multiplier"].GetDoubleSafe();
        double maxMultiplier = json["max_multiplier"].GetDoubleSafe();
        double normalDistribution = NJson::FromJson<TMaybe<double>>(json["normal_distribution"]).GetOrElse(0);
        double normalMean = NJson::FromJson<TMaybe<double>>(json["normal_mean"]).GetOrElse(0);
        auto hashType = NJson::FromJson<TMaybe<TString>>(json["hash_type"]).GetOrElse("");
        return MakeHolder<TJitterOfferModel>(name, hashType, minMultiplier, maxMultiplier, normalDistribution, normalMean);
    }
    // Supply/demand history model: each optional field overrides the TOptions default only when present.
    if (type == TSupplyDemandHistoryModel::Type()) {
        const TString& data = json["data"].GetStringSafe();
        const TString& polynom = json["polynom"].GetStringSafe();
        auto history = ContructFromString<NDrive::TSupplyDemandHistory>(data);
        auto poly = ContructFromString<SRelevanceFormula>(polynom);
        TSupplyDemandHistoryModel::TOptions options;
        if (auto sigmoidDelta = NJson::FromJson<TMaybe<float>>(json["sigmoid_delta"])) {
            options.SigmoidDelta = *sigmoidDelta;
        }
        if (auto sigmoidMultiplier = NJson::FromJson<TMaybe<float>>(json["sigmoid_multiplier"])) {
            options.SigmoidMultiplier = *sigmoidMultiplier;
        }
        if (auto futureDepth = NJson::FromJson<TMaybe<ui32>>(json["future_depth"])) {
            options.FutureDepth = *futureDepth;
        }
        if (auto futureDiscount = NJson::FromJson<TMaybe<float>>(json["future_demand_discount"])) {
            options.FutureDemandDiscount = *futureDiscount;
        }
        if (auto futureDiscount = NJson::FromJson<TMaybe<float>>(json["future_supply_discount"])) {
            options.FutureSupplyDiscount = *futureDiscount;
        }
        if (auto futureStep = NJson::FromJson<TMaybe<TDuration>>(json["future_step"])) {
            options.FutureStep = *futureStep;
        }
        return MakeHolder<TSupplyDemandHistoryModel>(name, std::move(history), std::move(poly), options);
    }
    // Variation optimizer: "data" is either a single string or an array of exactly two strings.
    if (type == TVariationOptimizerModel::Type()) {
        TVariationOptimizerModel::TOptions options;
        options.Name = name;
        const auto& data = json["data"];
        if (data.IsArray()) {
            const auto& arr = data.GetArray();
            Y_ENSURE(arr.size() == 2, "data array should be of size 2, actual " << arr.size());
            options.CatboostData = arr[0].GetStringSafe();
            options.Catboost2Data = arr[1].GetStringSafe();
        } else if (data.IsDefined()) {
            options.CatboostData = json["data"].GetStringSafe();
        }
        options.PolynomData = json["polynom"].GetStringSafe();
        if (auto variationIndex = NJson::FromJson<TMaybe<size_t>>(json["variation_index"])) {
            options.VariationIndex = *variationIndex;
        }
        if (auto variationAbsoluteMax = NJson::FromJson<TMaybe<double>>(json["variation_absolute_max"])) {
            options.VariationAbsoluteMax = *variationAbsoluteMax;
        }
        if (auto variationAbsoluteMin = NJson::FromJson<TMaybe<double>>(json["variation_absolute_min"])) {
            options.VariationAbsoluteMin = *variationAbsoluteMin;
        }
        if (auto variationMax = NJson::FromJson<TMaybe<double>>(json["variation_max"])) {
            options.VariationMax = *variationMax;
        }
        if (auto variationMin = NJson::FromJson<TMaybe<double>>(json["variation_min"])) {
            options.VariationMin = *variationMin;
        }
        if (auto variationStep = NJson::FromJson<TMaybe<double>>(json["variation_step"])) {
            options.VariationStep = *variationStep;
        }
        return MakeHolder<TVariationOptimizerModel>(options);
    }
    // Geo-local model: Elements are loaded from the serialized "data" blob.
    if (type == TGeoLocalModel::Type()) {
        TGeoLocalModel::TOptions options;
        options.Name = name;
        options.IsFridayWeekend = json["is_friday_weekend"].GetBoolean();
        options.Polynom = json["polynom"].GetString();
        TStringInput si(json["data"].GetStringSafe());
        Load(&si, options.Elements);
        NJson::ReadField(json, "base_index", options.BaseIndex);
        if (json.Has("version")) {
            Y_ENSURE(NJson::TryFromJson(json["version"], NJson::Stringify(options.Version)));
        }
        return MakeHolder<TGeoLocalModel>(options);
    }
    // Rounding model.
    if (type == TRoundModel::Type()) {
        TRoundModel::TOptions options;
        options.Name = name;
        options.MaxDiscount = json["max_discount"].GetDoubleSafe();
        return MakeHolder<TRoundModel>(options);
    }
    // Random choice from a set of values.
    if (type == TRandomSetModel::Type()) {
        TRandomSetModel::TOptions options;
        options.Name = name;
        options.HashType = NJson::FromJson<TMaybe<TString>>(json["hash_type"]).GetOrElse("");
        options.Values = NJson::FromJson<TSet<double>>(json["values"]);
        return MakeHolder<TRandomSetModel>(options);
    }
    // Time schedule model: the config is an inline JSON object.
    if (type == TTimeScheduleModel::Type()) {
        auto config = MakeHolder<TPriceByTimeConfig>();
        Y_ENSURE(config->DeserializeFromJson(json["config"]), "cannot parse Config");
        return MakeHolder<TTimeScheduleModel>(name, std::move(config));
    }
    // Meta model: every entry of "submodels" is itself a model description and must construct.
    if (type == TMetaMulticlassModel::Type()) {
        TMetaMulticlassModel::TSubmodels submodels;
        for (auto&& data : json["submodels"].GetArraySafe()) {
            auto submodel = IOfferModel::Construct(data);
            Y_ENSURE(submodel, "cannot construct submodel from " << data.GetStringRobust());
            submodels.push_back(std::move(submodel));
        }
        return MakeHolder<TMetaMulticlassModel>(name, std::move(submodels));
    }
    // Lua script model: unlike the proto path, the model is validated before being returned.
    if (type == TLuaModel::Type()) {
        auto model = MakeHolder<TLuaModel>(name, json["script"].GetStringSafe());
        model->Validate();
        // std::move performs the THolder<TLuaModel> -> THolder<IOfferModel> conversion.
        return std::move(model);
    }
    // Unknown type.
    return nullptr;
}

// Batch evaluation: results[i] receives Calc(features[i]).
// Both ranges must have the same length; a mismatch is reported through Y_ENSURE_BT.
void NDrive::IOfferModel::Calc(TArrayRef<TOfferFeatures> features, TArrayRef<double> results) const {
    Y_ENSURE_BT(features.size() == results.size());
    // The bound is clamped to the shorter range, mirroring the defensive loop condition.
    const size_t count = std::min(features.size(), results.size());
    for (size_t index = 0; index != count; ++index) {
        results[index] = Calc(features[index]);
    }
}

// Serializes the model to the binary TOfferModel wire format.
// The name is written here; the model-specific payload comes from the proto overload of Serialize.
void NDrive::IOfferModel::Serialize(TString& data) const {
    NDrive::NProto::TOfferModel message;
    message.SetName(GetName());
    Serialize(message);
    data = message.SerializeAsString();
}

// Scalar evaluation delegates to the first submodel only.
// Returns 0 when there are no submodels or the first one is null.
double NDrive::TMetaMulticlassModel::Calc(TOfferFeatures& features) const {
    if (Submodels.empty()) {
        return 0;
    }
    if (auto head = Submodels[0].Get()) {
        return head->Calc(features);
    }
    return 0;
}

// Evaluates every submodel and returns one score per submodel, in order.
// Each submodel works on its own copy of `features`, so writes made by one submodel
// are not seen by the next. A null submodel contributes 0.
TVector<double> NDrive::TMetaMulticlassModel::Predict(const TOfferFeatures& features) const {
    TVector<double> scores;
    for (const auto& model : Submodels) {
        TOfferFeatures scratch = features;
        double score = 0;
        if (model) {
            score = model->Calc(scratch);
        }
        scores.push_back(score);
    }
    return scores;
}

// Writes each submodel as a serialized blob into the MetaMulticlassModel payload.
// A null submodel is stored as an empty string so that positions are preserved.
void NDrive::TMetaMulticlassModel::Serialize(NDrive::NProto::TOfferModel& proto) const {
    auto* target = proto.MutableMetaMulticlassModel();
    for (const auto& model : Submodels) {
        if (model) {
            target->AddSubmodelData(model->Serialize<TString>());
        } else {
            target->AddSubmodelData(TString{});
        }
    }
}

// JSON form: {"name", "type", "submodels": [...]}.
// Each submodel is serialized to JSON in order; a null submodel becomes a JSON null.
void NDrive::TMetaMulticlassModel::Serialize(NJson::TJsonValue& json) const {
    json["name"] = Name;
    json["type"] = Type();
    NJson::TJsonValue items = NJson::JSON_ARRAY;
    for (const auto& model : Submodels) {
        if (model) {
            items.AppendValue(model->Serialize<NJson::TJsonValue>());
        } else {
            items.AppendValue(NJson::TJsonValue(NJson::JSON_NULL));
        }
    }
    json["submodels"] = std::move(items);
}

// Default JSON serialization: convert the protobuf form with Proto2Json, then add
// "name" and "type" on top — but only if the conversion produced a JSON map.
void NDrive::IOfferModel::Serialize(NJson::TJsonValue& json) const {
    NDrive::NProto::TOfferModel message;
    Serialize(message);
    NProtobufJson::Proto2Json(message, json);
    if (!json.IsMap()) {
        return;
    }
    json["name"] = GetName();
    json["type"] = GetType();
}

// Builds the optional Catboost calcer, polynom and linear-models tree from the options.
// The tree may come from a proto pointer or from JSON; when both are given, the JSON one
// is constructed last and therefore wins.
NDrive::TCatboostModel::TCatboostModel(const TOptions& options)
    : Options(options)
{
    Catboost = ContructFromString<NCatboostCalcer::TCatboostCalcer>(Options.CatboostData);
    Polynom = ContructFromString<SRelevanceFormula>(Options.PolynomData);
    if (const auto* treeProto = Options.LMTreeProto) {
        LMTree = MakeHolder<NLinearModelsTree::TModel>(NLinearModelsTree::TModel::Construct(*treeProto));
        // The stored copy of the options must not keep a pointer into the caller's proto.
        Options.LMTreeProto = nullptr;
    }
    const auto& treeJson = Options.LMTreeJson;
    if (treeJson.IsDefined()) {
        LMTree = MakeHolder<NLinearModelsTree::TModel>(NLinearModelsTree::TModel::Construct(treeJson));
    }
}

// Evaluation pipeline:
//   1. if a Catboost calcer exists, its relevance is written to Floats[FI_MATRIXNET];
//   2. if an LM tree exists, its result is returned;
//   3. otherwise the polynom result is returned if a polynom exists;
//   4. otherwise Floats[FI_MATRIXNET] is returned as is.
// UseNewFeatures switches the float inputs from Floats/FloatsView() to Floats2/FloatsView2().
double NDrive::TCatboostModel::Calc(TOfferFeatures& features) const {
    const bool useNew = Options.UseNewFeatures;
    if (Catboost) {
        // The result always lands in the first float set, regardless of UseNewFeatures.
        features.Floats[NDriveOfferFactors::FI_MATRIXNET] = useNew
            ? Catboost->CalcRelev(features.FloatsView2(), features.CategoriesView2())
            : Catboost->CalcRelev(features.FloatsView(), features.CategoriesView2());
    }
    if (LMTree) {
        return useNew
            ? LMTree->DoCalcRelev(features.Floats2.data())
            : LMTree->DoCalcRelev(features.Floats.data());
    }
    if (!Polynom) {
        return features.Floats[NDriveOfferFactors::FI_MATRIXNET];
    }
    return useNew
        ? Polynom->Calc(features.Floats2.data())
        : Polynom->Calc(features.Floats.data());
}

// Fills the CatboostModel payload: the UseNewFeatures flag plus whichever of the
// Catboost calcer, LM tree and polynom are present.
void NDrive::TCatboostModel::Serialize(NDrive::NProto::TOfferModel& proto) const {
    auto* target = proto.MutableCatboostModel();
    target->SetUseNewFeatures(Options.UseNewFeatures);
    if (Catboost) {
        TStringStream buffer;
        Catboost->Save(&buffer);
        target->SetCatboost(buffer.Str());
    }
    if (LMTree) {
        *target->MutableLMTree() = LMTree->Serialize<NLinearModelsTree::NProto::TTree>();
    }
    if (Polynom) {
        target->SetPolynom(Encode(*Polynom));
    }
}

// JSON description of the model. For the Catboost calcer only its ModelInfo dictionary
// is emitted (not the model bytes); the LM tree and polynom are emitted in full.
void NDrive::TCatboostModel::Serialize(NJson::TJsonValue& json) const {
    json["name"] = Options.Name;
    json["use_new_features"] = Options.UseNewFeatures;
    if (Catboost) {
        const auto& modelInfo = Catboost->GetModel().ModelInfo;
        json["catboost"] = NJson::ToJson(NJson::Dictionary(modelInfo));
    }
    if (LMTree) {
        json["lmtree"] = LMTree->Serialize<NJson::TJsonValue>();
    }
    if (Polynom) {
        json["polynom"] = Encode(*Polynom);
    }
}

// Loads the multiclass predictor from the serialized Catboost data.
// The predictor is mandatory: empty data yields a null holder and Y_ENSURE rejects it.
NDrive::TCatboostMulticlassModel::TCatboostMulticlassModel(const TOptions& options)
    : Options(options)
{
    Predictor = ContructFromString<NCatboostCalcer::TMulticlassPredictor>(Options.CatboostData);
    Y_ENSURE(Predictor);
}

// Returns the underlying predictor; Yensured checks the holder before it is dereferenced.
const NCatboostCalcer::TMulticlassPredictor& NDrive::TCatboostMulticlassModel::GetPredictor() const {
    auto&& checked = Yensured(Predictor);
    return *checked;
}

// Returns the raw per-class outputs of the predictor, passed through Softmax when
// Options.SoftMax is set.
// NOTE(review): Options.UseNewFeatures is not consulted here (FloatsView() is always used),
// unlike TCatboostModel::Calc — confirm this is intended.
TVector<double> NDrive::TCatboostMulticlassModel::Predict(const TOfferFeatures& features) const {
    const auto& predictor = GetPredictor();
    auto scores = predictor.CalcRaw(features.FloatsView(), features.CategoriesView2());
    if (!Options.SoftMax) {
        return scores;
    }
    Softmax(scores);
    return scores;
}

// Fills the CatboostMulticlassModel payload.
// The option flags are written unconditionally so that a Serialize -> Construct(proto)
// round trip restores the same TOptions: Construct reads both SoftMax and UseNewFeatures,
// but UseNewFeatures used to be dropped here and SoftMax was written only next to the model bytes.
void NDrive::TCatboostMulticlassModel::Serialize(NDrive::NProto::TOfferModel& proto) const {
    auto meta = proto.MutableCatboostMulticlassModel();
    meta->SetSoftMax(Options.SoftMax);
    meta->SetUseNewFeatures(Options.UseNewFeatures);
    if (Predictor) {
        TStringStream ss;
        Predictor->Save(&ss);
        meta->SetCatboost(ss.Str());
    }
}

// JSON description: name, type, the soft_max flag and, when a predictor is loaded,
// its ModelInfo dictionary under "predictor" (the model bytes are not emitted).
void NDrive::TCatboostMulticlassModel::Serialize(NJson::TJsonValue& json) const {
    json["name"] = Options.Name;
    json["type"] = Type();
    json["soft_max"] = Options.SoftMax;
    if (!Predictor) {
        return;
    }
    const auto& modelInfo = Predictor->GetModel().ModelInfo;
    json["predictor"] = NJson::ToJson(NJson::Dictionary(modelInfo));
}

// Loads the optional Catboost predictor and the optional polynom.
// Inline ModelData takes precedence; ModelFilename is read only when ModelData is empty.
NDrive::TAdditionThresholdModel::TAdditionThresholdModel(const TOptions& options)
    : Options(options)
{
    TString blob;
    if (Options.ModelData) {
        blob = Options.ModelData;
    } else if (Options.ModelFilename) {
        TIFStream file(Options.ModelFilename);
        blob = file.ReadAll();
    }
    if (!blob.empty()) {
        Predictor = ContructFromString<NCatboostCalcer::TCatboostCalcer>(blob);
    }
    if (Options.PolynomData) {
        Polynom = ContructFromString<SRelevanceFormula>(Options.PolynomData);
    }
}

// Defined out of line, in the translation unit where the held types are complete.
NDrive::TAdditionThresholdModel::~TAdditionThresholdModel() = default;

// Returns the base factor Floats[BaseIndex], increased by Options.Addition when the
// inner score from CalcOne is strictly above Options.Threshold.
double NDrive::TAdditionThresholdModel::Calc(TOfferFeatures& features) const {
    const auto score = CalcOne(features);
    // Read the base only after CalcOne, which may write into Floats.
    const auto base = features.Floats[Options.BaseIndex];
    if (score > Options.Threshold) {
        return base + Options.Addition;
    }
    return base;
}

// Inner score used for the threshold check.
// The predictor result (if any) is stored in Floats[FI_MATRIXNET]; when a polynom exists
// it is evaluated over the floats afterwards and its value is returned instead.
// With neither predictor nor polynom the score is 0.
double NDrive::TAdditionThresholdModel::CalcOne(TOfferFeatures& features) const {
    double score = 0;
    if (Predictor) {
        score = Predictor->CalcRelev(features.FloatsView(), features.CategoriesView2());
        features.Floats[NDriveOfferFactors::FI_MATRIXNET] = score;
    }
    if (!Polynom) {
        return score;
    }
    score = Polynom->Calc(features.Floats.data());
    return score;
}

// Fills the AdditionThresholdModel payload: the numeric options, the predictor bytes
// (saved straight into the proto string field) and the original polynom string.
void NDrive::TAdditionThresholdModel::Serialize(NDrive::NProto::TOfferModel& proto) const {
    auto* target = proto.MutableAdditionThresholdModel();
    target->SetBaseIndex(Options.BaseIndex);
    target->SetThreshold(Options.Threshold);
    target->SetAddition(Options.Addition);

    // The mutable accessor is called even when there is no predictor, as before.
    auto* blob = target->MutableCatboostData();
    if (blob && Predictor) {
        TStringOutput sink(*blob);
        Predictor->Save(&sink);
    }
    if (Polynom) {
        target->SetPolynomData(Options.PolynomData);
    }
}

// JSON description: the numeric options, the polynom string when it is non-empty, and
// the predictor's ModelInfo dictionary under "catboost" when a predictor is loaded.
void NDrive::TAdditionThresholdModel::Serialize(NJson::TJsonValue& json) const {
    json["name"] = Options.Name;
    json["base_index"] = Options.BaseIndex;
    json["threshold"] = Options.Threshold;
    json["addition"] = Options.Addition;
    if (Options.PolynomData) {
        json["polynom"] = Options.PolynomData;
    }
    if (Predictor) {
        const auto& modelInfo = Predictor->GetModel().ModelInfo;
        json["catboost"] = NJson::ToJson(NJson::Dictionary(modelInfo));
    }
}

// Always yields the configured constant and mirrors it into Floats[FI_MATRIXNET].
double NDrive::TConstantOfferModel::Calc(TOfferFeatures& features) const {
    auto& matrixnet = features.Floats[NDriveOfferFactors::FI_MATRIXNET];
    matrixnet = Value;
    return Value;
}

// Stores the constant in the ConstantModel payload.
void NDrive::TConstantOfferModel::Serialize(NDrive::NProto::TOfferModel& proto) const {
    proto.MutableConstantModel()->SetValue(Value);
}

// Hash over the day-of-week factor, half-hour buckets of the time-of-day and timestamp
// factors, and latitude/longitude scaled by 100 and truncated.
// The result is the XOR of the FNV hashes of the individual components.
// NOTE(review): FI_TIME_OF_THE_DAY is divided by 1800 here, whereas
// TSupplyDemandHistoryModel::Calc multiplies the same factor by 24*60*60 — the units
// look inconsistent; confirm which interpretation is right.
ui32 CalcJitterHash(const NDrive::TOfferFeatures& features) {
    const auto& floats = features.Floats;
    const auto halfHourSeconds = TDuration::Minutes(30).Seconds();

    auto day = floats[NDriveOfferFactors::FI_DAY_OF_THE_WEEK];
    ui32 halfhour = floats[NDriveOfferFactors::FI_TIME_OF_THE_DAY] / halfHourSeconds;
    ui32 timestamp = floats[NDriveOfferFactors::FI_TIMESTAMP] / halfHourSeconds;
    auto lat = static_cast<ui32>(floats[NDriveOfferFactors::FI_LATITUDE] * 100);
    auto lon = static_cast<ui32>(floats[NDriveOfferFactors::FI_LONGITUDE] * 100);

    ui32 hash = FnvHash<ui32>(&day, sizeof(day));
    hash ^= FnvHash<ui32>(&halfhour, sizeof(halfhour));
    hash ^= FnvHash<ui32>(&timestamp, sizeof(timestamp));
    hash ^= FnvHash<ui32>(&lat, sizeof(lat));
    hash ^= FnvHash<ui32>(&lon, sizeof(lon));
    return hash;
}

// Hash over the user id category and the timestamp factor bucketed into 3-hour windows:
// the XOR of the two FNV hashes.
ui32 CalcUserTimestamp3Hash(const NDrive::TOfferFeatures& features) {
    const auto windowSeconds = TDuration::Hours(3).Seconds();
    ui32 timestamp = features.Floats[NDriveOfferFactors::FI_TIMESTAMP] / windowSeconds;
    auto userId = features.Categories2[NDriveOfferCatFactors2::FI_USER_ID];

    ui32 hash = FnvHash<ui32>(userId);
    hash ^= FnvHash<ui32>(&timestamp, sizeof(timestamp));
    return hash;
}

// Stores the jitter parameters and resolves the hash function for `hashType`.
//   min / max — multiplier bounds;
//   d / m     — dispersion and mean passed to the normal distribution in Calc.
NDrive::TJitterOfferModel::TJitterOfferModel(const TString& name, const TString& hashType, double min, double max, double d, double m)
    : Name(name)
    , HashType(hashType)
    , MaxMultiplier(max)
    , MinMultiplier(min)
    , NormalDispersion(d)
    , NormalMean(m)
{
    // Resolved in the body, after all members are initialised.
    Hash = CreateHash(hashType);
}

// Multiplies the offer price by a pseudo-random multiplier derived from the feature hash.
//   - NormalDispersion > 0: the hash seeds an RNG, a normal sample is drawn and
//     1 + sample is clamped into the multiplier bounds;
//   - otherwise: hash % 100 selects one of 100 evenly spaced points inside the bounds.
// The multiplier is also written to Floats[FI_MATRIXNET].
// The bounds are normalised first, so a configuration with MinMultiplier > MaxMultiplier
// behaves the same in both branches (std::clamp with lo > hi is undefined behaviour).
double NDrive::TJitterOfferModel::Calc(TOfferFeatures& features) const {
    auto price = features.GetPrice();
    auto hash = Hash(features);
    const double lower = std::min(MinMultiplier, MaxMultiplier);
    const double upper = std::max(MinMultiplier, MaxMultiplier);
    double multiplier = 1;
    if (NormalDispersion > 0) {
        TReallyFastRng32 uniformDistribution(hash);
        auto normal = NormalDistribution<double>(uniformDistribution, NormalMean, NormalDispersion);
        multiplier = std::clamp(1 + normal, lower, upper);
    } else {
        multiplier = (hash % 100) * (upper - lower) / 100 + lower;
    }
    features.Floats[NDriveOfferFactors::FI_MATRIXNET] = multiplier;
    return multiplier * price;
}

// Writes all jitter parameters into the JitterModel payload.
void NDrive::TJitterOfferModel::Serialize(NDrive::NProto::TOfferModel& proto) const {
    auto* target = proto.MutableJitterModel();
    target->SetHashType(HashType);
    target->SetMinMultiplier(MinMultiplier);
    target->SetMaxMultiplier(MaxMultiplier);
    target->SetNormalMean(NormalMean);
    target->SetNormalDispersion(NormalDispersion);
}

// Takes ownership of the (possibly null) history and polynom and copies the options.
NDrive::TSupplyDemandHistoryModel::TSupplyDemandHistoryModel(
    const TString& name,
    THolder<TSupplyDemandHistory>&& history,
    THolder<SRelevanceFormula>&& polynom,
    const TOptions& options
)
    : Name(name)
    , History(std::move(history))
    , Polynom(std::move(polynom))
    , Options(options)
{
}

// Defined out of line, in the translation unit where the held types are complete.
NDrive::TSupplyDemandHistoryModel::~TSupplyDemandHistoryModel() = default;

// When a history is loaded, looks up the supply/demand statistics for the offer's
// location and time-of-week offset and writes three derived factors:
// FI_HISTORICAL_DEMAND, FI_INV_HISTORICAL_SUPPLY (supply is floored at 1 before inversion)
// and FI_HISTORICAL_SIGMOID.
// With a polynom its value is stored in FI_MATRIXNET and returned; otherwise FI_PRICE is returned.
double NDrive::TSupplyDemandHistoryModel::Calc(TOfferFeatures& features) const {
    auto& floats = features.Floats;
    if (History) {
        const auto dayPart = TDuration::Days(floats[NDriveOfferFactors::FI_DAY_OF_THE_WEEK]);
        const auto timePart = TDuration::Seconds(floats[NDriveOfferFactors::FI_TIME_OF_THE_DAY] * 24 * 60 * 60);
        const TGeoCoord position(floats[NDriveOfferFactors::FI_LONGITUDE], floats[NDriveOfferFactors::FI_LATITUDE]);
        const auto stat = GetHistory(position, dayPart + timePart);
        const auto inverseSupply = 1 / std::max(1.0f, stat.Supply);
        floats[NDriveOfferFactors::FI_HISTORICAL_DEMAND] = stat.Demand;
        floats[NDriveOfferFactors::FI_INV_HISTORICAL_SUPPLY] = inverseSupply;
        floats[NDriveOfferFactors::FI_HISTORICAL_SIGMOID] = Sigmoid(Options.SigmoidMultiplier * (stat.Demand * inverseSupply - Options.SigmoidDelta));
    }
    if (!Polynom) {
        return floats[NDriveOfferFactors::FI_PRICE];
    }
    const float score = Polynom->Calc(floats.data());
    floats[NDriveOfferFactors::FI_MATRIXNET] = score;
    return score;
}

// Fills the SupplyDemandHistoryModel payload: all options (FutureStep as microseconds),
// plus the history and polynom when they are present.
void NDrive::TSupplyDemandHistoryModel::Serialize(NDrive::NProto::TOfferModel& proto) const {
    auto* target = proto.MutableSupplyDemandHistoryModel();
    if (History) {
        TStringStream buffer;
        History->Save(&buffer);
        target->SetHistory(buffer.Str());
    }
    if (Polynom) {
        target->SetPolynom(Encode(*Polynom));
    }
    target->SetSigmoidDelta(Options.SigmoidDelta);
    target->SetSigmoidMultiplier(Options.SigmoidMultiplier);
    target->SetFutureDepth(Options.FutureDepth);
    target->SetFutureDemandDiscount(Options.FutureDemandDiscount);
    target->SetFutureSupplyDiscount(Options.FutureSupplyDiscount);
    target->SetFutureStep(Options.FutureStep.MicroSeconds());
}

// Discounted average of the historical supply/demand over the next FutureDepth steps.
// Step i (starting at `offset`, spaced by Options.FutureStep) is weighted by
// FutureDemandDiscount^i for demand and FutureSupplyDiscount^i for supply; each sum is
// divided by the sum of its weights.
// Returns a default-constructed stat when no history is loaded.
// When the weight sum is zero (e.g. FutureDepth == 0) the division is skipped, so the
// accumulated value is returned instead of the NaN that 0/0 would produce.
// NOTE(review): assumes a default-constructed TSupplyDemandStat starts at zero — the
// accumulation below already relies on that; confirm.
NDrive::TSupplyDemandStat NDrive::TSupplyDemandHistoryModel::GetHistory(const TGeoCoord& coordinate, TDuration offset) const {
    if (!History) {
        return {};
    }
    double demandMultiplier = 1;
    double supplyMultiplier = 1;
    double demandNormalizer = 0;
    double supplyNormalizer = 0;
    TDuration shift = TDuration::Zero();
    NDrive::TSupplyDemandStat result;
    for (size_t i = 0; i < Options.FutureDepth; ++i) {
        auto stat = History->Get(coordinate, offset + shift);
        result.Demand += demandMultiplier * stat.Demand;
        result.Supply += supplyMultiplier * stat.Supply;
        demandNormalizer += demandMultiplier;
        supplyNormalizer += supplyMultiplier;
        shift += Options.FutureStep;
        demandMultiplier *= Options.FutureDemandDiscount;
        supplyMultiplier *= Options.FutureSupplyDiscount;
    }
    if (demandNormalizer != 0) {
        result.Demand /= demandNormalizer;
    }
    if (supplyNormalizer != 0) {
        result.Supply /= supplyNormalizer;
    }
    return result;
}

// Loads the two optional Catboost calcers and the optional polynom from the options;
// empty data leaves the corresponding holder null.
NDrive::TVariationOptimizerModel::TVariationOptimizerModel(const TOptions& options)
    : Options(options)
{
    using TCalcer = NCatboostCalcer::TCatboostCalcer;
    Catboost = ContructFromString<TCalcer>(Options.CatboostData);
    Catboost2 = ContructFromString<TCalcer>(Options.Catboost2Data);
    Polynom = ContructFromString<SRelevanceFormula>(Options.PolynomData);
}

// Grid search over the factor at Options.VariationIndex: returns the factor value that
// maximises CalcOne (the original value wins ties, candidates must be strictly better).
//   - VariationAbsoluteMax > 0: candidates are the absolute values
//     VariationAbsoluteMin, +VariationStep, ... up to VariationAbsoluteMax;
//   - otherwise: candidates are original + delta for delta in
//     VariationMin, +VariationStep, ... up to VariationMax.
// On return the varied factor is restored to its original value, while FI_MATRIXNET and
// FI_MATRIXNET_2 hold the values observed at the best candidate.
// The scan stops as soon as adding the step no longer advances the grid point, so a zero,
// negative, NaN or vanishingly small VariationStep cannot spin forever.
double NDrive::TVariationOptimizerModel::Calc(TOfferFeatures& features) const {
    const double original = features.Floats[Options.VariationIndex];
    double argument = original;
    double value = CalcOne(features);
    double matrixnet = features.Floats[NDriveOfferFactors::FI_MATRIXNET];
    double matrixnet2 = features.Floats[NDriveOfferFactors::FI_MATRIXNET_2];

    // Evaluates one candidate and remembers it if it beats the current best.
    const auto tryCandidate = [&](double arg) {
        features.Floats[Options.VariationIndex] = arg;
        const double v = CalcOne(features);
        if (v > value) {
            argument = arg;
            value = v;
            matrixnet = features.Floats[NDriveOfferFactors::FI_MATRIXNET];
            matrixnet2 = features.Floats[NDriveOfferFactors::FI_MATRIXNET_2];
        }
    };

    // Walks the grid [from, to]; `relative` treats grid points as offsets from `original`.
    const double step = Options.VariationStep;
    const auto scan = [&](double from, double to, bool relative) {
        for (double point = from; point <= to;) {
            tryCandidate(relative ? original + point : point);
            const double next = point + step;
            if (!(next > point)) {
                break;
            }
            point = next;
        }
    };

    if (Options.VariationAbsoluteMax > 0) {
        scan(Options.VariationAbsoluteMin, Options.VariationAbsoluteMax, false);
    } else {
        scan(Options.VariationMin, Options.VariationMax, true);
    }

    features.Floats[Options.VariationIndex] = original;
    features.Floats[NDriveOfferFactors::FI_MATRIXNET] = matrixnet;
    features.Floats[NDriveOfferFactors::FI_MATRIXNET_2] = matrixnet2;
    return argument;
}

// Scores the current feature vector once.
// Side effects: writes FI_MATRIXNET (when Catboost is set) and FI_MATRIXNET_2 plus
// FI_MATRIXNET_COMPOSITE (when Catboost2 is set) into features.Floats.
// Returns the polynom value when Polynom is set, otherwise FI_MATRIXNET.
double NDrive::TVariationOptimizerModel::CalcOne(TOfferFeatures& features) const {
    auto& floats = features.Floats;

    if (Catboost) {
        floats[NDriveOfferFactors::FI_MATRIXNET] = Catboost->CalcRelev(features.FloatsView(), features.CategoriesView2());
    }
    if (Catboost2) {
        floats[NDriveOfferFactors::FI_MATRIXNET_2] = Catboost2->CalcRelev(features.FloatsView(), features.CategoriesView2());
        // NOTE(review): the composite is the reciprocal of the sum; a zero sum is not guarded here.
        floats[NDriveOfferFactors::FI_MATRIXNET_COMPOSITE] = 1 / (floats[NDriveOfferFactors::FI_MATRIXNET] + floats[NDriveOfferFactors::FI_MATRIXNET_2]);
    }

    if (!Polynom) {
        return floats[NDriveOfferFactors::FI_MATRIXNET];
    }
    return Polynom->Calc(floats.data());
}

// Writes the model into proto.VariationOptimizerModel: serialized calcers and
// polynom (only those that are present) followed by the variation settings.
void NDrive::TVariationOptimizerModel::Serialize(NDrive::NProto::TOfferModel& proto) const {
    // Saves a calcer into an in-memory stream and returns the resulting bytes.
    const auto dump = [](auto& calcer) {
        TStringStream out;
        calcer.Save(&out);
        return out.Str();
    };

    auto* target = proto.MutableVariationOptimizerModel();
    if (Catboost) {
        target->SetCatboost(dump(*Catboost));
    }
    if (Catboost2) {
        target->SetCatboost2(dump(*Catboost2));
    }
    if (Polynom) {
        target->SetPolynom(Encode(*Polynom));
    }

    target->SetVariationIndex(Options.VariationIndex);
    target->SetVariationStep(Options.VariationStep);
    target->SetVariationMin(Options.VariationMin);
    target->SetVariationMax(Options.VariationMax);
    target->SetVariationAbsoluteMin(Options.VariationAbsoluteMin);
    target->SetVariationAbsoluteMax(Options.VariationAbsoluteMax);
}

namespace {
    // Car model codes that TGeoLocalModel::CalcHash treats as "base" models.
    // Every entry must be followed by a comma: adjacent string literals are
    // concatenated by the compiler, which silently merges two codes into one.
    TSet<TString> GeoLocalModelBaseModelCodes = {
        "hyundai_solaris",
        "kia_rio",
        "kia_rio_xline",
        "renault_kaptur",
        "skoda_octavia",
        "skoda_rapid",
        "vw_polo",
    };
}

// Derives the V1 lookup key parts from the features and hashes them.
// The day counts as a weekend when FI_DAY_OF_THE_WEEK >= 5, or >= 4 when
// isFridayWeekend is set.
ui32 NDrive::TGeoLocalModel::CalcHash(const TOfferFeatures& features, bool isFridayWeekend) {
    const ui32 firstWeekendDay = isFridayWeekend ? 4 : 5;
    const bool weekend = features.Floats[NDriveOfferFactors::FI_DAY_OF_THE_WEEK] >= firstWeekendDay;
    const bool baseModel = GeoLocalModelBaseModelCodes.contains(features.Categories2[NDriveOfferCatFactors2::FI_MODEL]);
    const TStringBuf streetName = features.Categories2[NDriveOfferCatFactors2::FI_GEOCODED_STREET];
    return CalcHash(baseModel, weekend, GetHour(features), streetName);
}

// V1 key: XOR of the FNV hashes of the street name and of the raw bytes of
// hour, isWeekend and isBaseModel.
ui32 NDrive::TGeoLocalModel::CalcHash(bool isBaseModel, bool isWeekend, ui8 hour, TStringBuf street) {
    ui32 hash = FnvHash<ui32>(street);
    hash ^= FnvHash<ui32>(&hour, sizeof(hour));
    hash ^= FnvHash<ui32>(&isWeekend, sizeof(isWeekend));
    hash ^= FnvHash<ui32>(&isBaseModel, sizeof(isBaseModel));
    return hash;
}

// V2 key: XOR of the FNV hashes of the model code, "<weekday>-weekday",
// "<hour>-hour" and the location string.
ui32 NDrive::TGeoLocalModel::CalcHash2(TStringBuf model, ui8 weekday, ui8 hour, TStringBuf location) {
    const TString weekdayToken = ToString<ui32>(weekday) + "-weekday";
    const TString hourToken = ToString<ui32>(hour) + "-hour";

    ui32 hash = FnvHash<ui32>(model);
    hash ^= FnvHash<ui32>(weekdayToken);
    hash ^= FnvHash<ui32>(hourToken);
    hash ^= FnvHash<ui32>(location);
    return hash;
}

// Formats the offer position as "<floor(lat * 100)>_<floor(lon * 100)>".
// NOTE(review): both values are converted to ui32, so negative coordinates are
// not representable here - confirm the inputs are always non-negative.
TString NDrive::TGeoLocalModel::GetCoordinates(const TOfferFeatures& features) {
    const auto latitudeCell = static_cast<ui32>(std::floor(features.Floats[NDriveOfferFactors::FI_LATITUDE] * 100));
    const auto longitudeCell = static_cast<ui32>(std::floor(features.Floats[NDriveOfferFactors::FI_LONGITUDE] * 100));
    return TStringBuilder() << latitudeCell << '_' << longitudeCell;
}

// Returns floor(FI_TIME_OF_THE_DAY * 24) as ui8.
ui8 NDrive::TGeoLocalModel::GetHour(const TOfferFeatures& features) {
    const auto scaledTimeOfDay = features.Floats[NDriveOfferFactors::FI_TIME_OF_THE_DAY] * 24;
    return static_cast<ui8>(std::floor(scaledTimeOfDay));
}

// Copies the options, sorts the elements (Find() relies on std::lower_bound)
// and decodes the optional polynom; an empty polynom string leaves Polynom null.
NDrive::TGeoLocalModel::TGeoLocalModel(const TOptions& options)
    : Options(options)
{
    auto& elements = Options.Elements;
    std::sort(elements.begin(), elements.end());

    Polynom = ContructFromString<SRelevanceFormula>(options.Polynom);
}

// Defined out of line in this translation unit; no custom teardown logic.
NDrive::TGeoLocalModel::~TGeoLocalModel() = default;

// Applies the geo-local coefficient found for the features.
// - No matching element: returns Floats[BaseIndex] untouched.
// - Match: stores the coefficient in FI_MATRIXNET, then returns the polynom value
//   when Polynom is set, otherwise coefficient * Floats[BaseIndex].
double NDrive::TGeoLocalModel::Calc(TOfferFeatures& features) const {
    const auto coefficient = Find(features);
    if (!coefficient) {
        return features.Floats[Options.BaseIndex];
    }

    features.Floats[NDriveOfferFactors::FI_MATRIXNET] = *coefficient;
    if (Polynom) {
        return Polynom->Calc(features.Floats.data());
    }

    const double multiplier = *coefficient;
    return multiplier * features.Floats[Options.BaseIndex];
}

// Looks up the coefficient for the given features in Options.Elements.
// V1 uses a single CalcHash() key. V2 tries, in order, keys built from the
// geocoded street, from the coordinate cell (GetCoordinates) and from an empty
// location, returning the first hit. Returns an empty TMaybe when nothing matches.
TMaybe<float> NDrive::TGeoLocalModel::Find(const TOfferFeatures& features) const {
    // Exact-match binary search; Options.Elements is sorted in the constructor.
    const auto lookup = [this](ui32 hash) -> TMaybe<float> {
        const auto last = Options.Elements.end();
        const auto it = std::lower_bound(Options.Elements.begin(), last, hash);
        if (it == last || !(it->Hash == hash)) {
            return {};
        }
        return it->Value;
    };

    switch (Options.Version) {
    case EVersion::V1:
        return lookup(CalcHash(features, Options.IsFridayWeekend));
    case EVersion::V2:
    {
        const auto& model = features.Categories2[NDriveOfferCatFactors2::FI_MODEL];
        const auto weekday = static_cast<ui8>(features.Floats[NDriveOfferFactors::FI_DAY_OF_THE_WEEK]);
        const ui8 hour = GetHour(features);

        const auto& street = features.Categories2[NDriveOfferCatFactors2::FI_GEOCODED_STREET];
        if (auto byStreet = lookup(CalcHash2(model, weekday, hour, street))) {
            return byStreet;
        }

        const auto coordinates = GetCoordinates(features);
        if (auto byCoordinates = lookup(CalcHash2(model, weekday, hour, coordinates))) {
            return byCoordinates;
        }

        return lookup(CalcHash2(model, weekday, hour, TStringBuf()));
    }
    }
    return {};
}

// Writes the model into proto.GeoLocalModel: the elements as a binary blob
// produced by Save(), plus version, base index, weekend flag and, when a
// polynom is present, its source string.
void NDrive::TGeoLocalModel::Serialize(NDrive::NProto::TOfferModel& proto) const {
    auto* target = proto.MutableGeoLocalModel();

    TStringStream elementsBlob;
    Save(&elementsBlob, Options.Elements);
    target->SetElementsData(elementsBlob.Str());

    target->SetVersion(ToString(Options.Version));
    target->SetBaseIndex(Options.BaseIndex);
    target->SetIsFridayWeekend(Options.IsFridayWeekend);
    if (Polynom) {
        target->SetPolynomData(Options.Polynom);
    }
}

// Writes a human-readable summary of the model into json.
// The elements themselves are not dumped; only their count is reported.
void NDrive::TGeoLocalModel::Serialize(NJson::TJsonValue& json) const {
    json["name"] = Options.Name;
    json["type"] = Type();
    json["version"] = ToString(Options.Version);
    json["polynom"] = Options.Polynom;
    json["base_index"] = Options.BaseIndex;
    json["is_friday_weekend"] = Options.IsFridayWeekend;
    json["element_count"] = Options.Elements.size();
}

// Tries to lower FI_PRICE to the nearest value ending in .99 below it
// (price * 100 truncated, rounded down to a multiple of 100, minus 1).
// The lowered price is used, and also written to FI_MATRIXNET, only when the
// relative reduction does not exceed Options.MaxDiscount; otherwise the
// original price is returned unchanged.
double NDrive::TRoundModel::Calc(TOfferFeatures& features) const {
    const auto price = features.Floats[NDriveOfferFactors::FI_PRICE];
    const auto priceInHundredths = static_cast<i64>(price * 100);
    const auto targetInHundredths = priceInHundredths - (priceInHundredths % 100) - 1;
    const auto discount = 1.0 * (priceInHundredths - targetInHundredths) / priceInHundredths;

    // Written as a negated "<=" so that unordered comparisons keep taking this branch.
    if (!(discount <= Options.MaxDiscount)) {
        return price;
    }

    const auto lowered = 0.01 * targetInHundredths;
    features.Floats[NDriveOfferFactors::FI_MATRIXNET] = lowered;
    return lowered;
}

// Writes the model into proto.RoundModel; MaxDiscount is its only setting.
void NDrive::TRoundModel::Serialize(NDrive::NProto::TOfferModel& proto) const {
    auto* target = proto.MutableRoundModel();
    target->SetMaxDiscount(Options.MaxDiscount);
}

// Picks one of Values by Hash(features) modulo the number of values.
// With no values configured, FI_PRICE is returned as is.
double NDrive::TRandomSetModel::Calc(TOfferFeatures& features) const {
    const auto count = Values.size();
    if (count == 0) {
        return features.Floats[NDriveOfferFactors::FI_PRICE];
    }
    const auto bucket = Hash(features) % count;
    return Values[bucket];
}

// Writes the model into proto.RandomSetModel: the hash type and the values in
// their configured order.
void NDrive::TRandomSetModel::Serialize(NDrive::NProto::TOfferModel& proto) const {
    auto* target = proto.MutableRandomSetModel();
    target->SetHashType(Options.HashType);
    for (const auto& item : Options.Values) {
        target->AddValue(item);
    }
}

// Stores the model name and takes ownership of the schedule config.
// A null config is allowed: Calc() and Serialize() both check Config before use.
NDrive::TTimeScheduleModel::TTimeScheduleModel(const TString& name, THolder<TPriceByTimeConfig>&& config)
    : Name(name)
    , Config(std::move(config))
{
}

// Defined out of line in this translation unit; no custom teardown logic.
NDrive::TTimeScheduleModel::~TTimeScheduleModel() = default;

// Returns 0.01 * Config->GetBasePrice() for the moment given by FI_TIMESTAMP
// (interpreted as seconds). Without a config, FI_PRICE is returned as is.
double NDrive::TTimeScheduleModel::Calc(NDrive::TOfferFeatures& features) const {
    if (!Config) {
        return features.Floats[NDriveOfferFactors::FI_PRICE];
    }
    const TInstant moment = TInstant::Seconds(features.Floats[NDriveOfferFactors::FI_TIMESTAMP]);
    return 0.01 * Config->GetBasePrice(moment);
}

// Writes the model into proto.TimeScheduleModel; the config, when present, is
// stored as its JSON text.
void NDrive::TTimeScheduleModel::Serialize(NDrive::NProto::TOfferModel& proto) const {
    // The sub-message is created even when there is no config to store.
    auto* target = proto.MutableTimeScheduleModel();
    if (!Config) {
        return;
    }
    target->SetConfig(Config->SerializeToJson().GetStringRobust());
}

namespace {
    // Registry used by NDrive::CreateHash: hash type name -> hash function.
    const TMap<TString, NDrive::TFeaturesHash> HashFunctions = {
        // The empty name is an alias for "jitter".
        { "", CalcJitterHash },
        { "jitter", CalcJitterHash },

        { "user-timestamp-3", CalcUserTimestamp3Hash },
    };
}

// Returns the hash function registered under the given type name.
// Throws (via Y_ENSURE) when the name is not present in HashFunctions.
NDrive::TFeaturesHash NDrive::CreateHash(TStringBuf type) {
    const auto entry = HashFunctions.find(type);
    const bool known = entry != HashFunctions.end();
    Y_ENSURE(known, "unknown hash type: " << type);
    return entry->second;
}

// Stores the model name and the Lua script text; the script is only evaluated
// later, in Calc().
NDrive::TLuaModel::TLuaModel(TString name, TString script)
    : Name(std::move(name))
    , Script(std::move(script))
{
}

// Defined out of line in this translation unit; no custom teardown logic.
NDrive::TLuaModel::~TLuaModel() = default;

// Evaluates Script with the offer features exposed as the Lua variable "coefs":
//   coefs.floats2[<factor name>]     - values of features.Floats2
//   coefs.categories2[<factor name>] - values of features.Categories2
// The script result is parsed as a double. A per-thread TLuaEval instance is
// created on first use and reused afterwards.
// Any failure is logged as "LuaModelCalcError" and rethrown as
// TCodedException(HTTP_INTERNAL_SERVER_ERROR).
double NDrive::TLuaModel::Calc(TOfferFeatures& features) const try {
    NJson::TJsonValue coefs;

    const auto* floatsInfo = NDrive::GetOfferFactorsInfo2();
    coefs.InsertValue("floats2", NJson::TJsonValue());
    for (size_t index = 0; index < floatsInfo->GetFactorCount(); ++index) {
        const auto factorName = Sprintf("%s", floatsInfo->GetFactorName(index));
        coefs["floats2"].InsertValue(factorName, features.Floats2[index]);
    }

    const auto* categoriesInfo = NDrive::GetCatOfferFactorsInfo2();
    coefs.InsertValue("categories2", NJson::TJsonValue());
    for (size_t index = 0; index < features.Categories2.size(); ++index) {
        const auto factorName = Sprintf("%s", categoriesInfo->GetFactorName(index));
        coefs["categories2"].InsertValue(factorName, features.Categories2[index]);
    }

    // Lazily create the evaluator owned by the current thread.
    if (!TlsRef(LuaEvaluator)) {
        LuaEvaluator = MakeSimpleShared<TLuaEval>();
    }
    auto* evaluator = TlsRef(LuaEvaluator).Get();
    evaluator->SetVariable("coefs", coefs);
    const auto rawResult = evaluator->EvalRaw(Script);
    return FromString<double>(rawResult);
} catch (...) {
    NDrive::TEventLog::Log("LuaModelCalcError", NJson::TMapBuilder
        ("name", GetName())
        ("features", NJson::ToJson(features))
        ("error", CurrentExceptionInfo())
    );
    throw TCodedException(HTTP_INTERNAL_SERVER_ERROR).AddInfo("model", GetName()) << "lua error";
}

// Writes the model into proto.LuaModel; the script field is set only when
// Script is non-empty.
void NDrive::TLuaModel::Serialize(NDrive::NProto::TOfferModel& proto) const {
    // The sub-message is created even when there is no script to store.
    auto* target = proto.MutableLuaModel();
    if (!Script) {
        return;
    }
    target->SetScript(Script);
}

// Writes the model name, type and full script text into json.
void NDrive::TLuaModel::Serialize(NJson::TJsonValue& json) const {
    json["name"] = Name;
    json["type"] = Type();
    json["script"] = Script;
}

void NDrive::TLuaModel::Validate() const {
    TOfferFeatures features{};
    try {
        Calc(features);
    } catch (const yexception& error) { // TLuaStateHolder::TError can not catch non-lua errors =, i.e. parse of 1/0.
        ythrow TCodedException(400) << "Error while validating model \"" << Name << "\":" << error.what();
    }
}

// Stores the list of models; Calc() applies them in the given order.
NDrive::TOfferMultiModel::TOfferMultiModel(const TVector<TOfferModelConstPtr>& models)
    : Models(models)
{
}

// Runs every model in order on the same feature vector and returns the result
// of the last one (0 when there are no models). CalcModelResultFeatures() is
// called once with the initial 0 and again after each model with its result.
// Throws (via Y_ENSURE) on a null model pointer.
double NDrive::TOfferMultiModel::Calc(TOfferFeatures& features) const {
    double current = 0;
    CalcModelResultFeatures(features, current);
    for (const auto& stage : Models) {
        Y_ENSURE(stage);
        current = stage->Calc(features);
        CalcModelResultFeatures(features, current);
    }
    return current;
}
