#include "matrixnet.h"

#include <library/cpp/pybind/typedesc.h>
#include <quality/deprecated/Misc/FastIO.h>
#include <quality/relev_tools/mx_model_lib/mx_model.h>
#include <util/folder/path.h>
#include <util/system/winint.h>

#include <kernel/matrixnet/convert.h>

#include <util/stream/file.h>
#include <util/stream/str.h>
#include <util/ysaveload.h>
#include <kernel/matrixnet/mn_dynamic.h>
#include <kernel/matrixnet/mn_multi_categ.h>
#include <quality/relev_tools/mx_model_lib/nfelem.h>

#include <quality/relev_tools/mx_ops/lib/sliced_input_helper.h>

// Holder stored inside a `libmxnet.TMXNetInfo` Python object.
// Owns the TMnSseDynamic model passed to the constructor (via THolder).
struct TMXNetInfoHolder {
    THolder<NMatrixnet::TMnSseDynamic> MXNetInfo;

    // Takes ownership of `info`. Explicit so that a raw model pointer is never
    // silently converted into an owning holder.
    explicit TMXNetInfoHolder(NMatrixnet::TMnSseDynamic *info)
        : MXNetInfo(info) {
    }
};

// Holder stored inside a `libmxnet.TMXNetMC` Python object.
// Owns the TMnMultiCateg model passed to the constructor (via THolder).
struct TMXNetMCHolder {
    THolder<NMatrixnet::TMnMultiCateg> MXNetMC;

    // Takes ownership of `info`. Explicit so that a raw model pointer is never
    // silently converted into an owning holder.
    explicit TMXNetMCHolder(NMatrixnet::TMnMultiCateg *info)
        : MXNetMC(info) {
    }
};

// NPyBind type description for `libmxnet.TMXNetInfo`, whose Python objects
// carry a TMXNetInfoHolder wrapping a TMnSseDynamic model.
class TMXNetInfoTraits : public NPyBind::TPythonType<TMXNetInfoHolder, NMatrixnet::TMnSseDynamic, TMXNetInfoTraits> {
private:
    using TParent = NPyBind::TPythonType<TMXNetInfoHolder, NMatrixnet::TMnSseDynamic, TMXNetInfoTraits>;
    // The base needs access to the private constructor below.
    friend TParent;
    TMXNetInfoTraits();

public:
    // Non-owning access to the model kept alive by `holder`.
    static NMatrixnet::TMnSseDynamic *GetObject(const TMXNetInfoHolder &holder) {
        NMatrixnet::TMnSseDynamic *model = holder.MXNetInfo.Get();
        return model;
    }

    // Builds a holder from Python constructor arguments.
    static TMXNetInfoHolder *DoInitObject(PyObject *args, PyObject *kwargs);
    // Builds a holder around an empty model.
    static TMXNetInfoHolder *DoInitPureObject(const TVector<TString> &);
};

// NPyBind type description for `libmxnet.TMXNetMC`, whose Python objects
// carry a TMXNetMCHolder wrapping a TMnMultiCateg model.
class TMXNetMCTraits : public NPyBind::TPythonType<TMXNetMCHolder, NMatrixnet::TMnMultiCateg, TMXNetMCTraits> {
private:
    using TParent = NPyBind::TPythonType<TMXNetMCHolder, NMatrixnet::TMnMultiCateg, TMXNetMCTraits>;
    // The base needs access to the private constructor below.
    friend TParent;
    TMXNetMCTraits();

public:
    // Non-owning access to the model kept alive by `holder`.
    static NMatrixnet::TMnMultiCateg *GetObject(const TMXNetMCHolder &holder) {
        NMatrixnet::TMnMultiCateg *model = holder.MXNetMC.Get();
        return model;
    }

    // Builds a holder from Python constructor arguments.
    static TMXNetMCHolder *DoInitObject(PyObject *args, PyObject *kwargs);
    // Builds a holder around an empty model.
    static TMXNetMCHolder *DoInitPureObject(const TVector<TString> &);
};


namespace NCustomSlices {
    // One factor row plus the mapping from slice name to where that slice's
    // factors live inside the row. Both pointers are non-owning: the row and
    // the map must outlive every use of the holder.
    struct TSparseFactorsHolder {
        TVector<float>* Factors;
        struct TBeginSize {
            size_t Begin;
            size_t Size;
        };
        THashMap<TString, TBeginSize>* SliceNameToBorders;
    };

    // Returns the factors of slice `name` in this row, or an empty region when
    // the slice is unknown.
    //
    // The region is clamped to the row's real length. Slice borders can
    // describe more factors than a row holds when the caller skipped the
    // row-length validation. Without clamping, the returned view would point
    // past the end of the vector.
    TArrayRef<const float> GetFactorsRegion(const TSparseFactorsHolder& factors, TStringBuf name) {
        if (const TSparseFactorsHolder::TBeginSize* ptr = factors.SliceNameToBorders->FindPtr(name)) {
            const size_t rowSize = factors.Factors->size();
            const size_t begin = std::min(ptr->Begin, rowSize);
            const size_t size = std::min(ptr->Size, rowSize - begin);
            return {factors.Factors->data() + begin, size};
        }
        return {};
    }
} // namespace NCustomSlices


