Skip to content

Commit

Permalink
Fxed the argname finding function.
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnMark Taylor committed Jul 22, 2024
1 parent 7611300 commit 7f356c8
Showing 1 changed file with 16 additions and 83 deletions.
99 changes: 16 additions & 83 deletions torchlens/model_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,7 +1215,7 @@ def get_func_argnames(self, orig_func: Callable, func_name: str):
argnames = []
for arg in arg_list:
argname = arg.split('=')[0]
if argname in ['*', '/']:
if argname in ['*', '/', '//', '']:
continue
argname = argname.replace('*', '')
argnames.append(argname)
Expand Down Expand Up @@ -4307,7 +4307,7 @@ def _fix_buffer_layers(self):
layer.orig_ancestors.remove(layer.tensor_label_raw)
layer.orig_ancestors.update(self[layer.buffer_parent].orig_ancestors)
layer.parent_layer_arg_locs['args'][0] = layer.buffer_parent
if self[layer.buffer_parent].tensor_contents is not None:
if (self[layer.buffer_parent].tensor_contents is not None) and (layer.creation_args is not None):
layer.creation_args.append(self[layer.buffer_parent].tensor_contents.detach().clone())

buffer_hash = str(layer.containing_modules_origin_nested) + str(layer.buffer_parent) + layer.buffer_address
Expand Down Expand Up @@ -4775,6 +4775,9 @@ def _add_lookup_keys_for_tensor_entry(
f"{module_name}:{module_pass}"
for module_name, module_pass in tensor_entry.containing_modules_origin_nested
]
if (tensor_entry.containing_module_origin is None) and len(
tensor_entry.containing_modules_origin_nested) > 0:
tensor_entry.containing_module_origin = tensor_entry.containing_modules_origin_nested[-1]

