Skip to content

Commit

Permalink
Merge pull request #1213 from sys-bio/issue-1210-only-fix
Browse files Browse the repository at this point in the history
Fix for issue #1210
  • Loading branch information
luciansmith authored Apr 24, 2024
2 parents 92d063b + 92b3d02 commit f87c96e
Show file tree
Hide file tree
Showing 21 changed files with 215 additions and 100 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ cmake_minimum_required(VERSION 3.16)
# Version information and include modules

set(ROADRUNNER_VERSION_MAJOR 2)
set(ROADRUNNER_VERSION_MINOR 6)
set(ROADRUNNER_VERSION_MINOR 7)
set(ROADRUNNER_VERSION_PATCH 0)

set(ROADRUNNER_VERSION "${ROADRUNNER_VERSION_MAJOR}.${ROADRUNNER_VERSION_MINOR}.${ROADRUNNER_VERSION_PATCH}")
Expand Down
28 changes: 22 additions & 6 deletions source/llvm/Jit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,16 @@ using namespace rr;
namespace rrllvm {

Jit::Jit(std::uint32_t options)
: options(options),
context(std::make_unique<llvm::LLVMContext>()),
: options(options)
, context(std::make_unique<llvm::LLVMContext>())
// todo the module name should be the sbmlMD5. Might be cleaner to
// add this as a parameter to Jit constructor.
module(std::make_unique<llvm::Module>("LLVM Module", *context)),
moduleNonOwning(module.get()), /*Maintain a weak ref so we don't lose our handle to the module*/
builder(std::make_unique<llvm::IRBuilder<>>(*context)) {
, module(std::make_unique<llvm::Module>("LLVM Module", *context))
, moduleNonOwning(module.get()) /*Maintain a weak ref so we don't lose our handle to the module*/
, builder(std::make_unique<llvm::IRBuilder<>>(*context))
, compiledModuleBinaryStream(nullptr)
, moduleBuffer()
{


// IR module is initialized with just a ModuleID and a source filename
Expand All @@ -67,7 +70,7 @@ namespace rrllvm {
}

Jit::Jit()
: Jit(LoadSBMLOptions().modelGeneratorOpt) {}
: Jit(LoadSBMLOptions().modelGeneratorOpt) {}

llvm::Module *Jit::getModuleNonOwning() {
return moduleNonOwning;
Expand Down Expand Up @@ -307,6 +310,19 @@ namespace rrllvm {
FunctionType::get(double_type, args_d2, false));
}

std::string Jit::getModuleBinaryStreamAsString()
{
return compiledModuleBinaryStream->str().str();
}

void Jit::resetModuleBinaryStream(std::string cmbs)
{
//There might be a way to do this directly, but this is at least clean.
llvm::raw_svector_ostream* binarystream = new llvm::raw_svector_ostream(moduleBuffer);
*binarystream << cmbs;
compiledModuleBinaryStream.reset(binarystream);
}

/**
* getProcessTriple() - Return an appropriate target triple for generating
* code to be loaded into the current process, e.g. when using the JIT.
Expand Down
17 changes: 16 additions & 1 deletion source/llvm/Jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -328,11 +328,26 @@ namespace rrllvm {
*/
virtual std::string getModuleAsString(std::string sbmlMD5) = 0;

/**
* @brief Return MCJit's compiled binary stream as a string.
* @details Converts the binary stream to a string. Useful for saving and
* loading state for MCJit objects (LLJit doesn't use it).
*/
virtual std::string getModuleBinaryStreamAsString();

/**
* @brief Reset MCJit's compiled binary stream with the given string.
* @details MCJit is the only Jit that uses this. (It's bad interface design
* but works at least.). LLJit uses a caching mechanism which allows us
* to retrieve object files directly, foregoing the need for this function.
*/
virtual void resetModuleBinaryStream(std::string cmbs);

/**
* @brief MCJit compiles the generated LLVM IR to this binary stream
* which is then used both for adding to the Jit as a module and for
* saveState.
* @details MCJit is the only Jit that uses this. (Its bad interface design
* @details MCJit is the only Jit that uses this. (It's bad interface design
* but works at least.). LLJit uses a caching mechanism which allows us
* to retrieve object files directly, foregoing the need for this variable.
*/
Expand Down
22 changes: 14 additions & 8 deletions source/llvm/JitFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,28 @@

namespace rrllvm {

std::unique_ptr<Jit> JitFactory::makeJitEngine(std::uint32_t opt) {
std::unique_ptr<Jit> jit;
Jit* JitFactory::makeJitEngine(std::uint32_t opt) {
rrLog(Logger::LOG_DEBUG) << __FUNC__;
Jit* jit = NULL;
if (opt & LoadSBMLOptions::MCJIT) {
jit = std::move(std::make_unique<MCJit>(opt));
rrLog(Logger::LOG_DEBUG) << "Creating an MCJit object.";
jit = new rrllvm::MCJit(opt);
}

else if (opt & LoadSBMLOptions::LLJIT) {
jit = std::move(std::make_unique<LLJit>(opt));
jit = new rrllvm::LLJit(opt);
}

else {
throw std::invalid_argument("Cannot create JIT object; need to say whether it's MCJit or LLJit in the options.");
}

return std::move(jit);
rrLog(Logger::LOG_DEBUG) << "Done creating a Jit object.";
return jit;
}

std::unique_ptr<Jit> JitFactory::makeJitEngine() {
Jit* JitFactory::makeJitEngine() {
LoadSBMLOptions opt;
std::unique_ptr<Jit> j = JitFactory::makeJitEngine(opt.modelGeneratorOpt);
return std::move(j);
return JitFactory::makeJitEngine(opt.modelGeneratorOpt);
}
}
4 changes: 2 additions & 2 deletions source/llvm/JitFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ namespace rrllvm {
/**
* @brief Create a Jit engine using local options provided by the user
*/
static std::unique_ptr<Jit> makeJitEngine(std::uint32_t opt);
static Jit* makeJitEngine(std::uint32_t opt);

/**
* @brief Create a Jit engine using the global options in Config.
* @details LoadSBMLOptions is populated based on the global Config values.
* This function instantiates the LoadSBMLOptions and provides the default
* modelGeneratorOpt to JitFactory::makeJitEngine(opt);
*/
static std::unique_ptr<Jit> makeJitEngine();
static Jit* makeJitEngine();

};

Expand Down
9 changes: 8 additions & 1 deletion source/llvm/LLJit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,15 @@ namespace rrllvm {
return mangledNameStream.str();
}

LLJit::LLJit()
: LLJit(LoadSBMLOptions().modelGeneratorOpt)
{
}

LLJit::LLJit(std::uint32_t options)
: Jit(options) {
: Jit(options)
, llJit()
{

// todo, can we cross compile providing a different host arch?

Expand Down
2 changes: 1 addition & 1 deletion source/llvm/LLJit.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ namespace rrllvm {
class LLJit : public Jit {
public:

LLJit() = default;
LLJit();

~LLJit() override = default;

Expand Down
2 changes: 2 additions & 0 deletions source/llvm/LLVMExecutableModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ LLVMExecutableModel::LLVMExecutableModel(std::istream& in, uint modelGeneratorOp
conversionFactor(1.0),
flags(defaultFlags())
{

rrLog(Logger::LOG_DEBUG) << __FUNC__;
modelData = LLVMModelData_from_save(in);
resources->loadState(in, modelGeneratorOpt);

Expand Down
75 changes: 42 additions & 33 deletions source/llvm/MCJit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ namespace rrllvm {
* file type.
*/
#if LLVM_VERSION_MAJOR == 6
llvm::LLVMTargetMachine::CodeGenFileType getCodeGenFileType(){
llvm::LLVMTargetMachine::CodeGenFileType getCodeGenFileType()
{
return llvm::TargetMachine::CGFT_ObjectFile;
}
#elif LLVM_VERSION_MAJOR >= 12
Expand All @@ -57,25 +58,33 @@ namespace rrllvm {


MCJit::MCJit(std::uint32_t opt)
: Jit(opt),
engineBuilder(EngineBuilder(std::move(module))) {

: Jit(opt)
, engineBuilder(EngineBuilder(std::move(module)))
, executionEngine()
, functionPassManager()
, errString()
{
compiledModuleBinaryStream = std::make_unique<llvm::raw_svector_ostream>(moduleBuffer);

engineBuilder
.setErrorStr(errString.get())
.setMCJITMemoryManager(std::make_unique<SectionMemoryManager>());
executionEngine = std::unique_ptr<ExecutionEngine>(engineBuilder.create());
.setErrorStr(errString.get())
.setMCJITMemoryManager(std::make_unique<SectionMemoryManager>());
executionEngine.reset(engineBuilder.create());

MCJit::mapFunctionsToJitSymbols();
MCJit::initFunctionPassManager();
}

MCJit::MCJit()
: MCJit(LoadSBMLOptions().modelGeneratorOpt)
{
}

ExecutionEngine *MCJit::getExecutionEngineNonOwning() const {
ExecutionEngine* MCJit::getExecutionEngineNonOwning() const {
return executionEngine.get();
}
std::string MCJit::mangleName(const std::string &unmangledName) const {

std::string MCJit::mangleName(const std::string& unmangledName) const {
std::string mangledName;
llvm::raw_string_ostream mangledNameStream(mangledName);
llvm::Mangler::getNameWithPrefix(mangledNameStream, unmangledName, getDataLayout());
Expand All @@ -86,9 +95,9 @@ namespace rrllvm {
llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr); // for symbols in current process


for (auto [fnName, fnTy_addr_pair] : externalFunctionSignatures()){
for (auto [fnName, fnTy_addr_pair] : externalFunctionSignatures()) {
auto [fnTy, addr] = fnTy_addr_pair;
rrLogDebug << "Creating function \"" << fnName << "\"; fn type: " << toStringRef(fnTy).str() << "; at addr: " <<addr;
rrLogDebug << "Creating function \"" << fnName << "\"; fn type: " << toStringRef(fnTy).str() << "; at addr: " << addr;
Function::Create(fnTy, Function::ExternalLinkage, fnName, getModuleNonOwning());
llvm::sys::DynamicLibrary::AddSymbol(fnName, addr);
}
Expand All @@ -98,8 +107,8 @@ namespace rrllvm {
ModelDataIRBuilder::getCSRMatrixSetNZDecl(getModuleNonOwning());
ModelDataIRBuilder::getCSRMatrixGetNZDecl(getModuleNonOwning());
// Add the symbol to the library
llvm::sys::DynamicLibrary::AddSymbol(ModelDataIRBuilder::csr_matrix_set_nzName, (void*) rr::csr_matrix_set_nz);
llvm::sys::DynamicLibrary::AddSymbol(ModelDataIRBuilder::csr_matrix_get_nzName, (void*) rr::csr_matrix_get_nz);
llvm::sys::DynamicLibrary::AddSymbol(ModelDataIRBuilder::csr_matrix_set_nzName, (void*)rr::csr_matrix_set_nz);
llvm::sys::DynamicLibrary::AddSymbol(ModelDataIRBuilder::csr_matrix_get_nzName, (void*)rr::csr_matrix_get_nz);
}

void MCJit::transferObjectsToResources(std::shared_ptr<rrllvm::ModelResources> modelResources) {
Expand All @@ -112,12 +121,12 @@ namespace rrllvm {

}

std::uint64_t MCJit::lookupFunctionAddress(const std::string &name) {
void *v = executionEngine->getPointerToNamedFunction(mangleName(name));
return (std::uint64_t) v;
std::uint64_t MCJit::lookupFunctionAddress(const std::string& name) {
void* v = executionEngine->getPointerToNamedFunction(mangleName(name));
return (std::uint64_t)v;
}

llvm::TargetMachine *MCJit::getTargetMachine() {
llvm::TargetMachine* MCJit::getTargetMachine() {
return executionEngine->getTargetMachine();
}

Expand All @@ -138,7 +147,7 @@ namespace rrllvm {


llvm::Expected<std::unique_ptr<llvm::object::ObjectFile> > objectFileExpected =
llvm::object::ObjectFile::createObjectFile(obj->getMemBufferRef());
llvm::object::ObjectFile::createObjectFile(obj->getMemBufferRef());
if (!objectFileExpected) {
throw std::invalid_argument("Failed to load object data");
}
Expand All @@ -147,11 +156,11 @@ namespace rrllvm {
}


const llvm::DataLayout &MCJit::getDataLayout() const {
const llvm::DataLayout& MCJit::getDataLayout() const {
return getExecutionEngineNonOwning()->getDataLayout();
}

void MCJit::addModule(llvm::Module *M) {
void MCJit::addModule(llvm::Module* M) {

}

Expand All @@ -169,16 +178,16 @@ namespace rrllvm {

if (compiledModuleBinaryStream->str().empty()) {
std::string err = "Attempt to add module before its been written to binary. Make a call to "
"MCJit::writeObjectToBinaryStream() before addModule()";
"MCJit::writeObjectToBinaryStream() before addModule()";
rrLogErr << err;
throw_llvm_exception(err);
}

auto memBuffer(llvm::MemoryBuffer::getMemBuffer(compiledModuleBinaryStream->str().str()));
auto memBuffer(llvm::MemoryBuffer::getMemBuffer(getModuleBinaryStreamAsString()));

llvm::Expected<std::unique_ptr<llvm::object::ObjectFile> > objectFileExpected =
llvm::object::ObjectFile::createObjectFile(
llvm::MemoryBufferRef(compiledModuleBinaryStream->str(), "id"));
llvm::object::ObjectFile::createObjectFile(
llvm::MemoryBufferRef(compiledModuleBinaryStream->str(), "id"));

if (!objectFileExpected) {
//LS DEBUG: find a way to get the text out of the error.
Expand All @@ -196,7 +205,7 @@ namespace rrllvm {
getExecutionEngineNonOwning()->finalizeObject();
}

std::unique_ptr<llvm::MemoryBuffer> MCJit::getCompiledModelFromCache(const std::string &sbmlMD5) {
std::unique_ptr<llvm::MemoryBuffer> MCJit::getCompiledModelFromCache(const std::string& sbmlMD5) {
return nullptr;
}

Expand All @@ -213,8 +222,8 @@ namespace rrllvm {

//Write the object file to modBufferOut
std::error_code EC;
// llvm::SmallVector<char, 10> modBufferOut;
// postOptimizedModuleStream(modBufferOut);
// llvm::SmallVector<char, 10> modBufferOut;
// postOptimizedModuleStream(modBufferOut);

llvm::legacy::PassManager pass;
auto FileType = getCodeGenFileType();
Expand Down Expand Up @@ -244,10 +253,10 @@ namespace rrllvm {
* Note to developers - passes are stored in llvm/Transforms/Scalar.h.
*/

// we only support LLVM >= 3.1
// we only support LLVM >= 3.1
#if (LLVM_VERSION_MAJOR == 3) && (LLVM_VERSION_MINOR == 1)
//#if (LLVM_VERSION_MAJOR == 6)
functionPassManager->add(new TargetData(*executionEngine->getTargetData()));
functionPassManager->add(new TargetData(*executionEngine->getTargetData()));
#elif (LLVM_VERSION_MAJOR == 3) && (LLVM_VERSION_MINOR <= 4)
functionPassManager->add(new DataLayout(*executionEngine->getDataLayout()));
#elif (LLVM_VERSION_MINOR > 4)
Expand All @@ -271,7 +280,7 @@ namespace rrllvm {
functionPassManager->add(createInstSimplifyLegacyPass());
#else
rrLogWarn << "Not using llvm optimization \"OPTIMIZE_INSTRUCTION_SIMPLIFIER\" "
"because llvm version is " << LLVM_VERSION_MAJOR;
"because llvm version is " << LLVM_VERSION_MAJOR;
#endif
}

Expand Down Expand Up @@ -300,7 +309,7 @@ namespace rrllvm {
// or replaced with createDeadCodeEliminationPass, which we add below anyway
#else
rrLogWarn << "Not using OPTIMIZE_DEAD_INST_ELIMINATION because you are using"
"LLVM version " << LLVM_VERSION_MAJOR;
"LLVM version " << LLVM_VERSION_MAJOR;
#endif
}

Expand All @@ -313,11 +322,11 @@ namespace rrllvm {
}
}

llvm::legacy::FunctionPassManager *MCJit::getFunctionPassManager() const {
llvm::legacy::FunctionPassManager* MCJit::getFunctionPassManager() const {
return functionPassManager.get();
}

llvm::raw_svector_ostream &MCJit::getCompiledModuleStream() {
llvm::raw_svector_ostream& MCJit::getCompiledModuleStream() {
return *compiledModuleBinaryStream;
}

Expand Down
8 changes: 8 additions & 0 deletions source/llvm/MCJit.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ namespace rrllvm {

explicit MCJit(std::uint32_t options);

/**
* @brief default constructor.
* @details delegates to MCJit(std::uint32_t options). The options
* argument is the default constructed from LoadSBMLOptions.modelGeneratorOpt.
* Note, that LoadSBMLOptions is influenced by the global Config.
*/
MCJit();

~MCJit() override = default;


Expand Down
Loading

0 comments on commit f87c96e

Please sign in to comment.