// Implements `TMXNetInfo.Calculate(factors, slices="", disallow_custom_slices=False,
// skip_slices_check=False)` and returns one value per factor row.
//
// - If neither the model's "Slices" property nor the `slices` argument is set,
//   the rows go to CalcRelevs unchanged.
// - If both are set, each formula slice is read from the same-named data slice.
//   The data slice is clipped to the formula slice's length; formula slices
//   absent from the data get an empty region. This lets the data layout differ
//   from the layout the model was built with.
// - Slice info on only one side is an error.
class TMXNetInfoSlicedCalcCaller: public NPyBind::TBaseMethodCaller<NMatrixnet::TMnSseDynamic> {
private:
    // Throws yexception when the data cannot feed the formula:
    // - a non-empty formula slice is missing from the data, or is longer than
    //   the same-named data slice;
    // - the total size of the data slices differs from their largest end;
    // - a factor row's length differs from that total.
    void CheckSlices(const NMLPool::TFeatureSlices& formulaBorders,
                     const TString& formulaSlices,
                     const THashMap<TString,
                                    NCustomSlices::TSparseFactorsHolder::TBeginSize>
                                        & sliceNameToBorders,
                     const TString& dataSlices,
                     yssize_t totalGivenFactorsBySlices,
                     yssize_t maximalDescribedFactorBySlice,
                     const TVector<TVector<float>>& factors) const {
        TStringStream error("Can't calculate: found more formula factors than data has: ");
        bool foundBadSlices = false;
        for (const auto& slice : formulaBorders) {
            const size_t numFactorsForSlice = slice.End - slice.Begin;
            if (numFactorsForSlice == 0) {
                continue; // an empty formula slice needs nothing from the data
            }
            if (const NCustomSlices::TSparseFactorsHolder::TBeginSize* ptr = sliceNameToBorders.FindPtr(slice.Name)) {
                if (numFactorsForSlice > ptr->Size) {
                    foundBadSlices = true;
                    error << "formula slice " << slice.Name << "["
                          << slice.Begin << ";" << slice.End << ")"
                          << " data slice " << slice.Name << "["
                          << ptr->Begin << ";"
                          << ptr->Begin + ptr->Size << "). ";
                }
            } else {
                foundBadSlices = true;
                error << "formula slice " << slice.Name
                      << "[" << slice.Begin << ";" << slice.End
                      << ") is missing in data. ";
            }
        }
        if (foundBadSlices) {
            ythrow yexception() << error.Str();
        }
        if (totalGivenFactorsBySlices != maximalDescribedFactorBySlice) {
            ythrow yexception() << "Number of factors described by"
                                << " factor slices " << formulaSlices
                                << " is " << totalGivenFactorsBySlices
                                << " while maximum number is "
                                << maximalDescribedFactorBySlice;
        }
        for (yssize_t i = 0; i < factors.ysize(); ++i) {
            if (factors[i].ysize() != totalGivenFactorsBySlices) {
                ythrow yexception() << "Line " << i << " has " << factors[i].ysize()
                                    << " while given factor slices "
                                    << dataSlices
                                    << " describe " << totalGivenFactorsBySlices;
            }
        }
    }
public:
    bool CallMethod(PyObject *, NMatrixnet::TMnSseDynamic *self, PyObject *args, PyObject *kwargs, PyObject *&res) const override {
        try {
            TVector< TVector<float> > factors;
            TString dataSlices = "";
            bool disallowCustomSlices = false;
            bool skipSlicesCheck = false;
            const char* keywords[] = {"factors", "slices", "disallow_custom_slices", "skip_slices_check", nullptr};
            if (!NPyBind::ExtractOptionalArgs(args, kwargs, keywords, factors, dataSlices, disallowCustomSlices, skipSlicesCheck)) {
                ythrow yexception() << "Can't calculate: factors should be matrix of float or unknown flag given";
            }
            const bool useCustomSlices = !disallowCustomSlices;

            TString formulaSlices = "";
            if (self->GetInfo()->contains("Slices")) {
                formulaSlices = self->GetInfo()->at("Slices");
            }
            if (!formulaSlices && !!dataSlices) {
                ythrow yexception() << "Can't calculate: formula has no slices info and data has slices: " + dataSlices;
            }
            if (!!formulaSlices && !dataSlices) {
                ythrow yexception() << "Can't calculate: data has no slices and formula has slices info: " + formulaSlices;
            }

            TVector<double> result(factors.size());
            if (!formulaSlices) {
                // No slices on either side: apply the model to the rows as they are.
                self->CalcRelevs(factors, result);
                res = NPyBind::BuildPyObject(result);
                return true;
            }

            NMLPool::TFeatureSlices dataBorders;
            NFactorSlices::DeserializeFeatureSlices(dataSlices, useCustomSlices, dataBorders);

            NMLPool::TFeatureSlices formulaBorders;
            NFactorSlices::DeserializeFeatureSlices(formulaSlices, useCustomSlices, formulaBorders);

            // Where every data slice lives inside a factor row.
            yssize_t totalGivenFactorsBySlices = 0;
            yssize_t maximalDescribedFactorBySlice = 0;
            THashMap<TString, NCustomSlices::TSparseFactorsHolder::TBeginSize> sliceNameToBorders;
            for (const auto& slice : dataBorders) {
                sliceNameToBorders[slice.Name] = {slice.Begin, slice.End - slice.Begin};
                totalGivenFactorsBySlices += slice.End - slice.Begin;
                maximalDescribedFactorBySlice = Max(maximalDescribedFactorBySlice,
                                                    static_cast<yssize_t>(slice.End));
            }

            /*
               I see no reason to allow applying a formula that requires factors missing in the data.
               But some users seem to do it, so I can't break it.
               It's still allowed under the skip_slices_check flag.
            */
            if (!skipSlicesCheck) {
                CheckSlices(formulaBorders, formulaSlices,
                            sliceNameToBorders, dataSlices,
                            totalGivenFactorsBySlices,
                            maximalDescribedFactorBySlice, factors);
            }

            // Clip every data slice to the length of the same-named formula slice.
            // Formula slices absent from the data get an empty region.
            // (End - Begin is the slice length; the reverse order wraps around
            // for unsigned borders and would disable the clipping.)
            for (const auto& slice : formulaBorders) {
                if (NCustomSlices::TSparseFactorsHolder::TBeginSize* ptr = sliceNameToBorders.FindPtr(slice.Name)) {
                    ptr->Size = Min(ptr->Size, slice.End - slice.Begin);
                } else {
                    sliceNameToBorders[slice.Name] = {0, 0};
                }
            }

            TVector<NCustomSlices::TSparseFactorsHolder> sparseFactors;
            sparseFactors.reserve(factors.size());
            for (auto& row : factors) {
                sparseFactors.push_back({&row, &sliceNameToBorders});
            }

            NMatrixnet::TMnSseDynamic calcer(*self); //Have no other way to renew slices, should they be changed with "SetProperty"
            calcer.SetInfo("Slices", formulaSlices);
            calcer.SetInfo("CustomSlices", ToString(useCustomSlices));
            calcer.SetSlicesFromInfo();
            // data() instead of &v[0]: indexing an empty vector is undefined.
            calcer.CustomSlicedCalcRelevs<NCustomSlices::TSparseFactorsHolder>(
                sparseFactors.data(),
                result, factors.size());
            res = NPyBind::BuildPyObject(result);
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't calculate: unknown exception");
        }
        res = nullptr;
        return true;
    }
};

