Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure to properly support newer versions of onnxruntime #104

Merged
merged 1 commit into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,20 @@ find_package(k4FWCore) # implicit: Gaudi
find_package(EDM4HEP) # implicit: Podio
find_package(DD4hep)
find_package(k4geo)
find_package(ONNXRuntime)
find_package(FastJet)
# New versions of ONNRuntime package provide onnxruntime-Config.cmake
# and use the name onnxruntime
find_package(onnxruntime)
if (NOT onnxruntime_FOUND)
message(STATUS "Could not find onnxruntime (> 1.17.1). Looking for an older version")
find_package(ONNXRuntime)
endif()

if(onnxruntime_FOUND OR ONNXRuntime_FOUND)
else()
message(FATAL_ERROR "Failed to locate ONNXRuntime!")
endif()

#---------------------------------------------------------------


Expand Down
13 changes: 1 addition & 12 deletions RecFCCeeCalorimeter/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,6 @@
# Package: RecFCCeeCalorimeter
################################################################################

# ONNX DEPENDENCIES
# includes
include_directories("${ONNXRUNTIME_INCLUDE_DIRS}")
# libs
link_directories("${ONNXRUNTIME_LIBRARY_DIRS}")
# New versions of ONNXRuntime add directories to include
# through the target onnxruntime::onnxruntime
if(onnxruntime_FOUND)
set(ONNXRUNTIME_LIBRARY onnxruntime::onnxruntime)
endif()

