#include "linear_regression.h"

#include <library/cpp/pybind/typedesc.h>

#include <library/cpp/linear_regression/linear_model.h>
#include <library/cpp/linear_regression/linear_regression.h>

#include <utility>


// Model = [coefficients] + intercept
struct TLinearModelHolder {
    TLinearModel Model;

    TLinearModelHolder(const TLinearModel &model)
        : Model(model) {
    }
};

// Python type description for liblinregr.TModel: exposes the TLinearModel stored
// inside a TLinearModelHolder.
class TLinearModelTraits : public NPyBind::TPythonType<TLinearModelHolder, TLinearModel, TLinearModelTraits> {
    using TParent = NPyBind::TPythonType<TLinearModelHolder, TLinearModel, TLinearModelTraits>;

    // The constructor is private; only the NPyBind base (which provides
    // Instance()) is allowed to build the traits object.
    friend TParent;
    TLinearModelTraits();

public:
    // Returns a pointer to the model inside `holder`; the holder keeps ownership.
    static TLinearModel* GetObject(TLinearModelHolder& holder) {
        TLinearModel& model = holder.Model;
        return &model;
    }

    // Builds a holder from the Python constructor arguments.
    static TLinearModelHolder* DoInitObject(PyObject* args, PyObject* kwargs);
};


// Solver
// Owns the TLinearRegressionSolver behind a liblinregr.TSolver Python object.
struct TLinearRegressionSolverHolder {
    THolder<TLinearRegressionSolver> Solver;

    // Takes ownership of `solver` (may be nullptr). explicit so that a raw
    // pointer can never be turned into an owning holder by accident.
    explicit TLinearRegressionSolverHolder(TLinearRegressionSolver *solver)
        : Solver(solver) {
    }
};

// Python type description for liblinregr.TSolver: exposes the
// TLinearRegressionSolver owned by a TLinearRegressionSolverHolder.
class TLinearRegressionSolverTraits : public NPyBind::TPythonType<TLinearRegressionSolverHolder, TLinearRegressionSolver, TLinearRegressionSolverTraits> {
    using TParent = NPyBind::TPythonType<TLinearRegressionSolverHolder, TLinearRegressionSolver, TLinearRegressionSolverTraits>;

    // The constructor is private; only the NPyBind base (which provides
    // Instance()) is allowed to build the traits object.
    friend TParent;
    TLinearRegressionSolverTraits();

public:
    // Returns the solver owned by `holder`; ownership stays with the holder.
    static TLinearRegressionSolver* GetObject(const TLinearRegressionSolverHolder& holder) {
        TLinearRegressionSolver* const solver = holder.Solver.Get();
        return solver;
    }

    // Creates a fresh holder when Python constructs a TSolver.
    static TLinearRegressionSolverHolder* DoInitObject(PyObject* args, PyObject* kwargs);
};


// FastSolver
// Owns the TFastLinearRegressionSolver behind a liblinregr.TFastSolver Python object.
struct TFastLinearRegressionSolverHolder {
    THolder<TFastLinearRegressionSolver> Solver;

    // Takes ownership of `solver` (may be nullptr). explicit so that a raw
    // pointer can never be turned into an owning holder by accident.
    explicit TFastLinearRegressionSolverHolder(TFastLinearRegressionSolver *solver)
        : Solver(solver) {
    }
};

// Python type description for liblinregr.TFastSolver: exposes the
// TFastLinearRegressionSolver owned by a TFastLinearRegressionSolverHolder.
class TFastLinearRegressionSolverTraits : public NPyBind::TPythonType<TFastLinearRegressionSolverHolder, TFastLinearRegressionSolver, TFastLinearRegressionSolverTraits> {
    using TParent = NPyBind::TPythonType<TFastLinearRegressionSolverHolder, TFastLinearRegressionSolver, TFastLinearRegressionSolverTraits>;

