Skip to content

Commit

Permalink
Merge pull request #17 from johnmarktaylor91/jmt-repeated-buffer-passes
Browse files Browse the repository at this point in the history
Repeated Buffer Passes
  • Loading branch information
johnmarktaylor91 authored Oct 12, 2023
2 parents 98e9fd1 + 7caa777 commit 04ffe48
Show file tree
Hide file tree
Showing 6 changed files with 330 additions and 32 deletions.
35 changes: 35 additions & 0 deletions tests/example_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,41 @@ def forward(self, x):
return x


class BufferRewriteModule(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.rand(12, 12))
self.register_buffer("buffer2", torch.rand(12, 12))

def forward(self, x):
x = torch.sin(x)
x = x + self.buffer1
x = x * self.buffer2
self.buffer1 = torch.rand(12, 12)
self.buffer2 = x ** 2
x = self.buffer1 + self.buffer2
return x


class BufferRewriteModel(nn.Module):
def __init__(self):
super().__init__()
self.buffer_mod = BufferRewriteModule()

def forward(self, x):
x = torch.cos(x)
x = self.buffer_mod(x)
x = x * 4
x = self.buffer_mod(x)
x = self.buffer_mod(x)
x = x + 1
x = self.buffer_mod(x)
x = self.buffer_mod(x)
x = self.buffer_mod(x)
x = x * 2
return x


class SimpleBranching(nn.Module):
def __init__(self):
super().__init__()
Expand Down
66 changes: 66 additions & 0 deletions tests/test_validation_and_visuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,72 @@ def test_buffer_model():
)


def test_buffer_rewrite_model():
model = example_models.BufferRewriteModel()
model_input = torch.rand(12, 12)
assert validate_saved_activations(model, model_input)
show_model_graph(
model,
model_input,
vis_opt="unrolled",
vis_nesting_depth=1,
vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_visible_unnested_unrolled"),
vis_buffer_layers=True,
)
show_model_graph(
model,
model_input,
vis_opt="unrolled",
vis_nesting_depth=1,
vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_invisible_unnested_unrolled"),
vis_buffer_layers=False,
)
show_model_graph(
model,
model_input,
vis_opt="unrolled",
vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_visible_nested_unrolled"),
vis_buffer_layers=True,
)
show_model_graph(
model,
model_input,
vis_opt="unrolled",
vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_invisible_nested_unrolled"),
vis_buffer_layers=False,
)
show_model_graph(
model,
model_input,
vis_opt="rolled",
vis_nesting_depth=1,
vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_visible_unnested_rolled"),
vis_buffer_layers=True,
)
show_model_graph(
model,
model_input,
vis_opt="rolled",
vis_nesting_depth=1,
vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_invisible_unnested_rolled"),
vis_buffer_layers=False,
)
show_model_graph(
model,
model_input,
vis_opt="rolled",
vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_visible_nested_rolled"),
vis_buffer_layers=True,
)
show_model_graph(
model,
model_input,
vis_opt="rolled",
vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_invisible_nested_rolled"),
vis_buffer_layers=False,
)


def test_simple_branching(default_input1):
model = example_models.SimpleBranching()
assert validate_saved_activations(model, default_input1)
Expand Down
3 changes: 3 additions & 0 deletions torchlens/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
"input_layers",
"output_layers",
"buffer_layers",
"buffer_num_passes",
"internally_initialized_layers",
"layers_where_internal_branches_merge_with_input",
"internally_terminated_layers",
Expand Down Expand Up @@ -207,6 +208,8 @@
"max_distance_from_output",
"is_buffer_layer",
"buffer_address",
"buffer_pass",
"buffer_parent",
"initialized_inside_model",
"has_internally_initialized_ancestor",
"internally_initialized_parents",
Expand Down
11 changes: 9 additions & 2 deletions torchlens/helper_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,9 @@ def extend_search_stack_from_item(
if (attr_name.startswith("__")) or (attr_name == 'T') or ('grad' in attr_name):
continue
try:
attr = getattr(item, attr_name)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
attr = getattr(item, attr_name)
except:
continue
attr_cls = type(attr)
Expand Down Expand Up @@ -417,6 +419,9 @@ def nested_getattr(obj: Any, attr: str) -> Any:
if a in [
"volatile",
"T",
'H',
'mH',
'mT'
]: # avoid annoying warning; if there's more, make a list
with warnings.catch_warnings():
warnings.simplefilter("ignore")
Expand All @@ -438,7 +443,9 @@ def nested_assign(obj, addr, val):
if entry_type == "ind":
obj = obj[entry_val]
elif entry_type == "attr":
obj = getattr(obj, entry_val)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
obj = getattr(obj, entry_val)


def remove_attributes_starting_with_str(obj: Any, s: str):
Expand Down
Loading

0 comments on commit 04ffe48

Please sign in to comment.