#include <solomon/agent/lib/python2/code_module.h>
#include <solomon/agent/lib/python2/gil.h>
#include <solomon/agent/lib/python2/initializer.h>

#include <library/cpp/testing/gtest/gtest.h>

#include <util/system/tempfile.h>

using namespace NPython2;

// Test fixture for LoadPyModule.
// Provides a small Python 2 script (TestPy) and a helper that loads it
// as a module. Each test body runs while the fixture's members are alive.
class TCodeModuleTest: public ::testing::Test {
public:
    // Python source used by the tests: a zero-argument `value()` returning 42
    // and `increment(value)` returning its argument plus one.
    static constexpr TStringBuf TestPy =
            "def value():\n"
            "    return 42\n"
            "\n"
            "def increment(value):\n"
            "    return value + 1\n"sv;

    // Writes TestPy to a temporary file and loads it as module "test_module".
    // Throws (via Y_ENSURE) when LoadPyModule returns an empty pointer.
    // The temporary file handle goes out of scope when this returns; only the
    // already-loaded module object is handed back to the caller.
    TObjectPtr TestCodeModule() {
        TTempFileHandle source;
        source.Write(TestPy.data(), TestPy.size());
        source.FlushData();

        TObjectPtr loaded = LoadPyModule("test_module", source.Name());
        Y_ENSURE(loaded);
        return loaded;
    }

private:
    // NOTE: declaration order matters. Members are constructed top-to-bottom
    // and destroyed in reverse, so Initializer_ (presumably brings up the
    // interpreter) is obtained before Gil_ (presumably acquires the GIL for
    // the duration of the test) and outlives it.
    TInitializer* Initializer_{TInitializer::Instance()};
    NPython2::TGilGuard Gil_;
};

// Loads TestPy as module "my_module" and checks the module-level attributes
// that LoadPyModule is expected to set: __name__, __file__ and __builtins__.
TEST_F(TCodeModuleTest, Exec) {
    TTempFileHandle tempFile;
    tempFile.Write(TestPy.data(), TestPy.size());
    tempFile.FlushData();

    TObjectPtr module = LoadPyModule("my_module", tempFile.Name());
    ASSERT_TRUE(module);

    // (1) module has __name__ string equal to the requested module name
    {
        TObjectPtr name = PyObject_GetAttrString(module.Get(), "__name__");
        ASSERT_TRUE(name);
        ASSERT_TRUE(PyString_Check(name.Get()));

        const char* nameStr = PyString_AsString(name.Get());
        ASSERT_STREQ("my_module", nameStr);
    }

    // (2) module has __file__ string equal to the path it was loaded from
    {
        TObjectPtr file = PyObject_GetAttrString(module.Get(), "__file__");
        ASSERT_TRUE(file);
        ASSERT_TRUE(PyString_Check(file.Get()));

        const char* fileStr = PyString_AsString(file.Get());
        ASSERT_EQ(tempFile.Name(), fileStr);
    }

    // (3) module has a non-empty __builtins__ dict
    {
        TObjectPtr builtins =
                PyObject_GetAttrString(module.Get(), "__builtins__");
        ASSERT_TRUE(builtins);
        ASSERT_TRUE(PyDict_Check(builtins.Get()));
        // ASSERT_GT reports the actual size on failure, unlike ASSERT_TRUE(x > 0).
        ASSERT_GT(PyDict_Size(builtins.Get()), 0);
    }
}

// Calls the zero-argument `value()` function of the test module and checks
// that it returns the Python int 42.
TEST_F(TCodeModuleTest, CallFunc) {
    TObjectPtr module = TestCodeModule();

    TObjectPtr func = PyObject_GetAttrString(module.Get(), "value");
    ASSERT_TRUE(func);
    ASSERT_TRUE(PyCallable_Check(func.Get()));

    // PyObject_CallObject accepts NULL for "no arguments". The previous
    // PyObject_Call(func, PyTuple_New(0), nullptr) leaked the new reference
    // returned by PyTuple_New, because nothing ever released it.
    TObjectPtr result = PyObject_CallObject(func.Get(), nullptr);
    ASSERT_TRUE(result);
    ASSERT_TRUE(PyInt_Check(result.Get()));

    long resultLong = PyInt_AsLong(result.Get());
    ASSERT_EQ(42L, resultLong);
}

// Calls `increment(42)` from the test module and checks that it returns
// the Python int 43.
TEST_F(TCodeModuleTest, CallFuncWithArg) {
    TObjectPtr module = TestCodeModule();

    TObjectPtr func = PyObject_GetAttrString(module.Get(), "increment");
    ASSERT_TRUE(func);
    ASSERT_TRUE(PyCallable_Check(func.Get()));

    // "i" is the Py_BuildValue format code for a C int ("I" expects
    // unsigned int). The parentheses force a 1-tuple, as PyObject_Call requires.
    TObjectPtr args = Py_BuildValue("(i)", 42);
    ASSERT_TRUE(args);
    TObjectPtr result = PyObject_Call(func.Get(), args.Get(), nullptr);
    ASSERT_TRUE(result);
    ASSERT_TRUE(PyInt_Check(result.Get()));

    long resultLong = PyInt_AsLong(result.Get());
    ASSERT_EQ(43L, resultLong);
}
