#include "extcalc.h"

#include <library/cpp/pybind/typedesc.h>
#include <library/cpp/streams/factory/factory.h>

#include <util/folder/path.h>
#include <util/stream/file.h>
#include <util/stream/str.h>
#include <util/ysaveload.h>

#include <kernel/extended_mx_calcer/interface/extended_relev_calcer.h>
#include <kernel/formula_storage/loader.h>


/// Python-side storage for a `TMXNetBundle` object.
///
/// Takes ownership of the calcer passed to the constructor; the calcer is
/// destroyed together with the holder (via THolder).
struct TMXNetBundleHolder {
    THolder<NExtendedMx::TExtendedRelevCalcer> MXNetBundle;

    // explicit: adopting a raw owning pointer must never happen through an
    // implicit conversion.
    explicit TMXNetBundleHolder(NExtendedMx::TExtendedRelevCalcer *info)
        : MXNetBundle(info) {
    }
};


/// NPyBind type description binding TMXNetBundleHolder (Python-side storage)
/// to the wrapped NExtendedMx::TExtendedRelevCalcer.
///
/// The constructor is private; only the TPythonType base (declared a friend)
/// can create the instance, e.g. through Instance().
class TMXNetBundleTraits : public NPyBind::TPythonType<TMXNetBundleHolder, NExtendedMx::TExtendedRelevCalcer, TMXNetBundleTraits> {
private:
    using TParent = NPyBind::TPythonType<TMXNetBundleHolder, NExtendedMx::TExtendedRelevCalcer, TMXNetBundleTraits>;
    friend class NPyBind::TPythonType<TMXNetBundleHolder, NExtendedMx::TExtendedRelevCalcer, TMXNetBundleTraits>;

    TMXNetBundleTraits();

public:
    /// Returns the calcer owned by `holder` (non-owning; null if the holder is empty).
    static NExtendedMx::TExtendedRelevCalcer *GetObject(const TMXNetBundleHolder &holder) {
        NExtendedMx::TExtendedRelevCalcer *const calcer = holder.MXNetBundle.Get();
        return calcer;
    }

    /// Python constructor hook: builds a new holder from the call arguments.
    /// Returns nullptr with a Python error set on failure.
    static TMXNetBundleHolder *DoInitObject(PyObject *args, PyObject *kwargs);
    //static TMXNetBundleHolder *DoInitPureObject(const TVector<TString> &);
};


/// Method caller taking `(name, value)` string arguments: stores `value` in
/// `self->Info[name]`, overwriting any previous value, and returns None.
///
/// On failure a RuntimeError is set and `res` is nullptr, so that Python
/// actually raises it (returning a non-null result while an exception is
/// pending is invalid in the CPython API).
template<class TSubObject>
class TMXNetSetPropertyCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    bool CallMethod(PyObject*, TSubObject *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TString name, value;
            if (!NPyBind::ExtractArgs(args, name, value))
                ythrow yexception() << "Can't set property: it should be exactly two parameters - name and value";
            self->Info[name] = value;
            Py_INCREF(Py_None);
            res = Py_None;
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't set property: unknown exception");
        }
        res = nullptr;
        return true;
    }
};


/// Method caller taking a single `name` string argument: returns the value
/// stored in `self->Info[name]` (converted with NPyBind::BuildPyObject), or
/// None when there is no such key. The map is not modified.
///
/// On failure a RuntimeError is set and `res` is explicitly nullptr instead of
/// being left with whatever value the caller passed in.
template<class TSubObject>
class TMXNetGetPropertyCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    bool CallMethod(PyObject*, TSubObject *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TString name;
            if (!NPyBind::ExtractArgs(args, name))
                ythrow yexception() << "Can't get property: it should be exactly one parameter - name of the property";
            const auto iter = self->Info.find(name);
            if (iter != self->Info.end()) {
                res = NPyBind::BuildPyObject(iter->second);
            } else {
                Py_INCREF(Py_None);
                res = Py_None;
            }
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't get property: unknown exception");
        }
        res = nullptr;
        return true;
    }
};


/// Method caller taking a single `path` argument: serializes `*self` with
/// `::Save` into the file at `path` and returns None.
///
/// The stream is finished explicitly inside the try block so that flush/close
/// failures are reported to Python rather than surfacing in the destructor.
/// On failure a RuntimeError is set and `res` is nullptr.
template<class TSubObject>
class TObjectSaveCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    bool CallMethod(PyObject*, TSubObject *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TString path;
            if (!NPyBind::ExtractArgs(args, path))
                ythrow yexception() << "Can't save: it should be exactly one parameter - path for saving model";
            TFixedBufferFileOutput out(path);
            ::Save(&out, *self);
            out.Finish();
            Py_INCREF(Py_None);
            res = Py_None;
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't save: unknown exception");
        }
        res = nullptr;
        return true;
    }
};

