//
// Created by alxmopo3ov on 14.11.16.
//

#include "difacto.h"

#include <library/cpp/pybind/typedesc.h>
#include <quality/deprecated/Misc/FastIO.h>
#include <util/folder/path.h>

#include "ads/dmlc/difacto/sources/learn/difacto_light_applicator/difacto_light_applicator.h"
#include "ads/dmlc/difacto/sources/learn/base/parser/arg_parser.h"


// Builds a DiFacto Config message from the file at `path`: the file is fed to
// dmlc::ArgParser and the parsed arguments are converted with ParseToProto.
dmlc::difacto::Config read_conf(const char *path) {
    dmlc::ArgParser arg_parser;
    arg_parser.ReadFile(path);

    dmlc::difacto::Config config;
    arg_parser.ParseToProto(&config);
    return config;
}


// TString flavour of read_conf: hands the underlying C string of `path` to
// dmlc::ArgParser and converts the parsed arguments into a Config message.
dmlc::difacto::Config read_conf(TString path) {
    dmlc::ArgParser arg_parser;
    arg_parser.ReadFile(path.data());

    dmlc::difacto::Config config;
    arg_parser.ParseToProto(&config);
    return config;
}


// Holder object stored by the Python type: keeps the LightDifactoApplicator
// alive through a THolder for as long as the holder itself exists.
struct TDiFactoLightApplicatorHolder {
    THolder<LightDifactoApplicator> applicator_holder;

    // Wraps `app` into the THolder member, which owns it from now on.
    TDiFactoLightApplicatorHolder(LightDifactoApplicator* app)
        : applicator_holder(app)
    {
    }
};


template <typename T>
void make_protobuf_stream_from_raw_input(const TVector<TVector<TVector<T>>> &raw_input,
                                         std::vector <dmlc::data::TrainingSample> &out_proto_stream) {
    for(size_t i = 0; i < raw_input.ysize(); i++) { //iterate over objects
        const TVector <TVector<T>> &cur_instance = raw_input[i];
        dmlc::data::TrainingSample sample;
        sample.set_label(true);
        sample.set_weight(1.0f);
        for(size_t f = 0; f < cur_instance.ysize(); f++) {
            const TVector <T> &cur_field = cur_instance[f];
            dmlc::data::TrainingSample_Field *fld = sample.add_fields();

            for(size_t j = 0; j < cur_field.ysize(); j++) {
                fld->add_features(cur_field[j]);
                fld->add_values(1.0f);
            }
        }
        out_proto_stream.push_back(sample);
    }
}

// Debug helper: dumps `raw_input` to std::cout. Features of one field are
// printed space-separated on a single line, a blank line follows every object
// and one more newline terminates the whole dump.
template <typename T>
void print_raw_input(const TVector<TVector<TVector<T>>> &raw_input) {
    for (const auto &instance : raw_input) {
        for (const auto &field : instance) {
            for (const auto &feature : field) {
                std::cout << feature << " ";
            }
            std::cout << "\n";
        }
        std::cout << "\n";
    }
    std::cout << "\n";
}


void make_protobuf_stream_from_serialized_input(const TVector <TString> &protobuf_stream,
                                                const TVector <size_t> &sizes,
                                                std::vector <dmlc::data::TrainingSample> &out_proto_stream) {
    out_proto_stream.reserve(protobuf_stream.size());
    for (size_t i = 0; i < protobuf_stream.size(); i++) {
        dmlc::data::TrainingSample sample;
        Y_PROTOBUF_SUPPRESS_NODISCARD sample.ParseFromArray((const void *)protobuf_stream[i].data(), sizes[i]);
        out_proto_stream.push_back(sample);
    }
}


