Skip to content

Commit

Permalink
reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
hollandjg committed Oct 10, 2024
1 parent d7bfe84 commit 99cf31e
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 57 deletions.
61 changes: 25 additions & 36 deletions src/social_norms_trees/atomic_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,10 @@
BehaviorIdentifier = TypeVar(
"BehaviorIdentifier", bound=Union[ExistingNode, NewNode, CompositeIndex]
)
BehaviorTreeNode = TypeVar(
"BehaviorTreeNode", bound=py_trees.behaviour.Behaviour)
BehaviorTreeNode = TypeVar("BehaviorTreeNode", bound=py_trees.behaviour.Behaviour)
BehaviorTree = TypeVar("BehaviorTree", bound=BehaviorTreeNode)
BehaviorLibrary = TypeVar("BehaviorLibrary", bound=List[BehaviorTreeNode])
TreeOrLibrary = TypeVar(
"TreeOrLibrary", bound=Union[BehaviorTree, BehaviorLibrary])
TreeOrLibrary = TypeVar("TreeOrLibrary", bound=Union[BehaviorTree, BehaviorLibrary])


# =============================================================================
Expand Down Expand Up @@ -451,15 +449,11 @@ def get_library_mapping(library: BehaviorLibrary) -> NodeMappingRepresentation:

mapping = {str(i): n for i, n in enumerate(library)}
labels = list(mapping.keys())
representation = "\n".join(
[f"{i}: {n.name}" for i, n in enumerate(library)])
representation = "\n".join([f"{i}: {n.name}" for i, n in enumerate(library)])
return NodeMappingRepresentation(mapping, labels, representation)


prompt_identify_library_node = partial(
prompt_identify,
function=get_library_mapping
)
prompt_identify_library_node = partial(prompt_identify, function=get_library_mapping)


def get_composite_mapping(tree: BehaviorTree, skip_label="_"):
Expand Down Expand Up @@ -502,10 +496,7 @@ def get_composite_mapping(tree: BehaviorTree, skip_label="_"):
return NodeMappingRepresentation(mapping, allowed_labels, representation)


prompt_identify_composite = partial(
prompt_identify,
function=get_composite_mapping
)
prompt_identify_composite = partial(prompt_identify, function=get_composite_mapping)


def get_child_index_mapping(tree: BehaviorTree, skip_label="_"):
Expand Down Expand Up @@ -559,29 +550,26 @@ def get_child_index_mapping(tree: BehaviorTree, skip_label="_"):
return NodeMappingRepresentation(mapping, allowed_labels, representation)


prompt_identify_child_index = partial(
prompt_identify,
function=get_child_index_mapping
)
prompt_identify_child_index = partial(prompt_identify, function=get_child_index_mapping)


def get_position_mapping(tree):
"""
[-] S0
--> {1}
[-] S1
--> {2}
--> Dummy
--> {3}
--> {4}
[-] S2
--> {5}
--> Failure
--> {6}
--> {7}
[-] S0
--> {1}
[-] S1
--> {2}
--> Dummy
--> {3}
--> {4}
[-] S2
--> {5}
--> Failure
--> {6}
--> {7}
Expand All @@ -595,8 +583,7 @@ def get_position_mapping(tree):


# Wrapper functions for the atomic operations which give them a UI.
MutationResult = namedtuple(
"MutationResult", ["result", "tree", "function", "kwargs"])
MutationResult = namedtuple("MutationResult", ["result", "tree", "function", "kwargs"])


def mutate_chooser(*fs: Union[Callable], message="Which action?"):
Expand All @@ -622,7 +609,9 @@ def mutate_chooser(*fs: Union[Callable], message="Which action?"):

def mutate_ui(
f: Callable,
) -> Callable[[py_trees.behaviour.Behaviour, List[py_trees.behaviour.Behaviour]], MutationResult]:
) -> Callable[
[py_trees.behaviour.Behaviour, List[py_trees.behaviour.Behaviour]], MutationResult
]:
"""Factory function for a tree mutator UI.
This creates a version of the atomic function `f`
which prompts the user for the appropriate arguments
Expand Down Expand Up @@ -663,14 +652,14 @@ def prompt_get_mutate_arguments(annotation: GenericAlias, tree, library):
return node
elif annotation_ == str(CompositeIndex):
_logger.debug("in CompositeIndex")
composite_node = prompt_identify_composite(
tree, message="Which parent?")
composite_node = prompt_identify_composite(tree, message="Which parent?")
index = prompt_identify_child_index(composite_node)
return composite_node, index
elif annotation_ == str(NewNode):
_logger.debug("in NewNode")
new_node = prompt_identify_library_node(
library, message="Which node from the library?")
library, message="Which node from the library?"
)
return new_node
else:
_logger.debug("in 'else'")
Expand Down
3 changes: 1 addition & 2 deletions src/social_norms_trees/behavior_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ def __init__(self, behavior_list):
self.behavior_from_display_name = {
behavior["name"]: behavior for behavior in behavior_list
}
self.behavior_from_id = {
behavior["id"]: behavior for behavior in behavior_list}
self.behavior_from_id = {behavior["id"]: behavior for behavior in behavior_list}

