diff --git a/examples/01_icetray/01_convert_i3_files.py b/examples/01_icetray/01_convert_i3_files.py index 03517f51b..495c9a13c 100644 --- a/examples/01_icetray/01_convert_i3_files.py +++ b/examples/01_icetray/01_convert_i3_files.py @@ -5,6 +5,7 @@ from graphnet.constants import EXAMPLE_OUTPUT_DIR, TEST_DATA_DIR from graphnet.data.extractors import ( I3FeatureExtractorIceCubeUpgrade, + I3FeatureExtractorIceCube86, I3RetroExtractor, I3TruthExtractor, I3GenericExtractor, @@ -34,12 +35,7 @@ def main_icecube86(backend: str) -> None: converter: DataConverter = CONVERTER_CLASS[backend]( [ - I3GenericExtractor( - keys=[ - "SRTInIcePulses", - "I3MCTree", - ] - ), + I3FeatureExtractorIceCube86("SRTInIcePulses"), I3TruthExtractor(), ], outdir, diff --git a/examples/02_data/01_read_dataset.py b/examples/02_data/01_read_dataset.py index 302529050..be913e190 100644 --- a/examples/02_data/01_read_dataset.py +++ b/examples/02_data/01_read_dataset.py @@ -17,7 +17,10 @@ from graphnet.data.dataset import ParquetDataset from graphnet.utilities.argparse import ArgumentParser from graphnet.utilities.logging import Logger - +from graphnet.models.graphs import KNNGraph +from graphnet.models.detector.icecube import ( + IceCubeDeepCore, +) DATASET_CLASS = { "sqlite": SQLiteDataset, @@ -44,6 +47,9 @@ def main(backend: str) -> None: num_workers = 30 wait_time = 0.00 # sec. + # Define graph representation + graph_definition = KNNGraph(detector=IceCubeDeepCore()) + for table in [pulsemap, truth_table]: # Get column names from backend if backend == "sqlite": @@ -62,15 +68,16 @@ def main(backend: str) -> None: # Common variables dataset = DATASET_CLASS[backend]( - path, - pulsemap, - features, - truth, + path=path, + pulsemaps=pulsemap, + features=features, + truth=truth, truth_table=truth_table, + graph_definition=graph_definition, ) assert isinstance(dataset, Dataset) - logger.info(dataset[1]) + logger.info(str(dataset[1])) logger.info(dataset[1].x) if backend == "sqlite": assert isinstance(dataset, SQLiteDataset) @@ -92,7 +99,7 @@ def main(backend: str) -> None: for batch in tqdm(dataloader, unit=" batches", colour="green"): time.sleep(wait_time) - logger.info(batch) + logger.info(str(batch)) logger.info(batch.size()) logger.info(batch.num_graphs) diff --git a/examples/02_data/02_plot_feature_distributions.py b/examples/02_data/02_plot_feature_distributions.py index 88ffef576..b46be0623 100644 --- a/examples/02_data/02_plot_feature_distributions.py +++ b/examples/02_data/02_plot_feature_distributions.py @@ -1,4 +1,4 @@ -"""Example of plotting feature distributions from SQLite database.""" +"""Example of visualization of input data from a configured Dataset.""" import os.path @@ -8,8 +8,6 @@ from graphnet.constants import CONFIG_DIR from graphnet.data.dataset import Dataset -from graphnet.models.detector.icecube import IceCubeDeepCore -from graphnet.models.graph_builders import KNNGraphBuilder from graphnet.utilities.logging import Logger from graphnet.utilities.argparse import ArgumentParser @@ -27,46 +25,20 @@ def main() -> None: assert isinstance(dataset, Dataset) features = dataset._features[1:] - # Building model - detector = IceCubeDeepCore( - graph_builder=KNNGraphBuilder(nb_nearest_neighbours=8), - ) - # Get feature matrix - x_original_list = [] x_preprocessed_list = [] for batch in tqdm(dataset, colour="green"): - x_original_list.append(batch.x.numpy()) - x_preprocessed_list.append(detector(batch).x.numpy()) + x_preprocessed_list.append(batch.x.numpy()) - x_original = np.concatenate(x_original_list, axis=0) 
x_preprocessed = np.concatenate(x_preprocessed_list, axis=0) - - logger.info(f"Number of NaNs: {np.sum(np.isnan(x_original))}") - logger.info(f"Number of infs: {np.sum(np.isinf(x_original))}") + logger.info(f"Number of NaNs: {np.sum(np.isnan(x_preprocessed))}") + logger.info(f"Number of infs: {np.sum(np.isinf(x_preprocessed))}") # Plot feature distributions - nb_features_original = x_original.shape[1] nb_features_preprocessed = x_preprocessed.shape[1] dim = int(np.ceil(np.sqrt(nb_features_preprocessed))) axis_size = 4 - bins = 100 - - # -- Original - fig, axes = plt.subplots( - dim, dim, figsize=(dim * axis_size, dim * axis_size) - ) - for ix, ax in enumerate(axes.ravel()[:nb_features_original]): - ax.hist(x_original[:, ix], bins=bins) - ax.set_xlabel( - f"x{ix}: {features[ix] if ix < len(features) else 'N/A'}" - ) - ax.set_yscale("log") - - fig.tight_layout - figure_name_original = "feature_distribution_original.png" - fig.savefig(figure_name_original) - logger.info(f"Figure written to {figure_name_original}") + bins = 50 # -- Preprocessed fig, axes = plt.subplots( diff --git a/examples/04_training/02_train_model_without_configs.py b/examples/04_training/01_train_dynedge.py similarity index 93% rename from examples/04_training/02_train_model_without_configs.py rename to examples/04_training/01_train_dynedge.py index 6d9c5746e..58e513ec2 100644 --- a/examples/04_training/02_train_model_without_configs.py +++ b/examples/04_training/01_train_dynedge.py @@ -79,12 +79,7 @@ def main( wandb_logger.experiment.config.update(config) # Define graph representation - graph_definition = KNNGraph( - detector=Prometheus(), - node_definition=NodesAsPulses(), - nb_nearest_neighbours=8, - node_feature_names=features, - ) + graph_definition = KNNGraph(detector=Prometheus()) ( training_dataloader, @@ -166,10 +161,19 @@ def main( logger.info(f"Writing results to {path}") os.makedirs(path, exist_ok=True) + # Save results as .csv results.to_csv(f"{path}/results.csv") - model.save_state_dict(f"{path}/state_dict.pth") + + # Save full model (including weights) to .pth file - not version safe + # Note: Models saved as .pth files in one version of graphnet + # may not be compatible with a different version of graphnet. model.save(f"{path}/model.pth") + # Save model config and state dict - Version safe save method. + # This method of saving models is the safest way. 
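The two save calls just below implement this. For completeness, a minimal sketch of the matching load path — the file locations are placeholders for the `{path}` directory used above, and the `ModelConfig.load`/`StandardModel.from_config` calls are the ones that appear later in this diff:

```python
# Sketch: rebuild a model saved via save_config() + save_state_dict().
# Paths are placeholders; substitute the output directory used above.
from graphnet.models import StandardModel
from graphnet.utilities.config import ModelConfig

model_config = ModelConfig.load("results/model_config.yml")  # architecture only
model = StandardModel.from_config(model_config, trust=True)  # untrained weights
model.load_state_dict("results/state_dict.pth")  # restore trained weights
```

Because the config stores class names and arguments rather than pickled objects, this round trip survives graphnet version changes that a pickled `model.pth` may not.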
+ model.save_state_dict(f"{path}/state_dict.pth") + model.save_config(f"{path}/model_config.yml") + if __name__ == "__main__": diff --git a/examples/04_training/04_train_tito_model_without_configs.py b/examples/04_training/02_train_tito_model.py similarity index 95% rename from examples/04_training/04_train_tito_model_without_configs.py rename to examples/04_training/02_train_tito_model.py index 858aa45f4..ee3d89760 100644 --- a/examples/04_training/04_train_tito_model_without_configs.py +++ b/examples/04_training/02_train_tito_model.py @@ -14,7 +14,6 @@ from graphnet.models.detector.prometheus import Prometheus from graphnet.models.gnn import DynEdgeTITO from graphnet.models.graphs import KNNGraph -from graphnet.models.graphs.nodes import NodesAsPulses from graphnet.models.task.reconstruction import ( DirectionReconstructionWithKappa, ) @@ -28,7 +27,6 @@ # Constants features = FEATURES.PROMETHEUS truth = TRUTH.PROMETHEUS -DYNTRANS_LAYER_SIZES = [(256, 256), (256, 256), (256, 256)] def main( @@ -76,12 +74,7 @@ def main( }, } - graph_definition = KNNGraph( - detector=Prometheus(), - node_definition=NodesAsPulses(), - nb_nearest_neighbours=8, - node_feature_names=features, - ) + graph_definition = KNNGraph(detector=Prometheus()) archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_tito_model") run_name = "dynedgeTITO_{}_example".format(config["target"]) if wandb: @@ -115,7 +108,6 @@ def main( gnn = DynEdgeTITO( nb_inputs=graph_definition.nb_outputs, global_pooling_schemes=["max"], - dyntrans_layer_sizes=DYNTRANS_LAYER_SIZES, ) task = DirectionReconstructionWithKappa( hidden_size=gnn.nb_outputs, @@ -182,10 +174,16 @@ def main( logger.info(f"Writing results to {path}") os.makedirs(path, exist_ok=True) + # Save results as .csv results.to_csv(f"{path}/results.csv") - model.save_state_dict(f"{path}/state_dict.pth") + + # Save full model (including weights) to .pth file - Not version proof model.save(f"{path}/model.pth") + # Save model config and state dict - Version safe save method. + model.save_state_dict(f"{path}/state_dict.pth") + model.save_config(f"{path}/model_config.yml") + if __name__ == "__main__": diff --git a/examples/04_training/01_train_model.py b/examples/04_training/03_train_dynedge_from_config.py similarity index 94% rename from examples/04_training/01_train_model.py rename to examples/04_training/03_train_dynedge_from_config.py index 0250f6602..9a6df73dd 100644 --- a/examples/04_training/01_train_model.py +++ b/examples/04_training/03_train_dynedge_from_config.py @@ -1,4 +1,4 @@ -"""Simplified example of training Model.""" +"""Simplified example of training DynEdge from pre-defined config files.""" from typing import List, Optional import os @@ -46,7 +46,7 @@ def main( log_model=True, ) - # Build model + # Build model from pre-defined config file made from Model.save_config model_config = ModelConfig.load(model_config_path) model: StandardModel = StandardModel.from_config(model_config, trust=True) @@ -69,7 +69,8 @@ def main( archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model") run_name = "dynedge_{}_example".format("_".join(config.target)) - # Construct dataloaders + # Construct dataloaders from pre-defined dataset config files. + # i.e. 
from Dataset.save_config dataset_config = DatasetConfig.load(dataset_config_path) dataloaders = DataLoader.from_dataset_config( dataset_config, diff --git a/examples/04_training/03_train_multiple_models.sh b/examples/04_training/03_train_multiple_models.sh deleted file mode 100644 index 931b574ca..000000000 --- a/examples/04_training/03_train_multiple_models.sh +++ /dev/null @@ -1,62 +0,0 @@ -#!/bin/bash - -#### This script enables the user to run multiple trainings in sequence on the same database but for different model configs. -# To execute this file, copy the file path and write in the terminal; $ bash - - -# execution of bash file in same directory as the script -bash_directory=$(dirname -- "$(readlink -f "${BASH_SOURCE}")") - -## Global; applies to all models -# path to dataset configuration file in the GraphNeT directory -dataset_config=$(realpath "$bash_directory/../../configs/datasets/training_example_data_sqlite.yml") -# what GPU to use; more information can be gained with the module nvitop -gpus=0 -# the maximum number of epochs; if used, this greatly affect learning rate scheduling -max_epochs=5 -# early stopping threshold -early_stopping_patience=5 -# events in a batch -batch_size=16 -# number of CPUs to use -num_workers=2 - -## Model dependent; applies to each model in sequence -# path to model files in the GraphNeT directory -model_directory=$(realpath "$bash_directory/../../configs/models") -# list of model configurations to train -declare -a model_configs=( - "${model_directory}/example_direction_reconstruction_model.yml" - "${model_directory}/example_energy_reconstruction_model.yml" - "${model_directory}/example_vertex_position_reconstruction_model.yml" -) - -# suffix ending on the created directory -declare -a suffixs=( - "direction" - "energy" - "position" -) - -# prediction name outputs per model -declare -a prediction_names=( - "zenith_pred zenith_kappa_pred azimuth_pred azimuth_kappa_pred" - "energy_pred" - "position_x_pred position_y_pred position_z_pred" -) - -for i in "${!model_configs[@]}"; do - echo "training iteration ${i} on ${model_configs[$i]} with output variables ${prediction_names[i][@]}" - python ${bash_directory}/01_train_model.py \ - --dataset-config ${dataset_config} \ - --model-config ${model_configs[$i]} \ - --gpus ${gpus} \ - --max-epochs ${max_epochs} \ - --early-stopping-patience ${early_stopping_patience} \ - --batch-size ${batch_size} \ - --num-workers ${num_workers} \ - --prediction-names ${prediction_names[i][@]} \ - --suffix ${suffixs[i]} - wait -done -echo "all trainings are done." 
\ No newline at end of file diff --git a/examples/04_training/03_train_classification_model.py b/examples/04_training/04_train_multiclassifier_from_configs.py similarity index 98% rename from examples/04_training/03_train_classification_model.py rename to examples/04_training/04_train_multiclassifier_from_configs.py index 65b1acf2a..6937b01b1 100644 --- a/examples/04_training/03_train_classification_model.py +++ b/examples/04_training/04_train_multiclassifier_from_configs.py @@ -1,4 +1,4 @@ -"""Simplified example of multi-class classification training Model.""" +"""Multi-class classification using DynEdge from pre-defined config files.""" import os from typing import List, Optional, Dict, Any diff --git a/examples/04_training/README.md b/examples/04_training/README.md index e60aef3ae..934767247 100644 --- a/examples/04_training/README.md +++ b/examples/04_training/README.md @@ -2,44 +2,44 @@ This subfolder contains two main training scripts: -**`01_train_model.py`** Shows how to train a GNN on neutrino telescope data **using configuration files** to construct the dataset that loads the data and the model that is trained. This is the recommended way to configure standard dataset and models, as it is easier to ready and share than doing so in pure code. This example can be run using a few different models targeting different physics use cases. For instance, you can try running: +**`01_train_dynedge.py`** Shows how to train a GNN on neutrino telescope data **without configuration files,** i.e., by programmatically constructing the dataset and model used. This is good for debugging and experimenting with different dataset settings and model configurations, as it is easier to build the model using the API than by writing configuration files from scratch. **This is our recommended way of getting started with the library**. For instance, try running: ```bash # Show the CLI -(graphnet) $ python examples/04_training/01_train_model.py --help +(graphnet) $ python examples/04_training/01_train_dynedge.py --help # Train energy regression model -(graphnet) $ python examples/04_training/01_train_model.py +(graphnet) $ python examples/04_training/01_train_dynedge.py -# Same as above, as this is the default model config. -(graphnet) $ python examples/04_training/01_train_model.py \ - --model-config configs/models/example_energy_reconstruction_model.yml # Train using a single GPU -(graphnet) $ python examples/04_training/01_train_model.py --gpus 0 +(graphnet) $ python examples/04_training/01_train_dynedge.py --gpus 0 # Train using multiple GPUs -(graphnet) $ python examples/04_training/01_train_model.py --gpus 0 1 +(graphnet) $ python examples/04_training/01_train_dynedge.py --gpus 0 1 ``` -**`02_train_model_without_configs.py`** Shows how to train a GNN on neutrino telescope data **without configuration files,** i.e., by programatically constructing the dataset and model used. This is good for debugging and experimenting with different dataset settings and model configurations, as it is easier to build the model using the API than by writing configuration files from scratch. For instance, try running: +**`03_train_dynedge_from_config.py`** Shows how to train a GNN on neutrino telescope data **using configuration files** to construct the dataset that loads the data and the model that is trained. This is the recommended way to configure standard datasets and models, as config files are easier to read and share than pure code. This example can be run using a few different models targeting different physics use cases. For instance, you can try running: ```bash # Show the CLI -(graphnet) $ python examples/04_training/02_train_model_without_configs.py --help +(graphnet) $ python examples/04_training/03_train_dynedge_from_config.py --help # Train energy regression model -(graphnet) $ python examples/04_training/02_train_model_without_configs.py +(graphnet) $ python examples/04_training/03_train_dynedge_from_config.py + +# Train a vertex position reconstruction model +(graphnet) $ python examples/04_training/03_train_dynedge_from_config.py \ + --model-config configs/models/example_vertex_position_reconstruction_model.yml + +# Trains a direction (zenith, azimuth) reconstruction model. Note that the +# chosen `Task` in the model config file also returns estimated "kappa" values, +# i.e. inverse variance, for each predicted feature, meaning that we need to +# manually specify the names of these. +(graphnet) $ python examples/04_training/03_train_dynedge_from_config.py --gpus 0 \ + --model-config configs/models/example_direction_reconstruction_model.yml \ + --prediction-names zenith_pred zenith_kappa_pred azimuth_pred azimuth_kappa_pred ``` diff --git a/src/graphnet/data/dataset/__init__.py b/src/graphnet/data/dataset/__init__.py index 3ccdd9642..f6eafee94 100644 --- a/src/graphnet/data/dataset/__init__.py +++ b/src/graphnet/data/dataset/__init__.py @@ -7,7 +7,6 @@ from .dataset import EnsembleDataset, Dataset, ColumnMissingException from .parquet.parquet_dataset import ParquetDataset from .sqlite.sqlite_dataset import SQLiteDataset - from .sqlite.sqlite_dataset_perturbed import SQLiteDatasetPerturbed torch.multiprocessing.set_sharing_strategy("file_system") diff --git a/src/graphnet/data/dataset/dataset.py b/src/graphnet/data/dataset/dataset.py index d6fe1d019..4253788a8 100644 --- a/src/graphnet/data/dataset/dataset.py +++ b/src/graphnet/data/dataset/dataset.py @@ -27,7 +27,7 @@ from graphnet.utilities.config import ( Configurable, DatasetConfig, - save_dataset_config, + DatasetConfigSaverABCMeta, ) from graphnet.utilities.config.parsing import traverse_and_apply from graphnet.utilities.logging import Logger @@ -85,7 +85,13 @@ def parse_graph_definition(cfg: dict) -> GraphDefinition: return graph_definition -class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): +class Dataset( + Logger, + Configurable, + torch.utils.data.Dataset, + ABC, + metaclass=DatasetConfigSaverABCMeta, +): """Base Dataset class for reading from any intermediate file format.""" # Class method(s) @@ -188,7 +194,6 @@ def _resolve_graphnet_paths( .replace("${GRAPHNET}", GRAPHNET_ROOT_DIR) ) - @save_dataset_config def __init__( self, path: Union[str, List[str]], diff --git a/src/graphnet/data/dataset/sqlite/__init__.py b/src/graphnet/data/dataset/sqlite/__init__.py index 84d67a921..c44d66184 100644 --- a/src/graphnet/data/dataset/sqlite/__init__.py +++ b/src/graphnet/data/dataset/sqlite/__init__.py @@ -3,6 +3,5 @@ if has_torch_package(): from .sqlite_dataset import SQLiteDataset - from .sqlite_dataset_perturbed import SQLiteDatasetPerturbed del has_torch_package diff --git
a/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py b/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py deleted file mode 100644 index b951e6916..000000000 --- a/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py +++ /dev/null @@ -1,152 +0,0 @@ -"""`Dataset` class(es) for reading perturbed data from SQLite databases.""" - -from typing import Dict, List, Optional, Tuple, Union - -import numpy as np -from numpy.random import default_rng, Generator -import torch -from torch_geometric.data import Data - -from .sqlite_dataset import SQLiteDataset - - -class SQLiteDatasetPerturbed(SQLiteDataset): - """Pytorch dataset for reading perturbed data from SQLite databases. - - This including a pre-processing step, where the input data is randomly - perturbed according to given per-feature "noise" levels. This is intended - to test the stability of a trained model under small changes to the input - parameters. - """ - - def __init__( - self, - path: Union[str, List[str]], - pulsemaps: Union[str, List[str]], - features: List[str], - truth: List[str], - *, - perturbation_dict: Dict[str, float], - node_truth: Optional[List[str]] = None, - index_column: str = "event_no", - truth_table: str = "truth", - node_truth_table: Optional[str] = None, - string_selection: Optional[List[int]] = None, - selection: Optional[List[int]] = None, - dtype: torch.dtype = torch.float32, - loss_weight_table: Optional[str] = None, - loss_weight_column: Optional[str] = None, - loss_weight_default_value: Optional[float] = None, - seed: Optional[Union[int, Generator]] = None, - ): - """Construct SQLiteDatasetPerturbed. - - Args: - path: Path to the file(s) from which this `Dataset` should read. - pulsemaps: Name(s) of the pulse map series that should be used to - construct the nodes on the individual graph objects, and their - features. Multiple pulse series maps can be used, e.g., when - different DOM types are stored in different maps. - features: List of columns in the input files that should be used as - node features on the graph objects. - truth: List of event-level columns in the input files that should - be used added as attributes on the graph objects. - perturbation_dict (Dict[str, float]): Dictionary mapping a feature - name to a standard deviation according to which the values for - this feature should be randomly perturbed. - node_truth: List of node-level columns in the input files that - should be used added as attributes on the graph objects. - index_column: Name of the column in the input files that contains - unique indicies to identify and map events across tables. - truth_table: Name of the table containing event-level truth - information. - node_truth_table: Name of the table containing node-level truth - information. - string_selection: Subset of strings for which data should be read - and used to construct graph objects. Defaults to None, meaning - all strings for which data exists are used. - selection: List of indicies (in `index_column`) of the events in - the input files that should be read. Defaults to None, meaning - that all events in the input files are read. - dtype: Type of the feature tensor on the graph objects returned. - loss_weight_table: Name of the table containing per-event loss - weights. - loss_weight_column: Name of the column in `loss_weight_table` - containing per-event loss weights. This is also the name of the - corresponding attribute assigned to the graph object. - loss_weight_default_value: Default per-event loss weight. 
- NOTE: This default value is only applied when - `loss_weight_table` and `loss_weight_column` are specified, and - in this case to events with no value in the corresponding - table/column. That is, if no per-event loss weight table/column - is provided, this value is ignored. Defaults to None. - seed: Optional seed for random number generation. Defaults to None. - """ - # Base class constructor - super().__init__( - path=path, - pulsemaps=pulsemaps, - features=features, - truth=truth, - node_truth=node_truth, - index_column=index_column, - truth_table=truth_table, - node_truth_table=node_truth_table, - string_selection=string_selection, - selection=selection, - dtype=dtype, - loss_weight_table=loss_weight_table, - loss_weight_column=loss_weight_column, - loss_weight_default_value=loss_weight_default_value, - ) - - # Custom member variables - assert isinstance(perturbation_dict, dict) - assert len(set(perturbation_dict.keys())) == len( - perturbation_dict.keys() - ) - self._perturbation_dict = perturbation_dict - - self._perturbation_cols = [ - self._features.index(key) for key in self._perturbation_dict.keys() - ] - - if seed is not None: - if isinstance(seed, int): - self.rng = default_rng(seed) - elif isinstance(seed, Generator): - self.rng = seed - else: - raise ValueError( - "Invalid seed. Must be an int or a numpy Generator." - ) - else: - self.rng = default_rng() - - def __getitem__(self, sequential_index: int) -> Data: - """Return graph `Data` object at `index`.""" - if not (0 <= sequential_index < len(self)): - raise IndexError( - f"Index {sequential_index} not in range [0, {len(self) - 1}]" - ) - features, truth, node_truth, loss_weight = self._query( - sequential_index - ) - perturbed_features = self._perturb_features(features) - graph = self._create_graph( - perturbed_features, truth, node_truth, loss_weight - ) - return graph - - def _perturb_features( - self, features: List[Tuple[float, ...]] - ) -> List[Tuple[float, ...]]: - features_array = np.array(features) - perturbed_features = self.rng.normal( - loc=features_array[:, self._perturbation_cols], - scale=np.array( - list(self._perturbation_dict.values()), dtype=np.float - ), - ) - features_array[:, self._perturbation_cols] = perturbed_features - return features_array.tolist() diff --git a/src/graphnet/models/coarsening.py b/src/graphnet/models/coarsening.py index 68eab50b9..d40f0c009 100644 --- a/src/graphnet/models/coarsening.py +++ b/src/graphnet/models/coarsening.py @@ -22,7 +22,6 @@ std_pool_x, ) from graphnet.models import Model -from graphnet.utilities.config import save_model_config # Utility method(s) from torch_geometric.utils import degree @@ -63,7 +62,6 @@ class Coarsening(Model): "sum": (sum_pool, sum_pool_x), } - @save_model_config def __init__( self, reduce: str = "avg", @@ -198,7 +196,6 @@ def _add_inc_dict(self, original: Data, pooled: Data) -> Data: class AttributeCoarsening(Coarsening): """Coarsen pulses based on specified attributes.""" - @save_model_config def __init__( self, attributes: List[str], diff --git a/src/graphnet/models/detector/detector.py b/src/graphnet/models/detector/detector.py index e1b1cc6ef..a7fb25f1d 100644 --- a/src/graphnet/models/detector/detector.py +++ b/src/graphnet/models/detector/detector.py @@ -8,13 +8,11 @@ from graphnet.models import Model from graphnet.utilities.decorators import final -from graphnet.utilities.config import save_model_config class Detector(Model): """Base class for all detector-specific read-ins in graphnet.""" - @save_model_config def __init__(self) -> 
None: """Construct `Detector`.""" # Base class constructor diff --git a/src/graphnet/models/gnn/convnet.py b/src/graphnet/models/gnn/convnet.py index 9c03c96f7..dcffd0c50 100644 --- a/src/graphnet/models/gnn/convnet.py +++ b/src/graphnet/models/gnn/convnet.py @@ -10,14 +10,12 @@ from torch_geometric.nn import TAGConv, global_add_pool, global_max_pool from torch_geometric.data import Data -from graphnet.utilities.config import save_model_config from graphnet.models.gnn.gnn import GNN class ConvNet(GNN): """ConvNet (convolutional network) model.""" - @save_model_config def __init__( self, nb_inputs: int, diff --git a/src/graphnet/models/gnn/dynedge.py b/src/graphnet/models/gnn/dynedge.py index 4e9e07b65..9ea93f9ce 100644 --- a/src/graphnet/models/gnn/dynedge.py +++ b/src/graphnet/models/gnn/dynedge.py @@ -7,7 +7,6 @@ from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum from graphnet.models.components.layers import DynEdgeConv -from graphnet.utilities.config import save_model_config from graphnet.models.gnn.gnn import GNN from graphnet.models.utils import calculate_xyzt_homophily @@ -22,7 +21,6 @@ class DynEdge(GNN): """DynEdge (dynamical edge convolutional) model.""" - @save_model_config def __init__( self, nb_inputs: int, diff --git a/src/graphnet/models/gnn/dynedge_jinst.py b/src/graphnet/models/gnn/dynedge_jinst.py index 36a0f1303..23c630fa9 100644 --- a/src/graphnet/models/gnn/dynedge_jinst.py +++ b/src/graphnet/models/gnn/dynedge_jinst.py @@ -10,7 +10,6 @@ from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum from graphnet.models.components.layers import DynEdgeConv -from graphnet.utilities.config import save_model_config from graphnet.models.gnn.gnn import GNN from graphnet.models.utils import calculate_xyzt_homophily @@ -18,7 +17,6 @@ class DynEdgeJINST(GNN): """DynEdge (dynamical edge convolutional) model used in [2209.03042].""" - @save_model_config def __init__( self, nb_inputs: int, diff --git a/src/graphnet/models/gnn/dynedge_kaggle_tito.py b/src/graphnet/models/gnn/dynedge_kaggle_tito.py index 2e07a4e72..d3196dd30 100644 --- a/src/graphnet/models/gnn/dynedge_kaggle_tito.py +++ b/src/graphnet/models/gnn/dynedge_kaggle_tito.py @@ -18,7 +18,6 @@ from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum from graphnet.models.components.layers import DynTrans -from graphnet.utilities.config import save_model_config from graphnet.models.gnn.gnn import GNN from graphnet.models.utils import calculate_xyzt_homophily @@ -33,7 +32,6 @@ class DynEdgeTITO(GNN): """DynEdge (dynamical edge convolutional) model.""" - @save_model_config def __init__( self, nb_inputs: int, diff --git a/src/graphnet/models/gnn/gnn.py b/src/graphnet/models/gnn/gnn.py index de155cb4a..5fd933d84 100644 --- a/src/graphnet/models/gnn/gnn.py +++ b/src/graphnet/models/gnn/gnn.py @@ -6,13 +6,11 @@ from torch_geometric.data import Data from graphnet.models import Model -from graphnet.utilities.config import save_model_config class GNN(Model): """Base class for all core GNN models in graphnet.""" - @save_model_config def __init__(self, nb_inputs: int, nb_outputs: int) -> None: """Construct `GNN`.""" # Base class constructor diff --git a/src/graphnet/models/graphs/edges/edges.py b/src/graphnet/models/graphs/edges/edges.py index 9f7844375..ad3cf3c46 100644 --- a/src/graphnet/models/graphs/edges/edges.py +++ b/src/graphnet/models/graphs/edges/edges.py @@ -7,7 +7,6 @@ from torch_geometric.nn import knn_graph, radius_graph from torch_geometric.data import Data 
-from graphnet.utilities.config import save_model_config from graphnet.models.utils import calculate_distance_matrix from graphnet.models import Model @@ -48,7 +47,6 @@ def _construct_edges(self, graph: Data) -> Data: class KNNEdges(EdgeDefinition): # pylint: disable=too-few-public-methods """Builds edges from the k-nearest neighbours.""" - @save_model_config def __init__( self, nb_nearest_neighbours: int, @@ -85,7 +83,6 @@ def _construct_edges(self, graph: Data) -> Data: class RadialEdges(EdgeDefinition): """Builds graph from a sphere of chosen radius centred at each node.""" - @save_model_config def __init__( self, radius: float, @@ -126,7 +123,6 @@ class EuclideanEdges(EdgeDefinition): # pylint: disable=too-few-public-methods See https://arxiv.org/pdf/1809.06166.pdf. """ - @save_model_config def __init__( self, sigma: float, diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index 275c56e57..042e54c58 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -6,30 +6,30 @@ """ -from typing import Any, List, Optional, Dict, Callable +from typing import Any, List, Optional, Dict, Callable, Union import torch from torch_geometric.data import Data import numpy as np - -from graphnet.utilities.config import save_model_config +from numpy.random import default_rng, Generator from graphnet.models.detector import Detector from .edges import EdgeDefinition -from .nodes import NodeDefinition +from .nodes import NodeDefinition, NodesAsPulses from graphnet.models import Model class GraphDefinition(Model): """An Abstract class to create graph definitions from.""" - @save_model_config def __init__( self, detector: Detector, - node_definition: NodeDefinition, + node_definition: NodeDefinition = NodesAsPulses(), edge_definition: Optional[EdgeDefinition] = None, node_feature_names: Optional[List[str]] = None, dtype: Optional[torch.dtype] = torch.float, + perturbation_dict: Optional[Dict[str, float]] = None, + seed: Optional[Union[int, Generator]] = None, ): """Construct ´GraphDefinition´. The ´detector´ holds. @@ -42,10 +42,16 @@ def __init__( Args: detector: The corresponding ´Detector´ representing the data. - node_definition: Definition of nodes. + node_definition: Definition of nodes. Defaults to NodesAsPulses. edge_definition: Definition of edges. Defaults to None. node_feature_names: Names of node feature columns. Defaults to None dtype: data type used for node features. e.g. ´torch.float´ + perturbation_dict: Dictionary mapping a feature name to a standard + deviation according to which the values for this + feature should be randomly perturbed. Defaults + to None. + seed: seed or Generator used to randomly sample perturbations. + Defaults to None. """ # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) @@ -54,6 +60,8 @@ def __init__( self._detector = detector self._edge_definition = edge_definition self._node_definition = node_definition + self._perturbation_dict = perturbation_dict + if node_feature_names is None: # Assume all features in Detector is used. 
node_feature_names = list(self._detector.feature_map().keys()) # type: ignore @@ -69,6 +77,24 @@ self.nb_inputs = len(self._node_feature_names) self.nb_outputs = self._node_definition.nb_outputs + # Set perturbation_cols if needed + if isinstance(self._perturbation_dict, dict): + self._perturbation_cols = [ + self._node_feature_names.index(key) + for key in self._perturbation_dict.keys() + ] + if seed is not None: + if isinstance(seed, int): + self.rng = default_rng(seed) + elif isinstance(seed, Generator): + self.rng = seed + else: + raise ValueError( + "Invalid seed. Must be an int or a numpy Generator." + ) + else: + self.rng = default_rng() + def forward( # type: ignore self, node_features: np.ndarray, @@ -87,9 +113,12 @@ node_feature_names: name of each column. Shape ´[,d]´. truth_dicts: Dictionary containing truth labels. custom_label_functions: Custom label functions. See https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels. - loss_weight_column: Name of column that holds loss weight. Defaults to None. + loss_weight_column: Name of column that holds loss weight. + Defaults to None. loss_weight: Loss weight associated with event. Defaults to None. - loss_weight_default_value: default value for loss weight. Used in instances where some events have no pre-defined loss weight. Defaults to None. + loss_weight_default_value: default value for loss weight. + Used in instances where some events have + no pre-defined loss weight. Defaults to None. data_path: Path to dataset data files. Defaults to None. Returns: @@ -100,6 +129,9 @@ node_features=node_features, node_feature_names=node_feature_names ) + # Gaussian perturbation of each column if perturbation dict is given + node_features = self._perturb_input(node_features) + # Transform to pytorch tensor node_features = torch.tensor(node_features, dtype=self.dtype) @@ -116,6 +148,7 @@ if self._edge_definition is not None: graph = self._edge_definition(graph) else: + self.warning_once( "No EdgeDefinition provided. Graphs will not have edges defined!" ) @@ -161,11 +194,31 @@ def _validate_input( # was instantiated with. assert len(node_feature_names) == len( self._node_feature_names - ), f"""Input features ({node_feature_names}) is not what {self.__class__.__name__} was instatiated with ({self._node_feature_names})""" + ), f"""Input features ({node_feature_names}) are not what + {self.__class__.__name__} was instantiated + with ({self._node_feature_names})""" # noqa for idx in range(len(node_feature_names)): assert ( node_feature_names[idx] == self._node_feature_names[idx] ), f""" Order of node features in data + is not the same as expected. Got {node_feature_names} + vs.
{self._node_feature_names}""" # noqa + + def _perturb_input(self, node_features: np.ndarray) -> np.ndarray: + if isinstance(self._perturbation_dict, dict): + self.warning_once( + f"""Will randomly perturb + {list(self._perturbation_dict.keys())} + using stds {self._perturbation_dict.values()}""" # noqa + ) + perturbed_features = self.rng.normal( + loc=node_features[:, self._perturbation_cols], + scale=np.array( + list(self._perturbation_dict.values()), dtype=float + ), + ) + node_features[:, self._perturbation_cols] = perturbed_features + return node_features def _add_loss_weights( self, @@ -254,7 +307,7 @@ def _add_features_individually( else: self.warning_once( """Cannot assign graph['x']. This field is reserved for node features. Please rename your input feature.""" - ) + return graph def _add_custom_labels( diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py index dc2ded022..4ae53037a 100644 --- a/src/graphnet/models/graphs/graphs.py +++ b/src/graphnet/models/graphs/graphs.py @@ -1,25 +1,26 @@ """A module containing different graph representations in GraphNeT.""" -from typing import List, Optional +from typing import List, Optional, Dict, Union import torch +from numpy.random import Generator -from graphnet.utilities.config import save_model_config from .graph_definition import GraphDefinition from graphnet.models.detector import Detector from graphnet.models.graphs.edges import EdgeDefinition, KNNEdges -from graphnet.models.graphs.nodes import NodeDefinition +from graphnet.models.graphs.nodes import NodeDefinition, NodesAsPulses class KNNGraph(GraphDefinition): """A Graph representation where Edges are drawn to nearest neighbours.""" - @save_model_config def __init__( self, detector: Detector, - node_definition: NodeDefinition, + node_definition: NodeDefinition = NodesAsPulses(), node_feature_names: Optional[List[str]] = None, dtype: Optional[torch.dtype] = torch.float, + perturbation_dict: Optional[Dict[str, float]] = None, + seed: Optional[Union[int, Generator]] = None, nb_nearest_neighbours: int = 8, columns: List[int] = [0, 1, 2], ) -> None: @@ -30,6 +31,12 @@ def __init__( node_definition: Definition of nodes in the graph. node_feature_names: Name of node features. dtype: data type for node features. + perturbation_dict: Dictionary mapping a feature name to a standard + deviation according to which the values for this + feature should be randomly perturbed. Defaults + to None. + seed: seed or Generator used to randomly sample perturbations. + Defaults to None. nb_nearest_neighbours: Number of edges for each node. Defaults to 8. columns: node feature columns used for distance calculation . Defaults to [0, 1, 2]. 
@@ -44,4 +51,6 @@ def __init__( ), dtype=dtype, node_feature_names=node_feature_names, + perturbation_dict=perturbation_dict, + seed=seed, ) diff --git a/src/graphnet/models/graphs/nodes/nodes.py b/src/graphnet/models/graphs/nodes/nodes.py index 6b3443e0c..ce539ee80 100644 --- a/src/graphnet/models/graphs/nodes/nodes.py +++ b/src/graphnet/models/graphs/nodes/nodes.py @@ -7,14 +7,12 @@ from torch_geometric.data import Data from graphnet.utilities.decorators import final -from graphnet.utilities.config import save_model_config from graphnet.models import Model class NodeDefinition(Model): # pylint: disable=too-few-public-methods """Base class for graph building.""" - @save_model_config def __init__(self) -> None: """Construct `Detector`.""" # Base class constructor diff --git a/src/graphnet/models/model.py b/src/graphnet/models/model.py index 111a2fa76..5713cbcf4 100644 --- a/src/graphnet/models/model.py +++ b/src/graphnet/models/model.py @@ -18,11 +18,17 @@ from torch_geometric.data import Data from graphnet.utilities.logging import Logger -from graphnet.utilities.config import Configurable, ModelConfig +from graphnet.utilities.config import ( + Configurable, + ModelConfig, + ModelConfigSaverABC, +) from graphnet.training.callbacks import ProgressBar -class Model(Logger, Configurable, LightningModule, ABC): +class Model( + Logger, Configurable, LightningModule, ABC, metaclass=ModelConfigSaverABC +): """Base class for all models in graphnet.""" @abstractmethod diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py index 1d439133f..0f4f6895b 100644 --- a/src/graphnet/models/standard_model.py +++ b/src/graphnet/models/standard_model.py @@ -10,7 +10,6 @@ from torch_geometric.data import Data import pandas as pd -from graphnet.utilities.config import save_model_config from graphnet.models.graphs import GraphDefinition from graphnet.models.gnn.gnn import GNN from graphnet.models.model import Model @@ -24,7 +23,6 @@ class StandardModel(Model): model (detector read-in, GNN architecture, and task-specific read-outs). 
""" - @save_model_config def __init__( self, *, diff --git a/src/graphnet/models/task/task.py b/src/graphnet/models/task/task.py index 094071a7c..0d7379f00 100644 --- a/src/graphnet/models/task/task.py +++ b/src/graphnet/models/task/task.py @@ -15,7 +15,6 @@ from graphnet.training.loss_functions import LossFunction # type: ignore[attr-defined] from graphnet.models import Model -from graphnet.utilities.config import save_model_config from graphnet.utilities.decorators import final @@ -39,7 +38,6 @@ def default_prediction_labels(self) -> List[str]: """Return default prediction labels.""" return self._default_prediction_labels - @save_model_config def __init__( self, *, @@ -264,7 +262,6 @@ def _validate_and_set_transforms( class IdentityTask(Task): """Identity, or trivial, task.""" - @save_model_config def __init__( self, nb_outputs: int, diff --git a/src/graphnet/training/loss_functions.py b/src/graphnet/training/loss_functions.py index 740f0b912..624a5fa53 100644 --- a/src/graphnet/training/loss_functions.py +++ b/src/graphnet/training/loss_functions.py @@ -19,7 +19,6 @@ softplus, ) -from graphnet.utilities.config import save_model_config from graphnet.models.model import Model from graphnet.utilities.decorators import final @@ -27,7 +26,6 @@ class LossFunction(Model): """Base class for loss functions in `graphnet`.""" - @save_model_config def __init__(self, **kwargs: Any) -> None: """Construct `LossFunction`, saving model config.""" super().__init__(**kwargs) @@ -120,7 +118,6 @@ class CrossEntropyLoss(LossFunction): (0, num_classes - 1). """ - @save_model_config def __init__( self, options: Union[int, List[Any], Dict[Any, int]], diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py index f7d5249f9..ec3b4c461 100644 --- a/src/graphnet/training/utils.py +++ b/src/graphnet/training/utils.py @@ -22,7 +22,7 @@ def collate_fn(graphs: List[Data]) -> Batch: """Remove graphs with less than two DOM hits. - Should not occur in "production. + Should not occur in "production". 
""" graphs = [g for g in graphs if g.n_pulses > 1] return Batch.from_data_list(graphs) @@ -32,7 +32,7 @@ def collate_fn(graphs: List[Data]) -> Batch: def make_dataloader( db: str, pulsemaps: Union[str, List[str]], - graph_definition: Optional[GraphDefinition], + graph_definition: GraphDefinition, features: List[str], truth: List[str], *, @@ -92,7 +92,7 @@ def make_dataloader( # @TODO: Remove in favour of DataLoader{,.from_dataset_config} def make_train_validation_dataloader( db: str, - graph_definition: Optional[GraphDefinition], + graph_definition: GraphDefinition, selection: Optional[List[int]], pulsemaps: Union[str, List[str]], features: List[str], diff --git a/src/graphnet/utilities/config/__init__.py b/src/graphnet/utilities/config/__init__.py index 5e37c6a00..1520ca68d 100644 --- a/src/graphnet/utilities/config/__init__.py +++ b/src/graphnet/utilities/config/__init__.py @@ -1,6 +1,16 @@ """Modules for configuration files for use across `graphnet`.""" from .configurable import Configurable -from .dataset_config import DatasetConfig, save_dataset_config -from .model_config import ModelConfig, save_model_config +from .dataset_config import ( + DatasetConfig, + DatasetConfigSaverMeta, + DatasetConfigSaverABCMeta, + save_dataset_config, +) +from .model_config import ( + ModelConfig, + ModelConfigSaverMeta, + ModelConfigSaverABC, + save_model_config, +) from .training_config import TrainingConfig diff --git a/src/graphnet/utilities/config/dataset_config.py b/src/graphnet/utilities/config/dataset_config.py index 34d92fc3c..57739b667 100644 --- a/src/graphnet/utilities/config/dataset_config.py +++ b/src/graphnet/utilities/config/dataset_config.py @@ -1,5 +1,6 @@ """Config classes for the `graphnet.data.dataset` module.""" - +import warnings +from abc import ABCMeta from functools import wraps from typing import ( TYPE_CHECKING, @@ -180,6 +181,11 @@ def _parse_torch(self, obj: Any) -> Any: def save_dataset_config(init_fn: Callable) -> Callable: """Save the arguments to `__init__` functions as member `DatasetConfig`.""" + warnings.warn( + "Warning: `save_dataset_config` is deprecated. 
Config saving " + "is now done automatically, for all classes inheriting from Dataset", + DeprecationWarning, + ) def _replace_model_instance_with_config( obj: Union["Model", Any] @@ -214,3 +220,42 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: return ret return wrapper + + +class DatasetConfigSaverMeta(type): + """Metaclass for `DatasetConfig` that saves the config after `__init__`.""" + + def __call__(cls: Any, *args: Any, **kwargs: Any) -> object: + """Catch object after construction and save config.""" + + def _replace_model_instance_with_config( + obj: Union["Model", Any] + ) -> Union[ModelConfig, Any]: + """Replace `Model` instances in `obj` with their `ModelConfig`.""" + from graphnet.models import Model + import torch + + if isinstance(obj, Model): + return obj.config + + if isinstance(obj, torch.dtype): + return obj.__str__() + else: + return obj + + # Create object + created_obj = super().__call__(*args, **kwargs) + + # Get all argument values, including defaults + cfg = get_all_argument_values(created_obj.__init__, *args, **kwargs) + cfg = traverse_and_apply(cfg, _replace_model_instance_with_config) + + # Store config in + created_obj._config = DatasetConfig(**cfg) + return created_obj + + +class DatasetConfigSaverABCMeta(DatasetConfigSaverMeta, ABCMeta): + """Common interface between DatasetConfigSaver and ABC Metaclasses.""" + + pass diff --git a/src/graphnet/utilities/config/model_config.py b/src/graphnet/utilities/config/model_config.py index 9c4d21d26..23b4c9b58 100644 --- a/src/graphnet/utilities/config/model_config.py +++ b/src/graphnet/utilities/config/model_config.py @@ -1,7 +1,9 @@ """Config classes for the `graphnet.models` module.""" +from abc import ABCMeta from functools import wraps import inspect import re +import warnings from typing import ( TYPE_CHECKING, Any, @@ -250,6 +252,11 @@ def as_dict(self) -> Dict[str, Dict[str, Any]]: def save_model_config(init_fn: Callable) -> Callable: """Save the arguments to `__init__` functions as a member `ModelConfig`.""" + warnings.warn( + "Warning: `save_model_config` is deprecated. 
Config saving is" + "now done automatically for all classes inheriting from Model", + DeprecationWarning, + ) def _replace_model_instance_with_config( obj: Union["Model", Any] @@ -283,3 +290,41 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: return ret return wrapper + + +class ModelConfigSaverMeta(type): + """Metaclass for saving `ModelConfig` to `Model` instances.""" + + def __call__(cls: Any, *args: Any, **kwargs: Any) -> object: + """Catch object construction and save config after `__init__`.""" + + def _replace_model_instance_with_config( + obj: Union["Model", Any] + ) -> Union[ModelConfig, Any]: + """Replace `Model` instances in `obj` with their `ModelConfig`.""" + from graphnet.models import Model + + if isinstance(obj, Model): + return obj.config + else: + return obj + + # Create object + created_obj = super().__call__(*args, **kwargs) + + # Get all argument values, including defaults + cfg = get_all_argument_values(created_obj.__init__, *args, **kwargs) + cfg = traverse_and_apply(cfg, _replace_model_instance_with_config) + + # Store config in + created_obj._config = ModelConfig( + class_name=str(cls.__name__), + arguments=dict(**cfg), + ) + return created_obj + + +class ModelConfigSaverABC(ModelConfigSaverMeta, ABCMeta): + """Common interface between ModelConfigSaver and ABC Metaclasses.""" + + pass diff --git a/tests/models/test_graph_definition.py b/tests/models/test_graph_definition.py new file mode 100644 index 000000000..bf16d7853 --- /dev/null +++ b/tests/models/test_graph_definition.py @@ -0,0 +1,53 @@ +"""Unit tests for GraphDefinition.""" + +from graphnet.models.graphs import KNNGraph +from graphnet.models.detector.prometheus import Prometheus +from graphnet.data.constants import FEATURES + +import numpy as np +from copy import deepcopy +import torch + + +def test_graph_definition() -> None: + """Tests the forward pass of GraphDefinition.""" + # Test configuration + features = FEATURES.PROMETHEUS + perturbation_dict = { + "sensor_pos_x": 1.4, + "sensor_pos_y": 2.2, + "sensor_pos_z": 3.7, + "t": 1.2, + } + mock_data = np.array([[1, 5, 2, 3], [2, 9, 6, 2]]) + seed = 42 + n_reps = 5 + + graph_definition = KNNGraph( + detector=Prometheus(), perturbation_dict=perturbation_dict, seed=seed + ) + original_output = graph_definition( + node_features=deepcopy(mock_data), node_feature_names=features + ) + + for _ in range(n_reps): + graph_definition_perturbed = KNNGraph( + detector=Prometheus(), perturbation_dict=perturbation_dict + ) + + graph_definition = KNNGraph( + detector=Prometheus(), + perturbation_dict=perturbation_dict, + seed=seed, + ) + + data = graph_definition( + node_features=deepcopy(mock_data), node_feature_names=features + ) + + perturbed_data = graph_definition_perturbed( + node_features=deepcopy(mock_data), node_feature_names=features + ) + + assert ~torch.equal(data.x, perturbed_data.x) # should not be equal. + assert torch.equal(data.x, original_output.x) # should be equal.