#include <library/cpp/pybind/typedesc.h>

#include <contrib/libs/opencv/include/opencv2/opencv.hpp>

#include <cv/imgclassifiers/danet/external/ext_entry_point/entry_point.h>

#include <util/generic/map.h>
#include <util/generic/vector.h>
#include <util/charset/wide.h>
#include <util/generic/algorithm.h>
#include <utility>

#include <sstream>

using namespace NNeuralNet;
using namespace cv;

/// Payload stored inside the Python object: keeps the wrapped TDanetOps
/// instance in a THolder so it is destroyed together with the holder.
struct TDanetOpsHolder {
    /// The wrapped neural-net entry point.
    THolder<TDanetOps> NNLibEntryPoint;

    /// Wraps the given raw pointer; the pointer is handed over to the THolder member.
    TDanetOpsHolder(TDanetOps* entry)
        : NNLibEntryPoint(entry)
    {
    }
};

/// Python type description for TDanetOps: tells NPyBind::TPythonType how to
/// construct the holder from Python arguments and how to reach the wrapped
/// object from a holder.
class TDanetOpsTraits : public NPyBind::TPythonType<TDanetOpsHolder, TDanetOps, TDanetOpsTraits> {
private:
    using TParent = NPyBind::TPythonType<TDanetOpsHolder, TDanetOps, TDanetOpsTraits>;

    // The constructor is private; the base template is a friend so that it can
    // create the traits object.
    friend TParent;

    TDanetOpsTraits();

public:
    /// Returns the TDanetOps instance stored in the holder (non-owning pointer).
    static TDanetOps* GetObject(const TDanetOpsHolder& holder) {
        return holder.NNLibEntryPoint.Get();
    }

    /// Builds a holder from the Python constructor arguments; defined out of line.
    static TDanetOpsHolder* DoInitObject(PyObject* args, PyObject* kwargs);
};

namespace NPyBind {
    /**
     * Converts a Python object exposing the buffer protocol into a cv::Mat.
     *
     * The object must provide a C-contiguous, 3-dimensional buffer laid out as
     * (rows, cols, 3) with one byte per element; it is interpreted as CV_8UC3.
     * The pixel data is deep-copied into `res`, so the resulting Mat does not
     * depend on the lifetime of the Python buffer, and the buffer view is
     * always released before returning or throwing.
     *
     * @param obj  Python object to read the image from.
     * @param res  Receives the converted image.
     * @return     Always true; failures are reported by throwing yexception.
     */
    template<>
    inline bool FromPyObject(PyObject *obj, cv::Mat& res) {
        if (0 == PyObject_CheckBuffer(obj))
            ythrow yexception() << "Obj should be buffer";

        Py_buffer imgBuf;
        if (0 != PyObject_GetBuffer(obj, &imgBuf, PyBUF_C_CONTIGUOUS)) {
            // The failure is reported through the C++ exception below, so drop
            // the Python error raised by the exporter to avoid a stale indicator.
            PyErr_Clear();
            ythrow yexception() << "Could not get C-contiguous buffer from obj";
        }

        // Every successful PyObject_GetBuffer must be paired with PyBuffer_Release,
        // including on the exceptional paths below.
        struct TBufferReleaser {
            Py_buffer* Buf;
            ~TBufferReleaser() {
                PyBuffer_Release(Buf);
            }
        } releaser = {&imgBuf};

        if (3 != imgBuf.ndim)
            ythrow yexception() << "img dim should be 3";

        // The data is reinterpreted as CV_8UC3 below; any other channel count or
        // element size would read past the buffer or produce garbage pixels.
        if (3 != imgBuf.shape[2] || 1 != imgBuf.itemsize)
            ythrow yexception() << "img should have 3 channels of 1 byte each";

        const int rows = static_cast<int>(imgBuf.shape[0]);
        const int cols = static_cast<int>(imgBuf.shape[1]);

        // cv::Mat built over external memory does not own it; clone() makes an
        // owning copy so the Python buffer can be released safely.
        res = cv::Mat(rows, cols, CV_8UC3, imgBuf.buf).clone();
        return true;
    }
}