/// Method caller taking a single `path` argument: deserializes `*self` with
/// `::Load` from the file at `path` and returns None.
///
/// On failure a RuntimeError is set and `res` is nullptr.
template<class TSubObject>
class TObjectLoadCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    bool CallMethod(PyObject*, TSubObject *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TString path;
            if (!NPyBind::ExtractArgs(args, path))
                ythrow yexception() << "Can't load: it should be exactly one parameter - file path";
            TFileInput in(path);
            ::Load(&in, *self);
            Py_INCREF(Py_None);
            res = Py_None;
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't load: unknown exception");
        }
        res = nullptr;
        return true;
    }
};

/// Method caller taking no arguments: serializes `*self` through its
/// `Save(IOutputStream*)` member into an in-memory string and returns that
/// string converted with NPyBind::BuildPyObject. This is the counterpart of
/// TObjectSetStateCaller.
///
/// On failure a RuntimeError is set and `res` is nullptr.
template<class TSubObject>
class TObjectGetStateCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    bool CallMethod(PyObject*, TSubObject *self, PyObject*, PyObject*, PyObject *&res) const override {
        try {
            TString serialized;
            TStringOutput sink(serialized);
            self->Save(&sink);
            // BuildPyObject already hands back a new reference, which is
            // exactly what the caller expects in `res`.
            res = NPyBind::BuildPyObject(serialized);
            return true;
        } catch (const std::exception &err) {
            PyErr_SetString(PyExc_RuntimeError, err.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't get state: unknown exception");
        }
        res = nullptr;
        return true;
    }
};

/// Method caller taking a single string argument produced by
/// TObjectGetStateCaller: deserializes `*self` from it with `::Load` and
/// returns None.
///
/// On failure a RuntimeError is set and `res` is nullptr.
template<class TSubObject>
class TObjectSetStateCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    bool CallMethod(PyObject*, TSubObject *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TString blob;
            if (NPyBind::ExtractArgs(args, blob)) {
                TStringInput source(blob);
                ::Load(&source, *self);
                Py_INCREF(Py_None);
                res = Py_None;
                return true;
            }
            ythrow yexception() << "Can't set state: it should be exactly one string parameter - the state";
        } catch (const std::exception &err) {
            PyErr_SetString(PyExc_RuntimeError, err.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't set state: unknown exception");
        }
        res = nullptr;
        return true;
    }
};


/// Python method `Calculate(factors[, context])`.
///
/// `factors` is extracted as TVector<float>. `context` is an optional JSON
/// string used to seed the calc context. The method runs CalcRelevExtended and
/// returns the resulting context root serialized with ToString().
///
/// On failure a RuntimeError is set and `res` is nullptr.
class TMXNetBundleCalcCaller : public NPyBind::TBaseMethodCaller<NExtendedMx::TExtendedRelevCalcer> {
public:
    bool CallMethod(PyObject *, NExtendedMx::TExtendedRelevCalcer *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TString ctx;
            TVector<float> features;
            if (!NPyBind::ExtractArgs(args, features, ctx) && !NPyBind::ExtractArgs(args, features))
                ythrow yexception() << "Can't calculate: factors should be vector of float and (opt) json context string";
            // The context is optional: when it is omitted (or empty) `ctx` is "",
            // which is not valid JSON. Start from an empty TValue instead of
            // trying to parse it.
            NExtendedMx::TCalcContext calcCtx(ctx.empty() ? NSc::TValue() : NSc::TValue::FromJsonThrow(ctx));
            self->CalcRelevExtended(features, calcCtx);
            res = NPyBind::BuildPyObject(ToString(calcCtx.Root()));
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't calculate: unknown exception");
        }
        res = nullptr;
        return true;
    }
};

/// Python method `Recalculate(context, params)`.
///
/// Both arguments are JSON strings. A calc context is seeded from `context`
/// and RecalcRelevExtended is run with the parsed `params`. The method returns
/// the updated context root serialized with ToString().
///
/// On failure a RuntimeError is set and `res` is nullptr.
class TMXNetBundleRecalcCaller : public NPyBind::TBaseMethodCaller<NExtendedMx::TExtendedRelevCalcer> {
public:
    bool CallMethod(PyObject *, NExtendedMx::TExtendedRelevCalcer *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TString contextJson;
            TString paramsJson;
            if (NPyBind::ExtractArgs(args, contextJson, paramsJson)) {
                NExtendedMx::TCalcContext calcCtx(NSc::TValue::FromJsonThrow(contextJson));
                const NSc::TValue recalcParams(NSc::TValue::FromJsonThrow(paramsJson));
                self->RecalcRelevExtended(calcCtx, recalcParams);
                res = NPyBind::BuildPyObject(ToString(calcCtx.Root()));
                return true;
            }
            ythrow yexception() << "Can't recalculate: error parsing args";
        } catch (const std::exception &err) {
            PyErr_SetString(PyExc_RuntimeError, err.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't recalculate: unknown exception");
        }
        res = nullptr;
        return true;
    }
};


