Skip to content

Commit

Permalink
Merge pull request #38 from brown-ccv/reformat-code
Browse files Browse the repository at this point in the history
reformat
  • Loading branch information
jashlu authored Oct 4, 2024
2 parents 959ab09 + 8717822 commit 1628d28
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 135 deletions.
8 changes: 5 additions & 3 deletions src/social_norms_trees/behavior_library.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
class BehaviorLibrary:
def __init__(self, behavior_list):
self.behaviors = behavior_list
self.behavior_from_display_name = {behavior["display_name"]: behavior for behavior in behavior_list}
self.behavior_from_id = {behavior["id"]: behavior for behavior in behavior_list}
self.behaviors = behavior_list
self.behavior_from_display_name = {
behavior["display_name"]: behavior for behavior in behavior_list
}
self.behavior_from_id = {behavior["id"]: behavior for behavior in behavior_list}
3 changes: 1 addition & 2 deletions src/social_norms_trees/custom_node_library.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import py_trees


class CustomBehavior(py_trees.behaviours.Dummy):
def __init__(self, name, id_, display_name):
super().__init__(name)
Expand All @@ -16,5 +17,3 @@ def __init__(self, name, id_, display_name, children=None, memory=False):
# id of the behavior within the behavior library (persists)
# but also the unique id for the behavior within the tree (in case there are multiple instances of
# the behavior in one tree)


116 changes: 51 additions & 65 deletions src/social_norms_trees/mutate_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,15 @@ def format_children_with_indices(composite: py_trees.composites.Composite) -> st
output = label_tree_lines(composite, index_strings)
return output


def format_parents_with_indices(composite: py_trees.composites.Composite) -> str:
index_strings = []
i = 0
for b in iterate_nodes(composite):
if b.__class__.__name__ == "CustomSequence" or b.__class__.__name__ == "CustomSelector":
if (
b.__class__.__name__ == "CustomSequence"
or b.__class__.__name__ == "CustomSelector"
):
index_strings.append(str(i))
else:
index_strings.append("_")
Expand All @@ -164,9 +168,9 @@ def format_parents_with_indices(composite: py_trees.composites.Composite) -> str


def format_tree_with_indices(
tree: py_trees.behaviour.Behaviour,
show_root: bool = False,
) -> tuple[str, List[str]]:
tree: py_trees.behaviour.Behaviour,
show_root: bool = False,
) -> tuple[str, List[str]]:
"""
Examples:
>>> print(format_tree_with_indices(py_trees.behaviours.Dummy()))
Expand Down Expand Up @@ -199,12 +203,12 @@ def format_tree_with_indices(
index = 0
for i, node in enumerate_nodes(tree):
if i == 0 and not show_root:
index_strings.append('_')
index_strings.append("_")
else:
index_strings.append(str(index))
index += 1
output = label_tree_lines(tree, index_strings)

return output, index_strings[1:]


Expand All @@ -224,6 +228,7 @@ def prompt_identify_node(
node = next(islice(iterate_nodes(tree), node_index, node_index + 1))
return node


def prompt_identify_parent_node(
tree: py_trees.behaviour.Behaviour,
message: str = "Which position?",
Expand All @@ -237,7 +242,7 @@ def prompt_identify_parent_node(
text=text,
type=int,
)

node = next(islice(iterate_nodes(tree), node_index, node_index + 1))
return node

Expand All @@ -257,7 +262,7 @@ def prompt_identify_tree_iterator_index(
node_index = click.prompt(
text=text,
type=click.Choice(index_options, case_sensitive=False),
show_choices=False
show_choices=False,
)
return int(node_index)

Expand Down Expand Up @@ -345,22 +350,19 @@ def remove_node(tree: T, node: Optional[py_trees.behaviour.Behaviour] = None) ->
f"{node}'s parent is None, so we can't remove it. You can't remove the root node."
)
action_log = {}
return tree,
return (tree,)
elif isinstance(parent_node, py_trees.composites.Composite):
parent_node.remove_child(node)
action_log = {
"type": "remove_node",
"nodes": [
{
"id_": node.id_,
"display_name": node.display_name
},
{"id_": node.id_, "display_name": node.display_name},
],
"timestamp": datetime.now().isoformat(),
}
"timestamp": datetime.now().isoformat(),
}
else:
raise NotImplementedError()

return tree, action_log


Expand Down Expand Up @@ -394,7 +396,6 @@ def move_node(
node.parent.remove_child(node)
new_parent.insert_child(node, index)


if not internal_call:
action_log = {
"type": "move_node",
Expand All @@ -404,10 +405,10 @@ def move_node(
"display_name": node.display_name,
},
],
"timestamp": datetime.now().isoformat(),
"timestamp": datetime.now().isoformat(),
}
return tree, action_log

return tree


Expand Down Expand Up @@ -481,49 +482,36 @@ def exchange_nodes(
}
)
else:
nodes.append(
{
"id": node0.id_,
"display_name": node0.display_name
}
)

nodes.append({"id": node0.id_, "display_name": node0.display_name})

if node1.__class__.__name__ != "CustomBehavior":
nodes.append(
{
"display_name": node1.display_name,
}
)
else:
nodes.append(
{
"id": node1.id_,
"display_name": node1.display_name
}
)
nodes.append({"id": node1.id_, "display_name": node1.display_name})

action_log = {
"type": "exchange_nodes",
"nodes": nodes,
"timestamp": datetime.now().isoformat(),
"timestamp": datetime.now().isoformat(),
}
return tree, action_log


def prompt_select_node(behavior_library, text):

for idx, tree_name in enumerate(behavior_library.behavior_from_display_name.keys(), 1):
print(f"{idx}. {tree_name}")
for idx, tree_name in enumerate(
behavior_library.behavior_from_display_name.keys(), 1
):
print(f"{idx}. {tree_name}")

choices = [str(i + 1) for i in range(len(behavior_library.behaviors))]
node_index = click.prompt(
text=text,
type=click.Choice(choices),
show_choices=False
)
node_index = click.prompt(text=text, type=click.Choice(choices), show_choices=False)

node_key = list(behavior_library.behavior_from_display_name.keys())[node_index - 1]

node_key = list(behavior_library.behavior_from_display_name.keys())[node_index-1]

return behavior_library.behavior_from_display_name[node_key]


Expand All @@ -538,40 +526,38 @@ def add_node(
"""