// Implements `TMXNetInfo.CopyTruncated(end)` and `CopyTruncated(begin, end)`.
// Returns a new TMXNetInfo holding self->CopyTreeRange(begin, end); `begin`
// defaults to 0. On failure a RuntimeError is raised.
class TMXNetInfoCopyTruncatedCaller : public NPyBind::TBaseMethodCaller<NMatrixnet::TMnSseDynamic> {
public:
    bool CallMethod(PyObject *, NMatrixnet::TMnSseDynamic *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            size_t begin = 0, end = 0;
            if (!NPyBind::ExtractArgs(args, end) && !NPyBind::ExtractArgs(args, begin, end))
                ythrow yexception() << "Can't truncate: it should be one or two parameters - treeCount or (begin, end)";
            THolder<NMatrixnet::TMnSseDynamic> mx(new NMatrixnet::TMnSseDynamic(self->CopyTreeRange(begin, end)));
            // CreatePyObject yields a freshly created object (presumably a new
            // reference). Passing `true` makes the smart pointer adopt that
            // reference instead of adding another one. Otherwise the object
            // handed to Python would never reach refcount zero.
            NPyBind::TPyObjectPtr result(
                TMXNetInfoTraits::Instance().CreatePyObject(new TMXNetInfoHolder(mx.Release())),
                true);
            res = result.RefGet();
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't truncate: unknown exception");
        }
        // An exception is pending: a null result makes Python raise it now
        // instead of returning None with a stale error set.
        res = nullptr;
        return true;
    }
};

