#include "perceptron.h"

#include <library/cpp/pybind/typedesc.h>

#include <library/cpp/perceptron/perceptron.h>


// TRumelhartPerceptron
/// Object stored inside the Python wrapper for TRumelhartPerceptron.
/// Takes ownership of the pointer it is constructed from (kept in a THolder).
struct TRumelhartPerceptronHolder {
    TRumelhartPerceptronHolder(NNeuralNetwork::TRumelhartPerceptron *owned)
        : Perceptron(owned)
    {
    }

    THolder<NNeuralNetwork::TRumelhartPerceptron> Perceptron;
};

/// NPyBind type description for libperceptron.TRumelhartPerceptron.
/// The constructor is private; the NPyBind base is befriended so that it can
/// create the traits instance itself.
class TRumelhartPerceptronTraits
    : public NPyBind::TPythonType<TRumelhartPerceptronHolder, NNeuralNetwork::TRumelhartPerceptron, TRumelhartPerceptronTraits> {
private:
    using TParent = NPyBind::TPythonType<TRumelhartPerceptronHolder, NNeuralNetwork::TRumelhartPerceptron, TRumelhartPerceptronTraits>;
    friend TParent;

    TRumelhartPerceptronTraits();

public:
    /// Builds a new holder from the Python constructor arguments (nullptr + Python error on failure).
    static TRumelhartPerceptronHolder *DoInitObject(PyObject *args, PyObject *kwargs);

    /// Returns the wrapped perceptron (non-owning).
    static NNeuralNetwork::TRumelhartPerceptron *GetObject(const TRumelhartPerceptronHolder &holder) {
        return holder.Perceptron.Get();
    }
};


TRumelhartPerceptronTraits::TRumelhartPerceptronTraits()
    : TParent("libperceptron.TRumelhartPerceptron", "Rumelhart's perceptron")
{
    // Each Python method name is bound to the C++ member of the same purpose.
    using TPerceptron = NNeuralNetwork::TRumelhartPerceptron;

    AddCaller("Save", NPyBind::CreateConstMethodCaller<TPerceptron>(&TPerceptron::SaveAsString));
    AddCaller("Load", NPyBind::CreateMethodCaller<TPerceptron>(&TPerceptron::LoadFromString));
    AddCaller("Calculate", NPyBind::CreateConstMethodCaller<TPerceptron>(&TPerceptron::CalculateAndReturnResult));
    AddCaller("CalculateReduced", NPyBind::CreateConstMethodCaller<TPerceptron>(&TPerceptron::CalculateReducedAndReturnResult));
}

/// Creates the holder for a new libperceptron.TRumelhartPerceptron object.
///
/// Accepted positional argument lists:
///   ()                     - default-constructed perceptron;
///   (data)                 - default-constructed perceptron, then LoadFromString(data);
///   (maxInit, dimensions)  - TRumelhartPerceptron(maxInit, dimensions).
/// Any other argument list is rejected. Previously it silently produced an
/// empty perceptron. Keyword arguments are ignored.
///
/// @return a new holder owning the perceptron, or nullptr with a Python
///         RuntimeError set when construction fails.
TRumelhartPerceptronHolder *TRumelhartPerceptronTraits::DoInitObject(PyObject *args, PyObject *) {
    try {
        THolder<NNeuralNetwork::TRumelhartPerceptron> perceptron;
        double maxInit;
        TVector<size_t> dimensions;
        TString data;
        if (NPyBind::ExtractArgs(args, maxInit, dimensions)) {
            perceptron.Reset(new NNeuralNetwork::TRumelhartPerceptron(maxInit, dimensions));
        } else if (NPyBind::ExtractArgs(args, data)) {
            perceptron.Reset(new NNeuralNetwork::TRumelhartPerceptron);
            perceptron->LoadFromString(data);
        } else if (args == nullptr || (PyTuple_Check(args) && PyTuple_Size(args) == 0)) {
            // Only an empty argument list yields an empty perceptron; mismatched
            // arguments fall through to the error below.
            perceptron.Reset(new NNeuralNetwork::TRumelhartPerceptron);
        }
        if (perceptron.Get() == nullptr)
            ythrow yexception() << "Failed to construct TRumelhartPerceptron: expected (), (data) or (maxInit, dimensions)";
        return new TRumelhartPerceptronHolder(perceptron.Release());
    } catch (const std::exception &ex) {
        PyErr_SetString(PyExc_RuntimeError, ex.what());
    } catch (...) {
        PyErr_SetString(PyExc_RuntimeError, "Can't create TRumelhartPerceptron: unknown exception");
    }
    return nullptr;
}


