diff --git a/source/modulo_controllers/include/modulo_controllers/BaseControllerInterface.hpp b/source/modulo_controllers/include/modulo_controllers/BaseControllerInterface.hpp index 0b49f628..bde57816 100644 --- a/source/modulo_controllers/include/modulo_controllers/BaseControllerInterface.hpp +++ b/source/modulo_controllers/include/modulo_controllers/BaseControllerInterface.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -22,6 +23,8 @@ #include +#include + namespace modulo_controllers { typedef std::variant< @@ -30,7 +33,7 @@ typedef std::variant< std::shared_ptr>, std::shared_ptr>, std::shared_ptr>, - std::shared_ptr>> + std::shared_ptr>, std::any> SubscriptionVariant; typedef std::variant< @@ -39,7 +42,7 @@ typedef std::variant< realtime_tools::RealtimeBuffer>, realtime_tools::RealtimeBuffer>, realtime_tools::RealtimeBuffer>, - realtime_tools::RealtimeBuffer>> + realtime_tools::RealtimeBuffer>, std::any> BufferVariant; typedef std::tuple< @@ -66,9 +69,11 @@ typedef std::pair< std::shared_ptr>, realtime_tools::RealtimePublisherSharedPtr> StringPublishers; +typedef std::pair CustomPublishers; typedef std::variant< - EncodedStatePublishers, BoolPublishers, DoublePublishers, DoubleVecPublishers, IntPublishers, StringPublishers> + EncodedStatePublishers, BoolPublishers, DoublePublishers, DoubleVecPublishers, IntPublishers, StringPublishers, + CustomPublishers> PublisherVariant; /** @@ -76,6 +81,7 @@ typedef std::variant< * @brief Input structure to save topic data in a realtime buffer and timestamps in one object. */ struct ControllerInput { + ControllerInput() = default; ControllerInput(BufferVariant buffer_variant) : buffer(std::move(buffer_variant)) {} BufferVariant buffer; std::chrono::time_point timestamp; @@ -471,6 +477,11 @@ class BaseControllerInterface : public controller_interface::ControllerInterface std::shared_ptr predicate_timer_; std::timed_mutex command_mutex_; + + std::map> + custom_output_configuration_callables_; + std::map> + custom_input_configuration_callables_; }; template @@ -515,11 +526,36 @@ inline void BaseControllerInterface::set_parameter_value(const std::string& name template inline void BaseControllerInterface::add_input(const std::string& name, const std::string& topic_name) { - auto buffer = realtime_tools::RealtimeBuffer>(); - auto input = ControllerInput(buffer); - create_input(input, name, topic_name); - input_message_pairs_.insert_or_assign( - name, modulo_core::communication::make_shared_message_pair(std::make_shared(), get_node()->get_clock())); + if constexpr (modulo_core::concepts::CustomT) { + auto buffer = std::make_shared>>(); + auto input = ControllerInput(buffer); + auto parsed_name = validate_and_declare_signal(name, "input", topic_name); + if (!parsed_name.empty()) { + inputs_.insert_or_assign(parsed_name, input); + custom_input_configuration_callables_.insert_or_assign( + name, [this](const std::string& name, const std::string& topic) { + auto subscription = + get_node()->create_subscription(topic, qos_, [this, name](const std::shared_ptr message) { + auto buffer_variant = std::get(inputs_.at(name).buffer); + auto buffer = std::any_cast>>>( + buffer_variant); + buffer->writeFromNonRT(message); + inputs_.at(name).timestamp = std::chrono::steady_clock::now(); + }); + subscriptions_.push_back(subscription); + }); + } + } else { + auto buffer = realtime_tools::RealtimeBuffer>(); + auto input = ControllerInput(buffer); + auto parsed_name = validate_and_declare_signal(name, "input", topic_name); + if (!parsed_name.empty()) { + inputs_.insert_or_assign(parsed_name, input); + input_message_pairs_.insert_or_assign( + parsed_name, + modulo_core::communication::make_shared_message_pair(std::make_shared(), get_node()->get_clock())); + } + } } template<> @@ -569,8 +605,22 @@ BaseControllerInterface::create_subscription(const std::string& name, const std: template inline void BaseControllerInterface::add_output(const std::string& name, const std::string& topic_name) { - std::shared_ptr state_ptr = std::make_shared(); - create_output(EncodedStatePublishers(state_ptr, {}, {}), name, topic_name); + if constexpr (modulo_core::concepts::CustomT) { + typedef std::pair>, realtime_tools::RealtimePublisherSharedPtr> PublisherT; + auto parsed_name = validate_and_declare_signal(name, "output", topic_name); + if (!parsed_name.empty()) { + outputs_.insert_or_assign(parsed_name, PublisherT()); + custom_output_configuration_callables_.insert_or_assign( + name, [this](CustomPublishers& pub, const std::string& topic) { + auto publisher = get_node()->create_publisher(topic, qos_); + pub.first = publisher; + pub.second = std::make_shared>(publisher); + }); + } + } else { + std::shared_ptr state_ptr = std::make_shared(); + create_output(EncodedStatePublishers(state_ptr, {}, {}), name, topic_name); + } } template<> @@ -604,33 +654,45 @@ inline std::optional BaseControllerInterface::read_input(const std::string& n if (!check_input_valid(name)) { return {}; } - auto message = - **std::get>>(inputs_.at(name).buffer) - .readFromNonRT(); - std::shared_ptr state; - try { - auto message_pair = input_message_pairs_.at(name); - message_pair->read(message); - state = message_pair->get_message_pair()->get_data(); - } catch (const std::exception& ex) { - RCLCPP_WARN_THROTTLE( - get_node()->get_logger(), *get_node()->get_clock(), 1000, - "Could not read EncodedState message on input '%s': %s", name.c_str(), ex.what()); - return {}; - } - if (state->is_empty()) { + + if constexpr (modulo_core::concepts::CustomT) { + try { + auto buffer_variant = std::get(inputs_.at(name).buffer); + auto buffer = std::any_cast>>>(buffer_variant); + return **(buffer->readFromNonRT()); + } catch (const std::bad_any_cast& ex) { + RCLCPP_ERROR(get_node()->get_logger(), "Failed to read custom input: %s", ex.what()); + } return {}; - } - auto cast_ptr = std::dynamic_pointer_cast(state); - if (cast_ptr != nullptr) { - return *cast_ptr; } else { - RCLCPP_WARN_THROTTLE( - get_node()->get_logger(), *get_node()->get_clock(), 1000, - "Dynamic cast of message on input '%s' from type '%s' to type '%s' failed.", name.c_str(), - get_state_type_name(state->get_type()).c_str(), get_state_type_name(T().get_type()).c_str()); + auto message = + **std::get>>(inputs_.at(name).buffer) + .readFromNonRT(); + std::shared_ptr state; + try { + auto message_pair = input_message_pairs_.at(name); + message_pair->read(message); + state = message_pair->get_message_pair()->get_data(); + } catch (const std::exception& ex) { + RCLCPP_WARN_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Could not read EncodedState message on input '%s': %s", name.c_str(), ex.what()); + return {}; + } + if (state->is_empty()) { + return {}; + } + auto cast_ptr = std::dynamic_pointer_cast(state); + if (cast_ptr != nullptr) { + return *cast_ptr; + } else { + RCLCPP_WARN_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Dynamic cast of message on input '%s' from type '%s' to type '%s' failed.", name.c_str(), + get_state_type_name(state->get_type()).c_str(), get_state_type_name(T().get_type()).c_str()); + } + return {}; } - return {}; } template<> @@ -689,44 +751,71 @@ inline std::optional BaseControllerInterface::read_input inline void BaseControllerInterface::write_output(const std::string& name, const T& data) { - if (data.is_empty()) { - RCLCPP_DEBUG_THROTTLE( - get_node()->get_logger(), *get_node()->get_clock(), 1000, - "Skipping publication of output '%s' due to emptiness of state", name.c_str()); - return; - } if (outputs_.find(name) == outputs_.end()) { RCLCPP_WARN_THROTTLE( get_node()->get_logger(), *get_node()->get_clock(), 1000, "Could not find output '%s'", name.c_str()); return; } - EncodedStatePublishers publishers; - try { - publishers = std::get(outputs_.at(name)); - } catch (const std::bad_variant_access&) { - RCLCPP_WARN_THROTTLE( - get_node()->get_logger(), *get_node()->get_clock(), 1000, - "Could not retrieve publisher for output '%s': Invalid output type", name.c_str()); - return; - } - if (const auto output_type = std::get<0>(publishers)->get_type(); output_type != data.get_type()) { - RCLCPP_WARN_THROTTLE( - get_node()->get_logger(), *get_node()->get_clock(), 1000, - "Skipping publication of output '%s' due to wrong data type (expected '%s', got '%s')", - state_representation::get_state_type_name(output_type).c_str(), - state_representation::get_state_type_name(data.get_type()).c_str(), name.c_str()); - return; - } - auto rt_pub = std::get<2>(publishers); - if (rt_pub && rt_pub->trylock()) { + + if constexpr (modulo_core::concepts::CustomT) { + CustomPublishers publishers; try { - modulo_core::translators::write_message(rt_pub->msg_, data, get_node()->get_clock()->now()); - } catch (const modulo_core::exceptions::MessageTranslationException& ex) { + publishers = std::get(outputs_.at(name)); + } catch (const std::bad_variant_access&) { + RCLCPP_WARN_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Could not retrieve publisher for output '%s': Invalid output type", name.c_str()); + return; + } + + std::shared_ptr> rt_pub; + try { + rt_pub = std::any_cast>>(publishers.second); + } catch (const std::bad_any_cast& ex) { RCLCPP_ERROR_THROTTLE( - get_node()->get_logger(), *get_node()->get_clock(), 1000, "Failed to publish output '%s': %s", name.c_str(), - ex.what()); + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Skipping publication of output '%s' due to wrong data type: %s", name.c_str(), ex.what()); + return; + } + if (rt_pub && rt_pub->trylock()) { + rt_pub->msg_ = data; + rt_pub->unlockAndPublish(); + } + } else { + if (data.is_empty()) { + RCLCPP_DEBUG_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Skipping publication of output '%s' due to emptiness of state", name.c_str()); + return; + } + EncodedStatePublishers publishers; + try { + publishers = std::get(outputs_.at(name)); + } catch (const std::bad_variant_access&) { + RCLCPP_WARN_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Could not retrieve publisher for output '%s': Invalid output type", name.c_str()); + return; + } + if (const auto output_type = std::get<0>(publishers)->get_type(); output_type != data.get_type()) { + RCLCPP_WARN_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, + "Skipping publication of output '%s' due to wrong data type (expected '%s', got '%s')", + state_representation::get_state_type_name(output_type).c_str(), + state_representation::get_state_type_name(data.get_type()).c_str(), name.c_str()); + return; + } + auto rt_pub = std::get<2>(publishers); + if (rt_pub && rt_pub->trylock()) { + try { + modulo_core::translators::write_message(rt_pub->msg_, data, get_node()->get_clock()->now()); + } catch (const modulo_core::exceptions::MessageTranslationException& ex) { + RCLCPP_ERROR_THROTTLE( + get_node()->get_logger(), *get_node()->get_clock(), 1000, "Failed to publish output '%s': %s", name.c_str(), + ex.what()); + } + rt_pub->unlockAndPublish(); } - rt_pub->unlockAndPublish(); } } diff --git a/source/modulo_controllers/src/BaseControllerInterface.cpp b/source/modulo_controllers/src/BaseControllerInterface.cpp index 54b2600f..e07ef5e2 100644 --- a/source/modulo_controllers/src/BaseControllerInterface.cpp +++ b/source/modulo_controllers/src/BaseControllerInterface.cpp @@ -330,7 +330,7 @@ void BaseControllerInterface::create_input( const ControllerInput& input, const std::string& name, const std::string& topic_name) { auto parsed_name = validate_and_declare_signal(name, "input", topic_name); if (!parsed_name.empty()) { - inputs_.insert_or_assign(name, input); + inputs_.insert_or_assign(parsed_name, input); } } @@ -357,6 +357,9 @@ void BaseControllerInterface::add_inputs() { }, [&](const realtime_tools::RealtimeBuffer>&) { subscriptions_.push_back(create_subscription(name, topic)); + }, + [&](const std::any&) { + custom_input_configuration_callables_.at(name)(name, topic); }}, input.buffer); } catch (const std::exception& ex) { @@ -369,7 +372,7 @@ void BaseControllerInterface::create_output( const PublisherVariant& publishers, const std::string& name, const std::string& topic_name) { auto parsed_name = validate_and_declare_signal(name, "output", topic_name); if (!parsed_name.empty()) { - outputs_.insert_or_assign(name, publishers); + outputs_.insert_or_assign(parsed_name, publishers); } } @@ -403,10 +406,15 @@ void BaseControllerInterface::add_outputs() { [&](StringPublishers& pub) { pub.first = get_node()->create_publisher(topic, qos_); pub.second = std::make_shared>(pub.first); + }, + [&](CustomPublishers& pub) { + custom_output_configuration_callables_.at(name)(pub, name); }}, publishers); + } catch (const std::bad_any_cast& ex) { + RCLCPP_ERROR(get_node()->get_logger(), "Failed to add custom output '%s': %s", name.c_str(), ex.what()); } catch (const std::exception& ex) { - RCLCPP_ERROR(get_node()->get_logger(), "Failed to add input '%s': %s", name.c_str(), ex.what()); + RCLCPP_ERROR(get_node()->get_logger(), "Failed to add output '%s': %s", name.c_str(), ex.what()); } } } diff --git a/source/modulo_controllers/test/test_controller_interface.cpp b/source/modulo_controllers/test/test_controller_interface.cpp index 8fceda3d..67cf4f05 100644 --- a/source/modulo_controllers/test/test_controller_interface.cpp +++ b/source/modulo_controllers/test/test_controller_interface.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -26,6 +27,12 @@ class FriendControllerInterface : public ControllerInterface { } }; +sensor_msgs::msg::Image make_image_msg(double width) { + auto msg = sensor_msgs::msg::Image(); + msg.width = width; + return msg; +} + using BoolT = std::tuple; using DoubleT = std::tuple; using DoubleVecT = std::tuple, std_msgs::msg::Float64MultiArray>; @@ -33,6 +40,7 @@ using IntT = std::tuple; using StringT = std::tuple; using CartesianStateT = std::tuple; using JointStateT = std::tuple; +using ImageT = std::tuple; template T write_std_msg(const T& message_data) { @@ -49,6 +57,12 @@ T write_state_msg(const T& message_data) { return copy; } +ImageT write_image_msg(const ImageT& message_data) { + auto copy = message_data; + std::get<1>(copy) = std::get<0>(message_data); + return copy; +} + template T read_std_msg(const T& message_data) { auto copy = message_data; @@ -63,6 +77,12 @@ T read_state_msg(const T& message_data) { return copy; } +ImageT read_image_msg(const ImageT& message_data) { + auto copy = message_data; + std::get<0>(copy) = std::get<1>(message_data); + return copy; +} + template bool std_msg_equal(const T& sent, const T& received) { return std::get<0>(sent) == std::get<0>(received); @@ -74,12 +94,16 @@ bool encoded_state_equal(const T& sent, const T& received) { return equal && std::get<0>(sent).data().isApprox(std::get<0>(received).data()); } +bool sensor_msg_equal(const ImageT& sent, const ImageT& received) { + return std::get<0>(sent).width == std::get<0>(received).width; +} + template using SignalT = std::vector, std::function, std::function>>; static std::tuple< SignalT, SignalT, SignalT, SignalT, SignalT, SignalT, - SignalT> + SignalT, SignalT> signal_test_cases{ {std::make_tuple( std::make_tuple(true, std_msgs::msg::Bool()), write_std_msg, read_std_msg, @@ -100,7 +124,9 @@ static std::tuple< write_state_msg, read_state_msg, encoded_state_equal)}, {std::make_tuple( std::make_tuple(JointState::Random("test", 3), modulo_core::EncodedState()), write_state_msg, - read_state_msg, encoded_state_equal)}}; + read_state_msg, encoded_state_equal)}, + {std::make_tuple( + std::make_tuple(make_image_msg(1), make_image_msg(2)), write_image_msg, read_image_msg, sensor_msg_equal)}}; template class ControllerInterfaceTest : public ::testing::Test { @@ -177,7 +203,6 @@ TYPED_TEST_P(ControllerInterfaceTest, OutputTest) { for (auto [message_data, write_func, read_func, validation_func] : this->test_cases_) { auto data = std::get<0>(message_data); this->interface_->template write_output("output", data); - // rclcpp::spin_some(this->interface_->get_node()->get_node_base_interface()); auto return_code = rclcpp::spin_until_future_complete(test_node.get_node_base_interface(), test_node.get_sub_future(), 200ms); ASSERT_EQ(return_code, rclcpp::FutureReturnCode::SUCCESS); @@ -188,5 +213,5 @@ TYPED_TEST_P(ControllerInterfaceTest, OutputTest) { REGISTER_TYPED_TEST_CASE_P(ControllerInterfaceTest, ConfigureErrorTest, InputTest, OutputTest); -typedef ::testing::Types SignalTypes; +typedef ::testing::Types SignalTypes; INSTANTIATE_TYPED_TEST_CASE_P(TestPrefix, ControllerInterfaceTest, SignalTypes);