// Implements `TMXNetInfo.SplitBySpecifiedFactors(factors)`.
// Splits the trees with SplitTreesBySpecifiedFactors into two new TMXNetInfo
// objects and returns them as a pair (with, without). On failure a
// RuntimeError is raised.
class TMXNetInfoSplitBySpecifiedFactorsCaller : public NPyBind::TBaseMethodCaller<NMatrixnet::TMnSseDynamic> {
public:
    bool CallMethod(PyObject *, NMatrixnet::TMnSseDynamic *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TVector<ui32> factorsVec;
            if (!NPyBind::ExtractArgs(args, factorsVec))
                ythrow yexception() << "Can't split: it should be exactly one parameter - list of factors";
            TSet<ui32> factors(factorsVec.begin(), factorsVec.end());
            THolder<NMatrixnet::TMnSseDynamic> mxWith(new NMatrixnet::TMnSseDynamic);
            THolder<NMatrixnet::TMnSseDynamic> mxWithout(new NMatrixnet::TMnSseDynamic);
            self->SplitTreesBySpecifiedFactors(factors, *mxWith, *mxWithout);
            // `true`: adopt the reference returned by CreatePyObject (presumably
            // a new one). The pair below takes its own references, so these
            // pointers release theirs on scope exit instead of leaking one each.
            NPyBind::TPyObjectPtr resWith(TMXNetInfoTraits::Instance().CreatePyObject(new TMXNetInfoHolder(mxWith.Release())), true);
            NPyBind::TPyObjectPtr resWithout(TMXNetInfoTraits::Instance().CreatePyObject(new TMXNetInfoHolder(mxWithout.Release())), true);
            res = BuildPyObject(std::make_pair(resWith, resWithout));
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't split: unknown exception");
        }
        // An exception is pending: signal it with a null result.
        res = nullptr;
        return true;
    }
};


// Implements `SetProperty(name, value)`: stores `value` under `name` in the
// model's info. Returns None on success; raises RuntimeError on failure.
template<class TSubObject>
class TMXNetSetPropertyCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    bool CallMethod(PyObject*, TSubObject *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TString name, value;
            if (!NPyBind::ExtractArgs(args, name, value))
                ythrow yexception() << "Can't set property: it should be exactly two parameters - name and value";
            self->SetInfo(name, value);
            if (name == "Slices") {
                // Rebuild the model through its copy constructor and swap it in.
                // NOTE(review): presumably the copy re-derives the slice layout
                // from the updated info — confirm in the model classes.
                TSubObject newModel(*self);
                self->Swap(newModel);
            }
            Py_INCREF(Py_None);
            res = Py_None;
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't set property: unknown exception");
        }
        // An exception is pending: a null result lets Python raise it
        // (returning None here would leave a stale error behind).
        res = nullptr;
        return true;
    }
};


// Implements `GetProperty(name)`: returns the info value stored under `name`,
// or None when there is none. Raises RuntimeError on failure.
template<class TSubObject>
class TMXNetGetPropertyCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    bool CallMethod(PyObject*, TSubObject *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TString name;
            if (!NPyBind::ExtractArgs(args, name))
                ythrow yexception() << "Can't get property: it should be exactly one parameter - name of the property";
            auto info = self->GetInfo();
            if (info->contains(name)) {
                res = NPyBind::BuildPyObject(info->at(name));
            } else {
                Py_INCREF(Py_None);
                res = Py_None;
            }
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't get property: unknown exception");
        }
        // Previously `res` was left unassigned here. Report the pending
        // exception with an explicit null result.
        res = nullptr;
        return true;
    }
};


// Implements `GetProperties()`: returns the names of all info properties of
// the model. Raises RuntimeError on failure.
template<class TSubObject>
class TMXNetGetPropertiesCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    bool CallMethod(PyObject*, TSubObject *self, PyObject * /*args*/, PyObject*, PyObject *&res) const override {
        try {
            TVector<TString> result;
            // Bind by reference: copying each (name, value) pair is wasted work.
            for (const auto& prop : *self->GetInfo()) {
                result.push_back(prop.first);
            }
            res = NPyBind::BuildPyObject(result);
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't get properties: unknown exception");
        }
        // An exception is pending: signal it with an explicit null result.
        res = nullptr;
        return true;
    }
};

// Implements `GetTreeCount()`: returns self->NumTrees(). Raises RuntimeError
// on failure.
template<class TSubObject>
class TMXNetGetTreeCountCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    bool CallMethod(PyObject*, TSubObject *self, PyObject * /*args*/, PyObject*, PyObject *&res) const override {
        try {
            const size_t result = self->NumTrees();
            res = NPyBind::BuildPyObject(result);
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't get tree count: unknown exception");
        }
        // An exception is pending: signal it with an explicit null result.
        res = nullptr;
        return true;
    }
};