// TRMSELearner
/// Object stored inside the Python wrapper for TRMSELearner.
/// Takes ownership of the pointer it is constructed from (kept in a THolder).
struct TRMSELearnerHolder {
    TRMSELearnerHolder(NNeuralNetwork::TRMSELearner *owned)
        : Learner(owned)
    {
    }

    THolder<NNeuralNetwork::TRMSELearner> Learner;
};

/// NPyBind type description for libperceptron.TRMSELearner.
/// The constructor is private; the NPyBind base is befriended so that it can
/// create the traits instance itself.
class TRMSELearnerTraits
    : public NPyBind::TPythonType<TRMSELearnerHolder, NNeuralNetwork::TRMSELearner, TRMSELearnerTraits> {
private:
    using TParent = NPyBind::TPythonType<TRMSELearnerHolder, NNeuralNetwork::TRMSELearner, TRMSELearnerTraits>;
    friend TParent;

    TRMSELearnerTraits();

public:
    /// Builds a new holder from the Python constructor arguments (nullptr + Python error on failure).
    static TRMSELearnerHolder *DoInitObject(PyObject *args, PyObject *kwargs);

    /// Returns the wrapped learner (non-owning).
    static NNeuralNetwork::TRMSELearner *GetObject(const TRMSELearnerHolder &holder) {
        return holder.Learner.Get();
    }
};


/// Registers the Python-visible type name, its docstring and the methods
/// exposed on libperceptron.TRMSELearner.
TRMSELearnerTraits::TRMSELearnerTraits()
    // Fixed user-visible docstring typo ("standart" -> "standard").
    : TParent("libperceptron.TRMSELearner", "Class for learning Rumelhart's perceptron using standard rmse loss function") {
    AddCaller("Save", NPyBind::CreateConstMethodCaller<NNeuralNetwork::TRMSELearner>(&NNeuralNetwork::TRMSELearner::SavePerceptronAsString));
    AddCaller("Load", NPyBind::CreateMethodCaller<NNeuralNetwork::TRMSELearner>(&NNeuralNetwork::TRMSELearner::LoadPerceptronFromString));
    AddCaller("Add", NPyBind::CreateMethodCaller<NNeuralNetwork::TRMSELearner>(&NNeuralNetwork::TRMSELearner::Add));
    // The Flush* methods are members of the TBackPropagationLearner base, bound on the derived type.
    AddCaller("FlushAdaDelta", NPyBind::CreateMethodCaller<NNeuralNetwork::TRMSELearner>(&NNeuralNetwork::TBackPropagationLearner::FlushAdaDelta));
    AddCaller("FlushSimple", NPyBind::CreateMethodCaller<NNeuralNetwork::TRMSELearner>(&NNeuralNetwork::TBackPropagationLearner::FlushSimple));
    AddCaller("GradientChecking", NPyBind::CreateConstMethodCaller<NNeuralNetwork::TRMSELearner>(&NNeuralNetwork::TRMSELearner::GradientChecking));
}

/// Creates the holder for a new libperceptron.TRMSELearner object.
///
/// Accepted positional argument lists: () or (maxInit, dimensions).
/// Keyword arguments are ignored.
/// @return a new holder, or nullptr with a Python RuntimeError set on failure.
TRMSELearnerHolder *TRMSELearnerTraits::DoInitObject(PyObject *args, PyObject *) {
    using TLearner = NNeuralNetwork::TRMSELearner;
    try {
        // -1 stands for "no positional tuple" and matches none of the cases below.
        const Py_ssize_t argCount = (args != nullptr && PyTuple_Check(args)) ? PyTuple_Size(args) : -1;

        THolder<TLearner> learner;
        switch (argCount) {
            case 0:
                learner.Reset(new TLearner);
                break;
            case 2: {
                double maxInit;
                TVector<size_t> dimensions;
                if (NPyBind::ExtractArgs(args, maxInit, dimensions))
                    learner.Reset(new TLearner(maxInit, dimensions));
                break;
            }
            default:
                break;
        }

        if (!learner.Get())
            ythrow yexception() << "Failed to construct TRMSELearner";
        return new TRMSELearnerHolder(learner.Release());
    } catch (const std::exception &ex) {
        PyErr_SetString(PyExc_RuntimeError, ex.what());
    } catch (...) {
        PyErr_SetString(PyExc_RuntimeError, "Can't create TRMSELearner: unknown exception");
    }
    return nullptr;
}