// NPyBind type description that exposes LightDifactoApplicator to Python.
// Instances are stored in TDiFactoLightApplicatorHolder objects.
class TDiFactoLightApplicatorTraits
    : public NPyBind::TPythonType<TDiFactoLightApplicatorHolder, LightDifactoApplicator, TDiFactoLightApplicatorTraits>
{
private:
    using TParent = NPyBind::TPythonType<TDiFactoLightApplicatorHolder, LightDifactoApplicator, TDiFactoLightApplicatorTraits>;
    // The base template gets access to the private constructor below.
    friend TParent;

    TDiFactoLightApplicatorTraits();

public:
    // Returns the applicator wrapped by `holder` (non-owning pointer).
    static LightDifactoApplicator *GetObject(const TDiFactoLightApplicatorHolder &holder) {
        return holder.applicator_holder.Get();
    }

    // Builds a holder from Python constructor arguments.
    static TDiFactoLightApplicatorHolder *DoInitObject(PyObject *args, PyObject *kwargs);

    // Builds a holder around a default-constructed applicator.
    static TDiFactoLightApplicatorHolder *DoInitPureObject(const TVector<TString> &);
};


// Python-callable "ReadHashTable": creates a hash table from the given model
// files and installs it into the applicator.
class TDiFactoLightApplicatorReadHashTableCaller: public NPyBind::TBaseMethodCaller<LightDifactoApplicator> {
public:
    // Expects one positional argument: a list of model file paths.
    // On success `res` is None. On failure a Python RuntimeError is set and
    // `res` is null. The method itself always returns true.
    bool CallMethod(PyObject *, LightDifactoApplicator *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TVector<TString> model_files;
            const bool args_ok = NPyBind::ExtractArgs(args, model_files);
            if (!args_ok) {
                ythrow yexception() << "Cant read model for prediction: pybind extractargs failed";
            }

            // CreateHashTable is fed std::string paths, so convert from TString.
            std::vector<std::string> std_paths;
            std_paths.reserve(model_files.size());
            for (const TString &file : model_files) {
                std_paths.push_back(std::string(file.data()));
            }

            auto hash_table = CreateHashTable<size_t, float>(self->GetModelType(), std_paths);
            self->SetHashTable(hash_table);

            Py_INCREF(Py_None);
            res = Py_None;
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't calculate: unknown exception");
        }
        // Reached only after an exception: report the failure through a null result.
        res = nullptr;
        return true;
    }
};


// Python-callable "Predict": runs the applicator on serialized TrainingSample
// messages and returns the predictions as a Python object built from a float vector.
class TDiFactoLightApplicatorPredictCaller: public NPyBind::TBaseMethodCaller<LightDifactoApplicator> {
public:
    // Expects two positional arguments: the serialized samples and their sizes.
    // They arrive as two parallel lists because PyBind cannot parse a list of
    // tuples [(x, y)] into TVector<pair<..>>.
    // On failure a Python RuntimeError is set and `res` is null. The method
    // itself always returns true.
    bool CallMethod(PyObject *, LightDifactoApplicator *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TVector<TString> serialized_samples;
            TVector<size_t> sample_sizes;
            const bool args_ok = NPyBind::ExtractArgs(args, serialized_samples, sample_sizes);
            if (!args_ok) {
                ythrow yexception() << "Cant read input data for prediction: NPyBind::ExtractArgs failed";
            }

            std::vector<dmlc::data::TrainingSample> samples;
            make_protobuf_stream_from_serialized_input(serialized_samples, sample_sizes, samples);

            std::vector<float> predictions;
            self->Predict(samples, predictions);

            // Predict fills a std::vector; copy it into a TVector before
            // handing it to BuildPyObject.
            TVector<float> py_predictions(predictions.begin(), predictions.end());
            res = NPyBind::BuildPyObject(py_predictions);
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't calculate: unknown exception");
        }
        // Reached only after an exception: report the failure through a null result.
        res = nullptr;
        return true;
    }
};