// Implements `GetUsedFactors()`: returns the set of factor indices collected by
// self->UsedFactors(). Raises RuntimeError on failure.
template<class TSubObject>
class TMXNetGetUsedFactorsCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    bool CallMethod(PyObject*, TSubObject *self, PyObject * /*args*/, PyObject*, PyObject *&res) const override {
        try {
            TSet<ui32> factors;
            self->UsedFactors(factors);
            res = NPyBind::BuildPyObject(factors);
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't get used factors: unknown exception");
        }
        // An exception is pending: signal it with an explicit null result.
        res = nullptr;
        return true;
    }
};


// Implements `Save(path)`: serializes the model with ::Save into the file at
// `path`. Returns None on success; raises RuntimeError on failure.
template<class TSubObject>
class TObjectSaveCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    bool CallMethod(PyObject*, TSubObject *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TString path;
            if (!NPyBind::ExtractArgs(args, path))
                ythrow yexception() << "Can't save: it should be exactly one parameter - path for saving model";
            TFixedBufferFileOutput out(path);
            ::Save(&out, *self);
            // Flush explicitly so that a failing write becomes a Python error.
            // A destructor cannot propagate it.
            out.Finish();
            Py_INCREF(Py_None);
            res = Py_None;
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't save: unknown exception");
        }
        // An exception is pending: a null result lets Python raise it
        // (returning None would leave a stale error behind).
        res = nullptr;
        return true;
    }
};

// Implements `Load(path)`: replaces the wrapped model with the one
// deserialized (via ::Load) from the file at `path`. Returns None on success;
// on failure sets RuntimeError and yields a null result.
template<class TSubObject>
class TObjectLoadCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    bool CallMethod(PyObject*, TSubObject *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TString modelFile;
            const bool parsed = NPyBind::ExtractArgs(args, modelFile);
            if (!parsed) {
                ythrow yexception() << "Can't load: it should be exactly one parameter - file path";
            }
            TFileInput stream(modelFile);
            ::Load(&stream, *self);

            Py_INCREF(Py_None);
            res = Py_None;
            return true;
        } catch (const std::exception &err) {
            PyErr_SetString(PyExc_RuntimeError, err.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't load: unknown exception");
        }
        res = nullptr;
        return true;
    }
};

// Implements `LoadInfoPart(path[, begin], end)`.
// Deserializes a TMnSseDynamic model from `path`, keeps only the trees chosen
// by CopyTreeRange(begin, end), and swaps the result into the wrapped object.
// Accepted forms: (path), (path, end), (path, begin, end); omitted values are 0.
// The loaded model is a TMnSseDynamic, so this only works when TSubObject can
// Swap with one.
template<class TSubObject>
class TObjectLoadInfoPartCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    bool CallMethod(PyObject*, TSubObject *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TString path;
            size_t begin = 0;
            unsigned long end = 0;
            // Try each accepted argument shape in turn; stop at the first match.
            const bool parsed = NPyBind::ExtractArgs(args, path)
                || NPyBind::ExtractArgs(args, path, end)
                || NPyBind::ExtractArgs(args, path, begin, end);
            if (!parsed) {
                ythrow yexception() << "Can't load: it should be either path, or (path, end) or (path, begin, end)";
            }

            TFsPath filePath(path);
            TFileInput stream(filePath);
            NMatrixnet::TMnSseDynamic loaded;
            ::Load(&stream, loaded);

            NMatrixnet::TMnSseDynamic truncated(loaded.CopyTreeRange(begin, end));
            self->Swap(truncated);

            Py_INCREF(Py_None);
            res = Py_None;
            return true;
        } catch (const std::exception &err) {
            PyErr_SetString(PyExc_RuntimeError, err.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't load: unknown exception");
        }
        res = nullptr;
        return true;
    }
};


// Implements `LoadBinPart(path[, begin], end)`.
// Reads a model in .bin format from `path`, keeps trees [begin, end) of its
// NFList together with all binary features, and converts the result into the
// wrapped object via MatrixnetConvert.
// Accepted forms: (path), (path, end), (path, begin, end). `end == 0` or an
// `end` beyond the tree count means "up to the last tree". `begin` past the
// resulting end raises RuntimeError instead of forming an invalid range.
template<class TSubObject>
class TObjectLoadBinPartCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    bool CallMethod(PyObject*, TSubObject *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TString path;
            size_t begin = 0;
            size_t end = 0;
            if (!NPyBind::ExtractArgs(args, path) && !NPyBind::ExtractArgs(args, path, end) && !NPyBind::ExtractArgs(args, path, begin, end))
                ythrow yexception() << "Can't load: it should be either path, or (path, end) or (path, begin, end)";
            TFsPath modelPath(path);
            // NOTE(review): `input` is never read. Opening it only makes an
            // unopenable path fail early with a file error, before Serialize runs.
            TFileInput input(modelPath);
            TFullMatrixClassifierInfo fullModel;

            if (!Serialize(true, modelPath.c_str(), fullModel)) {
                ythrow yexception() << "Can't load model in .bin format from '" << modelPath << '\'';
            }

            const size_t numberOfTrees = fullModel.NFList.size();
            if (end == 0 || end > numberOfTrees) {
                end = numberOfTrees;
            }
            // Without this check cbegin() + begin could point past the end, or
            // past cbegin() + end, which is undefined behavior.
            if (begin > end) {
                ythrow yexception() << "Can't load: begin " << begin << " is past end " << end
                                    << " (model '" << modelPath << "' has " << numberOfTrees << " trees)";
            }

            TFullMatrixClassifierInfo partModel;
            partModel.NFList = TVector<TNFElement>(
                fullModel.NFList.cbegin() + begin,
                fullModel.NFList.cbegin() + end
            );
            partModel.BinFeatures = fullModel.BinFeatures;
            MatrixnetConvert(partModel, *self);

            Py_INCREF(Py_None);
            res = Py_None;
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't load: unknown exception");
        }
        res = nullptr;
        return true;
    }
};