// TRegressionLearner
/// Object stored inside the Python wrapper for TRegressionLearner.
/// Takes ownership of the pointer it is constructed from (kept in a THolder).
struct TRegressionLearnerHolder {
    TRegressionLearnerHolder(NNeuralNetwork::TRegressionLearner *owned)
        : Learner(owned)
    {
    }

    THolder<NNeuralNetwork::TRegressionLearner> Learner;
};

/// NPyBind type description for libperceptron.TRegressionLearner.
/// The constructor is private; the NPyBind base is befriended so that it can
/// create the traits instance itself.
class TRegressionLearnerTraits
    : public NPyBind::TPythonType<TRegressionLearnerHolder, NNeuralNetwork::TRegressionLearner, TRegressionLearnerTraits> {
private:
    using TParent = NPyBind::TPythonType<TRegressionLearnerHolder, NNeuralNetwork::TRegressionLearner, TRegressionLearnerTraits>;
    friend TParent;

    TRegressionLearnerTraits();

public:
    /// Builds a new holder from the Python constructor arguments (nullptr + Python error on failure).
    static TRegressionLearnerHolder *DoInitObject(PyObject *args, PyObject *kwargs);

    /// Returns the wrapped learner (non-owning).
    static NNeuralNetwork::TRegressionLearner *GetObject(const TRegressionLearnerHolder &holder) {
        return holder.Learner.Get();
    }
};


TRegressionLearnerTraits::TRegressionLearnerTraits()
    : TParent("libperceptron.TRegressionLearner", "Class for learning Rumelhart's perceptron using basis decomposition loss function")
{
    // Flush* live on the TBackPropagationLearner base but are bound on the derived learner type.
    using TLearner = NNeuralNetwork::TRegressionLearner;
    using TBackProp = NNeuralNetwork::TBackPropagationLearner;

    AddCaller("Save", NPyBind::CreateConstMethodCaller<TLearner>(&TLearner::SavePerceptronAsString));
    AddCaller("Load", NPyBind::CreateMethodCaller<TLearner>(&TLearner::LoadPerceptronFromString));
    AddCaller("Add", NPyBind::CreateMethodCaller<TLearner>(&TLearner::Add));
    AddCaller("FlushAdaDelta", NPyBind::CreateMethodCaller<TLearner>(&TBackProp::FlushAdaDelta));
    AddCaller("FlushSimple", NPyBind::CreateMethodCaller<TLearner>(&TBackProp::FlushSimple));
}

/// Creates the holder for a new libperceptron.TRegressionLearner object.
///
/// Accepted positional argument lists: (basis) or (basis, maxInit, dimensions).
/// Keyword arguments are ignored.
/// @return a new holder, or nullptr with a Python RuntimeError set on failure.
TRegressionLearnerHolder *TRegressionLearnerTraits::DoInitObject(PyObject *args, PyObject *) {
    using TLearner = NNeuralNetwork::TRegressionLearner;
    try {
        // -1 stands for "no positional tuple" and matches none of the cases below.
        const Py_ssize_t argCount = (args != nullptr && PyTuple_Check(args)) ? PyTuple_Size(args) : -1;

        THolder<TLearner> learner;
        switch (argCount) {
            case 1: {
                TVector<double> basis;
                if (NPyBind::ExtractArgs(args, basis))
                    learner.Reset(new TLearner(basis));
                break;
            }
            case 3: {
                TVector<double> basis;
                double maxInit;
                TVector<size_t> dimensions;
                if (NPyBind::ExtractArgs(args, basis, maxInit, dimensions))
                    learner.Reset(new TLearner(basis, maxInit, dimensions));
                break;
            }
            default:
                break;
        }

        if (!learner.Get())
            ythrow yexception() << "Failed to construct TRegressionLearner";
        return new TRegressionLearnerHolder(learner.Release());
    } catch (const std::exception &ex) {
        PyErr_SetString(PyExc_RuntimeError, ex.what());
    } catch (...) {
        PyErr_SetString(PyExc_RuntimeError, "Can't create TRegressionLearner: unknown exception");
    }
    return nullptr;
}


// TArgmaxLearner
/// Object stored inside the Python wrapper for TArgmaxLearner.
/// Takes ownership of the pointer it is constructed from (kept in a THolder).
struct TArgmaxLearnerHolder {
    TArgmaxLearnerHolder(NNeuralNetwork::TArgmaxLearner *owned)
        : Learner(owned)
    {
    }

    THolder<NNeuralNetwork::TArgmaxLearner> Learner;
};

