Skip to content

Commit

Permalink
depth_anything: Add publication method
Browse files Browse the repository at this point in the history
  • Loading branch information
marcojob committed Nov 11, 2024
1 parent 34f32aa commit 8683434
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 20 deletions.
27 changes: 19 additions & 8 deletions include/usb_cam/learning/depth_anything_v2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,23 @@
#define DEPTH_ANYTHING_HPP_

#include "interface.hpp"
#include "ros/ros.h"
#include <cv_bridge/cv_bridge.h>
#include <opencv2/opencv.hpp>
#include <sensor_msgs/Image.h>

class DepthAnythingV2 : public LearningInterface {
public:
DepthAnythingV2(std::string model_path) {
_model_path = model_path;
DepthAnythingV2(ros::NodeHandle* nh, std::string model_path) {
_INPUT_SIZE = cv::Size(_HEIGHT, _WIDTH);
_model_path = model_path;

if (nh != nullptr) {
_depth_publication = nh->advertise<sensor_msgs::Image>("depth_anything_v2", 1);
}
}

void set_input(sensor_msgs::Image& msg) override {
// From ROS msg image to cv mat
cv_bridge::CvImagePtr cv_ptr = cv_bridge::toCvCopy(msg, sensor_msgs::image_encodings::RGB8);
cv::Mat image = cv_ptr->image;

Expand All @@ -33,18 +37,25 @@ class DepthAnythingV2 : public LearningInterface {
_input_data = float_image.reshape(1, 1).ptr<float>(0);
}

void get_output(uint8_t* output_buffer) override {
// TODO
}

void publish() override {
// TODO
cv::Mat depth_prediction = cv::Mat(_HEIGHT, _WIDTH, CV_32FC1, _output_data);

cv_bridge::CvImage depth_image;
depth_image.header.stamp = ros::Time::now(); // Set the timestamp
depth_image.header.frame_id = "depth_frame"; // Set the frame ID (update as needed)
depth_image.encoding = sensor_msgs::image_encodings::TYPE_32FC1; // Depth is typically float32, single channel
depth_image.image = depth_prediction;

if (_depth_publication.getTopic() != "") {
_depth_publication.publish(depth_image.toImageMsg());
}
}

private:
const size_t _HEIGHT = 518;
const size_t _WIDTH = 518;
cv::Size _INPUT_SIZE;
ros::Publisher _depth_publication;
};

#endif // DEPTH_ANYTHING_HPP_
1 change: 0 additions & 1 deletion include/usb_cam/learning/interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class LearningInterface {
LearningInterface() : _model_path("") {}

virtual void set_input(sensor_msgs::Image& image) = 0;
virtual void get_output(uint8_t* output_buffer) = 0;
virtual void publish() = 0;

void load_model();
Expand Down
2 changes: 0 additions & 2 deletions src/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ void LearningInterface::_build(std::string onnx_path) {
}

bool LearningInterface::_save_engine(const std::string& onnx_path) {
// Create an engine path from onnx path
std::string engine_path;
size_t dot_index = onnx_path.find_last_of(".");
if (dot_index != std::string::npos) {
Expand All @@ -71,7 +70,6 @@ bool LearningInterface::_save_engine(const std::string& onnx_path) {
return false;
}

// Save the engine to the path
if (_engine) {
nvinfer1::IHostMemory* data = _engine->serialize();
std::ofstream file;
Expand Down
7 changes: 3 additions & 4 deletions src/usb_cam_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class UsbCamNode {
UsbCamNode() : m_node("~") {
// Setup the network that outputs derivates of the image captured
// TODO: Actual network
networks.push_back(std::make_unique<DepthAnythingV2>("depth_anything_v2_vitb.onnx"));
networks.push_back(std::make_unique<DepthAnythingV2>(&m_node, "depth_anything_v2_vitb.onnx"));

// Advertise the main image topic
image_transport::ImageTransport it(m_node);
Expand Down Expand Up @@ -180,9 +180,8 @@ class UsbCamNode {
// Run all the networks
for (const auto& net : networks) {
net->set_input(m_image);
if (net->predict()) {
net->publish();
}
net->predict();
net->publish();
}

return true;
Expand Down
6 changes: 1 addition & 5 deletions test/test_depth_anything_v2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
// This class provides access to protected members that we normally don't want to expose
class DepthAnythingV2Test : public DepthAnythingV2 {
public:
DepthAnythingV2Test(const std::string& model_path) : DepthAnythingV2(model_path) {}
DepthAnythingV2Test(const std::string& model_path) : DepthAnythingV2(nullptr, model_path) {}
float* get_input_data() { return _input_data; }
};

Expand Down Expand Up @@ -59,11 +59,7 @@ TEST_F(TestDepthAnythingV2, TestSetInput) {
}

TEST_F(TestDepthAnythingV2, TestPredict) {
auto start = std::chrono::high_resolution_clock::now();
for (size_t i = 0; i < 10; i++) {
depth_anything_v2->predict();
}
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> elapsed = end - start;
std::cout << "Time taken for 10 predictions: " << elapsed.count() << " seconds." << std::endl;
}

0 comments on commit 8683434

Please sign in to comment.