// Implements `__getstate__()` for pickling: returns the model serialized by
// self->Save() as a string. On failure sets RuntimeError and yields a null
// result.
template<class TSubObject>
class TObjectGetStateCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    bool CallMethod(PyObject*, TSubObject *self, PyObject*, PyObject*, PyObject *&res) const override {
        try {
            TString serialized;
            TStringOutput sink(serialized);
            self->Save(&sink);

            // `true` makes the pointer adopt the reference from BuildPyObject
            // (assuming the usual NPyBind new-reference convention). RefGet()
            // then hands the caller its own reference.
            NPyBind::TPyObjectPtr pyState(NPyBind::BuildPyObject(serialized), true);
            res = pyState.RefGet();
            return true;
        } catch (const std::exception &err) {
            PyErr_SetString(PyExc_RuntimeError, err.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't get state: unknown exception");
        }
        res = nullptr;
        return true;
    }
};

// Implements `__setstate__(state)` for unpickling: restores the model from the
// serialized string via ::Load. Returns None on success; on failure sets
// RuntimeError and yields a null result.
template<class TSubObject>
class TObjectSetStateCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    bool CallMethod(PyObject*, TSubObject *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TString serialized;
            const bool parsed = NPyBind::ExtractArgs(args, serialized);
            if (!parsed) {
                ythrow yexception() << "Can't set state: it should be exactly one string parameter - the state";
            }
            TStringInput source(serialized);
            ::Load(&source, *self);

            Py_INCREF(Py_None);
            res = Py_None;
            return true;
        } catch (const std::exception &err) {
            PyErr_SetString(PyExc_RuntimeError, err.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't set state: unknown exception");
        }
        res = nullptr;
        return true;
    }
};

// Implements `TMXNetMC.Calculate(factors)`: one value per factor row, computed
// by TMnMultiCateg::CalcRelevs. On failure sets RuntimeError and yields a null
// result.
class TMXNetMCCalcCaller : public NPyBind::TBaseMethodCaller<NMatrixnet::TMnMultiCateg> {
public:
    bool CallMethod(PyObject *, NMatrixnet::TMnMultiCateg *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TVector< TVector<float> > rows;
            const bool parsed = NPyBind::ExtractArgs(args, rows);
            if (!parsed) {
                ythrow yexception() << "Can't calculate: factors should be matrix of float";
            }
            TVector<double> relevs(rows.size());
            self->CalcRelevs(rows, relevs);
            res = NPyBind::BuildPyObject(relevs);
            return true;
        } catch (const std::exception &err) {
            PyErr_SetString(PyExc_RuntimeError, err.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't calculate: unknown exception");
        }
        res = nullptr;
        return true;
    }
};

// Shared driver for the per-category calculate methods of TMXNetMC.
// Parses a factor matrix and lets the subclass fill a flat row-major buffer of
// `rows * CategValues().size()` doubles. Returns it to Python as one list of
// per-category values per input row. On failure sets RuntimeError and yields a
// null result.
class TMXNetMCBaseCalcCategsCaller : public NPyBind::TBaseMethodCaller<NMatrixnet::TMnMultiCateg> {
private:
    // Fills `result` (factors.size() * CategValues().size() values, row-major)
    // for the given factor rows.
    virtual void DoCalculate(NMatrixnet::TMnMultiCateg *self, const TVector< TVector<float> > &factors, double *result) const = 0;

public:
    bool CallMethod(PyObject *, NMatrixnet::TMnMultiCateg *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TVector< TVector<float> > factors;
            if (!NPyBind::ExtractArgs(args, factors))
                ythrow yexception() << "Can't calculate: factors should be matrix of float";
            // Loop-invariant: query the category count once.
            const size_t numCategs = self->CategValues().size();
            TVector<double> allCategs(factors.size() * numCategs);
            DoCalculate(self, factors, allCategs.data());
            TVector< TVector<double> > result;
            result.reserve(factors.size());
            for (size_t i = 0; i < factors.size(); ++i) {
                const auto rowBegin = allCategs.begin() + i * numCategs;
                // Construct each row in place rather than building and copying a temporary.
                result.emplace_back(rowBegin, rowBegin + numCategs);
            }
            res = NPyBind::BuildPyObject(result);
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't calculate: unknown exception");
        }
        res = nullptr;
        return true;
    }
};