    // The constructor is private; only the NPyBind base (which provides
    // Instance()) is allowed to build the traits object.
    friend TParent;
    TFastLinearRegressionSolverTraits();

public:
    // Returns the solver owned by `holder`; ownership stays with the holder.
    static TFastLinearRegressionSolver* GetObject(const TFastLinearRegressionSolverHolder& holder) {
        TFastLinearRegressionSolver* const solver = holder.Solver.Get();
        return solver;
    }

    // Creates a fresh holder when Python constructs a TFastSolver.
    static TFastLinearRegressionSolverHolder* DoInitObject(PyObject* args, PyObject* kwargs);
};


// Methods
// Registers the Python-visible surface of liblinregr.TModel.
TLinearModelTraits::TLinearModelTraits()
    : TParent("liblinregr.TModel", "coefficients for linear model")
{
    using TModel = TLinearModel;

    // Callable method, bound to the const Prediction<double> overload.
    AddCaller("Prediction", NPyBind::CreateConstMethodCaller<TModel>(&TModel::Prediction<double>));

    // Attributes backed by getter methods of the model.
    AddGetter("Coefficients", NPyBind::CreateMethodAttrGetter<TModel>(&TModel::GetCoefficients));
    AddGetter("Intercept", NPyBind::CreateMethodAttrGetter<TModel>(&TModel::GetIntercept));
}

// Python constructor for liblinregr.TModel: TModel(coefficients, intercept).
// Keyword arguments are ignored. On any failure a RuntimeError is set and
// nullptr is returned.
TLinearModelHolder *TLinearModelTraits::DoInitObject(PyObject *args, PyObject*) {
    try {
        TVector<double> coefficients;
        // Initialized defensively; it is only read after ExtractArgs succeeds.
        double intercept = 0.0;
        if (!NPyBind::ExtractArgs(args, coefficients, intercept))
            ythrow yexception() << "Can't create TLinearModel: it should be exactly two parameters - vectors of coefficients and intercept";
        TLinearModel model(std::move(coefficients), intercept);
        // The local model is not used afterwards, so hand it over instead of
        // copying its coefficient vector a second time.
        return new TLinearModelHolder(std::move(model));
    } catch (const std::exception &ex) {
        PyErr_SetString(PyExc_RuntimeError, ex.what());
    } catch (...) {
        PyErr_SetString(PyExc_RuntimeError, "Can't create TLinearModel: unknown exception");
    }
    return nullptr;
}


// Method caller backing `Solve()` on the solver types. It runs T::Solve() and
// wraps the resulting TLinearModel in a new liblinregr.TModel Python object.
// The positional and keyword arguments are not inspected.
template<typename T>
class TSolveCaller final : public NPyBind::TBaseMethodCaller<T> {
public:
    bool CallMethod(PyObject *, T *self, PyObject *, PyObject*, PyObject *&res) const override {
        try {
            TLinearModel model(self->Solve());
            // Move the solved model into the holder rather than copying its coefficients.
            NPyBind::TPyObjectPtr ptr = TLinearModelTraits::Instance().CreatePyObject(new TLinearModelHolder(std::move(model)));
            res = ptr.RefGet();
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't solve: unknown exception");
        }
        // Error path: a RuntimeError has been set above and `res` is null.
        // NOTE(review): returning true presumably tells NPyBind the call was
        // handled, so the pending exception reaches Python instead of other
        // callers being tried — confirm against TBaseMethodCaller's contract.
        res = nullptr;
        return true;
    }
};


// Registers the Python-visible methods of liblinregr.TSolver.
TLinearRegressionSolverTraits::TLinearRegressionSolverTraits()
    : TParent("liblinregr.TSolver", "class for solving multiple regression problem")
{
    using TSolver = TLinearRegressionSolver;

    // Methods forwarded directly to the solver.
    AddCaller("Add", NPyBind::CreateMethodCaller<TSolver>(&TSolver::Add));
    AddCaller("SumSquaredErrors", NPyBind::CreateConstMethodCaller<TSolver>(&TSolver::SumSquaredErrors));

    // Solve() returns a new TModel object, so it needs the custom caller.
    AddCaller("Solve", new TSolveCaller<TSolver>);
}

