Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: default predicate value on construction #159

Merged
merged 4 commits into from
Oct 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ Release Versions:
- [2.1.1](#211)
- [2.1.0](#210)

## Upcoming changes

- fix: default predicate value on construction (#158)

## 5.0.0

### October 9th, 2024
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,7 @@ def _publish_predicates(self):
"""
Helper function to publish all predicates.
"""
message = copy.copy(self.__predicate_message)
message = copy.deepcopy(self.__predicate_message)
for name in self._predicates.keys():
new_value = self._predicates[name].query()
if new_value is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,14 @@ class MinimalTwistInput : public ComponentT {
private:
std::promise<void> received_;
};

template<class ComponentT>
class MinimalTrigger : public ComponentT {
public:
MinimalTrigger(const rclcpp::NodeOptions& node_options) : ComponentT(node_options, "trigger") {
this->add_trigger("test");
}

void trigger() { ComponentT::trigger("test"); }
};
}// namespace modulo_components
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,6 @@

using namespace modulo_components;

class Trigger : public ComponentPublicInterface {
public:
explicit Trigger(const rclcpp::NodeOptions& node_options) : ComponentPublicInterface(node_options, "trigger") {
this->add_trigger("test");
}

void trigger() { Component::trigger("test"); }
};

class ComponentCommunicationTest : public ::testing::Test {
protected:
void SetUp() override {
Expand Down Expand Up @@ -73,7 +64,7 @@ TEST_F(ComponentCommunicationTest, TwistInputOutput) {
}

TEST_F(ComponentCommunicationTest, Trigger) {
auto trigger = std::make_shared<Trigger>(rclcpp::NodeOptions());
auto trigger = std::make_shared<MinimalTrigger<ComponentPublicInterface>>(rclcpp::NodeOptions());
auto listener =
std::make_shared<modulo_utils::testutils::PredicatesListener>("/trigger", std::vector<std::string>{"test"});
this->exec_->add_node(listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,6 @@

using namespace modulo_components;

class LifecycleTrigger : public LifecycleComponentPublicInterface {
public:
explicit LifecycleTrigger(const rclcpp::NodeOptions& node_options)
: LifecycleComponentPublicInterface(node_options, "trigger") {}

bool on_configure_callback() final {
this->add_trigger("test");
return true;
}

bool on_activate_callback() final {
this->trigger("test");
return true;
}
};

class LifecycleComponentCommunicationTest : public ::testing::Test {
protected:
void SetUp() override {
Expand Down Expand Up @@ -80,16 +64,17 @@ TEST_F(LifecycleComponentCommunicationTest, TwistInputOutput) {
}

TEST_F(LifecycleComponentCommunicationTest, Trigger) {
auto trigger = std::make_shared<LifecycleTrigger>(rclcpp::NodeOptions());
auto trigger = std::make_shared<MinimalTrigger<LifecycleComponentPublicInterface>>(rclcpp::NodeOptions());
auto listener =
std::make_shared<modulo_utils::testutils::PredicatesListener>("/trigger", std::vector<std::string>{"test"});
this->exec_->add_node(trigger->get_node_base_interface());
trigger->configure();
trigger->activate();
this->exec_->add_node(listener);
auto result_code = this->exec_->spin_until_future_complete(listener->get_predicate_future(), 500ms);
ASSERT_EQ(result_code, rclcpp::FutureReturnCode::TIMEOUT);
EXPECT_FALSE(listener->get_predicate_values().at("test"));
trigger->activate();
trigger->trigger();
result_code = this->exec_->spin_until_future_complete(listener->get_predicate_future(), 500ms);
ASSERT_EQ(result_code, rclcpp::FutureReturnCode::SUCCESS);
EXPECT_TRUE(listener->get_predicate_values().at("test"));
Expand Down
10 changes: 10 additions & 0 deletions source/modulo_components/test/python/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ def random_sensor():
return msg


@pytest.fixture
def make_minimal_trigger():
def _make_minimal_trigger(component_type):
component = component_type("trigger")
component.add_trigger("test")
return component

yield _make_minimal_trigger


@pytest.fixture
def minimal_cartesian_output(request, random_pose):
def _make_minimal_cartesian_output(component_type, topic, publish_on_step):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,6 @@
from modulo_core.exceptions import CoreError


class Trigger(Component):
def __init__(self):
super().__init__("trigger")
self.add_trigger("test")

def trigger(self):
super().trigger("test")


@pytest.mark.parametrize("minimal_cartesian_input", [[Component, "/topic"]], indirect=True)
@pytest.mark.parametrize("minimal_cartesian_output", [[Component, "/topic", True]], indirect=True)
def test_input_output(ros_exec, random_pose, minimal_cartesian_output, minimal_cartesian_input):
Expand Down Expand Up @@ -71,15 +62,15 @@ def test_input_output_invalid_msg(ros_exec, make_minimal_invalid_encoded_state_p
assert not minimal_cartesian_input.received_future.result()


def test_trigger(ros_exec, make_predicates_listener):
trigger = Trigger()
def test_trigger(ros_exec, make_predicates_listener, make_minimal_trigger):
trigger = make_minimal_trigger(Component)
listener = make_predicates_listener("/trigger", ["test"])
ros_exec.add_node(listener)
ros_exec.add_node(trigger)
ros_exec.spin_until_future_complete(listener.predicates_future, timeout_sec=0.5)
assert not listener.predicates_future.done()
assert not listener.predicate_values["test"]
trigger.trigger()
trigger.trigger("test")
ros_exec.spin_until_future_complete(listener.predicates_future, timeout_sec=0.5)
assert listener.predicates_future.done()
assert listener.predicate_values["test"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,19 @@ def test_input_output_invalid_msg(ros_exec, make_lifecycle_change_client, make_m
assert not minimal_cartesian_input.received_future.result()


def test_trigger(ros_exec, make_lifecycle_change_client, make_predicates_listener):
trigger = Trigger()
def test_trigger(ros_exec, make_lifecycle_change_client, make_predicates_listener, make_minimal_trigger):
trigger = make_minimal_trigger(LifecycleComponent)
listener = make_predicates_listener("/trigger", ["test"])
client = make_lifecycle_change_client("trigger")
ros_exec.add_node(trigger)
ros_exec.add_node(listener)
ros_exec.add_node(client)
client.configure(ros_exec)
client.activate(ros_exec)
ros_exec.spin_until_future_complete(listener.predicates_future, timeout_sec=0.5)
assert not listener.predicates_future.done()
assert not listener.predicate_values["test"]
client.activate(ros_exec)
trigger.trigger("test")
ros_exec.spin_until_future_complete(listener.predicates_future, timeout_sec=0.5)
assert listener.predicates_future.done()
assert listener.predicate_values["test"]
Expand Down
8 changes: 3 additions & 5 deletions source/modulo_core/include/modulo_core/Predicate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,14 @@ namespace modulo_core {
*/
class Predicate {
public:
explicit Predicate(const std::function<bool(void)>& predicate_function) : predicate_(std::move(predicate_function)) {
previous_value_ = !predicate_();
}
explicit Predicate(const std::function<bool(void)>& predicate_function) : predicate_(std::move(predicate_function)) {}

bool get_value() const { return predicate_(); }

void set_predicate(const std::function<bool(void)>& predicate_function) { predicate_ = predicate_function; }

std::optional<bool> query() {
if (const auto new_value = predicate_(); new_value != previous_value_) {
if (const auto new_value = predicate_(); !previous_value_ || new_value != *previous_value_) {
previous_value_ = new_value;
return new_value;
}
Expand All @@ -30,7 +28,7 @@ class Predicate {

private:
std::function<bool(void)> predicate_;
bool previous_value_;
std::optional<bool> previous_value_;
};

}// namespace modulo_core
2 changes: 1 addition & 1 deletion source/modulo_core/modulo_core/predicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def __init__(self, predicate_function: Callable[[], bool]):
value = predicate_function()
if not isinstance(value, bool):
raise CoreError("Predicate function does not return a bool")
self.__previous_value = not value
self.__previous_value: Optional[bool] = None

def get_value(self) -> bool:
return self.__predicate()
Expand Down
9 changes: 9 additions & 0 deletions source/modulo_core/test/cpp/test_predicate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,12 @@ TEST(PredicateTest, SimplePredicate) {
value = predicate.query();
EXPECT_FALSE(value);
}

TEST(PredicateTest, ChangeBeforeQuery) {
auto predicate = Predicate([]() { return true; });
predicate.set_predicate([]() { return false; });
EXPECT_FALSE(predicate.get_value());
auto value = predicate.query();
EXPECT_TRUE(value);
EXPECT_FALSE(*value);
}
23 changes: 23 additions & 0 deletions source/modulo_core/test/python/test_predicate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from modulo_core import Predicate


def test_simple_predicate():
predicate = Predicate(lambda: True)
assert predicate.get_value()
assert predicate.query()
assert predicate.get_value()
assert predicate.query() is None

predicate.set_predicate(lambda: False)
assert predicate.get_value() is False
assert predicate.query() is False
assert predicate.get_value() is False
assert predicate.query() is None


def test_predicate_change_before_query():
predicate = Predicate(lambda: False)
predicate.set_predicate(lambda: True)
assert predicate.get_value()
assert predicate.query()
assert predicate.query() is None
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def predicate_values(self) -> dict:
The values of the predicates
"""
return self.__predicates

def reset_future(self):
"""
Reset the future
"""
self.__future.set_result(False)

def __callback(self, message):
if message.node == self.__component:
Expand Down
14 changes: 10 additions & 4 deletions source/modulo_utils/test/python/test_encoded_state_recorder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os
import shutil
import time

import clproto
Expand All @@ -7,20 +9,24 @@
read_encoded_state_recording,
read_recording_directory)

directory = "/tmp/recording"


def test_encoded_state_recorder():
current_time = time.time()
if os.path.exists(directory) and os.path.isdir(directory):
shutil.rmtree(directory)
current_time = f"{time.time()}"
random_state = sr.CartesianState().Random("test")
msg = EncodedState()
msg.data = clproto.encode(random_state, clproto.MessageType.CARTESIAN_STATE_MESSAGE)
with EncodedStateRecorder(f"/tmp/{current_time}") as rec:
with EncodedStateRecorder(os.path.join(directory, current_time)) as rec:
rec.write(msg)

data = read_encoded_state_recording(f"/tmp/{current_time}")
data = read_encoded_state_recording(os.path.join(directory, current_time))
assert data
assert len(data) == 1
assert data[0]["state"].get_name() == random_state.get_name()

full_data = read_recording_directory("/tmp")
full_data = read_recording_directory(directory)
assert len(full_data) == 1
assert full_data[f"{current_time}"]
Loading