#include "catboost.h"

#include <library/cpp/pybind/typedesc.h>
#include <kernel/catboost/catboost_calcer.h>

#include <util/stream/file.h>

using NCatboostCalcer::TCatboostCalcer;

// Per-object storage for a TCatboostModel Python object: owns the wrapped calcer.
struct TCatboostModelHolder {
    // Owned calcer; released when the holder is destroyed.
    THolder<TCatboostCalcer> CatboostModel;

    // Takes ownership of `info`.
    // explicit: a raw pointer must never silently convert into an owning holder.
    explicit TCatboostModelHolder(TCatboostCalcer *info)
        : CatboostModel(info) {
    }
};

// NPyBind type traits for the Python type TCatboostModel.
// The Python object stores a TCatboostModelHolder, and methods operate on the TCatboostCalcer it owns.
class TCatboostModelTraits : public NPyBind::TPythonType<TCatboostModelHolder, TCatboostCalcer, TCatboostModelTraits> {
public:
    // Returns a non-owning pointer to the calcer held by `holder`.
    static TCatboostCalcer *GetObject(const TCatboostModelHolder &holder) {
        TCatboostCalcer *const calcer = holder.CatboostModel.Get();
        return calcer;
    }

    // Builds a holder from Python constructor arguments (defined out of line).
    static TCatboostModelHolder *DoInitObject(PyObject *args, PyObject *kwargs);
    // Builds a holder around an empty calcer (defined out of line).
    static TCatboostModelHolder *DoInitPureObject(const TVector<TString> &);

private:
    using TParent = NPyBind::TPythonType<TCatboostModelHolder, TCatboostCalcer, TCatboostModelTraits>;
    // The constructor is private; the base type is befriended so it can create the
    // instance (presumably the one returned by Instance()).
    friend class NPyBind::TPythonType<TCatboostModelHolder, TCatboostCalcer, TCatboostModelTraits>;

    TCatboostModelTraits();
};

// Implements TCatboostModel.Calculate(factors) and TCatboostModel.Calculate(factors, categoricalFactors).
// On success `res` receives NPyBind::BuildPyObject of the per-row results.
// On failure a RuntimeError is set and `res` is nullptr.
// Returns true in every path.
class TCatboostModelCalcCaller : public NPyBind::TBaseMethodCaller<TCatboostCalcer> {
public:
    bool CallMethod(PyObject *, TCatboostCalcer *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            // TODO support slices
            TVector<TVector<float>> floatRows;
            TVector<TVector<TStringBuf>> catRows;
            TVector<double> relevs;

            const bool withCategorical = NPyBind::ExtractArgs(args, floatRows, catRows);
            if (withCategorical) {
                // The categorical overload takes the float rows as array views.
                TVector<TConstArrayRef<float>> rowViews;
                rowViews.reserve(floatRows.size());
                for (const TVector<float>& row : floatRows) {
                    rowViews.emplace_back(row.data(), row.size());
                }
                relevs.resize(rowViews.size());
                self->CalcRelevs(rowViews, catRows, relevs);
            } else {
                // Fall back to the float-only signature.
                if (!NPyBind::ExtractArgs(args, floatRows))
                    ythrow yexception() << "Can't calculate: factors should be matrix of float; if there are categoricalFactors, they should be matrix of string";
                relevs.resize(floatRows.size());
                self->CalcRelevs(floatRows, relevs);
            }

            res = NPyBind::BuildPyObject(relevs);
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't calculate: unknown exception");
        }
        res = nullptr;
        return true;
    }
};


// Implements TCatboostModel.CopyTruncated(treeCount) and TCatboostModel.CopyTruncated(begin, end).
// Returns a new TCatboostModel built from self->CopyTreeRange(begin, end); the single-argument
// form uses begin = 0.
// On failure a RuntimeError is set and `res` is nullptr, matching the other callers.
// Returns true in every path.
class TCatboostModelCopyTruncatedCaller : public NPyBind::TBaseMethodCaller<TCatboostCalcer> {
public:
    bool CallMethod(PyObject *, TCatboostCalcer *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            size_t begin = 0, end = 0;
            if (!NPyBind::ExtractArgs(args, end) && !NPyBind::ExtractArgs(args, begin, end))
                ythrow yexception() << "Can't truncate: it should be one or two parameters - treeCount or (begin, end)";

            THolder<TCatboostCalcer> mx = self->CopyTreeRange(begin, end);
            // A freshly created Python object is an owned (new) reference; hand it to the
            // caller directly. Wrapping it in TPyObjectPtr without stealing and then calling
            // RefGet() would add one reference too many and leak the object.
            // NOTE(review): assumes CreatePyObject returns a new reference, as the original
            // RefGet() pattern implied.
            res = TCatboostModelTraits::Instance().CreatePyObject(new TCatboostModelHolder(mx.Release()));
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't truncate: unknown exception");
        }
        // An exception is pending, so the result must be NULL (returning Py_None here would
        // leave a stray pending exception and violate the Python C-API contract).
        res = nullptr;
        return true;
    }
};