// Python-callable "PredictRaw": runs the applicator on raw feature hashes and
// returns the predictions as a Python object built from a float vector.
class TDiFactoLightApplicatorPredictRawCaller: public NPyBind::TBaseMethodCaller<LightDifactoApplicator> {
public:
    // Expects one positional argument: a nested list indexed as
    // [object][field][feature hash].
    // On failure a Python RuntimeError is set and `res` is null. The method
    // itself always returns true.
    bool CallMethod(PyObject *, LightDifactoApplicator *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TVector<TVector<TVector<unsigned long long>>> raw_hashes;
            const bool args_ok = NPyBind::ExtractArgs(args, raw_hashes);
            if (!args_ok) {
                ythrow yexception() << "Cant read input data for prediction: NPyBind::ExtractArgs failed";
            }

            // The applicator consumes TrainingSample messages, so the raw
            // hashes are converted first.
            std::vector<dmlc::data::TrainingSample> samples;
            make_protobuf_stream_from_raw_input(raw_hashes, samples);

            std::vector<float> predictions;
            self->Predict(samples, predictions);

            // Predict fills a std::vector; copy it into a TVector before
            // handing it to BuildPyObject.
            TVector<float> py_predictions(predictions.begin(), predictions.end());
            res = NPyBind::BuildPyObject(py_predictions);
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't calculate: unknown exception");
        }
        // Reached only after an exception: report the failure through a null result.
        res = nullptr;
        return true;
    }
};


// Python-callable property setter: stores `value` under `name` in self->Info.
template<class TSubObject>
class TMXNetSetPropertyCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    // Expects two positional string arguments: name and value.
    // On success `res` is None. On failure a Python RuntimeError is set and
    // `res` is null, so Python never sees a non-null result while an error is
    // pending. The method itself always returns true.
    bool CallMethod(PyObject*, TSubObject *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TString name, value;
            if (!NPyBind::ExtractArgs(args, name, value))
                ythrow yexception() << "Can't set property: it should be exactly two parameters - name and value";
            self->Info[name] = value;
            Py_INCREF(Py_None);
            res = Py_None;
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't set property: unknown exception");
        }
        // Reached only after an exception: report the failure through a null result.
        res = nullptr;
        return true;
    }
};


// Python-callable property getter: looks `name` up in self->Info.
template<class TSubObject>
class TMXNetGetPropertyCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    // Expects one positional string argument: the property name.
    // On success `res` is the stored value, or None when the property is absent.
    // On failure a Python RuntimeError is set and `res` is null; previously
    // `res` was left unassigned on this path. The method itself always returns true.
    bool CallMethod(PyObject*, TSubObject *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TString name;
            if (!NPyBind::ExtractArgs(args, name))
                ythrow yexception() << "Can't get property: it should be exactly one parameter - name of the property";
            const auto iter = self->Info.find(name);
            if (iter != self->Info.end()) {
                res = NPyBind::BuildPyObject(iter->second);
            } else {
                Py_INCREF(Py_None);
                res = Py_None;
            }
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't get property: unknown exception");
        }
        // Reached only after an exception: report the failure through a null result.
        res = nullptr;
        return true;
    }
};



// Python-callable state getter: serializes the object with its Save() method
// into a string and returns that string to Python.
template<class TSubObject>
class TObjectGetStateCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    // Takes no arguments. On failure a Python RuntimeError is set and `res`
    // is null. The method itself always returns true.
    bool CallMethod(PyObject*, TSubObject *self, PyObject*, PyObject*, PyObject *&res) const override {
        try {
            TString serialized;
            TStringOutput sink(serialized);
            self->Save(&sink);

            // Wrap the freshly built object in a smart pointer and hand a
            // reference to the caller through RefGet().
            NPyBind::TPyObjectPtr py_state(NPyBind::BuildPyObject(serialized), true);
            res = py_state.RefGet();
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't get state: unknown exception");
        }
        // Reached only after an exception: report the failure through a null result.
        res = nullptr;
        return true;
    }
};

