Skip to content

Commit

Permalink
Merge branch 'custom-visuals'
Browse files Browse the repository at this point in the history
# Conflicts:
#	torchlens/model_history.py
  • Loading branch information
JohnMark Taylor committed Jul 30, 2024
2 parents fe7cd58 + a3bc421 commit 3be8051
Showing 1 changed file with 81 additions and 11 deletions.
92 changes: 81 additions & 11 deletions torchlens/model_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -4960,6 +4960,12 @@ def render_graph(
vis_opt: str = "unrolled",
vis_nesting_depth: int = 1000,
vis_outpath: str = "modelgraph",
vis_graph_overrides: Dict = None,
vis_node_overrides: Dict = None,
vis_nested_node_overrides: Dict = None,
vis_edge_overrides: Dict = None,
vis_gradient_edge_overrides: Dict = None,
vis_module_overrides: Dict = None,
save_only: bool = False,
vis_fileformat: str = "pdf",
show_buffer_layers: bool = False,
Expand All @@ -4978,6 +4984,19 @@ def render_graph(
direction: which way the graph should go: either 'bottomup', 'topdown', or 'leftright'
"""
if vis_graph_overrides is None:
vis_graph_overrides = {}
if vis_node_overrides is None:
vis_node_overrides = {}
if vis_nested_node_overrides is None:
vis_nested_node_overrides = {}
if vis_edge_overrides is None:
vis_edge_overrides = {}
if vis_gradient_edge_overrides is None:
vis_gradient_edge_overrides = {}
if vis_module_overrides is None:
vis_module_overrides = {}

if not self.all_layers_logged:
raise ValueError(
"Must have all layers logged in order to render the graph; either save all layers,"
Expand Down Expand Up @@ -5037,6 +5056,12 @@ def render_graph(
'labeljust': 'left',
'ordering': 'out'}

for arg_name, arg_val in vis_graph_overrides.items():
if callable(arg_val):
graph_args[arg_name] = str(arg_val(self))
else:
graph_args[arg_name] = str(arg_val)

dot.graph_attr.update(graph_args)
dot.node_attr.update({"ordering": "out"})

Expand All @@ -5059,10 +5084,14 @@ def render_graph(
collapsed_modules,
vis_nesting_depth,
show_buffer_layers,
vis_node_overrides,
vis_nested_node_overrides,
vis_edge_overrides,
vis_gradient_edge_overrides
)

# Finally, set up the subgraphs.
self._set_up_subgraphs(dot, vis_opt, module_cluster_dict)
self._set_up_subgraphs(dot, vis_opt, module_cluster_dict, vis_module_overrides)

if in_notebook() and not save_only:
display(dot)
Expand All @@ -5079,6 +5108,10 @@ def _add_node_to_graphviz(
collapsed_modules: Set,
vis_nesting_depth: int = 1000,
show_buffer_layers: bool = False,
vis_node_overrides: Dict = None,
vis_collapsed_node_overrides: Dict = None,
vis_edge_overrides: Dict = None,
vis_gradient_edge_overrides: Dict = None
):
"""Addes a node and its relevant edges to the graphviz figure.
Expand All @@ -5095,12 +5128,12 @@ def _add_node_to_graphviz(

if is_collapsed_module:
self._construct_collapsed_module_node(
node, graphviz_graph, collapsed_modules, vis_opt, vis_nesting_depth
node, graphviz_graph, collapsed_modules, vis_opt, vis_nesting_depth, vis_collapsed_node_overrides
)
node_color = "black"
else:
node_color = self._construct_layer_node(
node, graphviz_graph, show_buffer_layers, vis_opt
node, graphviz_graph, show_buffer_layers, vis_opt, vis_node_overrides
)

self._add_edges_for_node(
Expand All @@ -5112,7 +5145,9 @@ def _add_node_to_graphviz(
edges_used,
graphviz_graph,
vis_opt,
show_buffer_layers
show_buffer_layers,
vis_edge_overrides,
vis_gradient_edge_overrides
)

@staticmethod
Expand All @@ -5126,7 +5161,7 @@ def _check_if_collapsed_module(node, vis_nesting_depth):
else:
return False

def _construct_layer_node(self, node, graphviz_graph, show_buffer_layers, vis_opt):
def _construct_layer_node(self, node, graphviz_graph, show_buffer_layers, vis_opt, vis_node_overrides):
# Get the address, shape, color, and line style:

node_address, node_shape, node_color = self._get_node_address_shape_color(
Expand All @@ -5152,6 +5187,12 @@ def _construct_layer_node(self, node, graphviz_graph, show_buffer_layers, vis_op
'shape': node_shape,
'ordering': 'out'
}
for arg_name, arg_val in vis_node_overrides.items():
if callable(arg_val):
node_args[arg_name] = str(arg_val(self, node))
else:
node_args[arg_name] = str(arg_val)

graphviz_graph.node(**node_args)

if node.is_last_output_layer:
Expand All @@ -5162,7 +5203,7 @@ def _construct_layer_node(self, node, graphviz_graph, show_buffer_layers, vis_op
return node_color

def _construct_collapsed_module_node(
self, node, graphviz_graph, collapsed_modules, vis_opt, vis_nesting_depth
self, node, graphviz_graph, collapsed_modules, vis_opt, vis_nesting_depth, vis_collapsed_node_overrides
):
module_address_w_pass = node.containing_modules_origin_nested[
vis_nesting_depth - 1
Expand Down Expand Up @@ -5240,6 +5281,12 @@ def _construct_collapsed_module_node(
'ordering': 'out'
}

for arg_name, arg_val in vis_collapsed_node_overrides.items():
if callable(arg_val):
node_args[arg_name] = str(arg_val(self, node))
else:
node_args[arg_name] = str(arg_val)

graphviz_graph.node(**node_args)

def _get_node_address_shape_color(
Expand Down Expand Up @@ -5433,7 +5480,9 @@ def _add_edges_for_node(
edges_used: Set,
graphviz_graph,
vis_opt: str = "unrolled",
show_buffer_layers: bool = False
show_buffer_layers: bool = False,
vis_edge_overrides: Dict = None,
vis_gradient_edge_overrides: Dict = None
):
"""Add the rolled-up edges for a node, marking for the edge which passes it happened for.
Expand Down Expand Up @@ -5545,6 +5594,12 @@ def _add_edges_for_node(
if vis_opt == "rolled":
self._label_rolled_pass_nums(child_node, parent_node, edge_dict)

for arg_name, arg_val in vis_edge_overrides.items():
if callable(arg_val):
edge_dict[arg_name] = str(arg_val(self, parent_node, child_node))
else:
edge_dict[arg_name] = str(arg_val)

# Add it to the appropriate module cluster (most nested one containing both nodes)
containing_module = self._get_lowest_containing_module_for_two_nodes(
parent_node, child_node, both_nodes_collapsed_modules, vis_nesting_depth
Expand Down Expand Up @@ -5573,6 +5628,7 @@ def _add_edges_for_node(
containing_module,
module_edge_dict,
graphviz_graph,
vis_gradient_edge_overrides
)

def _label_node_arguments_if_needed(
Expand Down Expand Up @@ -5761,6 +5817,7 @@ def _add_gradient_edge(
containing_module,
module_edge_dict,
graphviz_graph,
vis_gradient_edge_overrides
):
"""Adds a backwards edge if both layers have saved gradients, showing the backward pass."""
if parent_layer.has_saved_grad and child_layer.has_saved_grad:
Expand All @@ -5773,14 +5830,19 @@ def _add_gradient_edge(
"arrowsize": ".7",
"labelfontsize": "8",
}
for arg_name, arg_val in vis_gradient_edge_overrides.items():
if callable(arg_val):
edge_dict[arg_name] = str(arg_val(self, parent_layer, child_layer))
else:
edge_dict[arg_name] = str(arg_val)

if containing_module != -1:
module_edge_dict[containing_module]["edges"].append(edge_dict)
else:
graphviz_graph.edge(**edge_dict)

def _set_up_subgraphs(
self, graphviz_graph, vis_opt: str, module_edge_dict: Dict[str, List]
self, graphviz_graph, vis_opt: str, module_edge_dict: Dict[str, List], vis_module_overrides: Dict = None
):
"""Given a dictionary specifying the edges in each cluster and the graphviz graph object,
set up the nested subgraphs and the nodes that should go inside each of them. There will be some tricky
Expand Down Expand Up @@ -5817,7 +5879,8 @@ def _set_up_subgraphs(
subgraph_stack,
nesting_depth,
max_nesting_depth,
vis_opt
vis_opt,
vis_module_overrides
)

def _setup_subgraphs_recurse(
Expand All @@ -5829,7 +5892,8 @@ def _setup_subgraphs_recurse(
subgraph_stack,
nesting_depth,
max_nesting_depth,
vis_opt
vis_opt,
vis_module_overrides
):
"""Utility function to crawl down several layers deep into nested subgraphs.
Expand Down Expand Up @@ -5875,7 +5939,8 @@ def _setup_subgraphs_recurse(
subgraph_stack,
nesting_depth + 1,
max_nesting_depth,
vis_opt
vis_opt,
vis_module_overrides
)

else: # we made it, make the subgraph and add all edges.
Expand All @@ -5898,6 +5963,11 @@ def _setup_subgraphs_recurse(
'fillcolor': 'white',
'penwidth': str(pen_width)}

for arg_name, arg_val in vis_module_overrides.items():
if callable(arg_val):
module_args[arg_name] = str(arg_val(self, subgraph_name))
else:
module_args[arg_name] = str(arg_val)
s.attr(**module_args)
subgraph_edges = module_edge_dict[subgraph_name]["edges"]
for edge_dict in subgraph_edges:
Expand Down

0 comments on commit 3be8051

Please sign in to comment.