diff --git a/setup.py b/setup.py index 0e76025..685f373 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ setup( name="torchlens", - version="0.1.6", + version="0.1.7", description="A package for extracting activations from PyTorch models", long_description="A package for extracting activations from PyTorch models. Contains functionality for " "extracting model activations, visualizing a model's computational graph, and " diff --git a/tests/test_validation_and_visuals.py b/tests/test_validation_and_visuals.py index c4bba9c..2e38f39 100644 --- a/tests/test_validation_and_visuals.py +++ b/tests/test_validation_and_visuals.py @@ -13,6 +13,7 @@ import torchvision import visualpriors from PIL import Image +from StyleTTS.models import TextEncoder from transformers import ( BertForNextSentencePrediction, BertTokenizer, @@ -2696,6 +2697,23 @@ def test_clip(): # for some reason CLIP breaks the PyCharm debugger assert validate_saved_activations(model, [], model_inputs, random_seed=1) +# Text to speech +def test_styletts(): + model = TextEncoder(3, 3, 3, 100) + tokens = torch.tensor([[3, 0, 1, 2, 0, 2, 2, 3, 1, 4]]) + input_lengths = torch.ones(1, dtype=torch.long) * 10 + m = torch.ones(1, 10) + model_inputs = (tokens, input_lengths, m) + show_model_graph( + model, + model_inputs, + random_seed=1, + vis_opt="unrolled", + vis_outpath=opj("visualization_outputs", "text-to-speech", "styletts_text_encoder"), + ) + assert validate_saved_activations(model, model_inputs, random_seed=1) + + # Graph neural networks diff --git a/torchlens/constants.py b/torchlens/constants.py index 550f553..5561b80 100644 --- a/torchlens/constants.py +++ b/torchlens/constants.py @@ -347,11 +347,7 @@ def my_get_overridable_functions() -> List: if namespace is not torch.Tensor: if func_name.startswith("__"): continue - elif func_name.startswith("_"): - ignore = True - elif func_name.endswith("_"): - ignore = True - elif not func_name[0].islower(): + elif func_name[0].isupper(): ignore = True elif func_name == "unique_dim": continue diff --git a/torchlens/model_history.py b/torchlens/model_history.py index fba5c78..c8df2b0 100644 --- a/torchlens/model_history.py +++ b/torchlens/model_history.py @@ -1686,6 +1686,9 @@ def run_and_log_inputs_through_model( if type(input_args) == torch.Tensor: input_args = [input_args] + if type(input_args) == tuple: + input_args = list(input_args) + if not input_args: input_args = [] @@ -5171,7 +5174,7 @@ def _make_param_label(node: Union[TensorLogEntry, RolledTensorLogEntry]) -> str: if len(param_shape) > 1: each_param_shape.append("x".join([str(s) for s in param_shape])) else: - each_param_shape.append("x1") + each_param_shape.append(f"x{param_shape[0]}") param_label = "
params: " + ", ".join( [param_shape for param_shape in each_param_shape] @@ -6001,6 +6004,8 @@ def _check_whether_func_on_saved_parents_yields_saved_tensor( ) ): return True + elif layer_to_validate_parents_for.func_applied_name == 'empty_like': + return True elif ( perturb and (layer_to_validate_parents_for.func_applied_name == "__setitem__") @@ -6035,6 +6040,59 @@ def _check_whether_func_on_saved_parents_yields_saved_tensor( ) ): return True + elif ( + perturb + and (layer_to_validate_parents_for.func_applied_name == "index_select") + and torch.equal( + self[layers_to_perturb[0]].tensor_contents, + layer_to_validate_parents_for.creation_args[2], + ) + ): + return True + elif ( + perturb + and (layer_to_validate_parents_for.func_applied_name == "lstm") + and (torch.equal( + self[layers_to_perturb[0]].tensor_contents, + layer_to_validate_parents_for.creation_args[1][0]) or + torch.equal( + self[layers_to_perturb[0]].tensor_contents, + layer_to_validate_parents_for.creation_args[1][1]) or + torch.equal( + self[layers_to_perturb[0]].tensor_contents, + layer_to_validate_parents_for.creation_args[1]) or + torch.equal( + self[layers_to_perturb[0]].tensor_contents, + layer_to_validate_parents_for.creation_args[2][0]) or + torch.equal( + self[layers_to_perturb[0]].tensor_contents, + layer_to_validate_parents_for.creation_args[2][1]) + )): + return True + elif ( + perturb + and (layer_to_validate_parents_for.func_applied_name == "_pad_packed_sequence") + and torch.equal( + self[layers_to_perturb[0]].tensor_contents, + layer_to_validate_parents_for.creation_args[1] + )): + return True + elif ( + perturb + and (layer_to_validate_parents_for.func_applied_name == "masked_fill_") + and torch.equal( + self[layers_to_perturb[0]].tensor_contents, + layer_to_validate_parents_for.creation_args[1] + )): + return True + elif ( + perturb + and (layer_to_validate_parents_for.func_applied_name == "scatter_") + and torch.equal( + self[layers_to_perturb[0]].tensor_contents, + layer_to_validate_parents_for.creation_args[2] + )): + return True # Prepare input arguments: keep the ones that should just be kept, perturb those that should be perturbed @@ -6220,14 +6278,16 @@ def _perturb_layer_activations( device=device, ).type(parent_activations.dtype) else: - if torch.min(parent_activations) < 0: - perturbed_activations = torch.randint( - -10, 11, size=parent_activations.shape, device=device - ).type(parent_activations.dtype) - else: - perturbed_activations = torch.randint( - 0, 11, size=parent_activations.shape, device=device - ).type(parent_activations.dtype) + perturbed_activations = parent_activations.detach().clone() + while torch.equal(perturbed_activations, parent_activations): + if torch.min(parent_activations) < 0: + perturbed_activations = torch.randint( + -10, 11, size=parent_activations.shape, device=device + ).type(parent_activations.dtype) + else: + perturbed_activations = torch.randint( + 0, 11, size=parent_activations.shape, device=device + ).type(parent_activations.dtype) elif parent_activations.dtype == torch.bool: perturbed_activations = parent_activations.detach().clone() diff --git a/torchlens/user_funcs.py b/torchlens/user_funcs.py index d05fc31..ddc5def 100644 --- a/torchlens/user_funcs.py +++ b/torchlens/user_funcs.py @@ -21,25 +21,25 @@ def log_forward_pass( - model: nn.Module, - input_args: Union[torch.Tensor, List[Any], Tuple[Any]], - input_kwargs: Dict[Any, Any] = None, - layers_to_save: Optional[Union[str, List]] = "all", - keep_unsaved_layers: bool = True, - output_device: str = "same", - activation_postfunc: Optional[Callable] = None, - mark_input_output_distances: bool = False, - detach_saved_tensors: bool = False, - save_function_args: bool = False, - save_gradients: bool = False, - vis_opt: str = "none", - vis_nesting_depth: int = 1000, - vis_outpath: str = "graph.gv", - vis_save_only: bool = False, - vis_fileformat: str = "pdf", - vis_buffer_layers: bool = False, - vis_direction: str = "bottomup", - random_seed: Optional[int] = None, + model: nn.Module, + input_args: Union[torch.Tensor, List[Any], Tuple[Any]], + input_kwargs: Dict[Any, Any] = None, + layers_to_save: Optional[Union[str, List]] = "all", + keep_unsaved_layers: bool = True, + output_device: str = "same", + activation_postfunc: Optional[Callable] = None, + mark_input_output_distances: bool = False, + detach_saved_tensors: bool = False, + save_function_args: bool = False, + save_gradients: bool = False, + vis_opt: str = "none", + vis_nesting_depth: int = 1000, + vis_outpath: str = "graph.gv", + vis_save_only: bool = False, + vis_fileformat: str = "pdf", + vis_buffer_layers: bool = False, + vis_direction: str = "bottomup", + random_seed: Optional[int] = None, ) -> ModelHistory: """Runs a forward pass through a model given input x, and returns a ModelHistory object containing a log (layer activations and accompanying layer metadata) of the forward pass for all layers specified in which_layers, @@ -149,9 +149,9 @@ def log_forward_pass( def get_model_metadata( - model: nn.Module, - input_args: Union[torch.Tensor, List[Any], Tuple[Any]], - input_kwargs: Dict[Any, Any] = None, + model: nn.Module, + input_args: Union[torch.Tensor, List[Any], Tuple[Any]], + input_kwargs: Dict[Any, Any] = None, ) -> ModelHistory: """Logs all metadata for a given model and inputs without saving any activations. NOTE: this function will be removed in a future version of TorchLens, since calling it is identical to calling @@ -175,17 +175,17 @@ def get_model_metadata( def show_model_graph( - model: nn.Module, - input_args: Union[torch.Tensor, List[Any], Tuple[Any]], - input_kwargs: Dict[Any, Any] = None, - vis_opt: str = "unrolled", - vis_nesting_depth: int = 1000, - vis_outpath: str = "graph.gv", - save_only: bool = False, - vis_fileformat: str = "pdf", - vis_buffer_layers: bool = False, - vis_direction: str = "bottomup", - random_seed: Optional[int] = None, + model: nn.Module, + input_args: Union[torch.Tensor, List, Tuple], + input_kwargs: Dict[Any, Any] = None, + vis_opt: str = "unrolled", + vis_nesting_depth: int = 1000, + vis_outpath: str = "graph.gv", + save_only: bool = False, + vis_fileformat: str = "pdf", + vis_buffer_layers: bool = False, + vis_direction: str = "bottomup", + random_seed: Optional[int] = None, ) -> None: """Visualize the model graph without saving any activations. @@ -240,11 +240,11 @@ def show_model_graph( def validate_saved_activations( - model: nn.Module, - input_args: Union[torch.Tensor, List[Any], Tuple[Any]], - input_kwargs: Dict[Any, Any] = None, - random_seed: Union[int, None] = None, - verbose: bool = False, + model: nn.Module, + input_args: Union[torch.Tensor, List[Any], Tuple[Any]], + input_kwargs: Dict[Any, Any] = None, + random_seed: Union[int, None] = None, + verbose: bool = False, ) -> bool: """Validate that the saved model activations correctly reproduce the ground truth output. This function works by running a forward pass through the model, saving all activations, re-running the forward pass starting from @@ -301,9 +301,9 @@ def validate_saved_activations( def validate_batch_of_models_and_inputs( - models_and_inputs_dict: Dict[str, Dict[str, Union[str, Callable, Dict]]], - out_path: str, - redo_model_if_already_run: bool = True, + models_and_inputs_dict: Dict[str, Dict[str, Union[str, Callable, Dict]]], + out_path: str, + redo_model_if_already_run: bool = True, ) -> pd.DataFrame: """Given multiple models and several inputs for each, validates the saved activations for all of them and returns a Pandas dataframe summarizing the validation results. @@ -332,7 +332,7 @@ def validate_batch_of_models_and_inputs( ) models_already_run = current_csv["model_name"].unique() for model_name, model_info in tqdm( - models_and_inputs_dict.items(), desc="Validating models" + models_and_inputs_dict.items(), desc="Validating models" ): print(f"Validating model {model_name}") if model_name in models_already_run and not redo_model_if_already_run: