#include "code_module.h"
#include "error.h"

#include <util/generic/string.h>
#include <util/generic/yexception.h>
#include <util/system/fstat.h>

#include <memory>

#include <Python-ast.h>
#include <marshal.h>
#include <stdio.h>

namespace NPython2 {

// Parses the Python source read from `file` and compiles it to a code object.
//
// `path` is used both as the filename recorded in the AST/code object and in
// the error message. The AST lives in a temporary PyArena that is released
// on every exit path.
//
// Returns the compiled code object. Fails via PY_ENSURE when the arena cannot
// be allocated, and via PY_FAIL (including the pending Python error text)
// when parsing or compilation fails.
TObjectPtr ParseSourceModule(const TString& path, FILE* file) {
    using TArenaHolder = std::unique_ptr<PyArena, decltype(&PyArena_Free)>;
    TArenaHolder arena(PyArena_New(), &PyArena_Free);
    PY_ENSURE(arena, "cannot allocate PyArena");

    PyCompilerFlags compilerFlags;
    compilerFlags.cf_flags = 0;

    const char* fileName = path.c_str();
    mod_ty ast = PyParser_ASTFromFile(
            file, fileName, Py_file_input,
            nullptr, nullptr, &compilerFlags,
            nullptr, arena.get());

    // Compilation is only attempted when parsing produced an AST.
    auto code = ast
            ? PyAST_Compile(ast, fileName, nullptr, arena.get())
            : nullptr;
    if (!code) {
        PY_FAIL("cannot compile file: " << path <<
                ", python error:\n" << LastErrorAsString());
    }

    return reinterpret_cast<PyObject*>(code);
}

// Tries to load a cached code object from the .pyc file at `path`.
//
// The file must start with the current interpreter's magic number, followed
// by a timestamp equal to `mtime` (the modification time of the source the
// cache was built from). The rest of the file must unmarshal to a code object.
//
// Returns the code object, or nullptr if the file is missing, stale, corrupt,
// or does not hold a code object. On a nullptr return no Python exception is
// left pending, so the caller can safely fall back to compiling the source.
TObjectPtr TryLoadCompiledModule(const TString& path, time_t mtime) {
    std::unique_ptr<FILE, decltype(&fclose)>
            file(fopen(path.c_str(), "rb"), &fclose);
    if (!file) {
        return nullptr;
    }

    long magic = PyMarshal_ReadLongFromFile(file.get());
    if (magic != PyImport_GetMagicNumber()) {
        return nullptr;
    }

    // NOTE(review): the stored timestamp is a 32-bit marshal long, so source
    // mtimes outside that range never match and the cache is just bypassed.
    long pycMtime = PyMarshal_ReadLongFromFile(file.get());
    if (mtime != pycMtime) {
        return nullptr;
    }

    TObjectPtr co = PyMarshal_ReadLastObjectFromFile(file.get());
    if (!co) {
        // Unmarshalling a truncated/corrupt .pyc raises a Python exception.
        // This loader is best effort, so drop that error instead of leaking
        // it into the caller's fallback compile/exec path.
        PyErr_Clear();
        return nullptr;
    }

    if (!PyCode_Check(co.Get())) {
        return nullptr;
    }
    return co;
}

// Loads a code object from an already opened .pyc stream.
//
// `file` must be positioned at the start of the .pyc data. The magic number
// must match the running interpreter. The recorded source mtime is read and
// ignored, because an explicitly requested .pyc is used as-is.
//
// Returns the code object. Fails via PY_ENSURE on a magic mismatch or a
// non-code payload, and via PY_FAIL (including the pending Python error text)
// when unmarshalling fails.
TObjectPtr LoadCompiledModule(const TString& path, FILE* file) {
    long magic = PyMarshal_ReadLongFromFile(file);
    long expected = PyImport_GetMagicNumber();
    PY_ENSURE(magic == expected,
            "invalid python magic in file: " << path
            << " (" << magic << " != " << expected << ')');

    // Skip the source mtime field; it is not validated here.
    (void) PyMarshal_ReadLongFromFile(file);

    if (TObjectPtr co = PyMarshal_ReadLastObjectFromFile(file)) {
        PY_ENSURE_TYPE(PyCode_Check, co.Get(),
                "read non code object from file: " << path);
        return co;
    }

    // Unmarshalling failed with a Python exception set; report it, as the
    // other loaders in this file do, instead of discarding it.
    PY_FAIL("cannot load compiled module from: " << path <<
            ", python error:\n" << LastErrorAsString());
}

// Best-effort write of the code object `co` to a .pyc cache file at `path`.
//
// File layout: interpreter magic number, source `mtime`, marshalled `co`.
// The two integers are written with PyMarshal_WriteLongToFile (32 bits each).
//
// Any existing file at `path` is removed first rather than truncated in
// place, so a process that already has the old file open keeps reading
// intact data. The mtime field is first written as 0 and is patched to the
// real `mtime` only after the marshalled object has been flushed
// successfully, so a partially written file does not carry the real source
// timestamp. On any write failure the file is removed. Errors are never
// reported, because the cache is optional.
void TryWriteCompiledModule(PyObject* co, const TString& path, time_t mtime) {
    (void) ::remove(path.c_str());

    std::unique_ptr<FILE, decltype(&fclose)>
            file(fopen(path.c_str(), "wb"), &fclose);
    if (!file) {
        return;
    }

    auto flushedOk = [&file]() {
        return fflush(file.get()) == 0 && !ferror(file.get());
    };

    long magic = PyImport_GetMagicNumber();
    PyMarshal_WriteLongToFile(magic, file.get(), Py_MARSHAL_VERSION);
    PyMarshal_WriteLongToFile(0L, file.get(), Py_MARSHAL_VERSION);  // placeholder mtime
    PyMarshal_WriteObjectToFile(co, file.get(), Py_MARSHAL_VERSION);

    // The mtime field starts right after the 4-byte magic number.
    bool ok = flushedOk() && fseek(file.get(), 4L, SEEK_SET) == 0;
    if (ok) {
        PyMarshal_WriteLongToFile(static_cast<long>(mtime), file.get(), Py_MARSHAL_VERSION);
        ok = flushedOk();
    }
    file.reset();

    if (!ok) {
        // Don't keep a partial or unverified cache file around.
        (void) ::remove(path.c_str());
    }
}

// Loads a Python 2 module from `path` and executes it as module `name`.
//
// `path` is either a compiled .pyc file, which is loaded directly, or a
// source file. For a source file, unless `ignorePyc` is set:
//   - a sibling cache `path + 'c'` is used when its magic number and
//     recorded mtime match the source;
//   - otherwise the source is compiled and the cache is rewritten (best
//     effort).
// When the cache is used, the module is executed with the .pyc path as its
// path; otherwise with `path`.
//
// Returns the executed module object. Fails via PY_ENSURE/PY_FAIL when:
//   - `path` is empty or cannot be opened;
//   - `path` is a .pyc while `ignorePyc` is set;
//   - loading or compiling fails;
//   - executing the module fails.
TObjectPtr LoadPyModule(const TString& name, const TString& path, bool ignorePyc) {
    PY_ENSURE(!path.empty(), "module path must be non empty");

    const bool isCompiled = path.EndsWith(TStringBuf(".pyc"));

    // A .pyc holds binary marshal data and must be opened in binary mode;
    // text mode would translate line endings on platforms such as Windows.
    std::unique_ptr<FILE, decltype(&fclose)>
            file(fopen(path.c_str(), isCompiled ? "rb" : "r"), &fclose);
    PY_ENSURE(file, "failed to open python2 file: " << path);

    TObjectPtr module;
    TString modulePath;
    if (isCompiled) {
        PY_ENSURE(!ignorePyc, "trying to load " << path << ", but .pyc files are ignored");
        module = LoadCompiledModule(path, file.get());
        modulePath = path;
    } else {
        auto fileHandle = fileno(file.get());   // WIN FHANDLE is void*, and UNIX FHANDLE is int
        // NOTE(review): on 64-bit Windows this reinterprets a 4-byte CRT
        // descriptor as an 8-byte HANDLE, which is not a valid handle. Confirm
        // this path is POSIX-only, or convert with _get_osfhandle there.
        TFileStat st(reinterpret_cast<FHANDLE&>(fileHandle));

        if (!ignorePyc) {
            modulePath = path + 'c';
            module = TryLoadCompiledModule(modulePath, st.MTime);
        }

        if (!module) {
            module = ParseSourceModule(path, file.get());
            if (!ignorePyc) {
                TryWriteCompiledModule(module.Get(), modulePath, st.MTime);
            }
            modulePath = path;
        }
    }

    // Python 2's C API takes non-const char* here but does not modify them.
    auto m = PyImport_ExecCodeModuleEx(
            const_cast<char*>(name.c_str()),
            module.Get(),
            const_cast<char*>(modulePath.c_str()));
    if (m) {
        return m;
    }

    PY_FAIL("cannot execute python2 module, "
            "python error:\n" << LastErrorAsString());
}

} // namespace NPython2
