Skip to content

Commit

Permalink
fixed errors in graphing
Browse files Browse the repository at this point in the history
  • Loading branch information
digicosmos86 committed Feb 22, 2024
1 parent 1a6a17d commit 80a202d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ def graph(self, formatting="plain", name=None, figsize=None, dpi=300, fmt="png")

graphviz = HSSMModelGraph(
model=self.pymc_model, parent=self._parent_param
).make_graph(formatting=formatting)
).make_graph(formatting=formatting, response_str=self.response_str)

width, height = (None, None) if figsize is None else figsize

Expand Down
11 changes: 7 additions & 4 deletions src/hssm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,10 @@ def __init__(self, model, parent):
super().__init__(model)

def make_graph(
self, var_names: Iterable[VarName] | None = None, formatting: str = "plain"
self,
var_names: Iterable[VarName] | None = None,
formatting: str = "plain",
response_str: str = "rt,response",
):
"""Make graphviz Digraph of PyMC model.
Expand Down Expand Up @@ -224,8 +227,8 @@ def make_graph(
label=f"{self.parent.name}\n~\nDeterministic",
shape="box",
)
shape = fast_eval(self.model[self.response_str].shape)
plate_label = f"{self.response_str}_obs({shape[0]})"
shape = fast_eval(self.model[response_str].shape)
plate_label = f"{response_str}_obs({shape[0]})"

sub.attr(
label=plate_label,
Expand All @@ -249,7 +252,7 @@ def make_graph(
graph.edge(parent.replace(":", "&"), child.replace(":", "&"))

if self.parent.is_regression:
graph.edge(self.parent.name, self.response_str)
graph.edge(self.parent.name, response_str)

return graph

Expand Down

0 comments on commit 80a202d

Please sign in to comment.