// Python-callable state setter: restores the object from a string by running
// ::Load on a TStringInput over that string.
template<class TSubObject>
class TObjectSetStateCaller : public NPyBind::TBaseMethodCaller<TSubObject> {
public:
    // Expects one positional string argument: the serialized state.
    // On success `res` is None. On failure a Python RuntimeError is set and
    // `res` is null. The method itself always returns true.
    bool CallMethod(PyObject*, TSubObject *self, PyObject *args, PyObject*, PyObject *&res) const override {
        try {
            TString serialized;
            const bool args_ok = NPyBind::ExtractArgs(args, serialized);
            if (!args_ok) {
                ythrow yexception() << "Can't set state: it should be exactly one string parameter - the state";
            }

            TStringInput source(serialized);
            ::Load(&source, *self);

            Py_INCREF(Py_None);
            res = Py_None;
            return true;
        } catch (const std::exception &ex) {
            PyErr_SetString(PyExc_RuntimeError, ex.what());
        } catch (...) {
            PyErr_SetString(PyExc_RuntimeError, "Can't set state: unknown exception");
        }
        // Reached only after an exception: report the failure through a null result.
        res = nullptr;
        return true;
    }
};


// Passes two strings to the base type (presumably the type name and its
// description) and registers the methods callable from Python:
// Predict, PredictRaw and ReadHashTable.
TDiFactoLightApplicatorTraits::TDiFactoLightApplicatorTraits()
    : TParent("libdifacto.TDiFactoLightApplicator", "DiFacto binding for applying models from python")
{
    AddCaller("Predict", new TDiFactoLightApplicatorPredictCaller);
    AddCaller("PredictRaw", new TDiFactoLightApplicatorPredictRawCaller);
    AddCaller("ReadHashTable", new TDiFactoLightApplicatorReadHashTableCaller);
}

// Creates a holder around a default-constructed applicator; the argument is unused.
TDiFactoLightApplicatorHolder *TDiFactoLightApplicatorTraits::DoInitPureObject(const TVector <TString> &) {
    // Keep the applicator in a THolder until the holder takes it over.
    THolder<LightDifactoApplicator> applicator{new LightDifactoApplicator};
    return new TDiFactoLightApplicatorHolder(applicator.Release());
}

// Creates a holder from Python constructor arguments.
//
// Expects exactly one positional argument: the path to a config file, which is
// parsed with read_conf() and passed to the LightDifactoApplicator constructor.
// Keyword arguments are ignored.
//
// Returns the new holder, or nullptr with a Python RuntimeError set on failure.
TDiFactoLightApplicatorHolder *TDiFactoLightApplicatorTraits::DoInitObject(PyObject *args, PyObject *) {
    try {
        TString configPath;
        if (!NPyBind::ExtractArgs(args, configPath))
            ythrow yexception() << "Can't create DiFactoLightApplicator: it should be exactly one parameter - path to config file";
        auto conf = read_conf(configPath);
        // Keep the applicator in a THolder until the holder takes it over.
        THolder<LightDifactoApplicator> mx(new LightDifactoApplicator(conf));
        return new TDiFactoLightApplicatorHolder(mx.Release());
    } catch (const std::exception &ex) {
        PyErr_SetString(PyExc_RuntimeError, ex.what());
    } catch (...) {
        PyErr_SetString(PyExc_RuntimeError, "Can't create DiFactoLightApplicator: unknown exception");
    }
    return nullptr;
}


// Module-level method table passed to Py_InitModule. It holds only the
// terminating all-null entry, i.e. the module defines no free functions.
static PyMethodDef LibDifactoMethods[] = {
        {nullptr, nullptr, 0, nullptr}
};

// Creates the "libdifacto" Python module and registers the applicator type in
// it under the name "DiFactoLightApplicator".
void DoInitLibdifacto() {
    PyObject* m = Py_InitModule("libdifacto", LibDifactoMethods);
    if (m == nullptr) {
        // Module creation failed; do not pass a null module to Register().
        return;
    }
    TDiFactoLightApplicatorTraits::Instance().Register(m, "DiFactoLightApplicator");
}