behavior = prompt_select_node(behavior_library, f"Which behavior do you want to add?")

if behavior['type'] == "Behavior":
behavior = prompt_select_node(
behavior_library, f"Which behavior do you want to add?"
)

if behavior["type"] == "Behavior":
new_node = CustomBehavior(
name=behavior['display_name'],
id_=behavior['id'],
display_name=behavior['display_name']
)
elif behavior['type'] == "Sequence":
name=behavior["display_name"],
id_=behavior["id"],
display_name=behavior["display_name"],
)

elif behavior["type"] == "Sequence":
new_node = CustomSequence(
name=behavior['display_name'],
id_=behavior['id'],
display_name=behavior['display_name'],
)
name=behavior["display_name"],
id_=behavior["id"],
display_name=behavior["display_name"],
)

new_parent = prompt_identify_parent_node(
tree, f"What should its parent be?", display_nodes=True
)

index = prompt_identify_child_index(new_parent)

assert isinstance(new_parent, py_trees.composites.Composite)

new_parent.insert_child(new_node, index)

action_log = {
"type": "add_node",
"node": {
"id": new_node.id_,
"display_name": new_node.display_name
},
"timestamp": datetime.now().isoformat(),
"node": {"id": new_node.id_, "display_name": new_node.display_name},
"timestamp": datetime.now().isoformat(),
}

return tree, action_log
57 changes: 30 additions & 27 deletions src/social_norms_trees/serialize_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,59 +13,62 @@ def serialize_node(node):

return serialize_node(tree)


def deserialize_tree(tree, behavior_library):
def deserialize_node(node):
assert type(node['type'] == str), (
assert type(node["type"] == str), (
f"\nThere was an invalid configuration detected in the inputted behavior tree: "
f"Invalid type for node attribute 'type' found for node '{node['name']}'. "
f"Please ensure that the 'name' attribute is a string."
)
assert type(node['name'] == str), (
assert type(node["name"] == str), (
f"\nThere was an invalid configuration detected in the inputted behavior tree: "
f"Invalid type for node attribute 'name' found for node '{node['name']}'. "
f"Please ensure that the 'name' attribute is a string."
)

node_type = node['type']
node_type = node["type"]
assert node_type in ["Sequence", "Selector", "Behavior"], (
f"\nThere was an invalid configuration detected in the inputted behavior tree: "
f"Invalid node type '{node_type}' found for node '{node['name']}'. "
f"Please ensure that all node types are correct and supported."
)

behavior = behavior_library.behavior_from_display_name[node['name']]
behavior = behavior_library.behavior_from_display_name[node["name"]]

if node_type == 'Sequence':
if node_type == "Sequence":
children = [deserialize_node(child) for child in node["children"]]

children = [deserialize_node(child) for child in node['children']]

if behavior:
return CustomSequence(
name=behavior['display_name'],
id_=behavior['id'],
display_name=behavior['display_name'],
children=children
name=behavior["display_name"],
id_=behavior["id"],
display_name=behavior["display_name"],
children=children,
)
else:
raise ValueError(f"Behavior {node['name']} not found in behavior library")

#TODO: node type Selector

elif node_type == 'Behavior':

assert ('children' not in node or len(node['children']) == 0), (
f"\nThere was an invalid configuration detected in the inputted behavior tree: "
f"Children were detected for Behavior type node '{node['name']}': "
f"Behavior nodes should not have any children. Please check the structure of your behavior tree."
)
raise ValueError(
f"Behavior {node['name']} not found in behavior library"
)

# TODO: node type Selector

elif node_type == "Behavior":
assert "children" not in node or len(node["children"]) == 0, (
f"\nThere was an invalid configuration detected in the inputted behavior tree: "
f"Children were detected for Behavior type node '{node['name']}': "
f"Behavior nodes should not have any children. Please check the structure of your behavior tree."
)

if behavior:
return CustomBehavior(
name=behavior['display_name'],
id_=behavior['id'],
display_name=behavior['display_name']
name=behavior["display_name"],
id_=behavior["id"],
display_name=behavior["display_name"],
)
else:
raise ValueError(f"Behavior {node['name']} not found in behavior library")
raise ValueError(
f"Behavior {node['name']} not found in behavior library"
)

return deserialize_node(tree)
return deserialize_node(tree)
Loading

0 comments on commit 1628d28

Please sign in to comment.