diff --git a/tests/example_models.py b/tests/example_models.py
index da9dde7..4c75574 100644
--- a/tests/example_models.py
+++ b/tests/example_models.py
@@ -427,6 +427,41 @@ def forward(self, x):
         return x
 
 
+class BufferRewriteModule(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.register_buffer("buffer1", torch.rand(12, 12))
+        self.register_buffer("buffer2", torch.rand(12, 12))
+
+    def forward(self, x):
+        x = torch.sin(x)
+        x = x + self.buffer1
+        x = x * self.buffer2
+        self.buffer1 = torch.rand(12, 12)
+        self.buffer2 = x ** 2
+        x = self.buffer1 + self.buffer2
+        return x
+
+
+class BufferRewriteModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.buffer_mod = BufferRewriteModule()
+
+    def forward(self, x):
+        x = torch.cos(x)
+        x = self.buffer_mod(x)
+        x = x * 4
+        x = self.buffer_mod(x)
+        x = self.buffer_mod(x)
+        x = x + 1
+        x = self.buffer_mod(x)
+        x = self.buffer_mod(x)
+        x = self.buffer_mod(x)
+        x = x * 2
+        return x
+
+
 class SimpleBranching(nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/tests/test_validation_and_visuals.py b/tests/test_validation_and_visuals.py
index c005e03..bb0fb6a 100644
--- a/tests/test_validation_and_visuals.py
+++ b/tests/test_validation_and_visuals.py
@@ -522,6 +522,72 @@ def test_buffer_model():
     )
 
 
+def test_buffer_rewrite_model():
+    model = example_models.BufferRewriteModel()
+    model_input = torch.rand(12, 12)
+    assert validate_saved_activations(model, model_input)
+    show_model_graph(
+        model,
+        model_input,
+        vis_opt="unrolled",
+        vis_nesting_depth=1,
+        vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_visible_unnested_unrolled"),
+        vis_buffer_layers=True,
+    )
+    show_model_graph(
+        model,
+        model_input,
+        vis_opt="unrolled",
+        vis_nesting_depth=1,
+        vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_invisible_unnested_unrolled"),
+        vis_buffer_layers=False,
+    )
+    show_model_graph(
+        model,
+        model_input,
+        vis_opt="unrolled",
+        vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_visible_nested_unrolled"),
+        vis_buffer_layers=True,
+    )
+    show_model_graph(
+        model,
+        model_input,
+        vis_opt="unrolled",
+        vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_invisible_nested_unrolled"),
+        vis_buffer_layers=False,
+    )
+    show_model_graph(
+        model,
+        model_input,
+        vis_opt="rolled",
+        vis_nesting_depth=1,
+        vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_visible_unnested_rolled"),
+        vis_buffer_layers=True,
+    )
+    show_model_graph(
+        model,
+        model_input,
+        vis_opt="rolled",
+        vis_nesting_depth=1,
+        vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_invisible_unnested_rolled"),
+        vis_buffer_layers=False,
+    )
+    show_model_graph(
+        model,
+        model_input,
+        vis_opt="rolled",
+        vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_visible_nested_rolled"),
+        vis_buffer_layers=True,
+    )
+    show_model_graph(
+        model,
+        model_input,
+        vis_opt="rolled",
+        vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_invisible_nested_rolled"),
+        vis_buffer_layers=False,
+    )
+
+
 def test_simple_branching(default_input1):
     model = example_models.SimpleBranching()
     assert validate_saved_activations(model, default_input1)
diff --git a/torchlens/constants.py b/torchlens/constants.py
index 5561b80..bcde820 100644
--- a/torchlens/constants.py
+++ b/torchlens/constants.py
@@ -56,6 +56,7 @@
     "input_layers",
     "output_layers",
     "buffer_layers",
"buffer_num_passes", "internally_initialized_layers", "layers_where_internal_branches_merge_with_input", "internally_terminated_layers", @@ -207,6 +208,8 @@ "max_distance_from_output", "is_buffer_layer", "buffer_address", + "buffer_pass", + "buffer_parent", "initialized_inside_model", "has_internally_initialized_ancestor", "internally_initialized_parents", diff --git a/torchlens/helper_funcs.py b/torchlens/helper_funcs.py index 89ac14f..eb9c5b5 100644 --- a/torchlens/helper_funcs.py +++ b/torchlens/helper_funcs.py @@ -356,7 +356,9 @@ def extend_search_stack_from_item( if (attr_name.startswith("__")) or (attr_name == 'T') or ('grad' in attr_name): continue try: - attr = getattr(item, attr_name) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + attr = getattr(item, attr_name) except: continue attr_cls = type(attr) @@ -417,6 +419,9 @@ def nested_getattr(obj: Any, attr: str) -> Any: if a in [ "volatile", "T", + 'H', + 'mH', + 'mT' ]: # avoid annoying warning; if there's more, make a list with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -438,7 +443,9 @@ def nested_assign(obj, addr, val): if entry_type == "ind": obj = obj[entry_val] elif entry_type == "attr": - obj = getattr(obj, entry_val) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + obj = getattr(obj, entry_val) def remove_attributes_starting_with_str(obj: Any, s: str): diff --git a/torchlens/model_history.py b/torchlens/model_history.py index 0e21dd8..789f4bf 100644 --- a/torchlens/model_history.py +++ b/torchlens/model_history.py @@ -9,12 +9,13 @@ import types from collections import OrderedDict, defaultdict from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +import warnings import graphviz import numpy as np import pandas as pd import torch -from IPython.core.display import display +from IPython.display import display from torch import nn from torchlens.constants import ( @@ -191,6 +192,8 @@ def __init__(self, fields_dict: Dict): self.input_output_address = fields_dict["input_output_address"] self.is_buffer_layer = fields_dict["is_buffer_layer"] self.buffer_address = fields_dict["buffer_address"] + self.buffer_pass = fields_dict["buffer_pass"] + self.buffer_parent = fields_dict["buffer_parent"] self.initialized_inside_model = fields_dict["initialized_inside_model"] self.has_internally_initialized_ancestor = fields_dict[ "has_internally_initialized_ancestor" @@ -553,6 +556,7 @@ def __init__(self, source_entry: TensorLogEntry): self.is_last_output_layer = source_entry.is_last_output_layer self.is_buffer_layer = source_entry.is_buffer_layer self.buffer_address = source_entry.buffer_address + self.buffer_pass = source_entry.buffer_pass self.input_output_address = source_entry.input_output_address self.cond_branch_start_children = source_entry.cond_branch_start_children self.is_terminal_bool_layer = source_entry.is_terminal_bool_layer @@ -805,6 +809,7 @@ def __init__( self.input_layers: List[str] = [] self.output_layers: List[str] = [] self.buffer_layers: List[str] = [] + self.buffer_num_passes: Dict = {} self.internally_initialized_layers: List[str] = [] self.layers_where_internal_branches_merge_with_input: List[str] = [] self.internally_terminated_layers: List[str] = [] @@ -1000,6 +1005,12 @@ def wrapped_func(*args, **kwargs): all_args = list(args) + list(kwargs.values()) arg_tensorlike = get_vars_of_type_from_obj(all_args, torch.Tensor) + # Register any buffer tensors in the arguments. 
+
+            for t in arg_tensorlike:
+                if hasattr(t, 'tl_buffer_address'):
+                    self.log_source_tensor(t, 'buffer', getattr(t, 'tl_buffer_address'))
+
             if (func_name in print_funcs) and (len(arg_tensorlike) > 0):
                 out = print_override(args[0], func_name)
                 return out
@@ -1085,7 +1096,9 @@ def decorate_pytorch(
                 continue
             new_func = self.torch_func_decorator(orig_func)
             try:
-                setattr(local_func_namespace, func_name, new_func)
+                with warnings.catch_warnings():
+                    warnings.simplefilter("ignore")
+                    setattr(local_func_namespace, func_name, new_func)
             except (AttributeError, TypeError) as _:
                 pass
             new_func.tl_is_decorated_function = True
@@ -1117,10 +1130,14 @@ def undecorate_pytorch(
         for namespace_name, func_name, orig_func in orig_func_defs:
             namespace_name_notorch = namespace_name.replace("torch.", "")
             local_func_namespace = nested_getattr(torch_module, namespace_name_notorch)
-            decorated_func = getattr(local_func_namespace, func_name)
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                decorated_func = getattr(local_func_namespace, func_name)
             del decorated_func
             try:
-                setattr(local_func_namespace, func_name, orig_func)
+                with warnings.catch_warnings():
+                    warnings.simplefilter("ignore")
+                    setattr(local_func_namespace, func_name, orig_func)
             except (AttributeError, TypeError) as _:
                 continue
         delattr(torch, "identity")
@@ -1260,7 +1277,9 @@ def prepare_buffer_tensors(self, model: nn.Module):
         submodules = self.get_all_submodules(model)
         for submodule in submodules:
             for attribute_name in dir(submodule):
-                attribute = getattr(submodule, attribute_name)
+                with warnings.catch_warnings():
+                    warnings.simplefilter("ignore")
+                    attribute = getattr(submodule, attribute_name)
                 if issubclass(type(attribute), torch.Tensor) and not issubclass(
                     type(attribute), torch.nn.Parameter
                 ):
@@ -1270,7 +1289,7 @@ def prepare_buffer_tensors(self, model: nn.Module):
                     buffer_address = (
                         submodule.tl_module_address + "." + attribute_name
                     )
-                    self.log_source_tensor(attribute, "buffer", buffer_address)
+                    setattr(attribute, 'tl_buffer_address', buffer_address)
 
     def module_forward_decorator(
         self, orig_forward: Callable, module: nn.Module
@@ -1298,6 +1317,8 @@ def decorated_forward(*args, **kwargs):
             )
             input_tensor_labels = set()
             for t in input_tensors:
+                if (not hasattr(t, 'tl_tensor_label_raw')) and hasattr(t, 'tl_buffer_address'):
+                    self.log_source_tensor(t, 'buffer', getattr(t, 'tl_buffer_address'))
                 tensor_entry = self.raw_tensor_dict[t.tl_tensor_label_raw]
                 input_tensor_labels.add(t.tl_tensor_label_raw)
                 module.tl_tensors_entered_labels.append(t.tl_tensor_label_raw)
@@ -1308,6 +1329,18 @@ def decorated_forward(*args, **kwargs):
                     ("+", module_pass_label[0], module_pass_label[1])
                 )
 
+            # Check the buffers.
+            for buffer_name, buffer_tensor in module.named_buffers():
+                if hasattr(buffer_tensor, 'tl_buffer_address'):
+                    continue
+                if module.tl_module_address == '':
+                    buffer_address = buffer_name
+                else:
+                    buffer_address = f"{module.tl_module_address}.{buffer_name}"
+                buffer_tensor.tl_buffer_address = buffer_address
+                buffer_tensor.tl_buffer_parent = buffer_tensor.tl_tensor_label_raw
+                delattr(buffer_tensor, 'tl_tensor_label_raw')
+
             # The function call
             out = orig_forward(*args, **kwargs)
@@ -1484,13 +1517,17 @@ def restore_module_attributes(
             if attribute_name.startswith(attribute_keyword):
                 delattr(module, attribute_name)
                 continue
-            attr = getattr(module, attribute_name)
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                attr = getattr(module, attribute_name)
             if (
                 isinstance(attr, Callable)
                 and (attr in decorated_func_mapper)
                 and (attribute_name[0:2] != "__")
             ):
-                setattr(module, attribute_name, decorated_func_mapper[attr])
+                with warnings.catch_warnings():
+                    warnings.simplefilter("ignore")
+                    setattr(module, attribute_name, decorated_func_mapper[attr])
 
     def restore_model_attributes(
         self,
@@ -1533,12 +1570,18 @@ def undecorate_model_tensors(self, model: nn.Module):
         submodules = self.get_all_submodules(model)
         for submodule in submodules:
             for attribute_name in dir(submodule):
-                attribute = getattr(submodule, attribute_name)
+                with warnings.catch_warnings():
+                    warnings.simplefilter("ignore")
+                    attribute = getattr(submodule, attribute_name)
                 if issubclass(type(attribute), torch.Tensor):
                     if not issubclass(type(attribute), torch.nn.Parameter) and hasattr(
                         attribute, "tl_tensor_label_raw"
                     ):
                         delattr(attribute, "tl_tensor_label_raw")
+                        if hasattr(attribute, 'tl_buffer_address'):
+                            delattr(attribute, "tl_buffer_address")
+                        if hasattr(attribute, 'tl_buffer_parent'):
+                            delattr(attribute, "tl_buffer_parent")
                     else:
                         remove_attributes_starting_with_str(attribute, "tl_")
                 elif type(attribute) in [list, tuple, set]:
                     for item in attribute:
                         if issubclass(type(item), torch.Tensor) and hasattr(
                             item, "tl_tensor_label_raw"
                         ):
                             delattr(item, "tl_tensor_label_raw")
+                            if hasattr(item, 'tl_buffer_address'):
+                                delattr(item, "tl_buffer_address")
+                            if hasattr(item, 'tl_buffer_parent'):
+                                delattr(item, "tl_buffer_parent")
                 elif type(attribute) == dict:
                     for key, val in attribute.items():
                         if issubclass(type(val), torch.Tensor) and hasattr(
                             val, "tl_tensor_label_raw"
                         ):
                             delattr(val, "tl_tensor_label_raw")
+                            if hasattr(val, 'tl_buffer_address'):
+                                delattr(val, "tl_buffer_address")
+                            if hasattr(val, 'tl_buffer_parent'):
+                                delattr(val, "tl_buffer_parent")
 
     def get_op_nums_from_user_labels(
         self, which_layers: Union[str, List[Union[str, int]]]
@@ -1889,6 +1940,7 @@
             input_output_address = extra_addr
             is_buffer_layer = False
             buffer_address = None
+            buffer_parent = None
             initialized_inside_model = False
             has_internally_initialized_ancestor = False
             input_ancestors = {tensor_label}
@@ -1907,6 +1959,10 @@
             internally_initialized_ancestors = {tensor_label}
             input_ancestors = set()
             operation_equivalence_type = f"buffer_{extra_addr}"
+            if hasattr(t, 'tl_buffer_parent'):
+                buffer_parent = t.tl_buffer_parent
+            else:
+                buffer_parent = None
         else:
             raise ValueError("source must be either 'input' or 'buffer'")
 
@@ -2013,6 +2069,8 @@
             "input_output_address": input_output_address,
             "is_buffer_layer": is_buffer_layer,
             "buffer_address": buffer_address,
+            "buffer_pass": None,
+            "buffer_parent": buffer_parent,
             "initialized_inside_model": initialized_inside_model,
             "has_internally_initialized_ancestor": has_internally_initialized_ancestor,
             "internally_initialized_parents": [],
@@ -2247,6 +2305,8 @@ def log_function_output_tensors_exhaustive(
         fields_dict["input_output_address"] = None
         fields_dict["is_buffer_layer"] = False
         fields_dict["buffer_address"] = None
+        fields_dict["buffer_pass"] = None
+        fields_dict["buffer_parent"] = None
         fields_dict["initialized_inside_model"] = len(parent_layer_labels) == 0
         fields_dict["has_internally_initialized_ancestor"] = (
             len(internally_initialized_ancestors) > 0
@@ -3077,8 +3137,10 @@ def _remove_log_entry(
         else:
             tensor_label = log_entry.tensor_label_raw
         for attr in dir(log_entry):
-            if not attr.startswith("_") and not callable(getattr(log_entry, attr)):
-                delattr(log_entry, attr)
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                if not attr.startswith("_") and not callable(getattr(log_entry, attr)):
+                    delattr(log_entry, attr)
         del log_entry
         if remove_references:
             self._remove_log_entry_references(tensor_label)
@@ -3170,34 +3232,38 @@ def postprocess(
 
         self._fix_modules_for_internal_tensors()
 
-        # Step 7: Identify all loops, mark repeated layers.
+        # Step 7: Fix the buffer passes and parent information.
+
+        self._fix_buffer_layers()
+
+        # Step 8: Identify all loops, mark repeated layers.
 
         self._assign_corresponding_tensors_to_same_layer()
 
-        # Step 8: Go down tensor list, get the mapping from raw tensor names to final tensor names.
+        # Step 9: Go down tensor list, get the mapping from raw tensor names to final tensor names.
 
         self._map_raw_tensor_labels_to_final_tensor_labels()
 
-        # Step 9: Go through and log information pertaining to all layers:
+        # Step 10: Go through and log information pertaining to all layers:
 
         self._log_final_info_for_all_layers()
 
-        # Step 10: Rename the raw tensor entries in the fields of ModelHistory:
+        # Step 11: Rename the raw tensor entries in the fields of ModelHistory:
 
         self._rename_model_history_layer_names()
         self._trim_and_reorder_model_history_fields()
 
-        # Step 11: And one more pass to delete unused layers from the record and do final tidying up:
+        # Step 12: And one more pass to delete unused layers from the record and do final tidying up:
 
         self._remove_unwanted_entries_and_log_remaining()
 
-        # Step 12: Undecorate all saved tensors and remove saved grad_fns.
+        # Step 13: Undecorate all saved tensors and remove saved grad_fns.
 
         self._undecorate_all_saved_tensors()
 
-        # Step 13: Clear the cache after any tensor deletions for garbage collection purposes:
+        # Step 14: Clear the cache after any tensor deletions for garbage collection purposes:
 
         torch.cuda.empty_cache()
 
-        # Step 14: Log time elapsed.
+        # Step 15: Log time elapsed.
 
         self._log_time_elapsed()
 
-        # Step 15: log the pass as finished, changing the ModelHistory behavior to its user-facing version.
+        # Step 16: Log the pass as finished, changing the ModelHistory behavior to its user-facing version.
 
         self._set_pass_finished()
@@ -4152,7 +4218,7 @@ def _fix_modules_for_single_internal_tensor(
         for enter_or_exit, module_address, module_pass in thread_modules[::step_val]:
             module_pass_label = (module_address, module_pass)
             if node_type_to_fix == "parent":
-                if enter_or_exit == "+":
+                if (enter_or_exit == "+") and (module_pass_label in node_to_fix.containing_modules_origin_nested):
                     node_to_fix.containing_modules_origin_nested.remove(
                         module_pass_label
                     )
@@ -4172,6 +4238,111 @@
                 node_stack.append(node_to_fix_label)
                 nodes_seen.add(node_to_fix_label)
 
+    def _fix_buffer_layers(self):
+        """Connect the buffer parents, merge duplicate buffer nodes, and label buffer passes correctly.
+        Buffers are duplicates if they occur in the same module, have the same value, and have the same parents.
+        """
+        buffer_counter = defaultdict(lambda: 1)
+        buffer_hash_groups = defaultdict(list)
+
+        for layer_label in self.buffer_layers:
+            layer = self[layer_label]
+            if layer.buffer_parent is not None:
+                layer.parent_layers.append(layer.buffer_parent)
+                self[layer.buffer_parent].child_layers.append(layer_label)
+                layer.func_applied = identity
+                layer.func_applied_name = 'identity'
+                layer.has_input_ancestor = True
+                layer.input_ancestors.update(self[layer.buffer_parent].input_ancestors)
+                layer.orig_ancestors.remove(layer.tensor_label_raw)
+                layer.orig_ancestors.update(self[layer.buffer_parent].orig_ancestors)
+                layer.parent_layer_arg_locs['args'][0] = layer.buffer_parent
+                if self[layer.buffer_parent].tensor_contents is not None:
+                    layer.creation_args.append(self[layer.buffer_parent].tensor_contents.detach().clone())
+
+            buffer_hash = str(layer.containing_modules_origin_nested) + str(layer.buffer_parent) + layer.buffer_address
+            buffer_hash_groups[buffer_hash].append(layer_label)
+
+        # Now go through and merge any layers with the same hash and the same value.
+        for _, buffers_orig in buffer_hash_groups.items():
+            buffers = buffers_orig[1:]
+            unique_buffers = buffers_orig[0:1]
+            for b, buffer_label in enumerate(buffers):
+                for unique_buffer_label in unique_buffers:
+                    buffer = self[buffer_label]
+                    unique_buffer = self[unique_buffer_label]
+                    if ((buffer.tensor_contents is not None) and (unique_buffer.tensor_contents is not None) and
+                            (torch.equal(buffer.tensor_contents, unique_buffer.tensor_contents))):
+                        self._merge_buffer_entries(unique_buffer, buffer)
+                        break
+                else:
+                    unique_buffers.append(buffer_label)
+
+        # And relabel the buffer passes.
+
+        for layer_label in self.buffer_layers:
+            layer = self[layer_label]
+            buffer_address = layer.buffer_address
+            layer.buffer_pass = buffer_counter[buffer_address]
+            self.buffer_num_passes[buffer_address] = buffer_counter[buffer_address]
+            buffer_counter[buffer_address] += 1
+
+    def _merge_buffer_entries(self, source_buffer: TensorLogEntry,
+                              buffer_to_remove: TensorLogEntry):
+        """Merges two identical buffer layers.
+        """
+        for child_layer in buffer_to_remove.child_layers:
+            if child_layer not in source_buffer.child_layers:
+                source_buffer.child_layers.append(child_layer)
+            self[child_layer].parent_layers.remove(buffer_to_remove.tensor_label_raw)
+            self[child_layer].parent_layers.append(source_buffer.tensor_label_raw)
+            if buffer_to_remove.tensor_label_raw in self[child_layer].internally_initialized_parents:
+                self[child_layer].internally_initialized_parents.remove(buffer_to_remove.tensor_label_raw)
+                self[child_layer].internally_initialized_parents.append(source_buffer.tensor_label_raw)
+
+            for arg_type in ['args', 'kwargs']:
+                for arg_label, arg_val in self[child_layer].parent_layer_arg_locs[arg_type].items():
+                    if arg_val == buffer_to_remove.tensor_label_raw:
+                        self[child_layer].parent_layer_arg_locs[arg_type][arg_label] = source_buffer.tensor_label_raw
+
+        for parent_layer in buffer_to_remove.parent_layers:
+            if parent_layer not in source_buffer.parent_layers:
+                source_buffer.parent_layers.append(parent_layer)
+            self[parent_layer].child_layers.remove(buffer_to_remove.tensor_label_raw)
+            self[parent_layer].child_layers.append(source_buffer.tensor_label_raw)
+
+        for parent_layer in buffer_to_remove.internally_initialized_parents:
+            if parent_layer not in source_buffer.internally_initialized_parents:
+                source_buffer.internally_initialized_parents.append(parent_layer)
+
+        if buffer_to_remove.tensor_label_raw in source_buffer.spouse_layers:
+            source_buffer.spouse_layers.remove(buffer_to_remove.tensor_label_raw)
+
+        if buffer_to_remove.tensor_label_raw in source_buffer.sibling_layers:
+            source_buffer.sibling_layers.remove(buffer_to_remove.tensor_label_raw)
+
+        for spouse_layer in buffer_to_remove.spouse_layers:
+            if buffer_to_remove.tensor_label_raw in self[spouse_layer].spouse_layers:
+                self[spouse_layer].spouse_layers.remove(buffer_to_remove.tensor_label_raw)
+                self[spouse_layer].spouse_layers.append(source_buffer.tensor_label_raw)
+
+        for sibling_layer in buffer_to_remove.sibling_layers:
+            if buffer_to_remove.tensor_label_raw in self[sibling_layer].sibling_layers:
+                self[sibling_layer].sibling_layers.remove(buffer_to_remove.tensor_label_raw)
+                self[sibling_layer].sibling_layers.append(source_buffer.tensor_label_raw)
+
+        self.raw_tensor_labels_list.remove(buffer_to_remove.tensor_label_raw)
+        self.raw_tensor_dict.pop(buffer_to_remove.tensor_label_raw)
+
+        for layer in self:
+            if buffer_to_remove.tensor_label_raw in layer.orig_ancestors:
+                layer.orig_ancestors.remove(buffer_to_remove.tensor_label_raw)
+                layer.orig_ancestors.add(source_buffer.tensor_label_raw)
+            if buffer_to_remove.tensor_label_raw in layer.internally_initialized_ancestors:
+                layer.internally_initialized_ancestors.remove(buffer_to_remove.tensor_label_raw)
+                layer.internally_initialized_ancestors.add(source_buffer.tensor_label_raw)
+
+        self._remove_log_entry(buffer_to_remove, remove_references=True)
+
     def _map_raw_tensor_labels_to_final_tensor_labels(self):
         """
         Determines the final label for each tensor, and stores this mapping as a dictionary
@@ -4564,7 +4735,9 @@ def _add_lookup_keys_for_tensor_entry(
 
         # Allow using buffer/input/output address as key, too:
         if tensor_entry.is_buffer_layer:
-            lookup_keys_for_tensor.append(tensor_entry.buffer_address)
+            if self.buffer_num_passes[tensor_entry.buffer_address] == 1:
+                lookup_keys_for_tensor.append(tensor_entry.buffer_address)
+            lookup_keys_for_tensor.append(f"{tensor_entry.buffer_address}:{tensor_entry.buffer_pass}")
         elif tensor_entry.is_input_layer or tensor_entry.is_output_layer:
             lookup_keys_for_tensor.append(tensor_entry.input_output_address)
 
@@ -4664,7 +4837,9 @@ def _trim_and_reorder_model_history_fields(self):
             new_dir_dict[field] = getattr(self, field)
         for field in dir(self):
             if field.startswith("_"):
-                new_dir_dict[field] = getattr(self, field)
+                with warnings.catch_warnings():
+                    warnings.simplefilter("ignore")
+                    new_dir_dict[field] = getattr(self, field)
         self.__dict__ = new_dir_dict
 
     def _undecorate_all_saved_tensors(self):
@@ -5053,7 +5228,12 @@ def _get_node_address_shape_color(
             node_shape = "box"
             node_color = "black"
         elif node.is_buffer_layer:
-            node_address = "<br/>@" + node.buffer_address
@" + node.buffer_address + if ((self.buffer_num_passes[node.buffer_address] == 1) or + (isinstance(node, RolledTensorLogEntry) and node.layer_passes_total > 1)): + buffer_address = node.buffer_address + else: + buffer_address = f"{node.buffer_address}:{node.buffer_pass}" + node_address = "
@" + buffer_address node_shape = "box" node_color = self.BUFFER_NODE_COLOR elif node.is_output_layer or node.is_input_layer: @@ -5225,6 +5405,9 @@ def _add_edges_for_node( f"vis_opt must be 'unrolled' or 'rolled', not {vis_opt}" ) + if child_node.is_buffer_layer and not show_buffer_layers: + continue + if parent_node.has_input_ancestor: edge_style = "solid" else: @@ -5296,7 +5479,7 @@ def _add_edges_for_node( } # Mark with "if" in the case edge starts a cond branch - if child_layer_label in parent_node.cond_branch_start_children: + if (child_layer_label in parent_node.cond_branch_start_children) and (not child_is_collapsed_module): edge_dict["label"] = '<IF>' # Label the arguments to the next node if multiple inputs @@ -5369,7 +5552,7 @@ def _label_node_arguments_if_needed( ): arg_labels.append(f"{arg_type[:-1]} {str(arg_loc)}") - arg_labels = "\n".join(arg_labels) + arg_labels = "
".join(arg_labels) arg_label = f"<{arg_labels}>" if "label" not in edge_dict: edge_dict["label"] = arg_label @@ -5502,11 +5685,12 @@ def _get_lowest_containing_module_for_two_nodes( containing_module = node1_modules[-1] return containing_module - if (node1_nestmodules == node2_nestmodules) and both_nodes_collapsed_modules: + if both_nodes_collapsed_modules: if (vis_nesting_depth == 1) or (len(node1_nestmodules) == 1): return -1 - containing_module = node1_modules[vis_nesting_depth - 2] - return containing_module + if node1_modules[vis_nesting_depth - 1] == node2_modules[vis_nesting_depth - 1]: + containing_module = node1_modules[vis_nesting_depth - 2] + return containing_module containing_module = node1_modules[0] for m in range(min([len(node1_modules), len(node2_modules)])): @@ -5833,7 +6017,8 @@ def validate_parents_of_saved_layer( parent_layer.child_layers ): validated_layers.add(parent_layer_label) - if not (parent_layer.is_input_layer or parent_layer.is_buffer_layer): + if ((not parent_layer.is_input_layer) and + not (parent_layer.is_buffer_layer and (parent_layer.buffer_parent is None))): layers_to_validate_parents_for.append(parent_layer_label) return True diff --git a/torchlens/user_funcs.py b/torchlens/user_funcs.py index ddc5def..f5deb73 100644 --- a/torchlens/user_funcs.py +++ b/torchlens/user_funcs.py @@ -272,12 +272,14 @@ def validate_saved_activations( input_kwargs = {} input_args_copy = [copy.deepcopy(arg) for arg in input_args] input_kwargs_copy = {key: copy.deepcopy(val) for key, val in input_kwargs.items()} + state_dict = model.state_dict() ground_truth_output_tensors = get_vars_of_type_from_obj( model(*input_args_copy, **input_kwargs_copy), torch.Tensor, search_depth=5, allow_repeats=True, ) + model.load_state_dict(state_dict) model_history = run_model_and_save_specified_activations( model=model, input_args=input_args,