class TMXNetMCCalcCategsCaller : public TMXNetMCBaseCalcCategsCaller {
    private:
        void DoCalculate(NMatrixnet::TMnMultiCateg *self, const TVector< TVector<float> > &factors, double *result) const {
            self->CalcCategs(factors, result);
        }
};

class TMXNetMCCalcCategoriesRankingCaller : public TMXNetMCBaseCalcCategsCaller {
    private:
        void DoCalculate(NMatrixnet::TMnMultiCateg *self, const TVector< TVector<float> > &factors, double *result) const {
            self->CalcCategoriesRanking(factors, result);
        }
};

// Implements `TMXNetMC.CategValues()`: returns the model's category values as a
// list of floats. On failure sets RuntimeError and yields a null result.
class TMXNetMCCategValuesCaller : public NPyBind::TBaseMethodCaller<NMatrixnet::TMnMultiCateg> {
public:
    bool CallMethod(PyObject *, NMatrixnet::TMnMultiCateg *self, PyObject *, PyObject*, PyObject *&res) const override {
        try {
            auto source = self->CategValues();
            TVector<double> values(source.begin(), source.end());
            res = NPyBind::BuildPyObject(values);
            return true;
        } catch (const std::exception &err) {
            PyErr_SetString(PyExc_RuntimeError, err.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't get categ values: unknown exception");
        }
        res = nullptr;
        return true;
    }
};

// Describes the Python type "libmxnet.TMXNetInfo" and binds each Python method
// name to the caller object that implements it.
// NOTE(review): the callers are heap-allocated and passed to AddCaller;
// presumably the type takes ownership of them — confirm in NPyBind.
TMXNetInfoTraits::TMXNetInfoTraits()
    : TParent("libmxnet.TMXNetInfo", "matrixnet data")
{
    // Scoring.
    AddCaller("Calculate", new TMXNetInfoSlicedCalcCaller);
    // Build new TMXNetInfo objects from parts of this model.
    AddCaller("CopyTruncated", new TMXNetInfoCopyTruncatedCaller());
    AddCaller("SplitBySpecifiedFactors", new TMXNetInfoSplitBySpecifiedFactorsCaller());
    // Info properties and model introspection.
    AddCaller("SetProperty", new TMXNetSetPropertyCaller<NMatrixnet::TMnSseDynamic>());
    AddCaller("GetProperty", new TMXNetGetPropertyCaller<NMatrixnet::TMnSseDynamic>());
    AddCaller("GetProperties", new TMXNetGetPropertiesCaller<NMatrixnet::TMnSseDynamic>());
    AddCaller("GetUsedFactors", new TMXNetGetUsedFactorsCaller<NMatrixnet::TMnSseDynamic>());
    AddCaller("GetTreeCount", new TMXNetGetTreeCountCaller<NMatrixnet::TMnSseDynamic>());
    // File persistence (native format, .bin format, partial loads).
    AddCaller("Save", new TObjectSaveCaller<NMatrixnet::TMnSseDynamic>());
    AddCaller("Load", new TObjectLoadCaller<NMatrixnet::TMnSseDynamic>());
    AddCaller("LoadBinPart", new TObjectLoadBinPartCaller<NMatrixnet::TMnSseDynamic>());
    AddCaller("LoadInfoPart", new TObjectLoadInfoPartCaller<NMatrixnet::TMnSseDynamic>());
    // Pickle protocol support.
    AddCaller("__getstate__", new TObjectGetStateCaller<NMatrixnet::TMnSseDynamic>());
    AddCaller("__setstate__", new TObjectSetStateCaller<NMatrixnet::TMnSseDynamic>());
}

// Describes the Python type "libmxnet.TMXNetMC" (multi-category model) and
// binds each Python method name to the caller object that implements it.
// NOTE(review): the callers are heap-allocated and passed to AddCaller;
// presumably the type takes ownership of them — confirm in NPyBind.
TMXNetMCTraits::TMXNetMCTraits()
    : TParent("libmxnet.TMXNetMC", "matrixnet data")
{
    // Scoring: one value per row, or one value per category per row.
    AddCaller("Calculate", new TMXNetMCCalcCaller);
    AddCaller("CalculateCategs", new TMXNetMCCalcCategsCaller);
    AddCaller("CalculateCategoriesRanking", new TMXNetMCCalcCategoriesRankingCaller);
    AddCaller("CategValues", new TMXNetMCCategValuesCaller);
    // Info properties.
    AddCaller("SetProperty", new TMXNetSetPropertyCaller<NMatrixnet::TMnMultiCateg>());
    AddCaller("GetProperty", new TMXNetGetPropertyCaller<NMatrixnet::TMnMultiCateg>());
    // File persistence.
    AddCaller("Save", new TObjectSaveCaller<NMatrixnet::TMnMultiCateg>());
    AddCaller("Load", new TObjectLoadCaller<NMatrixnet::TMnMultiCateg>());
    // Pickle protocol support.
    AddCaller("__getstate__", new TObjectGetStateCaller<NMatrixnet::TMnMultiCateg>());
    AddCaller("__setstate__", new TObjectSetStateCaller<NMatrixnet::TMnMultiCateg>());
}

// Creates a holder around a default-constructed (empty) TMnSseDynamic.
// The string arguments are ignored.
TMXNetInfoHolder *TMXNetInfoTraits::DoInitPureObject(const TVector<TString> &) {
    THolder<NMatrixnet::TMnSseDynamic> emptyModel(new NMatrixnet::TMnSseDynamic);
    TMXNetInfoHolder *holder = new TMXNetInfoHolder(emptyModel.Release());
    return holder;
}

// Python constructor for TMXNetInfo.
// `TMXNetInfo()` yields an empty model; `TMXNetInfo(path)` loads it from a file.
// Any other argument list sets RuntimeError and returns null, as does any
// failure while loading.
TMXNetInfoHolder *TMXNetInfoTraits::DoInitObject(PyObject *args, PyObject*) {
    try {
        THolder<NMatrixnet::TMnSseDynamic> model(new NMatrixnet::TMnSseDynamic);
        TString modelPath;
        const bool hasPath = NPyBind::ExtractArgs(args, modelPath);
        if (hasPath) {
            TFileInput source(modelPath);
            ::Load(&source, *model);
        } else {
            // Only a missing or empty argument tuple is acceptable here.
            const bool hasUnexpectedArgs = args && (!PyTuple_Check(args) || PyTuple_Size(args) > 0);
            if (hasUnexpectedArgs) {
                ythrow yexception() << "Can't create TMXNetInfo: there should be no params or exactly one parameter - path to model file";
            }
        }
        return new TMXNetInfoHolder(model.Release());
    } catch (const std::exception &err) {
        PyErr_SetString(PyExc_RuntimeError, err.what());
    } catch (...) {
        PyErr_SetString(PyExc_RuntimeError, "Can't create TMXNetInfo: unknown exception");
    }
    return nullptr;
}

// Creates a holder around a default-constructed (empty) TMnMultiCateg.
// The string arguments are ignored.
TMXNetMCHolder *TMXNetMCTraits::DoInitPureObject(const TVector<TString> &) {
    THolder<NMatrixnet::TMnMultiCateg> emptyModel(new NMatrixnet::TMnMultiCateg);
    TMXNetMCHolder *holder = new TMXNetMCHolder(emptyModel.Release());
    return holder;
}

// Python constructor for TMXNetMC: `TMXNetMC(path)` loads the model from a file.
// A missing or malformed path argument, or any load failure, sets RuntimeError
// and returns null.
TMXNetMCHolder *TMXNetMCTraits::DoInitObject(PyObject *args, PyObject*) {
    try {
        TString modelPath;
        const bool hasPath = NPyBind::ExtractArgs(args, modelPath);
        if (!hasPath) {
            ythrow yexception() << "Can't create TMXNetMC: it should be exactly one parameter - path to model file";
        }
        TFileInput source(modelPath);
        THolder<NMatrixnet::TMnMultiCateg> model(new NMatrixnet::TMnMultiCateg);
        ::Load(&source, *model);
        return new TMXNetMCHolder(model.Release());
    } catch (const std::exception &err) {
        PyErr_SetString(PyExc_RuntimeError, err.what());
    } catch (...) {
        PyErr_SetString(PyExc_RuntimeError, "Can't create TMXNetMC: unknown exception");
    }
    return nullptr;
}

// Module-level method table for `libmxnet`. It exposes no free functions
// (everything lives on the registered types), so it holds only the
// all-null terminator entry.
static PyMethodDef LibMatrixnetMethods[] = {
    { /* ml_name */ nullptr, /* ml_meth */ nullptr, /* ml_flags */ 0, /* ml_doc */ nullptr }
};

// Creates the `libmxnet` Python module and registers the TMXNetInfo and
// TMXNetMC types in it.
void DoInitLibmxnet() {
    PyObject* m = Py_InitModule("libmxnet", LibMatrixnetMethods);
    if (!m) {
        // Py_InitModule returns null on failure (with a Python error set).
        // Registering types into a null module would pass a null pointer to Register().
        return;
    }
    TMXNetInfoTraits::Instance().Register(m, "TMXNetInfo");
    TMXNetMCTraits::Instance().Register(m, "TMXNetMC");
}