file(GLOB _module_sources src/components/*.cpp)
gaudi_add_module(k4RecFCCeeCalorimeterPlugins
SOURCES ${_module_sources}
Expand All @@ -26,7 +15,7 @@ gaudi_add_module(k4RecFCCeeCalorimeterPlugins
DD4hep::DDG4
ROOT::Core
ROOT::Hist
${ONNXRUNTIME_LIBRARY}
onnxruntime::onnxruntime
nlohmann_json::nlohmann_json
)
install(TARGETS k4RecFCCeeCalorimeterPlugins
Expand Down
51 changes: 28 additions & 23 deletions RecFCCeeCalorimeter/src/components/CalibrateCaloClusters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,18 @@
#include "edm4hep/Cluster.h"
#include "edm4hep/ClusterCollection.h"
#include "edm4hep/CalorimeterHitCollection.h"
#include <onnxruntime_cxx_api.h>

#include "OnnxruntimeUtilities.h"

DECLARE_COMPONENT(CalibrateCaloClusters)

// convert vector data with given shape into ONNX runtime tensor
template <typename T>
Ort::Value vec_to_tensor(std::vector<T> &data, const std::vector<std::int64_t> &shape)
{
Ort::MemoryInfo mem_info =
Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
auto tensor = Ort::Value::CreateTensor<T>(mem_info, data.data(), data.size(), shape.data(), shape.size());
return tensor;
}

CalibrateCaloClusters::CalibrateCaloClusters(const std::string &name,
ISvcLocator *svcLoc)
: Gaudi::Algorithm(name, svcLoc),
m_geoSvc("GeoSvc", "CalibrateCaloClusters")
m_geoSvc("GeoSvc", "CalibrateCaloClusters"),
m_ortMemInfo(Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault))
{
declareProperty("inClusters", m_inClusters,
"Input cluster collection");
Expand Down Expand Up @@ -227,7 +222,7 @@ StatusCode CalibrateCaloClusters::readCalibrationFile(const std::string &calibra
m_ortEnv = new Ort::Env(loggingLevel, "ONNX runtime environment");
Ort::SessionOptions session_options;
session_options.SetIntraOpNumThreads(1);
m_ortSession = new Ort::Experimental::Session(*m_ortEnv, const_cast<std::string &>(calibrationFile), session_options);
m_ortSession = new Ort::Session(*m_ortEnv, calibrationFile.data(), session_options);
}
catch (const Ort::Exception &exception)
{
Expand All @@ -239,12 +234,19 @@ StatusCode CalibrateCaloClusters::readCalibrationFile(const std::string &calibra
// use default allocator (CPU)
Ort::AllocatorWithDefaultOptions allocator;
debug() << "Input Node Name/Shape (" << m_input_names.size() << "):" << endmsg;
#if ORT_API_VERSION < 13
// Before 1.13 we have to roll our own unique_ptr wrapper here
auto allocDeleter = [&allocator](char* p) { allocator.Free(p); };
using AllocatedStringPtr = std::unique_ptr<char, decltype(allocDeleter)>;
#endif

for (std::size_t i = 0; i < m_ortSession->GetInputCount(); i++)
{
// for old ONNX runtime version
// m_input_names.emplace_back(m_ortSession->GetInputName(i, allocator));
// for new runtime version
m_input_names.emplace_back(m_ortSession->GetInputNameAllocated(i, allocator).get());
#if ORT_API_VERSION < 13
m_input_names.emplace_back(AllocatedStringPtr(m_ortSession->GetInputName(i, allocator), allocDeleter).release());
#else
m_input_names.emplace_back(m_ortSession->GetInputNameAllocated(i, allocator).release());
#endif
m_input_shapes = m_ortSession->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
debug() << "\t" << m_input_names.at(i) << " : ";
for (std::size_t k = 0; k < m_input_shapes.size() - 1; k++)
Expand All @@ -266,10 +268,11 @@ StatusCode CalibrateCaloClusters::readCalibrationFile(const std::string &calibra
debug() << "Output Node Name/Shape (" << m_output_names.size() << "):" << endmsg;
for (std::size_t i = 0; i < m_ortSession->GetOutputCount(); i++)
{
// for old ONNX runtime version
// m_output_names.emplace_back(m_ortSession->GetOutputName(i, allocator));
// for new runtime version
#if ORT_API_VERSION < 13
m_output_names.emplace_back(AllocatedStringPtr(m_ortSession->GetOutputName(i, allocator), allocDeleter).release());
#else
m_output_names.emplace_back(m_ortSession->GetOutputNameAllocated(i, allocator).get());
giovannimarchiori marked this conversation as resolved.
Show resolved Hide resolved
#endif
m_output_shapes = m_ortSession->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
debug() << "\t" << m_output_names.at(i) << " : ";
for (std::size_t k = 0; k < m_output_shapes.size() - 1; k++)
Expand Down Expand Up @@ -329,18 +332,20 @@ StatusCode CalibrateCaloClusters::calibrateClusters(const edm4hep::ClusterCollec
float corr = 1.0;
// Create a single Ort tensor
std::vector<Ort::Value> input_tensors;
input_tensors.emplace_back(vec_to_tensor<float>(energiesInLayers, m_input_shapes));
input_tensors.emplace_back(vec_to_tensor<float>(energiesInLayers, m_input_shapes, m_ortMemInfo));

// double-check the dimensions of the input tensor
// assert(input_tensors[0].IsTensor() && input_tensors[0].GetTensorTypeAndShapeInfo().GetShape() == m_input_shapes);

// pass data through model
try
{
std::vector<Ort::Value> output_tensors = m_ortSession->Run(m_input_names,
input_tensors,
m_output_names,
Ort::RunOptions{nullptr});
auto output_tensors = m_ortSession->Run(Ort::RunOptions{nullptr},
m_input_names.data(),
input_tensors.data(),
input_tensors.size(),
m_output_names.data(),
m_output_names.size());

// double-check the dimensions of the output tensors
// NOTE: the number of output tensors is equal to the number of output nodes specifed in the Run() call
Expand Down
9 changes: 5 additions & 4 deletions RecFCCeeCalorimeter/src/components/CalibrateCaloClusters.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace dd4hep {
}

// ONNX
#include "onnxruntime/core/session/experimental_onnxruntime_cxx_api.h"
#include "onnxruntime_cxx_api.h"

/** @class CalibrateCaloClusters
*
Expand Down Expand Up @@ -149,12 +149,13 @@ class CalibrateCaloClusters : public Gaudi::Algorithm {

// the ONNX runtime session for applying the calibration,
// the environment, and the input and output shapes and names
Ort::Experimental::Session* m_ortSession = nullptr;
Ort::Session* m_ortSession = nullptr;
Ort::Env* m_ortEnv = nullptr;
Ort::MemoryInfo m_ortMemInfo;
std::vector<std::int64_t> m_input_shapes;
std::vector<std::int64_t> m_output_shapes;
std::vector<std::string> m_input_names;
std::vector<std::string> m_output_names;
std::vector<const char*> m_input_names;
std::vector<const char*> m_output_names;

// the indices of the shapeParameters containing the inputs to the model (if they exist)
std::vector<unsigned short int> m_inputPositionsInShapeParameters;
Expand Down
18 changes: 18 additions & 0 deletions RecFCCeeCalorimeter/src/components/OnnxruntimeUtilities.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef RECFCCEECALORIMETER_ONNXRUNTIMEUTILITIES_H
#define RECFCCEECALORIMETER_ONNXRUNTIMEUTILITIES_H

#include "onnxruntime_cxx_api.h"

#include <vector>

// convert vector data with given shape into ONNX runtime tensor
template <typename T>
Ort::Value vec_to_tensor(std::vector<T> &data,
const std::vector<std::int64_t> &shape,
const Ort::MemoryInfo &mem_info) {
auto tensor = Ort::Value::CreateTensor<T>(mem_info, data.data(), data.size(),
shape.data(), shape.size());
return tensor;
}

#endif
56 changes: 31 additions & 25 deletions RecFCCeeCalorimeter/src/components/PhotonIDTool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,18 @@

#include "nlohmann/json.hpp"

#include "OnnxruntimeUtilities.h"

using json = nlohmann::json;


DECLARE_COMPONENT(PhotonIDTool)

// convert vector data with given shape into ONNX runtime tensor
template <typename T>
Ort::Value vec_to_tensor(std::vector<T> &data, const std::vector<std::int64_t> &shape)
{
Ort::MemoryInfo mem_info =
Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
auto tensor = Ort::Value::CreateTensor<T>(mem_info, data.data(), data.size(), shape.data(), shape.size());
return tensor;
}

PhotonIDTool::PhotonIDTool(const std::string &name,
ISvcLocator *svcLoc)
: Gaudi::Algorithm(name, svcLoc)
: Gaudi::Algorithm(name, svcLoc),
m_ortMemInfo(Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault))
{
declareProperty("inClusters", m_inClusters, "Input cluster collection");
declareProperty("outClusters", m_outClusters, "Output cluster collection");
Expand Down Expand Up @@ -296,8 +290,7 @@ StatusCode PhotonIDTool::readMVAFiles(const std::string& mvaInputsFileName,
m_ortEnv = new Ort::Env(loggingLevel, "ONNX runtime environment for photonID");
Ort::SessionOptions session_options;
session_options.SetIntraOpNumThreads(1);
m_ortSession = new Ort::Experimental::Session(*m_ortEnv, const_cast<std::string &>(mvaModelFileName), session_options);
// m_ortSession = new Ort::Session(*m_ortEnv, const_cast<std::string &>(mvaModelFileName), session_options);
m_ortSession = new Ort::Session(*m_ortEnv, mvaModelFileName.data(), session_options);
}
catch (const Ort::Exception &exception)
{
Expand All @@ -308,14 +301,23 @@ StatusCode PhotonIDTool::readMVAFiles(const std::string& mvaInputsFileName,
// print name/shape of inputs
// use default allocator (CPU)
Ort::AllocatorWithDefaultOptions allocator;
#if ORT_API_VERSION < 13
// Before 1.13 we have to roll our own unique_ptr wrapper here
auto allocDeleter = [&allocator](char* p) { allocator.Free(p); };
using AllocatedStringPtr = std::unique_ptr<char, decltype(allocDeleter)>;
#endif


debug() << "Input Node Name/Shape (" << m_ortSession->GetInputCount() << "):" << endmsg;
for (std::size_t i = 0; i < m_ortSession->GetInputCount(); i++)
{
// for old ONNX runtime version
// m_input_names.emplace_back(m_ortSession->GetInputName(i, allocator));
// for new runtime version
m_input_names.emplace_back(m_ortSession->GetInputNameAllocated(i, allocator).get());
m_input_shapes = m_ortSession->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
#if ORT_API_VERSION < 13
m_input_names.emplace_back(AllocatedStringPtr(m_ortSession->GetInputName(i, allocator), allocDeleter).release());
#else
m_input_names.emplace_back(m_ortSession->GetInputNameAllocated(i, allocator).release());
#endif

m_input_shapes = m_ortSession->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
debug() << "\t" << m_input_names.at(i) << " : ";
for (std::size_t k = 0; k < m_input_shapes.size() - 1; k++)
{
Expand All @@ -336,10 +338,12 @@ StatusCode PhotonIDTool::readMVAFiles(const std::string& mvaInputsFileName,
debug() << "Output Node Name/Shape (" << m_ortSession->GetOutputCount() << "):" << endmsg;
for (std::size_t i = 0; i < m_ortSession->GetOutputCount(); i++)
{
// for old ONNX runtime version
// m_output_names.emplace_back(m_ortSession->GetOutputName(i, allocator));
// for new runtime version
#if ORT_API_VERSION < 13
m_output_names.emplace_back(AllocatedStringPtr(m_ortSession->GetOutputName(i, allocator), allocDeleter).release());
#else
m_output_names.emplace_back(m_ortSession->GetOutputNameAllocated(i, allocator).get());
giovannimarchiori marked this conversation as resolved.
Show resolved Hide resolved
#endif

m_output_shapes = m_ortSession->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
debug() << m_output_shapes.size() << endmsg;
debug() << "\t" << m_output_names.at(i) << " : ";
Expand Down Expand Up @@ -384,15 +388,17 @@ StatusCode PhotonIDTool::applyMVAtoClusters(const edm4hep::ClusterCollection *in
float score= -1.0;
// Create a single Ort tensor
std::vector<Ort::Value> input_tensors;
input_tensors.emplace_back(vec_to_tensor<float>(mvaInputs, m_input_shapes));
input_tensors.emplace_back(vec_to_tensor<float>(mvaInputs, m_input_shapes, m_ortMemInfo));

// pass data through model
try
{
std::vector<Ort::Value> output_tensors = m_ortSession->Run(m_input_names,
input_tensors,
m_output_names,
Ort::RunOptions{nullptr});
auto output_tensors = m_ortSession->Run(Ort::RunOptions{nullptr},
m_input_names.data(),
input_tensors.data(),
input_tensors.size(),
m_output_names.data(),
m_output_names.size());

// double-check the dimensions of the output tensors
// NOTE: the number of output tensors is equal to the number of output nodes specified in the Run() call
Expand Down
9 changes: 5 additions & 4 deletions RecFCCeeCalorimeter/src/components/PhotonIDTool.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace edm4hep {
}

// ONNX
#include "onnxruntime/core/session/experimental_onnxruntime_cxx_api.h"
#include "onnxruntime_cxx_api.h"

/** @class PhotonIDTool
*
Expand Down Expand Up @@ -100,12 +100,13 @@ class PhotonIDTool : public Gaudi::Algorithm {

// the ONNX runtime session for running the inference,
// the environment, and the input and output shapes and names
Ort::Experimental::Session* m_ortSession = nullptr;
Ort::Session* m_ortSession = nullptr;
Ort::Env* m_ortEnv = nullptr;
Ort::MemoryInfo m_ortMemInfo;
std::vector<std::int64_t> m_input_shapes;
std::vector<std::int64_t> m_output_shapes;
std::vector<std::string> m_input_names;
std::vector<std::string> m_output_names;
std::vector<const char*> m_input_names;
std::vector<const char*> m_output_names;
std::vector<std::string> m_internal_input_names;

// the indices of the shapeParameters containing the inputs to the model (-1 if not found)
Expand Down
15 changes: 13 additions & 2 deletions cmake/FindONNXRuntime.cmake
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
find_path(ONNXRUNTIME_INCLUDE_DIR onnxruntime/core/session/onnxruntime_cxx_inline.h
find_path(ONNXRUNTIME_INCLUDE_DIR
NAMES onnxruntime_cxx_api.h
PATH_SUFFIXES onnxruntime/core/session
HINTS $ENV{ONNXRUNTIME_ROOT_DIR}/include ${ONNXRUNTIME_ROOT_DIR}/include)

find_library(ONNXRUNTIME_LIBRARY NAMES onnxruntime
Expand All @@ -11,4 +13,13 @@ mark_as_advanced(ONNXRUNTIME_FOUND ONNXRUNTIME_INCLUDE_DIR ONNXRUNTIME_LIBRARY)

set(ONNXRUNTIME_INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIR})
set(ONNXRUNTIME_LIBRARIES ${ONNXRUNTIME_LIBRARY})
get_filename_component(ONNXRUNTIME_LIBRARY_DIRS ${ONNXRUNTIME_LIBRARY} PATH)

# Rig an onnxruntime::onnxruntime target that works similar (enough) to the one
# that can be directly found via find_package(onnxruntime) for newer versions of
# onnxruntime
add_library(onnxruntime::onnxruntime INTERFACE IMPORTED GLOBAL)
set_target_properties(onnxruntime::onnxruntime
PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${ONNXRUNTIME_INCLUDE_DIRS}"
INTERFACE_LINK_LIBRARIES "${ONNXRUNTIME_LIBRARIES}"
)
Loading