/// NPyBind type description for libperceptron.TArgmaxLearner.
/// The constructor is private; the NPyBind base is befriended so that it can
/// create the traits instance itself.
class TArgmaxLearnerTraits
    : public NPyBind::TPythonType<TArgmaxLearnerHolder, NNeuralNetwork::TArgmaxLearner, TArgmaxLearnerTraits> {
private:
    using TParent = NPyBind::TPythonType<TArgmaxLearnerHolder, NNeuralNetwork::TArgmaxLearner, TArgmaxLearnerTraits>;
    friend TParent;

    TArgmaxLearnerTraits();

public:
    /// Builds a new holder from the Python constructor arguments (nullptr + Python error on failure).
    static TArgmaxLearnerHolder *DoInitObject(PyObject *args, PyObject *kwargs);

    /// Returns the wrapped learner (non-owning).
    static NNeuralNetwork::TArgmaxLearner *GetObject(const TArgmaxLearnerHolder &holder) {
        return holder.Learner.Get();
    }
};


TArgmaxLearnerTraits::TArgmaxLearnerTraits()
    : TParent("libperceptron.TArgmaxLearner", "Class for learning Rumelhart's perceptron using softargmax loss function")
{
    // Flush* live on the TBackPropagationLearner base but are bound on the derived learner type.
    using TLearner = NNeuralNetwork::TArgmaxLearner;
    using TBackProp = NNeuralNetwork::TBackPropagationLearner;

    AddCaller("Save", NPyBind::CreateConstMethodCaller<TLearner>(&TLearner::SavePerceptronAsString));
    AddCaller("Load", NPyBind::CreateMethodCaller<TLearner>(&TLearner::LoadPerceptronFromString));
    AddCaller("Add", NPyBind::CreateMethodCaller<TLearner>(&TLearner::Add));
    AddCaller("FlushAdaDelta", NPyBind::CreateMethodCaller<TLearner>(&TBackProp::FlushAdaDelta));
    AddCaller("FlushSimple", NPyBind::CreateMethodCaller<TLearner>(&TBackProp::FlushSimple));
}

/// Creates the holder for a new libperceptron.TArgmaxLearner object.
///
/// Accepted positional argument lists: () or (maxInit, dimensions).
/// Keyword arguments are ignored.
/// @return a new holder, or nullptr with a Python RuntimeError set on failure.
TArgmaxLearnerHolder *TArgmaxLearnerTraits::DoInitObject(PyObject *args, PyObject *) {
    using TLearner = NNeuralNetwork::TArgmaxLearner;
    try {
        // -1 stands for "no positional tuple" and matches none of the cases below.
        const Py_ssize_t argCount = (args != nullptr && PyTuple_Check(args)) ? PyTuple_Size(args) : -1;

        THolder<TLearner> learner;
        switch (argCount) {
            case 0:
                learner.Reset(new TLearner);
                break;
            case 2: {
                double maxInit;
                TVector<size_t> dimensions;
                if (NPyBind::ExtractArgs(args, maxInit, dimensions))
                    learner.Reset(new TLearner(maxInit, dimensions));
                break;
            }
            default:
                break;
        }

        if (!learner.Get())
            ythrow yexception() << "Failed to construct TArgmaxLearner";
        return new TArgmaxLearnerHolder(learner.Release());
    } catch (const std::exception &ex) {
        PyErr_SetString(PyExc_RuntimeError, ex.what());
    } catch (...) {
        PyErr_SetString(PyExc_RuntimeError, "Can't create TArgmaxLearner: unknown exception");
    }
    return nullptr;
}


// TSoftmaxLearner
/// Object stored inside the Python wrapper for TSoftmaxLearner.
/// Takes ownership of the pointer it is constructed from (kept in a THolder).
struct TSoftmaxLearnerHolder {
    TSoftmaxLearnerHolder(NNeuralNetwork::TSoftmaxLearner *owned)
        : Learner(owned)
    {
    }

    THolder<NNeuralNetwork::TSoftmaxLearner> Learner;
};

/// NPyBind type description for libperceptron.TSoftmaxLearner.
/// The constructor is private; the NPyBind base is befriended so that it can
/// create the traits instance itself.
class TSoftmaxLearnerTraits
    : public NPyBind::TPythonType<TSoftmaxLearnerHolder, NNeuralNetwork::TSoftmaxLearner, TSoftmaxLearnerTraits> {
private:
    using TParent = NPyBind::TPythonType<TSoftmaxLearnerHolder, NNeuralNetwork::TSoftmaxLearner, TSoftmaxLearnerTraits>;
    friend TParent;

    TSoftmaxLearnerTraits();

public:
    /// Builds a new holder from the Python constructor arguments (nullptr + Python error on failure).
    static TSoftmaxLearnerHolder *DoInitObject(PyObject *args, PyObject *kwargs);

    /// Returns the wrapped learner (non-owning).
    static NNeuralNetwork::TSoftmaxLearner *GetObject(const TSoftmaxLearnerHolder &holder) {
        return holder.Learner.Get();
    }
};


