#pragma once

#include "features.h"

#include <drive/backend/models/fwd.h>
#include <util/datetime/base.h>
#include <util/generic/array_ref.h>
#include <util/string/cast.h>

// Forward declarations for types that this header only uses through pointers,
// references or THolder, so their full definitions are not needed here.
namespace NCatboostCalcer {
    class TCatboostCalcer;
    class TMulticlassPredictor;
}

namespace NDrive::NProto {
    class TOfferModel;
}

namespace NLinearModelsTree {
    class TModel;
}

namespace NLinearModelsTree::NProto {
    class TTree;
}

class TGeoCoord;
class TPriceByTimeConfig;
struct SRelevanceFormula;

namespace NDrive {
    // Callable that maps an offer's features to a 32-bit hash value.
    using TFeaturesHash = std::function<ui32(const TOfferFeatures&)>;

    // Returns the hash function identified by `hashType`
    // (e.g. TRandomSetModel passes its HashType option here).
    TFeaturesHash CreateHash(TStringBuf hashType);

    // Base interface for offer models.
    //
    // A model takes a TOfferFeatures record (by non-const reference, so an
    // implementation may update it) and produces a single double. Models can be
    // rebuilt from, and serialized to, a TString, a protobuf message or JSON.
    //
    // Note: derived classes that override one `Calc` or `Serialize` overload
    // hide the other base overloads when called through the derived type;
    // call them through an IOfferModel reference/pointer instead.
    class IOfferModel {
    public:
        // Factories that rebuild a model from one of its serialized forms.
        static THolder<IOfferModel> Construct(const TString& data);
        static THolder<IOfferModel> Construct(const NDrive::NProto::TOfferModel& proto);
        // Explicitly declared as potentially throwing.
        static THolder<IOfferModel> Construct(const NJson::TJsonValue& json) noexcept(false);

        virtual ~IOfferModel() = default;

        // Computes the model value for one feature record.
        virtual double Calc(TOfferFeatures& features) const = 0;
        // Batch form with a base implementation in the .cpp.
        // NOTE(review): presumably writes one entry of `results` per element of
        // `features` — confirm how size mismatches are handled.
        virtual void Calc(TArrayRef<TOfferFeatures> features, TArrayRef<double> results) const;

        // Model instance name and the type tag of the concrete model kind.
        virtual TString GetName() const = 0;
        virtual TString GetType() const = 0;

        // Serializes into a default-constructed TResult by dispatching to the
        // matching protected Serialize overload below (TString,
        // NDrive::NProto::TOfferModel or NJson::TJsonValue).
        template <class TResult>
        TResult Serialize() const {
            TResult serialized;
            Serialize(serialized);
            return serialized;
        }

    protected:
        // Only the protobuf form must be implemented by every model; the TString
        // and JSON forms have base implementations defined in the .cpp.
        virtual void Serialize(TString& data) const;
        virtual void Serialize(NDrive::NProto::TOfferModel& proto) const = 0;
        virtual void Serialize(NJson::TJsonValue& json) const;
    };

    // Interface for models that produce a vector of values instead of a single
    // one (presumably one entry per class).
    class IMulticlassModel {
    public:
        virtual ~IMulticlassModel() = default;

        // Returns the model outputs for `features`; does not modify the features.
        virtual TVector<double> Predict(const TOfferFeatures& features) const = 0;
    };

    class TMetaMulticlassModel
        : public IOfferModel
        , public IMulticlassModel
    {
    public:
        using TSubmodels = TVector<THolder<IOfferModel>>;

    public:
        static TStringBuf Type() {
            return "meta_multiclass"sv;
        }

    public:
        TMetaMulticlassModel(const TString& name, TSubmodels&& submodels)
            : Name(name)
            , Submodels(std::move(submodels))
        {
        }

        TString GetName() const override {
            return Name;
        }
        TString GetType() const override {
            return ToString(Type());
        }

        virtual double Calc(TOfferFeatures& features) const override;
        virtual TVector<double> Predict(const TOfferFeatures& features) const override;

    protected:
        virtual void Serialize(NDrive::NProto::TOfferModel& proto) const override;
        virtual void Serialize(NJson::TJsonValue& json) const override;

    private:
        TString Name;
        TSubmodels Submodels;
    };

