diff --git a/include/usb_cam/learning/depth_anything_v2.hpp b/include/usb_cam/learning/depth_anything_v2.hpp index 4182b870..e4a2f84e 100644 --- a/include/usb_cam/learning/depth_anything_v2.hpp +++ b/include/usb_cam/learning/depth_anything_v2.hpp @@ -2,19 +2,23 @@ #define DEPTH_ANYTHING_HPP_ #include "interface.hpp" +#include "ros/ros.h" #include #include #include class DepthAnythingV2 : public LearningInterface { public: - DepthAnythingV2(std::string model_path) { - _model_path = model_path; + DepthAnythingV2(ros::NodeHandle* nh, std::string model_path) { _INPUT_SIZE = cv::Size(_HEIGHT, _WIDTH); + _model_path = model_path; + + if (nh != nullptr) { + _depth_publication = nh->advertise("depth_anything_v2", 1); + } } void set_input(sensor_msgs::Image& msg) override { - // From ROS msg image to cv mat cv_bridge::CvImagePtr cv_ptr = cv_bridge::toCvCopy(msg, sensor_msgs::image_encodings::RGB8); cv::Mat image = cv_ptr->image; @@ -33,18 +37,25 @@ class DepthAnythingV2 : public LearningInterface { _input_data = float_image.reshape(1, 1).ptr(0); } - void get_output(uint8_t* output_buffer) override { - // TODO - } - void publish() override { - // TODO + cv::Mat depth_prediction = cv::Mat(_HEIGHT, _WIDTH, CV_32FC1, _output_data); + + cv_bridge::CvImage depth_image; + depth_image.header.stamp = ros::Time::now(); // Set the timestamp + depth_image.header.frame_id = "depth_frame"; // Set the frame ID (update as needed) + depth_image.encoding = sensor_msgs::image_encodings::TYPE_32FC1; // Depth is typically float32, single channel + depth_image.image = depth_prediction; + + if (_depth_publication.getTopic() != "") { + _depth_publication.publish(depth_image.toImageMsg()); + } } private: const size_t _HEIGHT = 518; const size_t _WIDTH = 518; cv::Size _INPUT_SIZE; + ros::Publisher _depth_publication; }; #endif // DEPTH_ANYTHING_HPP_ diff --git a/include/usb_cam/learning/interface.hpp b/include/usb_cam/learning/interface.hpp index 4e1f21f2..8d1d1e0d 100644 --- a/include/usb_cam/learning/interface.hpp +++ b/include/usb_cam/learning/interface.hpp @@ -16,7 +16,6 @@ class LearningInterface { LearningInterface() : _model_path("") {} virtual void set_input(sensor_msgs::Image& image) = 0; - virtual void get_output(uint8_t* output_buffer) = 0; virtual void publish() = 0; void load_model(); diff --git a/src/interface.cpp b/src/interface.cpp index 7e59bf27..2e64109f 100644 --- a/src/interface.cpp +++ b/src/interface.cpp @@ -61,7 +61,6 @@ void LearningInterface::_build(std::string onnx_path) { } bool LearningInterface::_save_engine(const std::string& onnx_path) { - // Create an engine path from onnx path std::string engine_path; size_t dot_index = onnx_path.find_last_of("."); if (dot_index != std::string::npos) { @@ -71,7 +70,6 @@ bool LearningInterface::_save_engine(const std::string& onnx_path) { return false; } - // Save the engine to the path if (_engine) { nvinfer1::IHostMemory* data = _engine->serialize(); std::ofstream file; diff --git a/src/usb_cam_node.cpp b/src/usb_cam_node.cpp index 3094ac94..517a43e7 100644 --- a/src/usb_cam_node.cpp +++ b/src/usb_cam_node.cpp @@ -73,7 +73,7 @@ class UsbCamNode { UsbCamNode() : m_node("~") { // Setup the network that outputs derivates of the image captured // TODO: Actual network - networks.push_back(std::make_unique("depth_anything_v2_vitb.onnx")); + networks.push_back(std::make_unique(&m_node, "depth_anything_v2_vitb.onnx")); // Advertise the main image topic image_transport::ImageTransport it(m_node); @@ -180,9 +180,8 @@ class UsbCamNode { // Run all the networks for (const auto& net : networks) { net->set_input(m_image); - if (net->predict()) { - net->publish(); - } + net->predict(); + net->publish(); } return true; diff --git a/test/test_depth_anything_v2.cpp b/test/test_depth_anything_v2.cpp index c251b4e0..fc400a8c 100644 --- a/test/test_depth_anything_v2.cpp +++ b/test/test_depth_anything_v2.cpp @@ -9,7 +9,7 @@ // This class provides access to protected members that we normally don't want to expose class DepthAnythingV2Test : public DepthAnythingV2 { public: - DepthAnythingV2Test(const std::string& model_path) : DepthAnythingV2(model_path) {} + DepthAnythingV2Test(const std::string& model_path) : DepthAnythingV2(nullptr, model_path) {} float* get_input_data() { return _input_data; } }; @@ -59,11 +59,7 @@ TEST_F(TestDepthAnythingV2, TestSetInput) { } TEST_F(TestDepthAnythingV2, TestPredict) { - auto start = std::chrono::high_resolution_clock::now(); for (size_t i = 0; i < 10; i++) { depth_anything_v2->predict(); } - auto end = std::chrono::high_resolution_clock::now(); - std::chrono::duration elapsed = end - start; - std::cout << "Time taken for 10 predictions: " << elapsed.count() << " seconds." << std::endl; }