Skip to content

Commit

Permalink
Fixed a bug with decorating some of the PyTorch functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnMark Taylor committed Sep 9, 2023
1 parent f4517e8 commit 8aec5c6
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 57 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

setup(
name="torchlens",
version="0.1.6",
version="0.1.7",
description="A package for extracting activations from PyTorch models",
long_description="A package for extracting activations from PyTorch models. Contains functionality for "
"extracting model activations, visualizing a model's computational graph, and "
Expand Down
18 changes: 18 additions & 0 deletions tests/test_validation_and_visuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torchvision
import visualpriors
from PIL import Image
from StyleTTS.models import TextEncoder
from transformers import (
BertForNextSentencePrediction,
BertTokenizer,
Expand Down Expand Up @@ -2696,6 +2697,23 @@ def test_clip(): # for some reason CLIP breaks the PyCharm debugger
assert validate_saved_activations(model, [], model_inputs, random_seed=1)


# Text to speech
def test_styletts():
    """Render the graph of the StyleTTS ``TextEncoder`` and validate its saved activations.

    Builds a ``TextEncoder`` and a single 10-token input, draws the unrolled
    model graph to the text-to-speech visualization folder, then asserts that
    the saved activations validate against the model.
    """
    encoder = TextEncoder(3, 3, 3, 100)

    # One sequence of ten token ids, its length, and an all-ones (1, 10) tensor.
    # NOTE(review): the third input is presumably a mask -- confirm against TextEncoder.forward.
    token_ids = torch.tensor([[3, 0, 1, 2, 0, 2, 2, 3, 1, 4]])
    lengths = torch.ones(1, dtype=torch.long) * 10
    ones_input = torch.ones(1, 10)
    model_inputs = (token_ids, lengths, ones_input)

    graph_outpath = opj(
        "visualization_outputs", "text-to-speech", "styletts_text_encoder"
    )
    show_model_graph(
        encoder,
        model_inputs,
        random_seed=1,
        vis_opt="unrolled",
        vis_outpath=graph_outpath,
    )

    is_valid = validate_saved_activations(encoder, model_inputs, random_seed=1)
    assert is_valid


# Graph neural networks


Expand Down
6 changes: 1 addition & 5 deletions torchlens/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,11 +347,7 @@ def my_get_overridable_functions() -> List:
if namespace is not torch.Tensor:
if func_name.startswith("__"):
continue
elif func_name.startswith("_"):
ignore = True
elif func_name.endswith("_"):
ignore = True
elif not func_name[0].islower():
elif func_name[0].isupper():
ignore = True
elif func_name == "unique_dim":
continue
Expand Down
78 changes: 69 additions & 9 deletions torchlens/model_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -1686,6 +1686,9 @@ def run_and_log_inputs_through_model(
if type(input_args) == torch.Tensor:
input_args = [input_args]

if type(input_args) == tuple:
input_args = list(input_args)

if not input_args:
input_args = []

Expand Down Expand Up @@ -5171,7 +5174,7 @@ def _make_param_label(node: Union[TensorLogEntry, RolledTensorLogEntry]) -> str:
if len(param_shape) > 1:
each_param_shape.append("x".join([str(s) for s in param_shape]))
else:
each_param_shape.append("x1")
each_param_shape.append(f"x{param_shape[0]}")

param_label = "<br/>params: " + ", ".join(
[param_shape for param_shape in each_param_shape]
Expand Down Expand Up @@ -6001,6 +6004,8 @@ def _check_whether_func_on_saved_parents_yields_saved_tensor(
)
):
return True
elif layer_to_validate_parents_for.func_applied_name == 'empty_like':
return True
elif (
perturb
and (layer_to_validate_parents_for.func_applied_name == "__setitem__")
Expand Down Expand Up @@ -6035,6 +6040,59 @@ def _check_whether_func_on_saved_parents_yields_saved_tensor(
)
):
return True
elif (
perturb
and (layer_to_validate_parents_for.func_applied_name == "index_select")
and torch.equal(
self[layers_to_perturb[0]].tensor_contents,
layer_to_validate_parents_for.creation_args[2],
)
):
return True
elif (
perturb
and (layer_to_validate_parents_for.func_applied_name == "lstm")
and (torch.equal(
self[layers_to_perturb[0]].tensor_contents,
layer_to_validate_parents_for.creation_args[1][0]) or
torch.equal(
self[layers_to_perturb[0]].tensor_contents,
layer_to_validate_parents_for.creation_args[1][1]) or
torch.equal(
self[layers_to_perturb[0]].tensor_contents,
layer_to_validate_parents_for.creation_args[1]) or
torch.equal(
self[layers_to_perturb[0]].tensor_contents,
layer_to_validate_parents_for.creation_args[2][0]) or
torch.equal(
self[layers_to_perturb[0]].tensor_contents,
layer_to_validate_parents_for.creation_args[2][1])
)):
return True
elif (
perturb
and (layer_to_validate_parents_for.func_applied_name == "_pad_packed_sequence")
and torch.equal(
self[layers_to_perturb[0]].tensor_contents,
layer_to_validate_parents_for.creation_args[1]
)):
return True
elif (
perturb
and (layer_to_validate_parents_for.func_applied_name == "masked_fill_")
and torch.equal(
self[layers_to_perturb[0]].tensor_contents,
layer_to_validate_parents_for.creation_args[1]
)):
return True
elif (
perturb
and (layer_to_validate_parents_for.func_applied_name == "scatter_")
and torch.equal(
self[layers_to_perturb[0]].tensor_contents,
layer_to_validate_parents_for.creation_args[2]
)):
return True

# Prepare input arguments: keep the ones that should just be kept, perturb those that should be perturbed

Expand Down Expand Up @@ -6220,14 +6278,16 @@ def _perturb_layer_activations(
device=device,
).type(parent_activations.dtype)
else:
if torch.min(parent_activations) < 0:
perturbed_activations = torch.randint(
-10, 11, size=parent_activations.shape, device=device
).type(parent_activations.dtype)
else:
perturbed_activations = torch.randint(
0, 11, size=parent_activations.shape, device=device
).type(parent_activations.dtype)
perturbed_activations = parent_activations.detach().clone()
while torch.equal(perturbed_activations, parent_activations):
if torch.min(parent_activations) < 0:
perturbed_activations = torch.randint(
-10, 11, size=parent_activations.shape, device=device
).type(parent_activations.dtype)
else:
perturbed_activations = torch.randint(
0, 11, size=parent_activations.shape, device=device
).type(parent_activations.dtype)

elif parent_activations.dtype == torch.bool:
perturbed_activations = parent_activations.detach().clone()
Expand Down
84 changes: 42 additions & 42 deletions torchlens/user_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,25 @@


def log_forward_pass(
model: nn.Module,
input_args: Union[torch.Tensor, List[Any], Tuple[Any]],
input_kwargs: Dict[Any, Any] = None,
layers_to_save: Optional[Union[str, List]] = "all",
keep_unsaved_layers: bool = True,
output_device: str = "same",
activation_postfunc: Optional[Callable] = None,
mark_input_output_distances: bool = False,
detach_saved_tensors: bool = False,
save_function_args: bool = False,
save_gradients: bool = False,
vis_opt: str = "none",
vis_nesting_depth: int = 1000,
vis_outpath: str = "graph.gv",
vis_save_only: bool = False,
vis_fileformat: str = "pdf",
vis_buffer_layers: bool = False,
vis_direction: str = "bottomup",
random_seed: Optional[int] = None,
model: nn.Module,
input_args: Union[torch.Tensor, List[Any], Tuple[Any]],
input_kwargs: Dict[Any, Any] = None,
layers_to_save: Optional[Union[str, List]] = "all",
keep_unsaved_layers: bool = True,
output_device: str = "same",
activation_postfunc: Optional[Callable] = None,
mark_input_output_distances: bool = False,
detach_saved_tensors: bool = False,
save_function_args: bool = False,
save_gradients: bool = False,
vis_opt: str = "none",
vis_nesting_depth: int = 1000,
vis_outpath: str = "graph.gv",
vis_save_only: bool = False,
vis_fileformat: str = "pdf",
vis_buffer_layers: bool = False,
vis_direction: str = "bottomup",
random_seed: Optional[int] = None,
) -> ModelHistory:
"""Runs a forward pass through a model given input x, and returns a ModelHistory object containing a log
(layer activations and accompanying layer metadata) of the forward pass for all layers specified in which_layers,
Expand Down Expand Up @@ -149,9 +149,9 @@ def log_forward_pass(


def get_model_metadata(
model: nn.Module,
input_args: Union[torch.Tensor, List[Any], Tuple[Any]],
input_kwargs: Dict[Any, Any] = None,
model: nn.Module,
input_args: Union[torch.Tensor, List[Any], Tuple[Any]],
input_kwargs: Dict[Any, Any] = None,
) -> ModelHistory:
"""Logs all metadata for a given model and inputs without saving any activations. NOTE: this function
will be removed in a future version of TorchLens, since calling it is identical to calling
Expand All @@ -175,17 +175,17 @@ def get_model_metadata(


def show_model_graph(
model: nn.Module,
input_args: Union[torch.Tensor, List[Any], Tuple[Any]],
input_kwargs: Dict[Any, Any] = None,
vis_opt: str = "unrolled",
vis_nesting_depth: int = 1000,
vis_outpath: str = "graph.gv",
save_only: bool = False,
vis_fileformat: str = "pdf",
vis_buffer_layers: bool = False,
vis_direction: str = "bottomup",
random_seed: Optional[int] = None,
model: nn.Module,
input_args: Union[torch.Tensor, List, Tuple],
input_kwargs: Dict[Any, Any] = None,
vis_opt: str = "unrolled",
vis_nesting_depth: int = 1000,
vis_outpath: str = "graph.gv",
save_only: bool = False,
vis_fileformat: str = "pdf",
vis_buffer_layers: bool = False,
vis_direction: str = "bottomup",
random_seed: Optional[int] = None,
) -> None:
"""Visualize the model graph without saving any activations.
Expand Down Expand Up @@ -240,11 +240,11 @@ def show_model_graph(


def validate_saved_activations(
model: nn.Module,
input_args: Union[torch.Tensor, List[Any], Tuple[Any]],
input_kwargs: Dict[Any, Any] = None,
random_seed: Union[int, None] = None,
verbose: bool = False,
model: nn.Module,
input_args: Union[torch.Tensor, List[Any], Tuple[Any]],
input_kwargs: Dict[Any, Any] = None,
random_seed: Union[int, None] = None,
verbose: bool = False,
) -> bool:
"""Validate that the saved model activations correctly reproduce the ground truth output. This function works by
running a forward pass through the model, saving all activations, re-running the forward pass starting from
Expand Down Expand Up @@ -301,9 +301,9 @@ def validate_saved_activations(


def validate_batch_of_models_and_inputs(
models_and_inputs_dict: Dict[str, Dict[str, Union[str, Callable, Dict]]],
out_path: str,
redo_model_if_already_run: bool = True,
models_and_inputs_dict: Dict[str, Dict[str, Union[str, Callable, Dict]]],
out_path: str,
redo_model_if_already_run: bool = True,
) -> pd.DataFrame:
"""Given multiple models and several inputs for each, validates the saved activations for all of them
and returns a Pandas dataframe summarizing the validation results.
Expand Down Expand Up @@ -332,7 +332,7 @@ def validate_batch_of_models_and_inputs(
)
models_already_run = current_csv["model_name"].unique()
for model_name, model_info in tqdm(
models_and_inputs_dict.items(), desc="Validating models"
models_and_inputs_dict.items(), desc="Validating models"
):
print(f"Validating model {model_name}")
if model_name in models_already_run and not redo_model_if_already_run:
Expand Down

0 comments on commit 8aec5c6

Please sign in to comment.