    class TCatboostModel : public IOfferModel {
    public:
        struct TOptions {
            TString Name;
            TString CatboostData;
            TString PolynomData;
            NJson::TJsonValue LMTreeJson;
            const NLinearModelsTree::NProto::TTree* LMTreeProto = nullptr;
            bool UseNewFeatures = false;
        };

    public:
        static TStringBuf Type() {
            return "catboost"sv;
        }

    public:
        TCatboostModel(const TOptions& options);

        TString GetName() const override {
            return Options.Name;
        }
        TString GetType() const override {
            return ToString(Type());
        }

        virtual double Calc(TOfferFeatures& features) const override;

    protected:
        virtual void Serialize(NDrive::NProto::TOfferModel& proto) const override;
        virtual void Serialize(NJson::TJsonValue& json) const override;

    private:
        TOptions Options;
        THolder<NCatboostCalcer::TCatboostCalcer> Catboost;
        THolder<NLinearModelsTree::TModel> LMTree;
        THolder<SRelevanceFormula> Polynom;
    };

    class TCatboostMulticlassModel
        : public IOfferModel
        , public IMulticlassModel
    {
    public:
        struct TOptions {
            TString Name;
            TString CatboostData;
            bool SoftMax = false;
            bool UseNewFeatures = false;
        };

    public:
        static TStringBuf Type() {
            return "catboost_multiclass"sv;
        }

    public:
        TCatboostMulticlassModel(const TOptions& options);

        virtual TVector<double> Predict(const TOfferFeatures& features) const override;
        virtual double Calc(TOfferFeatures& features) const override {
            return features.Floats[NDriveOfferFactors::FI_PRICE];
        }
        TString GetName() const override {
            return Options.Name;
        }
        TString GetType() const override {
            return ToString(Type());
        }

    protected:
        const NCatboostCalcer::TMulticlassPredictor& GetPredictor() const;

        virtual void Serialize(NDrive::NProto::TOfferModel& proto) const override;
        virtual void Serialize(NJson::TJsonValue& json) const override;

    private:
        TOptions Options;
        THolder<NCatboostCalcer::TMulticlassPredictor> Predictor;
    };

    class TAdditionThresholdModel : public IOfferModel {
    public:
        struct TOptions {
            TString Name;
            TString ModelData;
            TString ModelFilename;
            TString PolynomData;
            size_t BaseIndex = NDriveOfferFactors::FI_PRICE;
            double Threshold = 0;
            double Addition = 0;
        };

    public:
        static TStringBuf Type() {
            return "addition_threshold"sv;
        }

    public:
        TAdditionThresholdModel(const TOptions& options);
        ~TAdditionThresholdModel();

        const TOptions& GetOptions() const {
            return Options;
        }
        virtual TString GetName() const override {
            return Options.Name;
        }
        virtual TString GetType() const override {
            return ToString(Type());
        }

        virtual double Calc(TOfferFeatures& features) const override;

    protected:
        virtual double CalcOne(TOfferFeatures& features) const;
        virtual void Serialize(NDrive::NProto::TOfferModel& proto) const override;
        virtual void Serialize(NJson::TJsonValue& json) const override;

    private:
        const TOptions Options;

        THolder<NCatboostCalcer::TCatboostCalcer> Predictor;
        THolder<SRelevanceFormula> Polynom;
    };

    class TConstantOfferModel : public IOfferModel {
    public:
        static TStringBuf Type() {
            return "stub"sv;
        }

    public:
        TConstantOfferModel(const TString& name, double value)
            : Name(name)
            , Value(value)
        {
        }

        TString GetName() const override {
            return Name;
        }
        TString GetType() const override {
            return ToString(Type());
        }

        virtual double Calc(TOfferFeatures& features) const override;

    protected:
        virtual void Serialize(NDrive::NProto::TOfferModel& proto) const override;

    private:
        TString Name;
        double Value;
    };

    class TJitterOfferModel : public IOfferModel {
    public:
        static TStringBuf Type() {
            return "jitter"sv;
        }

