Skip to content

Commit

Permalink
Fixed edge case of models where the buffer gets rewritten during the …
Browse files Browse the repository at this point in the history
…forward pass.
  • Loading branch information
JohnMark Taylor committed Oct 12, 2023
1 parent e05818c commit 7caa777
Show file tree
Hide file tree
Showing 5 changed files with 291 additions and 19 deletions.
35 changes: 35 additions & 0 deletions tests/example_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,41 @@ def forward(self, x):
return x


class BufferRewriteModule(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.rand(12, 12))
self.register_buffer("buffer2", torch.rand(12, 12))

def forward(self, x):
x = torch.sin(x)
x = x + self.buffer1
x = x * self.buffer2
self.buffer1 = torch.rand(12, 12)
self.buffer2 = x ** 2
x = self.buffer1 + self.buffer2
return x


class BufferRewriteModel(nn.Module):
def __init__(self):
super().__init__()
self.buffer_mod = BufferRewriteModule()

def forward(self, x):
x = torch.cos(x)
x = self.buffer_mod(x)
x = x * 4
x = self.buffer_mod(x)
x = self.buffer_mod(x)
x = x + 1
x = self.buffer_mod(x)
x = self.buffer_mod(x)
x = self.buffer_mod(x)
x = x * 2
return x


class SimpleBranching(nn.Module):
def __init__(self):
super().__init__()
Expand Down
66 changes: 66 additions & 0 deletions tests/test_validation_and_visuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,72 @@ def test_buffer_model():
)


def test_buffer_rewrite_model():
model = example_models.BufferRewriteModel()
model_input = torch.rand(12, 12)
assert validate_saved_activations(model, model_input)
show_model_graph(
model,
model_input,
vis_opt="unrolled",
vis_nesting_depth=1,
vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_visible_unnested_unrolled"),
vis_buffer_layers=True,
)
show_model_graph(
model,
model_input,
vis_opt="unrolled",
vis_nesting_depth=1,
vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_invisible_unnested_unrolled"),
vis_buffer_layers=False,
)
show_model_graph(
model,
model_input,
vis_opt="unrolled",
vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_visible_nested_unrolled"),
vis_buffer_layers=True,
)
show_model_graph(
model,
model_input,
vis_opt="unrolled",
vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_invisible_nested_unrolled"),
vis_buffer_layers=False,
)
show_model_graph(
model,
model_input,
vis_opt="rolled",
vis_nesting_depth=1,
vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_visible_unnested_rolled"),
vis_buffer_layers=True,
)
show_model_graph(
model,
model_input,
vis_opt="rolled",
vis_nesting_depth=1,
vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_invisible_unnested_rolled"),
vis_buffer_layers=False,
)
show_model_graph(
model,
model_input,
vis_opt="rolled",
vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_visible_nested_rolled"),
vis_buffer_layers=True,
)
show_model_graph(
model,
model_input,
vis_opt="rolled",
vis_outpath=opj("visualization_outputs", "toy-networks", "buffer_rewrite_model_invisible_nested_rolled"),
vis_buffer_layers=False,
)


def test_simple_branching(default_input1):
model = example_models.SimpleBranching()
assert validate_saved_activations(model, default_input1)
Expand Down
3 changes: 3 additions & 0 deletions torchlens/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
"input_layers",
"output_layers",
"buffer_layers",
"buffer_num_passes",
"internally_initialized_layers",
"layers_where_internal_branches_merge_with_input",
"internally_terminated_layers",
Expand Down Expand Up @@ -207,6 +208,8 @@
"max_distance_from_output",
"is_buffer_layer",
"buffer_address",
"buffer_pass",
"buffer_parent",
"initialized_inside_model",
"has_internally_initialized_ancestor",
"internally_initialized_parents",
Expand Down
Loading

0 comments on commit 7caa777

Please sign in to comment.