# Allow indexing by modules exited as well:
for module_pass in tensor_entry.module_passes_exited:
Expand Down Expand Up @@ -4957,12 +4960,6 @@ def render_graph(
vis_opt: str = "unrolled",
vis_nesting_depth: int = 1000,
vis_outpath: str = "modelgraph",
vis_graph_overrides: Dict = None,
vis_node_overrides: Dict = None,
vis_nested_node_overrides: Dict = None,
vis_edge_overrides: Dict = None,
vis_gradient_edge_overrides: Dict = None,
vis_module_overrides: Dict = None,
save_only: bool = False,
vis_fileformat: str = "pdf",
show_buffer_layers: bool = False,
Expand All @@ -4981,19 +4978,6 @@ def render_graph(
direction: which way the graph should go: either 'bottomup', 'topdown', or 'leftright'
"""
if vis_graph_overrides is None:
vis_graph_overrides = {}
if vis_node_overrides is None:
vis_node_overrides = {}
if vis_nested_node_overrides is None:
vis_nested_node_overrides = {}
if vis_edge_overrides is None:
vis_edge_overrides = {}
if vis_gradient_edge_overrides is None:
vis_gradient_edge_overrides = {}
if vis_module_overrides is None:
vis_module_overrides = {}

if not self.all_layers_logged:
raise ValueError(
"Must have all layers logged in order to render the graph; either save all layers,"
Expand Down Expand Up @@ -5053,12 +5037,6 @@ def render_graph(
'labeljust': 'left',
'ordering': 'out'}

for arg_name, arg_val in vis_graph_overrides.items():
if callable(arg_val):
graph_args[arg_name] = str(arg_val(self))
else:
graph_args[arg_name] = str(arg_val)

dot.graph_attr.update(graph_args)
dot.node_attr.update({"ordering": "out"})

Expand All @@ -5081,14 +5059,10 @@ def render_graph(
collapsed_modules,
vis_nesting_depth,
show_buffer_layers,
vis_node_overrides,
vis_nested_node_overrides,
vis_edge_overrides,
vis_gradient_edge_overrides
)

# Finally, set up the subgraphs.
self._set_up_subgraphs(dot, vis_opt, module_cluster_dict, vis_module_overrides)
self._set_up_subgraphs(dot, vis_opt, module_cluster_dict)

if in_notebook() and not save_only:
display(dot)
Expand All @@ -5105,10 +5079,6 @@ def _add_node_to_graphviz(
collapsed_modules: Set,
vis_nesting_depth: int = 1000,
show_buffer_layers: bool = False,
vis_node_overrides: Dict = None,
vis_collapsed_node_overrides: Dict = None,
vis_edge_overrides: Dict = None,
vis_gradient_edge_overrides: Dict = None
):
"""Addes a node and its relevant edges to the graphviz figure.
Expand All @@ -5125,12 +5095,12 @@ def _add_node_to_graphviz(

if is_collapsed_module:
self._construct_collapsed_module_node(
node, graphviz_graph, collapsed_modules, vis_opt, vis_nesting_depth, vis_collapsed_node_overrides
node, graphviz_graph, collapsed_modules, vis_opt, vis_nesting_depth
)
node_color = "black"
else:
node_color = self._construct_layer_node(
node, graphviz_graph, show_buffer_layers, vis_opt, vis_node_overrides
node, graphviz_graph, show_buffer_layers, vis_opt
)

self._add_edges_for_node(
Expand All @@ -5142,9 +5112,7 @@ def _add_node_to_graphviz(
edges_used,
graphviz_graph,
vis_opt,
show_buffer_layers,
vis_edge_overrides,
vis_gradient_edge_overrides
show_buffer_layers
)

@staticmethod
Expand All @@ -5158,7 +5126,7 @@ def _check_if_collapsed_module(node, vis_nesting_depth):
else:
return False

def _construct_layer_node(self, node, graphviz_graph, show_buffer_layers, vis_opt, vis_node_overrides):
def _construct_layer_node(self, node, graphviz_graph, show_buffer_layers, vis_opt):
# Get the address, shape, color, and line style:

node_address, node_shape, node_color = self._get_node_address_shape_color(
Expand All @@ -5184,12 +5152,6 @@ def _construct_layer_node(self, node, graphviz_graph, show_buffer_layers, vis_op
'shape': node_shape,
'ordering': 'out'
}
for arg_name, arg_val in vis_node_overrides.items():
if callable(arg_val):
node_args[arg_name] = str(arg_val(self, node))
else:
node_args[arg_name] = str(arg_val)

graphviz_graph.node(**node_args)

if node.is_last_output_layer:
Expand All @@ -5200,7 +5162,7 @@ def _construct_layer_node(self, node, graphviz_graph, show_buffer_layers, vis_op
return node_color

def _construct_collapsed_module_node(
self, node, graphviz_graph, collapsed_modules, vis_opt, vis_nesting_depth, vis_collapsed_node_overrides
self, node, graphviz_graph, collapsed_modules, vis_opt, vis_nesting_depth
):
module_address_w_pass = node.containing_modules_origin_nested[
vis_nesting_depth - 1
Expand Down Expand Up @@ -5278,12 +5240,6 @@ def _construct_collapsed_module_node(
'ordering': 'out'
}

for arg_name, arg_val in vis_collapsed_node_overrides.items():
if callable(arg_val):
node_args[arg_name] = str(arg_val(self, node))
else:
node_args[arg_name] = str(arg_val)

graphviz_graph.node(**node_args)

def _get_node_address_shape_color(
Expand Down Expand Up @@ -5477,9 +5433,7 @@ def _add_edges_for_node(
edges_used: Set,
graphviz_graph,
vis_opt: str = "unrolled",
show_buffer_layers: bool = False,
vis_edge_overrides: Dict = None,
vis_gradient_edge_overrides: Dict = None
show_buffer_layers: bool = False
):
"""Add the rolled-up edges for a node, marking for the edge which passes it happened for.
Expand Down Expand Up @@ -5591,12 +5545,6 @@ def _add_edges_for_node(
if vis_opt == "rolled":
self._label_rolled_pass_nums(child_node, parent_node, edge_dict)

for arg_name, arg_val in vis_edge_overrides.items():
if callable(arg_val):
edge_dict[arg_name] = str(arg_val(self, parent_node, child_node))
else:
edge_dict[arg_name] = str(arg_val)

# Add it to the appropriate module cluster (most nested one containing both nodes)
containing_module = self._get_lowest_containing_module_for_two_nodes(
parent_node, child_node, both_nodes_collapsed_modules, vis_nesting_depth
Expand Down Expand Up @@ -5625,7 +5573,6 @@ def _add_edges_for_node(
containing_module,
module_edge_dict,
graphviz_graph,
vis_gradient_edge_overrides
)

def _label_node_arguments_if_needed(
Expand Down Expand Up @@ -5814,7 +5761,6 @@ def _add_gradient_edge(
containing_module,
module_edge_dict,
graphviz_graph,
vis_gradient_edge_overrides
):
"""Adds a backwards edge if both layers have saved gradients, showing the backward pass."""
if parent_layer.has_saved_grad and child_layer.has_saved_grad:
Expand All @@ -5827,19 +5773,14 @@ def _add_gradient_edge(
"arrowsize": ".7",
"labelfontsize": "8",
}
for arg_name, arg_val in vis_gradient_edge_overrides.items():
if callable(arg_val):
edge_dict[arg_name] = str(arg_val(self, parent_layer, child_layer))
else:
edge_dict[arg_name] = str(arg_val)

if containing_module != -1:
module_edge_dict[containing_module]["edges"].append(edge_dict)
else:
graphviz_graph.edge(**edge_dict)

def _set_up_subgraphs(
self, graphviz_graph, vis_opt: str, module_edge_dict: Dict[str, List], vis_module_overrides: Dict = None
self, graphviz_graph, vis_opt: str, module_edge_dict: Dict[str, List]
):
"""Given a dictionary specifying the edges in each cluster and the graphviz graph object,
set up the nested subgraphs and the nodes that should go inside each of them. There will be some tricky
Expand Down Expand Up @@ -5876,8 +5817,7 @@ def _set_up_subgraphs(
subgraph_stack,
nesting_depth,
max_nesting_depth,
vis_opt,
vis_module_overrides
vis_opt
)

def _setup_subgraphs_recurse(
Expand All @@ -5889,8 +5829,7 @@ def _setup_subgraphs_recurse(
subgraph_stack,
nesting_depth,
max_nesting_depth,
vis_opt,
vis_module_overrides
vis_opt
):
"""Utility function to crawl down several layers deep into nested subgraphs.
Expand Down Expand Up @@ -5936,8 +5875,7 @@ def _setup_subgraphs_recurse(
subgraph_stack,
nesting_depth + 1,
max_nesting_depth,
vis_opt,
vis_module_overrides
vis_opt
)

else: # we made it, make the subgraph and add all edges.
Expand All @@ -5960,11 +5898,6 @@ def _setup_subgraphs_recurse(
'fillcolor': 'white',
'penwidth': str(pen_width)}

for arg_name, arg_val in vis_module_overrides.items():
if callable(arg_val):
module_args[arg_name] = str(arg_val(self, subgraph_name))
else:
module_args[arg_name] = str(arg_val)
s.attr(**module_args)
subgraph_edges = module_edge_dict[subgraph_name]["edges"]
for edge_dict in subgraph_edges:
Expand Down

0 comments on commit 7f356c8

Please sign in to comment.