    public:
        TJitterOfferModel(const TString& name, const TString& hashType, double min, double max, double d = 0, double m = 0);

        TString GetName() const override {
            return Name;
        }
        TString GetType() const override {
            return ToString(Type());
        }

        virtual double Calc(TOfferFeatures& features) const override;

    protected:
        virtual void Serialize(NDrive::NProto::TOfferModel& proto) const override;

    private:
        TString Name;
        TString HashType;
        double MaxMultiplier = 1.05;
        double MinMultiplier = 0.95;
        double NormalDispersion = 0;
        double NormalMean = 0;

        TFeaturesHash Hash;
    };

    struct TSupplyDemandStat;
    class TSupplyDemandHistory;
    class TSupplyDemandHistoryModel : public IOfferModel {
    public:
        struct TOptions {
            double SigmoidDelta = 1;
            double SigmoidMultiplier = 5;

            ui32 FutureDepth = 5;
            double FutureDemandDiscount = 0.8;
            double FutureSupplyDiscount = 0;
            TDuration FutureStep = TDuration::Minutes(10);
        };

    public:
        static TStringBuf Type() {
            return "supply_demand_history"sv;
        }

    public:
        TSupplyDemandHistoryModel(const TString& name, THolder<TSupplyDemandHistory>&& history, THolder<SRelevanceFormula>&& polynom, const TOptions& options);
        ~TSupplyDemandHistoryModel();

        TString GetName() const override {
            return Name;
        }
        TString GetType() const override {
            return ToString(Type());
        }

        virtual double Calc(TOfferFeatures& features) const override;

    protected:
        virtual void Serialize(NDrive::NProto::TOfferModel& proto) const override;

    private:
        TSupplyDemandStat GetHistory(const TGeoCoord& coordinate, TDuration offset) const;

    private:
        TString Name;
        THolder<TSupplyDemandHistory> History;
        THolder<SRelevanceFormula> Polynom;
        TOptions Options;
    };

    class TVariationOptimizerModel : public IOfferModel {
    public:
        struct TOptions {
            TString Name;
            TString CatboostData;
            TString Catboost2Data;
            TString PolynomData;

            size_t VariationIndex = NDriveOfferFactors::FI_PRICE;
            double VariationAbsoluteMax = 0;
            double VariationAbsoluteMin = 0;
            double VariationMax = 1;
            double VariationMin = 0;
            double VariationStep = 0.1;
        };

    public:
        static TStringBuf Type() {
            return "variation_optimizer"sv;
        }

    public:
        TVariationOptimizerModel(const TOptions& options);

        TString GetName() const override {
            return Options.Name;
        }
        TString GetType() const override {
            return ToString(Type());
        }

        virtual double Calc(TOfferFeatures& features) const override;

    protected:
        virtual double CalcOne(TOfferFeatures& features) const;
        virtual void Serialize(NDrive::NProto::TOfferModel& proto) const override;

    private:
        TOptions Options;
        THolder<NCatboostCalcer::TCatboostCalcer> Catboost;
        THolder<NCatboostCalcer::TCatboostCalcer> Catboost2;
        THolder<SRelevanceFormula> Polynom;
    };

    class TGeoLocalModel: public IOfferModel {
    public:
        enum class EVersion {
            V1,
            V2,
        };

        struct TElement {
            ui32 Hash;
            float Value;

            inline TElement() = default;
            inline TElement(ui32 hash, float value)
                : Hash(hash)
                , Value(value)
            {
            }

            inline bool operator<(const TElement& other) const {
                return Hash < other.Hash;
            }
            inline bool operator<(ui32 hash) const {
                return Hash < hash;
            }

            Y_SAVELOAD_DEFINE(
                Hash,
                Value
            );
        };
        using TElements = TVector<TElement>;

        struct TOptions {
            TString Name;
            TElements Elements;
            TString Polynom;
            size_t BaseIndex = NDriveOfferFactors::FI_PRICE;
            EVersion Version = EVersion::V1;
            bool IsFridayWeekend = false;
        };