/// Method caller for pydanet::Features: takes an image and a layer name,
/// runs TDanetOps::GetLayersOutputsForImage on a free worker and returns the
/// resulting TVectorFloat converted to a Python object.
class TFeatureCaller : public NPyBind::TBaseMethodCaller<TDanetOps> {
public:
    /// Always returns true. On success `res` holds the result object; on any
    /// failure a Python RuntimeError is set and `res` is null.
    bool CallMethod(PyObject *, TDanetOps *self, PyObject *args, PyObject *, PyObject *&res) const override {
        // Stays null unless the whole call chain below succeeds.
        res = nullptr;
        try {
            cv::Mat image;
            TString layer;
            const bool parsed = NPyBind::ExtractArgs(args, image, layer);
            if (!parsed) {
                ythrow yexception() << "Could not parse args for pydanet::Features";
            }

            TVectorFloat features;
            self->GetLayersOutputsForImage(self->GetFreeWorker(), image, layer, features);
            res = NPyBind::BuildPyObject(features);
        } catch (const std::exception &ex) {
            // C++ exceptions must not cross into the interpreter: translate them.
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Could not parse args for pydanet::Features unknown exception");
        }
        return true;
    }
};

/// Method caller for pydanet::Predict: takes an image, runs
/// TDanetOps::PredictOnImage on a free worker and returns the resulting
/// TPreds converted to a Python object.
class TPredictCaller : public NPyBind::TBaseMethodCaller<TDanetOps> {
public:
    /// Always returns true. On success `res` holds the result object; on any
    /// failure a Python RuntimeError is set and `res` is set to null.
    bool CallMethod(PyObject *, TDanetOps *self, PyObject *args, PyObject *, PyObject *&res) const override {
        try {
            cv::Mat img;
            if (!NPyBind::ExtractArgs(args, img))
                ythrow yexception() << "Could not parse args for pydanet::Predict";

            TPreds result;
            self->PredictOnImage(self->GetFreeWorker(), img, result);
            res = NPyBind::BuildPyObject(result);
            return true;
        } catch (const std::exception &ex) {
            // C++ exceptions must not cross into the interpreter: translate them.
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            // Can originate from argument parsing, prediction or result
            // conversion, so the message names only the method that failed.
            PyErr_SetString(PyExc_RuntimeError, "pydanet::Predict failed: unknown exception");
        }
        res = nullptr;
        return true;
    }
};


// Describes the Python type "pydanet.pydanet" and registers the method
// callers that back its Python-visible methods.
TDanetOpsTraits::TDanetOpsTraits()
    : TParent("pydanet.pydanet", "Neural net module")
{
    // Features -> TFeatureCaller
    AddCaller("Features", new TFeatureCaller);
    // Predict -> TPredictCaller
    AddCaller("Predict", new TPredictCaller);
}


// Creates the holder for a new Python-side object.
//
// Arguments are read via ExtractOptionalArgs using the keywords "config",
// "thread_num" (defaults to 1 here) and "model". When a non-empty model is
// given, the three-argument TDanetOps constructor is used, otherwise the
// two-argument one.
//
// Returns a newly allocated holder, or nullptr with a Python RuntimeError set
// if argument parsing or construction throws.
TDanetOpsHolder *TDanetOpsTraits::DoInitObject(PyObject* args, PyObject* kwargs) {
    try {
        static const char* keywords[] = {"config", "thread_num", "model"};

        TString configArg;
        TString modelArg;
        int threadCount = 1;

        const bool parsed = NPyBind::ExtractOptionalArgs(args, kwargs, keywords, configArg, threadCount, modelArg);
        if (!parsed) {
            ythrow yexception() << "Could not parse args for pydanet::__init__()";
        }

        // No model supplied: fall back to the config-only constructor.
        if (modelArg.empty()) {
            return new TDanetOpsHolder(new TDanetOps(configArg, threadCount));
        }
        return new TDanetOpsHolder(new TDanetOps(configArg, modelArg, threadCount));
    } catch (const std::exception &ex) {
        // C++ exceptions must not cross into the interpreter: translate them.
        PyErr_SetString(PyExc_RuntimeError, ex.what());
    } catch (...) {
        PyErr_SetString(PyExc_RuntimeError, "Can't create pydanet: unknown exception");
    }
    return nullptr;
}

// Module-level method table passed to Py_InitModule. It contains only the
// terminating sentinel entry, i.e. the module defines no free functions here.
static PyMethodDef PyDanetMethods[] = {
     {nullptr, nullptr, 0, nullptr}
};

// Module initialisation entry point for the "pydanet" extension module:
// creates the module and registers the TDanetOps Python type in it under the
// name "pydanet".
PyMODINIT_FUNC initpydanet() {
    PyObject* m = Py_InitModule("pydanet", PyDanetMethods);
    // Py_InitModule returns NULL on failure with a Python error already set;
    // bail out instead of registering the type into a null module.
    if (nullptr == m)
        return;
    TDanetOpsTraits::Instance().Register(m, "pydanet");
}