def __iter__(self):
for i in self.behaviors:
Expand Down
5 changes: 3 additions & 2 deletions src/social_norms_trees/serialize_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ def deserialize_library_element(description: dict):

if node_type == "Sequence":
if "children" in description.keys():
children = [deserialize_library_element(
child) for child in description["children"]]
children = [
deserialize_library_element(child) for child in description["children"]
]
else:
children = []

Expand Down
35 changes: 18 additions & 17 deletions src/social_norms_trees/ui_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
remove,
end_experiment,
)
from social_norms_trees.serialize_tree import deserialize_library_element, serialize_tree, deserialize_tree
from social_norms_trees.serialize_tree import (
deserialize_library_element,
serialize_tree,
deserialize_tree,
)

from social_norms_trees.behavior_library import BehaviorLibrary

Expand Down Expand Up @@ -50,8 +54,7 @@ def experiment_setup(db, origin_tree):
print("\n")
participant_id = participant_login()

experiment_id = initialize_experiment_record(
db, participant_id, origin_tree)
experiment_id = initialize_experiment_record(db, participant_id, origin_tree)

print("\nSetup Complete.\n")

Expand All @@ -66,8 +69,7 @@ def participant_login():

def load_resources(file_path):
try:
print(
f"\nLoading behavior tree and behavior library from {file_path}...\n")
print(f"\nLoading behavior tree and behavior library from {file_path}...\n")
with open(file_path, "r") as file:
resources = json.load(file)

Expand All @@ -80,8 +82,7 @@ def load_resources(file_path):
behavior_list = resources.get("behavior_library")
context_paragraph = resources.get("context")

behavior_tree = deserialize_tree(
behavior_tree, BehaviorLibrary(behavior_list))
behavior_tree = deserialize_tree(behavior_tree, BehaviorLibrary(behavior_list))

behavior_library = [deserialize_library_element(e) for e in behavior_list]

Expand Down Expand Up @@ -138,7 +139,7 @@ def run_experiment(tree, library):
results_dict = {
"start_time": datetime.now().isoformat(),
"initial_behavior_tree": serialize_tree(tree),
"action_log": []
"action_log": [],
}

try:
Expand All @@ -148,11 +149,13 @@ def run_experiment(tree, library):
if f is end_experiment:
break
results = f(tree, library)
results_dict["action_log"].append({
"type": results.function.__name__,
"kwargs": serialize_function_arguments(results.kwargs),
"time": datetime.now().isoformat(),
})
results_dict["action_log"].append(
{
"type": results.function.__name__,
"kwargs": serialize_function_arguments(results.kwargs),
"time": datetime.now().isoformat(),
}
)

except QuitException:
pass
Expand Down Expand Up @@ -183,8 +186,7 @@ def main(
],
db_file: Annotated[
pathlib.Path,
typer.Option(
help="file where the experimental results will be written"),
typer.Option(help="file where the experimental results will be written"),
] = "db.json",
verbose: Annotated[bool, typer.Option("--verbose")] = False,
debug: Annotated[bool, typer.Option("--debug")] = False,
Expand All @@ -206,8 +208,7 @@ def main(

# load tree to run experiment on, and behavior library

original_tree, behavior_library, context_paragraph = load_resources(
resources_file)
original_tree, behavior_library, context_paragraph = load_resources(resources_file)
print(f"\nContext of this experiment: {context_paragraph}")

participant_id, experiment_id = experiment_setup(db, original_tree)
Expand Down

0 comments on commit 99cf31e

Please sign in to comment.