// NOTE(review): the docstring below is identical to TArgmaxLearner's; confirm
// whether a softmax-specific description was intended.
TSoftmaxLearnerTraits::TSoftmaxLearnerTraits()
    : TParent("libperceptron.TSoftmaxLearner", "Class for learning Rumelhart's perceptron using softargmax loss function")
{
    // Flush* live on the TBackPropagationLearner base but are bound on the derived learner type.
    using TLearner = NNeuralNetwork::TSoftmaxLearner;
    using TBackProp = NNeuralNetwork::TBackPropagationLearner;

    AddCaller("Save", NPyBind::CreateConstMethodCaller<TLearner>(&TLearner::SavePerceptronAsString));
    AddCaller("Load", NPyBind::CreateMethodCaller<TLearner>(&TLearner::LoadPerceptronFromString));
    AddCaller("Add", NPyBind::CreateMethodCaller<TLearner>(&TLearner::Add));
    AddCaller("FlushAdaDelta", NPyBind::CreateMethodCaller<TLearner>(&TBackProp::FlushAdaDelta));
    AddCaller("FlushSimple", NPyBind::CreateMethodCaller<TLearner>(&TBackProp::FlushSimple));
    AddCaller("GradientChecking", NPyBind::CreateConstMethodCaller<TLearner>(&TLearner::GradientChecking));
}

/// Creates the holder for a new libperceptron.TSoftmaxLearner object.
///
/// Accepted positional argument lists: (), (maxInit, dimensions) or
/// (maxInit, alpha, dimensions). Keyword arguments are ignored.
/// @return a new holder, or nullptr with a Python RuntimeError set on failure.
TSoftmaxLearnerHolder *TSoftmaxLearnerTraits::DoInitObject(PyObject *args, PyObject *) {
    using TLearner = NNeuralNetwork::TSoftmaxLearner;
    try {
        // -1 stands for "no positional tuple" and matches none of the cases below.
        const Py_ssize_t argCount = (args != nullptr && PyTuple_Check(args)) ? PyTuple_Size(args) : -1;

        THolder<TLearner> learner;
        switch (argCount) {
            case 0:
                learner.Reset(new TLearner);
                break;
            case 2: {
                double maxInit;
                TVector<size_t> dimensions;
                if (NPyBind::ExtractArgs(args, maxInit, dimensions))
                    learner.Reset(new TLearner(maxInit, dimensions));
                break;
            }
            case 3: {
                double maxInit;
                double alpha;
                TVector<size_t> dimensions;
                if (NPyBind::ExtractArgs(args, maxInit, alpha, dimensions))
                    learner.Reset(new TLearner(maxInit, alpha, dimensions));
                break;
            }
            default:
                break;
        }

        if (!learner.Get())
            ythrow yexception() << "Failed to construct TSoftmaxLearner";
        return new TSoftmaxLearnerHolder(learner.Release());
    } catch (const std::exception &ex) {
        PyErr_SetString(PyExc_RuntimeError, ex.what());
    } catch (...) {
        PyErr_SetString(PyExc_RuntimeError, "Can't create TSoftmaxLearner: unknown exception");
    }
    return nullptr;
}


// Module-level method table: the module defines no free functions, so the
// table holds only the all-null terminating entry.
static PyMethodDef LibPerceptronMethods[] = {
    {nullptr, nullptr, 0, nullptr}, // sentinel
};

/// Creates the `libperceptron` module and registers every wrapped type in it.
/// If module creation fails, the function returns immediately; the Python
/// error set by Py_InitModule is left for the caller/import machinery.
void DoInitLibPerceptron() {
    PyObject* m = Py_InitModule("libperceptron", LibPerceptronMethods);
    if (m == nullptr) {
        // Py_InitModule returns NULL on failure; do not hand a null module to Register().
        return;
    }
    TRumelhartPerceptronTraits::Instance().Register(m, "TRumelhartPerceptron");
    TRMSELearnerTraits::Instance().Register(m, "TRMSELearner");
    TRegressionLearnerTraits::Instance().Register(m, "TRegressionLearner");
    TArgmaxLearnerTraits::Instance().Register(m, "TArgmaxLearner");
    TSoftmaxLearnerTraits::Instance().Register(m, "TSoftmaxLearner");
}