    public:
        static ui32 CalcHash(const TOfferFeatures& features, bool isFridayWeekend);
        static ui32 CalcHash(bool isBaseModel, bool isWeekend, ui8 hour, TStringBuf street);
        static ui32 CalcHash2(TStringBuf model, ui8 weekday, ui8 hour, TStringBuf location);

        static TString GetCoordinates(const TOfferFeatures& features);
        static ui8 GetHour(const TOfferFeatures& features);

        static TStringBuf Type() {
            return "geolocal"sv;
        }

    public:
        TGeoLocalModel(const TOptions& options);
        ~TGeoLocalModel();

        TString GetName() const override {
            return Options.Name;
        }
        TString GetType() const override {
            return ToString(Type());
        }

        virtual double Calc(TOfferFeatures& features) const override;

    protected:
        TMaybe<float> Find(const TOfferFeatures& features) const;

        virtual void Serialize(NDrive::NProto::TOfferModel& proto) const override;
        virtual void Serialize(NJson::TJsonValue& json) const override;

    private:
        TOptions Options;
        THolder<SRelevanceFormula> Polynom;
    };

    class TRoundModel: public IOfferModel {
    public:
        struct TOptions {
            TString Name;
            float MaxDiscount = 0.07;
        };

    public:
        static TStringBuf Type() {
            return "round"sv;
        }

    public:
        TRoundModel(const TOptions& options)
            : Options(options)
        {
        }

        TString GetName() const override {
            return Options.Name;
        }
        TString GetType() const override {
            return ToString(Type());
        }

        virtual double Calc(TOfferFeatures& features) const override;

    protected:
        virtual void Serialize(NDrive::NProto::TOfferModel& proto) const override;

    private:
        TOptions Options;
    };

    class TRandomSetModel : public IOfferModel {
    public:
        struct TOptions {
            TString Name;
            TString HashType;
            TSet<double> Values;
        };

    public:
        static TStringBuf Type() {
            return "random_set"sv;
        }

    public:
        TRandomSetModel(const TOptions& options)
            : Options(options)
            , Values(options.Values.begin(), options.Values.end())
        {
            Hash = CreateHash(Options.HashType);
        }

        TString GetName() const override {
            return Options.Name;
        }
        TString GetType() const override {
            return ToString(Type());
        }

        virtual double Calc(TOfferFeatures& features) const override;

    protected:
        virtual void Serialize(NDrive::NProto::TOfferModel& proto) const override;

    private:
        TOptions Options;
        TVector<double> Values;

        TFeaturesHash Hash;
    };

    class TTimeScheduleModel: public IOfferModel {
    public:
        static TStringBuf Type() {
            return "time_schedule"sv;
        }

    public:
        TTimeScheduleModel(const TString& name, THolder<TPriceByTimeConfig>&& config);
        ~TTimeScheduleModel();

        TString GetName() const override {
            return Name;
        }
        TString GetType() const override {
            return ToString(Type());
        }

        virtual double Calc(TOfferFeatures& features) const override;

    protected:
        virtual void Serialize(NDrive::NProto::TOfferModel& proto) const override;

    private:
        TString Name;
        THolder<TPriceByTimeConfig> Config;
    };

    class TLuaModel : public IOfferModel {
    public:
        static TStringBuf Type() {
            return "lua"sv;
        }

    public:
        TLuaModel(TString name, TString script);
        ~TLuaModel();

        TString GetName() const override {
            return Name;
        }

        TString GetType() const override {
            return ToString(Type());
        }

        double Calc(TOfferFeatures& features) const override;

        void Validate() const;

    protected:
        void Serialize(NDrive::NProto::TOfferModel& proto) const override;
        void Serialize(NJson::TJsonValue& json) const override;

    private:
        TString Name;
        TString Script;
    };

    // TOfferMultiModel holds an ordered list of models and applies them one
    // after another to the same feature record.
    // NOTE(review): how each model's output feeds the next is defined in the .cpp.
    class TOfferMultiModel {
    public:
        TOfferMultiModel(const TVector<TOfferModelConstPtr>& models);

        // Runs the models in order on `features`.
        double Calc(TOfferFeatures& features) const;

    private:
        TVector<TOfferModelConstPtr> Models;
    };
}
