diff --git a/tests/example_models.py b/tests/example_models.py
index da9dde7..4c75574 100644
--- a/tests/example_models.py
+++ b/tests/example_models.py
@@ -427,6 +427,41 @@ def forward(self, x):
return x
+class BufferRewriteModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer("buffer1", torch.rand(12, 12))
+ self.register_buffer("buffer2", torch.rand(12, 12))
+
+ def forward(self, x):
+ x = torch.sin(x)
+ x = x + self.buffer1
+ x = x * self.buffer2
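+ # Rewrite both buffers mid-forward: buffer1 gets a brand-new tensor,
+ # buffer2 a value computed from the current input.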
+ self.buffer1 = torch.rand(12, 12)
+ self.buffer2 = x ** 2
+ x = self.buffer1 + self.buffer2
+ return x
+
+
+class BufferRewriteModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.buffer_mod = BufferRewriteModule()
+
+ def forward(self, x):
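+ # Call the buffer-rewriting submodule six times so each buffer
+ # accumulates multiple passes in the traced graph.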
+ x = torch.cos(x)
+ x = self.buffer_mod(x)
+ x = x * 4
+ x = self.buffer_mod(x)
+ x = self.buffer_mod(x)
+ x = x + 1
+ x = self.buffer_mod(x)
+ x = self.buffer_mod(x)
+ x = self.buffer_mod(x)
+ x = x * 2
+ return x
+
+
class SimpleBranching(nn.Module):
def __init__(self):
super().__init__()
diff --git a/tests/test_validation_and_visuals.py b/tests/test_validation_and_visuals.py
index c005e03..bb0fb6a 100644
--- a/tests/test_validation_and_visuals.py
+++ b/tests/test_validation_and_visuals.py
@@ -522,6 +522,72 @@ def test_buffer_model():
)
+def test_buffer_rewrite_model():
+ model = example_models.BufferRewriteModel()
+ model_input = torch.rand(12, 12)
+ assert validate_saved_activations(model, model_input)
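+ # Render every combination of unrolled/rolled, nested/unnested, and
+ # buffer layers shown/hidden.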
+ show_model_graph(
+ model,
+ model_input,
+ vis_opt="unrolled",
+ vis_nesting_depth=1,
+ vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_visible_unnested_unrolled"),
+ vis_buffer_layers=True,
+ )
+ show_model_graph(
+ model,
+ model_input,
+ vis_opt="unrolled",
+ vis_nesting_depth=1,
+ vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_invisible_unnested_unrolled"),
+ vis_buffer_layers=False,
+ )
+ show_model_graph(
+ model,
+ model_input,
+ vis_opt="unrolled",
+ vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_visible_nested_unrolled"),
+ vis_buffer_layers=True,
+ )
+ show_model_graph(
+ model,
+ model_input,
+ vis_opt="unrolled",
+ vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_invisible_nested_unrolled"),
+ vis_buffer_layers=False,
+ )
+ show_model_graph(
+ model,
+ model_input,
+ vis_opt="rolled",
+ vis_nesting_depth=1,
+ vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_visible_unnested_rolled"),
+ vis_buffer_layers=True,
+ )
+ show_model_graph(
+ model,
+ model_input,
+ vis_opt="rolled",
+ vis_nesting_depth=1,
+ vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_invisible_unnested_rolled"),
+ vis_buffer_layers=False,
+ )
+ show_model_graph(
+ model,
+ model_input,
+ vis_opt="rolled",
+ vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_visible_nested_rolled"),
+ vis_buffer_layers=True,
+ )
+ show_model_graph(
+ model,
+ model_input,
+ vis_opt="rolled",
+ vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_invisible_nested_rolled"),
+ vis_buffer_layers=False,
+ )
+
+
def test_simple_branching(default_input1):
model = example_models.SimpleBranching()
assert validate_saved_activations(model, default_input1)
diff --git a/torchlens/constants.py b/torchlens/constants.py
index 5561b80..bcde820 100644
--- a/torchlens/constants.py
+++ b/torchlens/constants.py
@@ -56,6 +56,7 @@
"input_layers",
"output_layers",
"buffer_layers",
+ "buffer_num_passes",
"internally_initialized_layers",
"layers_where_internal_branches_merge_with_input",
"internally_terminated_layers",
@@ -207,6 +208,8 @@
"max_distance_from_output",
"is_buffer_layer",
"buffer_address",
+ "buffer_pass",
+ "buffer_parent",
"initialized_inside_model",
"has_internally_initialized_ancestor",
"internally_initialized_parents",
diff --git a/torchlens/helper_funcs.py b/torchlens/helper_funcs.py
index 89ac14f..eb9c5b5 100644
--- a/torchlens/helper_funcs.py
+++ b/torchlens/helper_funcs.py
@@ -356,7 +356,9 @@ def extend_search_stack_from_item(
if (attr_name.startswith("__")) or (attr_name == 'T') or ('grad' in attr_name):
continue
try:
- attr = getattr(item, attr_name)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ attr = getattr(item, attr_name)
except:
continue
attr_cls = type(attr)
@@ -417,6 +419,9 @@ def nested_getattr(obj: Any, attr: str) -> Any:
if a in [
"volatile",
"T",
+ 'H',
+ 'mH',
+ 'mT'
]: # avoid annoying warning; if there's more, make a list
with warnings.catch_warnings():
warnings.simplefilter("ignore")
@@ -438,7 +443,9 @@ def nested_assign(obj, addr, val):
if entry_type == "ind":
obj = obj[entry_val]
elif entry_type == "attr":
- obj = getattr(obj, entry_val)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ obj = getattr(obj, entry_val)
def remove_attributes_starting_with_str(obj: Any, s: str):
diff --git a/torchlens/model_history.py b/torchlens/model_history.py
index 0e21dd8..789f4bf 100644
--- a/torchlens/model_history.py
+++ b/torchlens/model_history.py
@@ -9,12 +9,13 @@
import types
from collections import OrderedDict, defaultdict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+import warnings
import graphviz
import numpy as np
import pandas as pd
import torch
-from IPython.core.display import display
+from IPython.display import display
from torch import nn
from torchlens.constants import (
@@ -191,6 +192,8 @@ def __init__(self, fields_dict: Dict):
self.input_output_address = fields_dict["input_output_address"]
self.is_buffer_layer = fields_dict["is_buffer_layer"]
self.buffer_address = fields_dict["buffer_address"]
+ self.buffer_pass = fields_dict["buffer_pass"]
+ self.buffer_parent = fields_dict["buffer_parent"]
self.initialized_inside_model = fields_dict["initialized_inside_model"]
self.has_internally_initialized_ancestor = fields_dict[
"has_internally_initialized_ancestor"
@@ -553,6 +556,7 @@ def __init__(self, source_entry: TensorLogEntry):
self.is_last_output_layer = source_entry.is_last_output_layer
self.is_buffer_layer = source_entry.is_buffer_layer
self.buffer_address = source_entry.buffer_address
+ self.buffer_pass = source_entry.buffer_pass
self.input_output_address = source_entry.input_output_address
self.cond_branch_start_children = source_entry.cond_branch_start_children
self.is_terminal_bool_layer = source_entry.is_terminal_bool_layer
@@ -805,6 +809,7 @@ def __init__(
self.input_layers: List[str] = []
self.output_layers: List[str] = []
self.buffer_layers: List[str] = []
+ self.buffer_num_passes: Dict = {}
self.internally_initialized_layers: List[str] = []
self.layers_where_internal_branches_merge_with_input: List[str] = []
self.internally_terminated_layers: List[str] = []
@@ -1000,6 +1005,12 @@ def wrapped_func(*args, **kwargs):
all_args = list(args) + list(kwargs.values())
arg_tensorlike = get_vars_of_type_from_obj(all_args, torch.Tensor)
+ # Register any buffer tensors among the arguments: buffers are logged
+ # lazily on first use, so each rewrite yields a fresh buffer layer.
+
+ for t in arg_tensorlike:
+ if hasattr(t, 'tl_buffer_address'):
+ self.log_source_tensor(t, 'buffer', getattr(t, 'tl_buffer_address'))
+
if (func_name in print_funcs) and (len(arg_tensorlike) > 0):
out = print_override(args[0], func_name)
return out
@@ -1085,7 +1096,9 @@ def decorate_pytorch(
continue
new_func = self.torch_func_decorator(orig_func)
try:
- setattr(local_func_namespace, func_name, new_func)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ setattr(local_func_namespace, func_name, new_func)
except (AttributeError, TypeError) as _:
pass
new_func.tl_is_decorated_function = True
@@ -1117,10 +1130,14 @@ def undecorate_pytorch(
for namespace_name, func_name, orig_func in orig_func_defs:
namespace_name_notorch = namespace_name.replace("torch.", "")
local_func_namespace = nested_getattr(torch_module, namespace_name_notorch)
- decorated_func = getattr(local_func_namespace, func_name)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ decorated_func = getattr(local_func_namespace, func_name)
del decorated_func
try:
- setattr(local_func_namespace, func_name, orig_func)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ setattr(local_func_namespace, func_name, orig_func)
except (AttributeError, TypeError) as _:
continue
delattr(torch, "identity")
@@ -1260,7 +1277,9 @@ def prepare_buffer_tensors(self, model: nn.Module):
submodules = self.get_all_submodules(model)
for submodule in submodules:
for attribute_name in dir(submodule):
- attribute = getattr(submodule, attribute_name)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ attribute = getattr(submodule, attribute_name)
if issubclass(type(attribute), torch.Tensor) and not issubclass(
type(attribute), torch.nn.Parameter
):
@@ -1270,7 +1289,7 @@ def prepare_buffer_tensors(self, model: nn.Module):
buffer_address = (
submodule.tl_module_address + "." + attribute_name
)
- self.log_source_tensor(attribute, "buffer", buffer_address)
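+ # Tag the buffer with its address instead of logging it up front;
+ # it is logged lazily when first used, so a buffer rewritten
+ # mid-forward gets a separate entry per pass.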
+ setattr(attribute, 'tl_buffer_address', buffer_address)
def module_forward_decorator(
self, orig_forward: Callable, module: nn.Module
@@ -1298,6 +1317,8 @@ def decorated_forward(*args, **kwargs):
)
input_tensor_labels = set()
for t in input_tensors:
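+ # A buffer entering the module may not have been logged yet; log it now.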
+ if (not hasattr(t, 'tl_tensor_label_raw')) and hasattr(t, 'tl_buffer_address'):
+ self.log_source_tensor(t, 'buffer', getattr(t, 'tl_buffer_address'))
tensor_entry = self.raw_tensor_dict[t.tl_tensor_label_raw]
input_tensor_labels.add(t.tl_tensor_label_raw)
module.tl_tensors_entered_labels.append(t.tl_tensor_label_raw)
@@ -1308,6 +1329,18 @@ def decorated_forward(*args, **kwargs):
("+", module_pass_label[0], module_pass_label[1])
)
+ # Check the buffers: a buffer reassigned during an earlier pass carries
+ # a raw tensor label (from being logged as an op output) but no buffer
+ # address. Re-tag it, recording the raw label as its parent, so it is
+ # re-logged as a buffer source on its next use.
+ for buffer_name, buffer_tensor in module.named_buffers():
+ if hasattr(buffer_tensor, 'tl_buffer_address'):
+ continue
+ if module.tl_module_address == '':
+ buffer_address = buffer_name
+ else:
+ buffer_address = f"{module.tl_module_address}.{buffer_name}"
+ buffer_tensor.tl_buffer_address = buffer_address
+ buffer_tensor.tl_buffer_parent = buffer_tensor.tl_tensor_label_raw
+ delattr(buffer_tensor, 'tl_tensor_label_raw')
+
# The function call
out = orig_forward(*args, **kwargs)
@@ -1484,13 +1517,17 @@ def restore_module_attributes(
if attribute_name.startswith(attribute_keyword):
delattr(module, attribute_name)
continue
- attr = getattr(module, attribute_name)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ attr = getattr(module, attribute_name)
if (
isinstance(attr, Callable)
and (attr in decorated_func_mapper)
and (attribute_name[0:2] != "__")
):
- setattr(module, attribute_name, decorated_func_mapper[attr])
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ setattr(module, attribute_name, decorated_func_mapper[attr])
def restore_model_attributes(
self,
@@ -1533,12 +1570,18 @@ def undecorate_model_tensors(self, model: nn.Module):
submodules = self.get_all_submodules(model)
for submodule in submodules:
for attribute_name in dir(submodule):
- attribute = getattr(submodule, attribute_name)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ attribute = getattr(submodule, attribute_name)
if issubclass(type(attribute), torch.Tensor):
if not issubclass(type(attribute), torch.nn.Parameter) and hasattr(
attribute, "tl_tensor_label_raw"
):
delattr(attribute, "tl_tensor_label_raw")
+ if hasattr(attribute, 'tl_buffer_address'):
+ delattr(attribute, "tl_buffer_address")
+ if hasattr(attribute, 'tl_buffer_parent'):
+ delattr(attribute, "tl_buffer_parent")
else:
remove_attributes_starting_with_str(attribute, "tl_")
elif type(attribute) in [list, tuple, set]:
@@ -1547,12 +1590,20 @@ def undecorate_model_tensors(self, model: nn.Module):
item, "tl_tensor_label_raw"
):
delattr(item, "tl_tensor_label_raw")
+ if hasattr(item, 'tl_buffer_address'):
+ delattr(item, "tl_buffer_address")
+ if hasattr(item, 'tl_buffer_parent'):
+ delattr(item, "tl_buffer_parent")
elif type(attribute) == dict:
for key, val in attribute.items():
if issubclass(type(val), torch.Tensor) and hasattr(
val, "tl_tensor_label_raw"
):
delattr(val, "tl_tensor_label_raw")
+ if hasattr(val, 'tl_buffer_address'):
+ delattr(val, "tl_buffer_address")
+ if hasattr(val, 'tl_buffer_parent'):
+ delattr(val, "tl_buffer_parent")
def get_op_nums_from_user_labels(
self, which_layers: Union[str, List[Union[str, int]]]
@@ -1889,6 +1940,7 @@ def log_source_tensor_exhaustive(
input_output_address = extra_addr
is_buffer_layer = False
buffer_address = None
+ buffer_parent = None
initialized_inside_model = False
has_internally_initialized_ancestor = False
input_ancestors = {tensor_label}
@@ -1907,6 +1959,10 @@ def log_source_tensor_exhaustive(
internally_initialized_ancestors = {tensor_label}
input_ancestors = set()
operation_equivalence_type = f"buffer_{extra_addr}"
+ if hasattr(t, 'tl_buffer_parent'):
+ buffer_parent = t.tl_buffer_parent
+ else:
+ buffer_parent = None
else:
raise ValueError("source must be either 'input' or 'buffer'")
@@ -2013,6 +2069,8 @@ def log_source_tensor_exhaustive(
"input_output_address": input_output_address,
"is_buffer_layer": is_buffer_layer,
"buffer_address": buffer_address,
+ "buffer_pass": None,
+ "buffer_parent": buffer_parent,
"initialized_inside_model": initialized_inside_model,
"has_internally_initialized_ancestor": has_internally_initialized_ancestor,
"internally_initialized_parents": [],
@@ -2247,6 +2305,8 @@ def log_function_output_tensors_exhaustive(
fields_dict["input_output_address"] = None
fields_dict["is_buffer_layer"] = False
fields_dict["buffer_address"] = None
+ fields_dict["buffer_pass"] = None
+ fields_dict["buffer_parent"] = None
fields_dict["initialized_inside_model"] = len(parent_layer_labels) == 0
fields_dict["has_internally_initialized_ancestor"] = (
len(internally_initialized_ancestors) > 0
@@ -3077,8 +3137,10 @@ def _remove_log_entry(
else:
tensor_label = log_entry.tensor_label_raw
for attr in dir(log_entry):
- if not attr.startswith("_") and not callable(getattr(log_entry, attr)):
- delattr(log_entry, attr)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ if not attr.startswith("_") and not callable(getattr(log_entry, attr)):
+ delattr(log_entry, attr)
del log_entry
if remove_references:
self._remove_log_entry_references(tensor_label)
@@ -3170,34 +3232,38 @@ def postprocess(
self._fix_modules_for_internal_tensors()
- # Step 7: Identify all loops, mark repeated layers.
+ # Step 7: Fix the buffer passes and parent information.
+
+ self._fix_buffer_layers()
+
+ # Step 8: Identify all loops, mark repeated layers.
self._assign_corresponding_tensors_to_same_layer()
- # Step 8: Go down tensor list, get the mapping from raw tensor names to final tensor names.
+ # Step 9: Go down tensor list, get the mapping from raw tensor names to final tensor names.
self._map_raw_tensor_labels_to_final_tensor_labels()
- # Step 9: Go through and log information pertaining to all layers:
+ # Step 10: Go through and log information pertaining to all layers:
self._log_final_info_for_all_layers()
- # Step 10: Rename the raw tensor entries in the fields of ModelHistory:
+ # Step 11: Rename the raw tensor entries in the fields of ModelHistory:
self._rename_model_history_layer_names()
self._trim_and_reorder_model_history_fields()
- # Step 11: And one more pass to delete unused layers from the record and do final tidying up:
+ # Step 12: And one more pass to delete unused layers from the record and do final tidying up:
self._remove_unwanted_entries_and_log_remaining()
- # Step 12: Undecorate all saved tensors and remove saved grad_fns.
+ # Step 13: Undecorate all saved tensors and remove saved grad_fns.
self._undecorate_all_saved_tensors()
- # Step 13: Clear the cache after any tensor deletions for garbage collection purposes:
+ # Step 14: Clear the cache after any tensor deletions for garbage collection purposes:
torch.cuda.empty_cache()
- # Step 14: Log time elapsed.
+ # Step 15: Log time elapsed.
self._log_time_elapsed()
- # Step 15: log the pass as finished, changing the ModelHistory behavior to its user-facing version.
+ # Step 16: log the pass as finished, changing the ModelHistory behavior to its user-facing version.
self._set_pass_finished()
@@ -4152,7 +4218,7 @@ def _fix_modules_for_single_internal_tensor(
for enter_or_exit, module_address, module_pass in thread_modules[::step_val]:
module_pass_label = (module_address, module_pass)
if node_type_to_fix == "parent":
- if enter_or_exit == "+":
+ if (enter_or_exit == "+") and (module_pass_label in node_to_fix.containing_modules_origin_nested):
node_to_fix.containing_modules_origin_nested.remove(
module_pass_label
)
@@ -4172,6 +4238,111 @@ def _fix_modules_for_single_internal_tensor(
node_stack.append(node_to_fix_label)
nodes_seen.add(node_to_fix_label)
+ def _fix_buffer_layers(self):
+ """Connect the buffer parents, merge duplicate buffer nodes, and label buffer passes correctly.
+ Buffers are duplicates if they happen in the same module, have the same value, and have the same parents.
+ """
+ buffer_counter = defaultdict(lambda: 1)
+ buffer_hash_groups = defaultdict(list)
+
+ for layer_label in self.buffer_layers:
+ layer = self[layer_label]
+ if layer.buffer_parent is not None:
+ layer.parent_layers.append(layer.buffer_parent)
+ self[layer.buffer_parent].child_layers.append(layer_label)
+ layer.func_applied = identity
+ layer.func_applied_name = 'identity'
+ layer.has_input_ancestor = True
+ layer.input_ancestors.update(self[layer.buffer_parent].input_ancestors)
+ layer.orig_ancestors.remove(layer.tensor_label_raw)
+ layer.orig_ancestors.update(self[layer.buffer_parent].orig_ancestors)
+ layer.parent_layer_arg_locs['args'][0] = layer.buffer_parent
+ if self[layer.buffer_parent].tensor_contents is not None:
+ layer.creation_args.append(self[layer.buffer_parent].tensor_contents.detach().clone())
+
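+ # The hash keys candidate duplicates by module context, parent layer,
+ # and buffer address; values are compared separately below.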
+ buffer_hash = str(layer.containing_modules_origin_nested) + str(layer.buffer_parent) + layer.buffer_address
+ buffer_hash_groups[buffer_hash].append(layer_label)
+
+ # Now go through and merge any layers with the same hash and the same value.
+ for buffers_orig in buffer_hash_groups.values():
+     unique_buffers = buffers_orig[0:1]
+     for buffer_label in buffers_orig[1:]:
+         buffer = self[buffer_label]
+         for unique_buffer_label in unique_buffers:
+             unique_buffer = self[unique_buffer_label]
+             if ((buffer.tensor_contents is not None) and (unique_buffer.tensor_contents is not None)
+                     and torch.equal(buffer.tensor_contents, unique_buffer.tensor_contents)):
+                 self._merge_buffer_entries(unique_buffer, buffer)
+                 break
+         else:
+             # No match found: this buffer is unique so far; keep its label.
+             unique_buffers.append(buffer_label)
+
+ # And relabel the buffer passes.
+
+ for layer_label in self.buffer_layers:
+ layer = self[layer_label]
+ buffer_address = layer.buffer_address
+ layer.buffer_pass = buffer_counter[buffer_address]
+ self.buffer_num_passes[buffer_address] = buffer_counter[buffer_address]
+ buffer_counter[buffer_address] += 1
+
+ def _merge_buffer_entries(self, source_buffer: TensorLogEntry,
+ buffer_to_remove: TensorLogEntry):
+ """Merges two identical buffer layers.
+ """
+ for child_layer in buffer_to_remove.child_layers:
+ if child_layer not in source_buffer.child_layers:
+ source_buffer.child_layers.append(child_layer)
+ self[child_layer].parent_layers.remove(buffer_to_remove.tensor_label_raw)
+ self[child_layer].parent_layers.append(source_buffer.tensor_label_raw)
+ if buffer_to_remove.tensor_label_raw in self[child_layer].internally_initialized_parents:
+ self[child_layer].internally_initialized_parents.remove(buffer_to_remove.tensor_label_raw)
+ self[child_layer].internally_initialized_parents.append(source_buffer.tensor_label_raw)
+
+ for arg_type in ['args', 'kwargs']:
+ for arg_label, arg_val in self[child_layer].parent_layer_arg_locs[arg_type].items():
+ if arg_val == buffer_to_remove.tensor_label_raw:
+ self[child_layer].parent_layer_arg_locs[arg_type][arg_label] = source_buffer.tensor_label_raw
+
+ for parent_layer in buffer_to_remove.parent_layers:
+ if parent_layer not in source_buffer.parent_layers:
+ source_buffer.parent_layers.append(parent_layer)
+ self[parent_layer].child_layers.remove(buffer_to_remove.tensor_label_raw)
+ self[parent_layer].child_layers.append(source_buffer.tensor_label_raw)
+
+ for parent_layer in buffer_to_remove.internally_initialized_parents:
+ if parent_layer not in source_buffer.internally_initialized_parents:
+ source_buffer.internally_initialized_parents.append(parent_layer)
+
+ if buffer_to_remove.tensor_label_raw in source_buffer.spouse_layers:
+ source_buffer.spouse_layers.remove(buffer_to_remove.tensor_label_raw)
+
+ if buffer_to_remove.tensor_label_raw in source_buffer.sibling_layers:
+ source_buffer.sibling_layers.remove(buffer_to_remove.tensor_label_raw)
+
+ for spouse_layer in buffer_to_remove.spouse_layers:
+ if buffer_to_remove.tensor_label_raw in self[spouse_layer].spouse_layers:
+ self[spouse_layer].spouse_layers.remove(buffer_to_remove.tensor_label_raw)
+ self[spouse_layer].spouse_layers.append(source_buffer.tensor_label_raw)
+
+ for sibling_layer in buffer_to_remove.sibling_layers:
+ if buffer_to_remove.tensor_label_raw in self[sibling_layer].sibling_layers:
+ self[sibling_layer].sibling_layers.remove(buffer_to_remove.tensor_label_raw)
+ self[sibling_layer].sibling_layers.append(source_buffer.tensor_label_raw)
+
+ self.raw_tensor_labels_list.remove(buffer_to_remove.tensor_label_raw)
+ self.raw_tensor_dict.pop(buffer_to_remove.tensor_label_raw)
+
+ for layer in self:
+ if buffer_to_remove.tensor_label_raw in layer.orig_ancestors:
+ layer.orig_ancestors.remove(buffer_to_remove.tensor_label_raw)
+ layer.orig_ancestors.add(source_buffer.tensor_label_raw)
+ if buffer_to_remove.tensor_label_raw in layer.internally_initialized_ancestors:
+ layer.internally_initialized_ancestors.remove(buffer_to_remove.tensor_label_raw)
+ layer.internally_initialized_ancestors.add(source_buffer.tensor_label_raw)
+
+ self._remove_log_entry(buffer_to_remove, remove_references=True)
+
def _map_raw_tensor_labels_to_final_tensor_labels(self):
"""
Determines the final label for each tensor, and stores this mapping as a dictionary
@@ -4564,7 +4735,9 @@ def _add_lookup_keys_for_tensor_entry(
# Allow using buffer/input/output address as key, too:
if tensor_entry.is_buffer_layer:
- lookup_keys_for_tensor.append(tensor_entry.buffer_address)
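+ # The bare buffer address is only an unambiguous key for single-pass
+ # buffers; the "address:pass" key always works.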
+ if self.buffer_num_passes[tensor_entry.buffer_address] == 1:
+ lookup_keys_for_tensor.append(tensor_entry.buffer_address)
+ lookup_keys_for_tensor.append(f"{tensor_entry.buffer_address}:{tensor_entry.buffer_pass}")
elif tensor_entry.is_input_layer or tensor_entry.is_output_layer:
lookup_keys_for_tensor.append(tensor_entry.input_output_address)
@@ -4664,7 +4837,9 @@ def _trim_and_reorder_model_history_fields(self):
new_dir_dict[field] = getattr(self, field)
for field in dir(self):
if field.startswith("_"):
- new_dir_dict[field] = getattr(self, field)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ new_dir_dict[field] = getattr(self, field)
self.__dict__ = new_dir_dict
def _undecorate_all_saved_tensors(self):
@@ -5053,7 +5228,12 @@ def _get_node_address_shape_color(
node_shape = "box"
node_color = "black"
elif node.is_buffer_layer:
- node_address = "
@" + node.buffer_address
+ if ((self.buffer_num_passes[node.buffer_address] == 1) or
+ (isinstance(node, RolledTensorLogEntry) and node.layer_passes_total > 1)):
+ buffer_address = node.buffer_address
+ else:
+ buffer_address = f"{node.buffer_address}:{node.buffer_pass}"
+ node_address = "
@" + buffer_address
node_shape = "box"
node_color = self.BUFFER_NODE_COLOR
elif node.is_output_layer or node.is_input_layer:
@@ -5225,6 +5405,9 @@ def _add_edges_for_node(
f"vis_opt must be 'unrolled' or 'rolled', not {vis_opt}"
)
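+ # Buffer layers hidden from the graph contribute no edges at all.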
+ if child_node.is_buffer_layer and not show_buffer_layers:
+ continue
+
if parent_node.has_input_ancestor:
edge_style = "solid"
else:
@@ -5296,7 +5479,7 @@ def _add_edges_for_node(
}
# Mark with "if" in the case edge starts a cond branch
- if child_layer_label in parent_node.cond_branch_start_children:
+ if (child_layer_label in parent_node.cond_branch_start_children) and (not child_is_collapsed_module):
edge_dict["label"] = '<IF>'
# Label the arguments to the next node if multiple inputs
@@ -5369,7 +5552,7 @@ def _label_node_arguments_if_needed(
):
arg_labels.append(f"{arg_type[:-1]} {str(arg_loc)}")
- arg_labels = "\n".join(arg_labels)
+ arg_labels = "
".join(arg_labels)
arg_label = f"<{arg_labels}>"
if "label" not in edge_dict:
edge_dict["label"] = arg_label
@@ -5502,11 +5685,12 @@ def _get_lowest_containing_module_for_two_nodes(
containing_module = node1_modules[-1]
return containing_module
- if (node1_nestmodules == node2_nestmodules) and both_nodes_collapsed_modules:
+ if both_nodes_collapsed_modules:
if (vis_nesting_depth == 1) or (len(node1_nestmodules) == 1):
return -1
- containing_module = node1_modules[vis_nesting_depth - 2]
- return containing_module
+ if node1_modules[vis_nesting_depth - 1] == node2_modules[vis_nesting_depth - 1]:
+ containing_module = node1_modules[vis_nesting_depth - 2]
+ return containing_module
containing_module = node1_modules[0]
for m in range(min([len(node1_modules), len(node2_modules)])):
@@ -5833,7 +6017,8 @@ def validate_parents_of_saved_layer(
parent_layer.child_layers
):
validated_layers.add(parent_layer_label)
- if not (parent_layer.is_input_layer or parent_layer.is_buffer_layer):
+ if ((not parent_layer.is_input_layer) and
+ not (parent_layer.is_buffer_layer and (parent_layer.buffer_parent is None))):
layers_to_validate_parents_for.append(parent_layer_label)
return True
diff --git a/torchlens/user_funcs.py b/torchlens/user_funcs.py
index ddc5def..f5deb73 100644
--- a/torchlens/user_funcs.py
+++ b/torchlens/user_funcs.py
@@ -272,12 +272,14 @@ def validate_saved_activations(
input_kwargs = {}
input_args_copy = [copy.deepcopy(arg) for arg in input_args]
input_kwargs_copy = {key: copy.deepcopy(val) for key, val in input_kwargs.items()}
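+ # Snapshot the buffers before the ground-truth pass: state_dict() returns
+ # live references, so deepcopy to guard against in-place buffer updates.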
+ state_dict = copy.deepcopy(model.state_dict())
ground_truth_output_tensors = get_vars_of_type_from_obj(
model(*input_args_copy, **input_kwargs_copy),
torch.Tensor,
search_depth=5,
allow_repeats=True,
)
+ model.load_state_dict(state_dict)
model_history = run_model_and_save_specified_activations(
model=model,
input_args=input_args,