Skip to content

Commit

Permalink
Small fix with scalar parameters.
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnMark Taylor committed Oct 3, 2023
1 parent 51ab52a commit 234e8a9
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 16 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

setup(
name="torchlens",
version="0.1.9",
version="0.1.10",
description="A package for extracting activations from PyTorch models",
long_description="A package for extracting activations from PyTorch models. Contains functionality for "
"extracting model activations, visualizing a model's computational graph, and "
Expand Down
15 changes: 1 addition & 14 deletions tests/test_validation_and_visuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2330,19 +2330,6 @@ def test_timm_beit_base_patch16_224(default_input1):
assert validate_saved_activations(model, default_input1)


def test_timm_beit_base_patch16_224_in22k(default_input1):
    """Render the unrolled computational graph for timm's beit_base_patch16_224_in22k
    and check that the activations torchlens saved for it validate correctly."""
    net = timm.models.beit_base_patch16_224_in22k()
    out_path = opj(
        "visualization_outputs", "timm", "timm_beit_base_patch16_224_in22k"
    )
    show_model_graph(net, default_input1, vis_opt="unrolled", vis_outpath=out_path)
    assert validate_saved_activations(net, default_input1)


def test_timm_cait_s24_224(default_input1):
model = timm.models.cait_s24_224()
show_model_graph(
Expand Down Expand Up @@ -2377,7 +2364,7 @@ def test_timm_convit_base(default_input1):


def test_timm_darknet21():
model = timm.create_model("darknet21", pretrained=True)
model = timm.create_model("darknet21", pretrained=False)
model_input = torch.randn(1, 3, 224, 224)
show_model_graph(
model,
Expand Down
13 changes: 12 additions & 1 deletion torchlens/model_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -5180,8 +5180,10 @@ def _make_param_label(node: Union[TensorLogEntry, RolledTensorLogEntry]) -> str:
for param_shape in node.parent_param_shapes:
if len(param_shape) > 1:
each_param_shape.append("x".join([str(s) for s in param_shape]))
else:
elif len(param_shape) == 1:
each_param_shape.append(f"x{param_shape[0]}")
else:
each_param_shape.append("x1")

param_label = "<br/>params: " + ", ".join(
[param_shape for param_shape in each_param_shape]
Expand Down Expand Up @@ -6009,6 +6011,15 @@ def _check_whether_func_on_saved_parents_yields_saved_tensor(
self[layers_to_perturb[0]].tensor_contents,
layer_to_validate_parents_for.creation_args[1],
)
):
return True
elif (
perturb
and (layer_to_validate_parents_for.func_applied_name == "__getitem__")
and not torch.equal(
self[layers_to_perturb[0]].tensor_contents,
layer_to_validate_parents_for.creation_args[0],
)
):
return True
elif layer_to_validate_parents_for.func_applied_name == 'empty_like':
Expand Down

0 comments on commit 234e8a9

Please sign in to comment.