// Python constructor for liblinregr.TSolver: creates an empty solver.
// All arguments are ignored. On any failure a RuntimeError is set and nullptr
// is returned.
TLinearRegressionSolverHolder *TLinearRegressionSolverTraits::DoInitObject(PyObject *, PyObject*) {
    try {
        THolder<TLinearRegressionSolver> solver(new TLinearRegressionSolver);
        // Allocate the wrapper while `solver` still owns the solver. If this
        // allocation throws, `solver` frees it. The hand-off below is a plain
        // move, so correctness does not depend on the order of evaluation
        // inside a new-expression.
        TLinearRegressionSolverHolder *wrapper = new TLinearRegressionSolverHolder(nullptr);
        wrapper->Solver = std::move(solver);
        return wrapper;
    } catch (const std::exception &ex) {
        PyErr_SetString(PyExc_RuntimeError, ex.what());
    } catch (...) {
        PyErr_SetString(PyExc_RuntimeError, "Can't create TLinearRegressionSolver: unknown exception");
    }
    return nullptr;
}


// Registers the Python-visible methods of liblinregr.TFastSolver.
TFastLinearRegressionSolverTraits::TFastLinearRegressionSolverTraits()
    : TParent("liblinregr.TFastSolver", "class for solving multiple regression problem")
{
    using TSolver = TFastLinearRegressionSolver;

    // Methods forwarded directly to the solver.
    AddCaller("Add", NPyBind::CreateMethodCaller<TSolver>(&TSolver::Add));
    AddCaller("SumSquaredErrors", NPyBind::CreateConstMethodCaller<TSolver>(&TSolver::SumSquaredErrors));

    // Solve() returns a new TModel object, so it needs the custom caller.
    AddCaller("Solve", new TSolveCaller<TSolver>);
}


// Python constructor for liblinregr.TFastSolver: creates an empty solver.
// All arguments are ignored. On any failure a RuntimeError is set and nullptr
// is returned.
TFastLinearRegressionSolverHolder *TFastLinearRegressionSolverTraits::DoInitObject(PyObject *, PyObject*) {
    try {
        THolder<TFastLinearRegressionSolver> solver(new TFastLinearRegressionSolver);
        // Allocate the wrapper while `solver` still owns the solver. If this
        // allocation throws, `solver` frees it. The hand-off below is a plain
        // move, so correctness does not depend on the order of evaluation
        // inside a new-expression.
        TFastLinearRegressionSolverHolder *wrapper = new TFastLinearRegressionSolverHolder(nullptr);
        wrapper->Solver = std::move(solver);
        return wrapper;
    } catch (const std::exception &ex) {
        PyErr_SetString(PyExc_RuntimeError, ex.what());
    } catch (...) {
        PyErr_SetString(PyExc_RuntimeError, "Can't create TFastLinearRegressionSolver: unknown exception");
    }
    return nullptr;
}


// Module-level functions of liblinregr. There are none: the table holds only
// the all-zero terminating entry, and the Python types are registered on the
// module separately.
static PyMethodDef LibLinearRegressionMethods[] = {
    {},  // sentinel: every field value-initialized to null/zero
};

// Creates the `liblinregr` module and registers the TModel, TSolver and
// TFastSolver Python types in it.
// NOTE(review): Py_InitModule belongs to the Python 2 C API; a Python 3 build
// would need PyModule_Create — confirm which interpreter this targets.
void DoInitLibLinearRegression() {
    PyObject* m = Py_InitModule("liblinregr", LibLinearRegressionMethods);
    if (m == nullptr) {
        // Module creation failed (presumably with a Python exception already
        // set). Do not register types into a null module.
        return;
    }
    TLinearModelTraits::Instance().Register(m, "TModel");
    TLinearRegressionSolverTraits::Instance().Register(m, "TSolver");
    TFastLinearRegressionSolverTraits::Instance().Register(m, "TFastSolver");
}