// Generic "Save(path)" method: serializes *self with ::Save into the file at `path`.
// Returns None on success.
// On failure (wrong arguments, I/O or serialization error) a RuntimeError is set and `res`
// is nullptr. Returns true in every path.
template<class TSubObject>
class TObjectSaveCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    bool CallMethod(PyObject*, TSubObject *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TString path;
            if (!NPyBind::ExtractArgs(args, path))
                ythrow yexception() << "Can't save: it should be exactly one parameter - file path";
            TFixedBufferFileOutput out(path);
            ::Save(&out, *self);
            // Flush and close explicitly, inside the try, so a failing final write
            // (e.g. disk full) is reported to Python. Errors from the stream's destructor
            // are not propagated, so without this a truncated file would look like success.
            out.Finish();
            Py_INCREF(Py_None);
            res = Py_None;
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't save: unknown exception");
        }
        res = nullptr;
        return true;
    }
};

// Generic "Load(path)" method: deserializes into *self with ::Load from the file at `path`.
// Returns None on success.
// On failure a RuntimeError is set and `res` stays nullptr. Returns true in every path.
// NOTE(review): if ::Load throws midway, *self may be left partially loaded.
template<class TSubObject>
class TObjectLoadCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    bool CallMethod(PyObject*, TSubObject *self, PyObject *args, PyObject*, PyObject *&res) const override {
        // Error result by default; overwritten only on success.
        res = nullptr;
        try {
            TString filePath;
            if (NPyBind::ExtractArgs(args, filePath)) {
                TFileInput input(filePath);
                ::Load(&input, *self);
                Py_INCREF(Py_None);
                res = Py_None;
                return true;
            }
            ythrow yexception() << "Can't load: it should be exactly one parameter - file path";
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't load: unknown exception");
        }
        return true;
    }
};

// Names the Python type and binds each Python method name to the caller that implements it.
TCatboostModelTraits::TCatboostModelTraits()
    : TParent("libcatboost.TCatboostModel", "catboost model data") {
    using TLoadCaller = TObjectLoadCaller<TCatboostCalcer>;
    using TSaveCaller = TObjectSaveCaller<TCatboostCalcer>;

    // Inference over factor matrices.
    AddCaller("Calculate", new TCatboostModelCalcCaller());
    // Copy of the model restricted to a range of trees.
    AddCaller("CopyTruncated", new TCatboostModelCopyTruncatedCaller());
    // File (de)serialization.
    AddCaller("Load", new TLoadCaller());
    AddCaller("Save", new TSaveCaller());
}

// Creates a holder around a default-constructed (empty) calcer; the argument list is unused.
TCatboostModelHolder *TCatboostModelTraits::DoInitPureObject(const TVector<TString> &) {
    // Own the calcer via THolder until it is handed over to the new holder.
    THolder<TCatboostCalcer> calcer(new TCatboostCalcer);
    TCatboostModelHolder *const holder = new TCatboostModelHolder(calcer.Release());
    return holder;
}

// Python constructor for TCatboostModel:
//   TCatboostModel()     -> empty model;
//   TCatboostModel(path) -> model loaded from the file at `path`.
// Keyword arguments are ignored. Any other argument list, or a load error, sets a
// RuntimeError and returns nullptr.
TCatboostModelHolder *TCatboostModelTraits::DoInitObject(PyObject *args, PyObject*) {
    try {
        THolder<TCatboostCalcer> calcer(new TCatboostCalcer);
        TString path;
        if (!NPyBind::ExtractArgs(args, path)) {
            // Not a single path argument: the only acceptable alternative is no arguments at all.
            const bool isEmptyCall = !args || (PyTuple_Check(args) && PyTuple_Size(args) == 0);
            if (!isEmptyCall)
                ythrow yexception() << "Can't create TCatboostModel: there should be no params or exactly one parameter - path to model file";
        } else {
            TFileInput input(path);
            ::Load(&input, *calcer);
        }
        return new TCatboostModelHolder(calcer.Release());
    } catch (const std::exception &ex) {
        PyErr_SetString(PyExc_RuntimeError, ex.what());
    } catch (...) {
        PyErr_SetString(PyExc_RuntimeError, "Can't create TCatboostModel: unknown exception");
    }
    return nullptr;
}

// Module-level function table for "libcatboost". It contains only the sentinel entry;
// all functionality is exposed through the TCatboostModel type registered below.
static PyMethodDef LibCatboostMethods[] = {
    {nullptr, nullptr, 0, nullptr}
};

// Creates the "libcatboost" module (Python 2 Py_InitModule API) and registers the
// TCatboostModel type in it.
void DoInitLibcatboost() {
    PyObject* m = Py_InitModule("libcatboost", LibCatboostMethods);
    if (!m) {
        // Py_InitModule returns NULL with a Python exception already set. Do not hand a
        // null module to Register; return so the import reports that exception.
        return;
    }
    TCatboostModelTraits::Instance().Register(m, "TCatboostModel");
}
