Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: add support for custom inputs and outputs #145

Merged
merged 8 commits into from
Oct 4, 2024
Merged
2 changes: 1 addition & 1 deletion .vscode/c_cpp_properties.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
],
"compilerPath": "/usr/bin/gcc",
"cStandard": "c17",
"cppStandard": "gnu++17",
"cppStandard": "c++20",
"intelliSenseMode": "linux-gcc-x64"
}
],
Expand Down
18 changes: 16 additions & 2 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,20 @@
"editor.rulers": [
120
],
"autopep8.args": ["--max-line-length", "120", "--experimental"],
"pylint.args": ["--generate-members", "--max-line-length", "120", "-d", "C0114", "-d", "C0115", "-d", "C0116"]
"autopep8.args": [
"--max-line-length",
"120",
"--experimental"
],
"pylint.args": [
"--generate-members",
"--max-line-length",
"120",
"-d",
"C0114",
"-d",
"C0115",
"-d",
"C0116"
],
}
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Release Versions:
- chore: format repository (#142)
- docs: update schema path in component descriptions (#154)
- feat(utils): add binary reader and recorder for encoded states (#152)
- feat!: add support for custom inputs and outputs (#133)

## 4.2.2

Expand Down
4 changes: 2 additions & 2 deletions source/modulo_components/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ if(NOT CMAKE_C_STANDARD)
set(CMAKE_C_STANDARD 99)
endif()

# default to C++17
# default to C++20
if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD 20)
endif()

if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,15 @@ inline void Component::add_output(
->create_publisher_interface(message_pair);
break;
}
case MessageType::CUSTOM_MESSAGE: {
if constexpr (modulo_core::concepts::CustomT<DataT>) {
auto publisher = this->create_publisher<DataT>(topic_name, this->get_qos());
this->outputs_.at(parsed_signal_name) =
std::make_shared<PublisherHandler<rclcpp::Publisher<DataT>, DataT>>(PublisherType::PUBLISHER, publisher)
->create_publisher_interface(message_pair);
}
break;
}
}
} catch (const std::exception& ex) {
RCLCPP_ERROR_STREAM(this->get_logger(), "Failed to add output '" << signal_name << "': " << ex.what());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <modulo_core/communication/PublisherHandler.hpp>
#include <modulo_core/communication/PublisherType.hpp>
#include <modulo_core/communication/SubscriptionHandler.hpp>
#include <modulo_core/concepts.hpp>
#include <modulo_core/exceptions.hpp>
#include <modulo_core/translators/parameter_translators.hpp>

Expand Down Expand Up @@ -646,54 +647,70 @@ inline void ComponentInterface::add_input(
std::shared_ptr<SubscriptionInterface> subscription_interface;
switch (message_pair->get_type()) {
case MessageType::BOOL: {
auto subscription_handler = std::make_shared<SubscriptionHandler<std_msgs::msg::Bool>>(message_pair);
auto subscription_handler =
std::make_shared<SubscriptionHandler<std_msgs::msg::Bool>>(message_pair, this->node_logging_->get_logger());
auto subscription = rclcpp::create_subscription<std_msgs::msg::Bool>(
this->node_parameters_, this->node_topics_, topic_name, this->qos_,
subscription_handler->get_callback(user_callback));
subscription_interface = subscription_handler->create_subscription_interface(subscription);
break;
}
case MessageType::FLOAT64: {
auto subscription_handler = std::make_shared<SubscriptionHandler<std_msgs::msg::Float64>>(message_pair);
auto subscription_handler = std::make_shared<SubscriptionHandler<std_msgs::msg::Float64>>(
message_pair, this->node_logging_->get_logger());
auto subscription = rclcpp::create_subscription<std_msgs::msg::Float64>(
this->node_parameters_, this->node_topics_, topic_name, this->qos_,
subscription_handler->get_callback(user_callback));
subscription_interface = subscription_handler->create_subscription_interface(subscription);
break;
}
case MessageType::FLOAT64_MULTI_ARRAY: {
auto subscription_handler =
std::make_shared<SubscriptionHandler<std_msgs::msg::Float64MultiArray>>(message_pair);
auto subscription_handler = std::make_shared<SubscriptionHandler<std_msgs::msg::Float64MultiArray>>(
message_pair, this->node_logging_->get_logger());
auto subscription = rclcpp::create_subscription<std_msgs::msg::Float64MultiArray>(
this->node_parameters_, this->node_topics_, topic_name, this->qos_,
subscription_handler->get_callback(user_callback));
subscription_interface = subscription_handler->create_subscription_interface(subscription);
break;
}
case MessageType::INT32: {
auto subscription_handler = std::make_shared<SubscriptionHandler<std_msgs::msg::Int32>>(message_pair);
auto subscription_handler = std::make_shared<SubscriptionHandler<std_msgs::msg::Int32>>(
message_pair, this->node_logging_->get_logger());
auto subscription = rclcpp::create_subscription<std_msgs::msg::Int32>(
this->node_parameters_, this->node_topics_, topic_name, this->qos_,
subscription_handler->get_callback(user_callback));
subscription_interface = subscription_handler->create_subscription_interface(subscription);
break;
}
case MessageType::STRING: {
auto subscription_handler = std::make_shared<SubscriptionHandler<std_msgs::msg::String>>(message_pair);
auto subscription_handler = std::make_shared<SubscriptionHandler<std_msgs::msg::String>>(
message_pair, this->node_logging_->get_logger());
auto subscription = rclcpp::create_subscription<std_msgs::msg::String>(
this->node_parameters_, this->node_topics_, topic_name, this->qos_,
subscription_handler->get_callback(user_callback));
subscription_interface = subscription_handler->create_subscription_interface(subscription);
break;
}
case MessageType::ENCODED_STATE: {
auto subscription_handler = std::make_shared<SubscriptionHandler<modulo_core::EncodedState>>(message_pair);
auto subscription_handler = std::make_shared<SubscriptionHandler<modulo_core::EncodedState>>(
message_pair, this->node_logging_->get_logger());
auto subscription = rclcpp::create_subscription<modulo_core::EncodedState>(
this->node_parameters_, this->node_topics_, topic_name, this->qos_,
subscription_handler->get_callback(user_callback));
subscription_interface = subscription_handler->create_subscription_interface(subscription);
break;
}
case MessageType::CUSTOM_MESSAGE: {
if constexpr (modulo_core::concepts::CustomT<DataT>) {
auto subscription_handler =
std::make_shared<SubscriptionHandler<DataT>>(message_pair, this->node_logging_->get_logger());
auto subscription = rclcpp::create_subscription<DataT>(
this->node_parameters_, this->node_topics_, topic_name, this->qos_,
subscription_handler->get_callback(user_callback));
subscription_interface = subscription_handler->create_subscription_interface(subscription);
}
break;
}
}
this->inputs_.insert_or_assign(parsed_signal_name, subscription_interface);
} catch (const std::exception& ex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,11 @@ class LifecycleComponent : public rclcpp_lifecycle::LifecycleNode, public Compon
using ComponentInterface::publish_outputs;
using ComponentInterface::publish_predicates;
using rclcpp_lifecycle::LifecycleNode::get_parameter;

std::map<
std::string,
std::function<std::shared_ptr<modulo_core::communication::PublisherInterface>(const std::string& topic_name)>>
custom_output_configuration_callables_;///< Map of custom output configuration callables
};

template<typename DataT>
Expand All @@ -287,9 +292,24 @@ inline void LifecycleComponent::add_output(
return;
}
try {
this->create_output(
modulo_core::communication::PublisherType::LIFECYCLE_PUBLISHER, signal_name, data, default_topic, fixed_topic,
publish_on_step);
using modulo_core::communication::PublisherHandler;
using modulo_core::communication::PublisherType;

auto parsed_signal_name = this->create_output(
PublisherType::LIFECYCLE_PUBLISHER, signal_name, data, default_topic, fixed_topic, publish_on_step);

auto message_pair = this->outputs_.at(parsed_signal_name)->get_message_pair();
if (message_pair->get_type() == modulo_core::communication::MessageType::CUSTOM_MESSAGE) {
if constexpr (modulo_core::concepts::CustomT<DataT>) {
this->custom_output_configuration_callables_.insert_or_assign(
parsed_signal_name, [this, message_pair](const std::string& topic_name) {
auto publisher = this->create_publisher<DataT>(topic_name, this->get_qos());
return std::make_shared<PublisherHandler<rclcpp_lifecycle::LifecyclePublisher<DataT>, DataT>>(
PublisherType::LIFECYCLE_PUBLISHER, publisher)
->create_publisher_interface(message_pair);
});
}
}
} catch (const modulo_core::exceptions::AddSignalException& ex) {
RCLCPP_ERROR_STREAM(this->get_logger(), "Failed to add output '" << signal_name << "': " << ex.what());
}
Expand Down
47 changes: 37 additions & 10 deletions source/modulo_components/modulo_components/component_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,11 @@ def _create_output(self, signal_name: str, data: str, message_type: MsgT, clprot
elif message_type == EncodedState:
translator = partial(modulo_writers.write_clproto_message,
clproto_message_type=clproto_message_type)
elif hasattr(message_type, 'get_fields_and_field_types'):
def write_ros_msg(message, data):
for field in message.get_fields_and_field_types().keys():
setattr(message, field, getattr(data, field))
translator = write_ros_msg
else:
raise AddSignalError("The provided message type is not supported to create a component output.")
self._outputs[parsed_signal_name] = {"attribute": data, "message_type": message_type,
Expand All @@ -469,7 +474,17 @@ def remove_input(self, signal_name: str):
return
self.get_logger().debug(f"Removing signal '{signal_name}'.")

def __subscription_callback(self, message: MsgT, attribute_name: str, reader: Callable, user_callback: Callable):
def __read_translated_message(self, message: MsgT, attribute_name: str, reader: Callable):
obj_type = type(self.__getattribute__(attribute_name))
decoded_message = reader(message)
self.__setattr__(attribute_name, obj_type(decoded_message))

def __read_custom_message(self, message: MsgT, attribute_name: str):
for field in message.get_fields_and_field_types().keys():
setattr(self.__getattribute__(attribute_name), field, getattr(message, field))

def __subscription_callback(
self, message: MsgT, attribute_name: str, read_message: Callable, user_callback: Callable):
"""
Subscription callback for the ROS subscriptions.

Expand All @@ -478,9 +493,7 @@ def __subscription_callback(self, message: MsgT, attribute_name: str, reader: Ca
:param reader: A callable that can read the ROS message and translate to the desired type
"""
try:
obj_type = type(self.__getattribute__(attribute_name))
decoded_message = reader(message)
self.__setattr__(attribute_name, obj_type(decoded_message))
read_message(message, attribute_name)
except (AttributeError, MessageTranslationError, TypeError) as e:
self.get_logger().warn(f"Failed to read message for attribute {attribute_name}: {e}",
throttle_duration_sec=1.0)
Expand All @@ -491,7 +504,7 @@ def __subscription_callback(self, message: MsgT, attribute_name: str, reader: Ca
self.get_logger().error(f"Failed to execute user callback in subscription for attribute"
f" '{attribute_name}': {e}", throttle_duration_sec=1.0)

def declare_signal(self, signal_name: str, signal_type: str, default_topic="", fixed_topic=False):
def __declare_signal(self, signal_name: str, signal_type: str, default_topic="", fixed_topic=False):
"""
Declare an input to create the topic parameter without adding it to the map of inputs yet.

Expand All @@ -505,7 +518,9 @@ def declare_signal(self, signal_name: str, signal_type: str, default_topic="", f
if not parsed_signal_name:
raise AddSignalError(topic_validation_warning(signal_name, signal_type))
if signal_name != parsed_signal_name:
self.get_logger().warn(topic_validation_warning(signal_name, signal_type))
self.get_logger().warn(
f"The parsed name for {signal_type} '{signal_name}' is '{parsed_signal_name}'."
"Use the parsed name to refer to this {signal_type}.")
if parsed_signal_name in self._inputs.keys():
raise AddSignalError(f"Signal with name '{parsed_signal_name}' already exists as input.")
if parsed_signal_name in self._outputs.keys():
Expand All @@ -529,7 +544,7 @@ def declare_input(self, signal_name: str, default_topic="", fixed_topic=False):
:param fixed_topic: If true, the topic name of the signal is fixed
:raises AddSignalError: if the input could not be declared (empty name or already created)
"""
self.declare_signal(signal_name, "input", default_topic, fixed_topic)
self.__declare_signal(signal_name, "input", default_topic, fixed_topic)

def declare_output(self, signal_name: str, default_topic="", fixed_topic=False):
"""
Expand All @@ -540,7 +555,7 @@ def declare_output(self, signal_name: str, default_topic="", fixed_topic=False):
:param fixed_topic: If true, the topic name of the signal is fixed
:raises AddSignalError: if the output could not be declared (empty name or already created)
"""
self.declare_signal(signal_name, "output", default_topic, fixed_topic)
self.__declare_signal(signal_name, "output", default_topic, fixed_topic)

def add_input(self, signal_name: str, subscription: Union[str, Callable], message_type: MsgT, default_topic="",
fixed_topic=False, user_callback: Callable = None):
Expand Down Expand Up @@ -581,19 +596,31 @@ def default_callback():
user_callback = default_callback
if message_type == Bool or message_type == Float64 or \
message_type == Float64MultiArray or message_type == Int32 or message_type == String:
read_message = partial(self.__read_translated_message,
reader=modulo_readers.read_std_message)
self._inputs[parsed_signal_name] = \
self.create_subscription(message_type, topic_name,
partial(self.__subscription_callback,
attribute_name=subscription,
reader=modulo_readers.read_std_message,
read_message=read_message,
user_callback=user_callback),
self._qos)
elif message_type == EncodedState:
read_message = partial(self.__read_translated_message,
reader=modulo_readers.read_clproto_message)
self._inputs[parsed_signal_name] = \
self.create_subscription(message_type, topic_name,
partial(self.__subscription_callback,
attribute_name=subscription,
read_message=read_message,
user_callback=user_callback),
self._qos)
elif hasattr(message_type, 'get_fields_and_field_types'):
self._inputs[parsed_signal_name] = \
self.create_subscription(message_type, topic_name,
partial(self.__subscription_callback,
attribute_name=subscription,
reader=modulo_readers.read_clproto_message,
read_message=self.__read_custom_message,
user_callback=user_callback),
self._qos)
else:
Expand Down
4 changes: 3 additions & 1 deletion source/modulo_components/src/ComponentInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,9 @@ void ComponentInterface::declare_signal(
}
if (signal_name != parsed_signal_name) {
RCLCPP_WARN_STREAM(
this->node_logging_->get_logger(), modulo_utils::parsing::topic_validation_warning(signal_name, type));
this->node_logging_->get_logger(),
"The parsed name for " + type + " '" + signal_name + "' is '" + parsed_signal_name
+ "'. Use the parsed name to refer to this " + type);
}
if (this->inputs_.find(parsed_signal_name) != this->inputs_.cend()) {
throw exceptions::AddSignalException("Signal with name '" + parsed_signal_name + "' already exists as input.");
Expand Down
4 changes: 4 additions & 0 deletions source/modulo_components/src/LifecycleComponent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,10 @@ bool LifecycleComponent::configure_outputs() {
->create_publisher_interface(message_pair);
break;
}
case MessageType::CUSTOM_MESSAGE: {
interface = this->custom_output_configuration_callables_.at(name)(topic_name);
break;
}
}
} catch (const modulo_core::exceptions::CoreException& ex) {
success = false;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <geometry_msgs/msg/twist.hpp>
#include <rclcpp/rclcpp.hpp>
#include <state_representation/space/cartesian/CartesianState.hpp>

Expand Down Expand Up @@ -47,6 +48,38 @@ class MinimalCartesianInput : public ComponentT {
std::shared_ptr<CartesianState> input;
std::shared_future<void> received_future;

private:
std::promise<void> received_;
};

template<class ComponentT>
class MinimalTwistOutput : public ComponentT {
public:
MinimalTwistOutput(
const rclcpp::NodeOptions& node_options, const std::string& topic,
std::shared_ptr<geometry_msgs::msg::Twist> twist, bool publish_on_step)
: ComponentT(node_options, "minimal_twist_output"), output_(twist) {
this->add_output("twist", this->output_, topic, true, publish_on_step);
}

void publish() { this->publish_output("twist"); }

private:
std::shared_ptr<geometry_msgs::msg::Twist> output_;
};

template<class ComponentT>
class MinimalTwistInput : public ComponentT {
public:
MinimalTwistInput(const rclcpp::NodeOptions& node_options, const std::string& topic)
: ComponentT(node_options, "minimal_twist_input"), input(std::make_shared<geometry_msgs::msg::Twist>()) {
this->received_future = this->received_.get_future();
this->add_input("twist", this->input, [this]() { this->received_.set_value(); }, topic);
}

std::shared_ptr<geometry_msgs::msg::Twist> input;
std::shared_future<void> received_future;

private:
std::promise<void> received_;
};
Expand Down
Loading
Loading