/// Describes the Python type and registers its methods.
///
/// The type is registered into the `libextcalc` module (see
/// DoInitLibextcalc), so its qualified name must use that module prefix:
/// CPython derives `__module__`, repr and pickle lookups from it.
/// Callers are passed as freshly allocated objects; ownership presumably
/// transfers to the type description (NOTE(review): confirm in typedesc.h).
TMXNetBundleTraits::TMXNetBundleTraits()
    : TParent("libextcalc.TMXNetBundle", "matrixnet data")
{
    AddCaller("Calculate", new TMXNetBundleCalcCaller);
    AddCaller("Recalculate", new TMXNetBundleRecalcCaller);
    // Serialization callers are currently disabled.
    //AddCaller("Save", new TObjectSaveCaller<NMatrixnet::TMnMultiCateg>());
    //AddCaller("Load", new TObjectLoadCaller<NMatrixnet::TMnMultiCateg>());
    //AddCaller("__getstate__", new TObjectGetStateCaller<NExtendedMx::TExtendedRelevCalcer>());
    //AddCaller("__setstate__", new TObjectSetStateCaller<NExtendedMx::TExtendedRelevCalcer>());
}


//TMXNetBundleHolder *TMXNetBundleTraits::DoInitPureObject(const TVector<TString> &) {
//    THolder<NExtendedMx::TExtendedRelevCalcer> mx(new NExtendedMx::TExtendedRelevCalcer);
//    return new TMXNetBundleHolder(mx.Release());
//}

/// Python constructor: `TMXNetBundle(model_path)`.
///
/// Loads the formula from `model_path`, choosing the loader by the file
/// extension (passed with a leading dot), initializes it, and wraps it in a
/// new holder. Keyword arguments are ignored.
///
/// Returns nullptr with a RuntimeError set if the arguments are wrong,
/// loading throws, or the loader yields no calcer.
TMXNetBundleHolder *TMXNetBundleTraits::DoInitObject(PyObject *args, PyObject*) {
    try {
        TString modelPath;
        if (!NPyBind::ExtractArgs(args, modelPath))
            ythrow yexception() << "Can't create TMXNetBundle: it should be exactly one parameter - path to model file";
        TFsPath bundle(modelPath);
        // The input stream returned by OpenInput is a temporary that lives
        // until the end of this full expression, i.e. for the whole load.
        THolder<NExtendedMx::TExtendedRelevCalcer> mx(LoadFormula<NExtendedMx::TExtendedRelevCalcer>("." + bundle.GetExtension(), OpenInput(modelPath).Get()).Release());
        // Guard against an empty result: dereferencing it below would crash
        // the interpreter instead of raising a Python exception.
        if (!mx)
            ythrow yexception() << "Can't create TMXNetBundle: failed to load model from " << modelPath;
        mx->Initialize(nullptr);
        return new TMXNetBundleHolder(mx.Release());
    } catch (const std::exception &ex) {
        PyErr_SetString(PyExc_RuntimeError, ex.what());
    } catch (...) {
        PyErr_SetString(PyExc_RuntimeError, "Can't create TMXNetBundle: unknown exception");
    }
    return nullptr;
}

// The module exports no free functions. The table holds only the all-zero
// sentinel entry that CPython requires to terminate a PyMethodDef array.
static PyMethodDef LibExtcalcMethods[] = {
    {/* ml_name */ nullptr, /* ml_meth */ nullptr, /* ml_flags */ 0, /* ml_doc */ nullptr},
};

/// Module initialization (Python 2 C API): creates the `libextcalc` module and
/// registers the `TMXNetBundle` type in it.
void DoInitLibextcalc() {
    PyObject* m = Py_InitModule("libextcalc", LibExtcalcMethods);
    // Py_InitModule returns NULL (with a Python error already set) on failure.
    // Bail out instead of handing a NULL module to Register.
    if (!m)
        return;
    TMXNetBundleTraits::Instance().Register(m, "TMXNetBundle");
}
