forked from ros-drivers/usb_cam
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
learning: WIP do depth anything directly
- Loading branch information
Showing
9 changed files
with
217 additions
and
238 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,45 +1,45 @@ | ||
#!/bin/bash | ||
source /opt/ros/noetic/setup.bash | ||
|
||
# Check if the plan file exists before generating it | ||
echo $LD_LIBRARY_PATH | ||
if [ ! -f "test/resources/raft-small.plan" ]; then | ||
echo "Plan file not found. Generating plan file..." | ||
if /usr/src/tensorrt/bin/trtexec --buildOnly --onnx="test/resources/raft-small.onnx" --saveEngine="test/resources/raft-small.plan" --plugins="/usr/lib/x86_64-linux-gnu/libnvinfer_plugin.so" | ||
then | ||
echo "Plan file generation successful" | ||
# Set paths for the model and plan files | ||
MODEL_PATH="test/resources/depth_anything_v2_vitb.onnx" | ||
MODEL_URL="https://github.com/fabio-sim/Depth-Anything-ONNX/releases/download/v2.0.0/depth_anything_v2_vitb.onnx" | ||
|
||
# Step 1: Check if the ONNX model file exists | ||
if [ ! -f "$MODEL_PATH" ]; then | ||
echo "ONNX model file not found. Downloading..." | ||
if wget -O "$MODEL_PATH" "$MODEL_URL"; then | ||
echo "Model downloaded successfully." | ||
else | ||
echo "Plan file generation failed" | ||
echo "Model download failed." | ||
exit 1 | ||
fi | ||
else | ||
echo "Plan file already exists. Skipping generation." | ||
echo "ONNX model file already exists. Skipping download." | ||
fi | ||
|
||
# Build the project and run tests | ||
rm -rf build | ||
mkdir -p build | ||
cd build | ||
if cmake .. -DBUILD_TESTING=ON | ||
then | ||
echo "CMake successfull" | ||
if make test_learning_interface | ||
then | ||
echo "Make successfull" | ||
|
||
if cmake .. -DBUILD_TESTING=ON; then | ||
echo "CMake successful." | ||
if make test_depth_anything_v2; then | ||
echo "Make successful." | ||
else | ||
echo "Make failed" | ||
echo "Make failed." | ||
exit 1 | ||
fi | ||
else | ||
echo "CMake failed" | ||
echo "CMake failed." | ||
exit 1 | ||
fi | ||
|
||
if ./devel/lib/usb_cam/test_learning_interface | ||
then | ||
echo "Tests successful" | ||
# Run the test executable | ||
if ./devel/lib/usb_cam/test_depth_anything_v2; then | ||
echo "Tests successful." | ||
else | ||
echo "Tests failed" | ||
echo "Tests failed." | ||
exit 1 | ||
fi | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#ifndef DEPTH_ANYTHING_HPP_ | ||
#define DEPTH_ANYTHING_HPP_ | ||
|
||
#include "interface.hpp" | ||
#include <opencv2/opencv.hpp> | ||
#include <vector> | ||
|
||
class DepthAnythingV2 : public LearningInterface { | ||
public: | ||
DepthAnythingV2(std::string model_path) { | ||
_model_path = model_path; | ||
} | ||
|
||
void get_output(uint8_t* output_buffer) override { | ||
// TODO | ||
} | ||
|
||
void publish() override { | ||
// TODO | ||
} | ||
}; | ||
|
||
#endif // DEPTH_ANYTHING_HPP_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,73 +1,58 @@ | ||
#ifndef LEARNING_INTERFACE_HPP_ | ||
#define LEARNING_INTERFACE_HPP_ | ||
|
||
#include <cassert> | ||
#include <cuda_runtime.h> | ||
#include <fstream> | ||
#include <iostream> | ||
#include <NvInfer.h> | ||
#include <NvInferPlugin.h> | ||
#include <NvInferRuntime.h> | ||
#include <NvInferRuntimeCommon.h> | ||
#include <NvOnnxParser.h> | ||
#include <sstream> | ||
#include <fstream> | ||
#include <string> | ||
#include <vector> | ||
#include <tuple> | ||
#include <algorithm> | ||
#include <opencv2/opencv.hpp> | ||
#include <NvInfer.h> | ||
|
||
class LearningInterface { | ||
public: | ||
LearningInterface() : _model_path("") { | ||
// Instantiate the logger and initialize plugins | ||
if (!initLibNvInferPlugins(static_cast<void*>(&_logger), "")) { | ||
std::cerr << "Error: Failed to initialize TensorRT plugins." << std::endl; | ||
throw std::runtime_error("Failed to initialize TensorRT plugins."); | ||
} | ||
} | ||
LearningInterface() : _model_path("") {} | ||
|
||
void set_input(cv::Mat input_image); | ||
|
||
virtual void set_input(const uint8_t* input_buffer, size_t height, size_t width) = 0; | ||
virtual void get_output(uint8_t* output_buffer) = 0; | ||
virtual void publish() = 0; | ||
|
||
void load_model(); | ||
bool run_inference(size_t batch_size); | ||
|
||
virtual ~LearningInterface() { | ||
// Release allocated CUDA memory | ||
if (_buffers[0]) cudaFree(_buffers[0]); | ||
if (_buffers[1]) cudaFree(_buffers[1]); | ||
|
||
delete[] _input_buffer; | ||
delete[] _output_buffer; | ||
} | ||
void predict(); | ||
|
||
float* get_input_buffer() { return _input_buffer; } | ||
nvinfer1::ICudaEngine* get_engine() { return _engine; } | ||
nvinfer1::IExecutionContext* get_context() { return _context; } | ||
nvinfer1::IRuntime* get_runtime() { return _runtime; } | ||
|
||
~LearningInterface(); | ||
|
||
protected: | ||
float* _input_buffer = nullptr; | ||
float* _output_buffer = nullptr; | ||
nvinfer1::ICudaEngine* _engine = nullptr; | ||
nvinfer1::IExecutionContext* _context = nullptr; | ||
nvinfer1::IRuntime* _runtime = nullptr; | ||
size_t input_height; | ||
size_t input_width; | ||
size_t output_height; | ||
size_t output_width; | ||
cudaStream_t _stream; | ||
float* _input_data = nullptr; | ||
float* _output_data = nullptr; | ||
nvinfer1::ICudaEngine* _engine; | ||
nvinfer1::IExecutionContext* _context; | ||
nvinfer1::INetworkDefinition* _network; | ||
nvinfer1::IRuntime* _runtime; | ||
std::string _model_path; | ||
|
||
private: | ||
void* _buffers[2] = { nullptr, nullptr }; | ||
|
||
// TODO: static? | ||
class Logger : public nvinfer1::ILogger { | ||
public: | ||
void log(Severity severity, const char* msg) noexcept override { | ||
if (severity <= Severity::kWARNING) { // Limit logging to warnings and errors | ||
// Only output logs with severity greater than warning | ||
if (severity <= Severity::kWARNING) { | ||
std::cout << msg << std::endl; | ||
} | ||
} | ||
}; | ||
Logger _logger; | ||
} _logger; | ||
|
||
bool _save_engine(const std::string& onnx_path); | ||
void _build(std::string onnx_path); | ||
}; | ||
|
||
#endif // LEARNING_INTERFACE_HPP_ |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.