diff --git a/requirements/torch_cu118.txt b/requirements/torch_cu118.txt index d01bdf439..31642c08e 100644 --- a/requirements/torch_cu118.txt +++ b/requirements/torch_cu118.txt @@ -1,4 +1,5 @@ # Contains packages requirements for GPU installation --find-links https://download.pytorch.org/whl/torch_stable.html torch==2.2.0+cu118 +torchvision==0.17.0+cu118 --find-links https://data.pyg.org/whl/torch-2.2.0+cu118.html diff --git a/requirements/torch_cu121.txt b/requirements/torch_cu121.txt index 477a67fb8..d7422e645 100644 --- a/requirements/torch_cu121.txt +++ b/requirements/torch_cu121.txt @@ -1,4 +1,5 @@ # Contains packages requirements for GPU installation --find-links https://download.pytorch.org/whl/torch_stable.html torch==2.2.0+cu121 +torchvision==0.17.0+cu121 --find-links https://data.pyg.org/whl/torch-2.2.0+cu121.html diff --git a/src/graphnet/models/components/embedding.py b/src/graphnet/models/components/embedding.py index e97ca90e7..9539fc444 100644 --- a/src/graphnet/models/components/embedding.py +++ b/src/graphnet/models/components/embedding.py @@ -93,6 +93,7 @@ def __init__( f"{n_features} features." 
) elif n_features >= 6: + hidden_dim = 6 * seq_length else: hidden_dim = int((n_features + 0.5) * seq_length) diff --git a/src/graphnet/models/easy_model.py b/src/graphnet/models/easy_model.py new file mode 100644 index 000000000..d26d88fa0 --- /dev/null +++ b/src/graphnet/models/easy_model.py @@ -0,0 +1,490 @@ +"""Suggested Model subclass that enables simple user syntax.""" + +from collections import OrderedDict +from typing import Any, Dict, List, Optional, Union, Type + +import numpy as np +import torch +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from torch import Tensor +from torch.nn import ModuleList +from torch.optim import Adam +from torch.utils.data import DataLoader, SequentialSampler +from torch_geometric.data import Data +import pandas as pd +from pytorch_lightning.loggers import Logger as LightningLogger + +from graphnet.training.callbacks import ProgressBar +from graphnet.models.graphs import GraphDefinition +from graphnet.models.model import Model +from graphnet.models.task import StandardLearnedTask + + +class EasySyntax(Model): + """A suggested Model class that comes with simple user syntax. + + This class delivers simple user syntax for training and prediction, while + imposing minimal constraints on structure. 
+ """ + + def __init__( + self, + *, + tasks: Union[StandardLearnedTask, List[StandardLearnedTask]], + optimizer_class: Type[torch.optim.Optimizer] = Adam, + optimizer_kwargs: Optional[Dict] = None, + scheduler_class: Optional[type] = None, + scheduler_kwargs: Optional[Dict] = None, + scheduler_config: Optional[Dict] = None, + ) -> None: + """Construct `EasySyntax`.""" + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + # Check(s) + if not isinstance(tasks, (list, tuple)): + tasks = [tasks] + + # Member variable(s) + self._tasks = ModuleList(tasks) + self._optimizer_class = optimizer_class + self._optimizer_kwargs = optimizer_kwargs or dict() + self._scheduler_class = scheduler_class + self._scheduler_kwargs = scheduler_kwargs or dict() + self._scheduler_config = scheduler_config or dict() + + self.validate_tasks() + + def compute_loss( + self, preds: Tensor, data: List[Data], verbose: bool = False + ) -> Tensor: + """Compute and sum losses across tasks.""" + raise NotImplementedError + + def forward( + self, data: Union[Data, List[Data]] + ) -> List[Union[Tensor, Data]]: + """Forward pass, chaining model components.""" + raise NotImplementedError + + def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor: + """Perform shared step. + + Applies the forward pass and the following loss calculation, shared + between the training and validation step. 
+ """ + raise NotImplementedError + + def validate_tasks(self) -> None: + """Verify that self._tasks contains compatible elements.""" + raise NotImplementedError + + @staticmethod + def _construct_trainer( + max_epochs: int = 10, + gpus: Optional[Union[List[int], int]] = None, + callbacks: Optional[List[Callback]] = None, + logger: Optional[LightningLogger] = None, + log_every_n_steps: int = 1, + gradient_clip_val: Optional[float] = None, + distribution_strategy: Optional[str] = "ddp", + **trainer_kwargs: Any, + ) -> Trainer: + if gpus: + accelerator = "gpu" + devices = gpus + else: + accelerator = "cpu" + devices = 1 + + trainer = Trainer( + accelerator=accelerator, + devices=devices, + max_epochs=max_epochs, + callbacks=callbacks, + log_every_n_steps=log_every_n_steps, + logger=logger, + gradient_clip_val=gradient_clip_val, + strategy=distribution_strategy, + **trainer_kwargs, + ) + + return trainer + + def fit( + self, + train_dataloader: DataLoader, + val_dataloader: Optional[DataLoader] = None, + *, + max_epochs: int = 10, + early_stopping_patience: int = 5, + gpus: Optional[Union[List[int], int]] = None, + callbacks: Optional[List[Callback]] = None, + ckpt_path: Optional[str] = None, + logger: Optional[LightningLogger] = None, + log_every_n_steps: int = 1, + gradient_clip_val: Optional[float] = None, + distribution_strategy: Optional[str] = "ddp", + **trainer_kwargs: Any, + ) -> None: + """Fit the model using `pytorch_lightning.Trainer`.""" + # Checks + if callbacks is None: + # We create the bare-minimum callbacks for you. + callbacks = self._create_default_callbacks( + val_dataloader=val_dataloader, + early_stopping_patience=early_stopping_patience, + ) + self.debug("No Callbacks specified. Default callbacks added.") + else: + # You are on your own! 
+ self.debug("Initializing training with user-provided callbacks.") + pass + self._print_callbacks(callbacks) + has_early_stopping = self._contains_callback(callbacks, EarlyStopping) + has_model_checkpoint = self._contains_callback( + callbacks, ModelCheckpoint + ) + + if (has_early_stopping) & (has_model_checkpoint is False): + self.warning( + "No ModelCheckpoint found in callbacks. Best-fit model will" + " not automatically be loaded after training!" + "" + ) + + self.train(mode=True) + trainer = self._construct_trainer( + max_epochs=max_epochs, + gpus=gpus, + callbacks=callbacks, + logger=logger, + log_every_n_steps=log_every_n_steps, + gradient_clip_val=gradient_clip_val, + distribution_strategy=distribution_strategy, + **trainer_kwargs, + ) + + try: + trainer.fit( + self, train_dataloader, val_dataloader, ckpt_path=ckpt_path + ) + except KeyboardInterrupt: + self.warning("[ctrl+c] Exiting gracefully.") + pass + + # Load weights from best-fit model after training if possible + if has_early_stopping & has_model_checkpoint: + for callback in callbacks: + if isinstance(callback, ModelCheckpoint): + checkpoint_callback = callback + self.load_state_dict( + torch.load(checkpoint_callback.best_model_path)["state_dict"] + ) + self.info("Best-fit weights from EarlyStopping loaded.") + + def _print_callbacks(self, callbacks: List[Callback]) -> None: + callback_names = [] + for cbck in callbacks: + callback_names.append(cbck.__class__.__name__) + self.info( + f"Training initiated with callbacks: {', '.join(callback_names)}" + ) + + def _contains_callback( + self, callbacks: List[Callback], callback: Callback + ) -> bool: + """Check if `callback` is in `callbacks`.""" + for cbck in callbacks: + if isinstance(cbck, callback): + return True + return False + + @property + def target_labels(self) -> List[str]: + """Return target label.""" + return [label for task in self._tasks for label in task._target_labels] + + @property + def prediction_labels(self) -> List[str]: + 
"""Return prediction labels.""" + return [ + label for task in self._tasks for label in task._prediction_labels + ] + + def configure_optimizers(self) -> Dict[str, Any]: + """Configure the model's optimizer(s).""" + optimizer = self._optimizer_class( + self.parameters(), **self._optimizer_kwargs + ) + config = { + "optimizer": optimizer, + } + if self._scheduler_class is not None: + scheduler = self._scheduler_class( + optimizer, **self._scheduler_kwargs + ) + config.update( + { + "lr_scheduler": { + "scheduler": scheduler, + **self._scheduler_config, + }, + } + ) + return config + + def training_step( + self, train_batch: Union[Data, List[Data]], batch_idx: int + ) -> Tensor: + """Perform training step.""" + if isinstance(train_batch, Data): + train_batch = [train_batch] + loss = self.shared_step(train_batch, batch_idx) + self.log( + "train_loss", + loss, + batch_size=self._get_batch_size(train_batch), + prog_bar=True, + on_epoch=True, + on_step=False, + sync_dist=True, + ) + + current_lr = self.trainer.optimizers[0].param_groups[0]["lr"] + self.log("lr", current_lr, prog_bar=True, on_step=True) + return loss + + def validation_step( + self, val_batch: Union[Data, List[Data]], batch_idx: int + ) -> Tensor: + """Perform validation step.""" + if isinstance(val_batch, Data): + val_batch = [val_batch] + loss = self.shared_step(val_batch, batch_idx) + self.log( + "val_loss", + loss, + batch_size=self._get_batch_size(val_batch), + prog_bar=True, + on_epoch=True, + on_step=False, + sync_dist=True, + ) + return loss + + def inference(self) -> None: + """Activate inference mode.""" + for task in self._tasks: + task.inference() + + def train(self, mode: bool = True) -> "Model": + """Deactivate inference mode.""" + super().train(mode) + if mode: + for task in self._tasks: + task.train_eval() + return self + + def predict( + self, + dataloader: DataLoader, + gpus: Optional[Union[List[int], int]] = None, + distribution_strategy: Optional[str] = "auto", + ) -> List[Tensor]: + 
"""Return predictions for `dataloader`.""" + self.inference() + self.train(mode=False) + + callbacks = self._create_default_callbacks( + val_dataloader=None, + ) + + inference_trainer = self._construct_trainer( + gpus=gpus, + distribution_strategy=distribution_strategy, + callbacks=callbacks, + ) + + predictions_list = inference_trainer.predict(self, dataloader) + assert len(predictions_list), "Got no predictions" + + nb_outputs = len(predictions_list[0]) + predictions: List[Tensor] = [ + torch.cat([preds[ix] for preds in predictions_list], dim=0) + for ix in range(nb_outputs) + ] + return predictions + + def predict_as_dataframe( + self, + dataloader: DataLoader, + prediction_columns: Optional[List[str]] = None, + *, + additional_attributes: Optional[List[str]] = None, + gpus: Optional[Union[List[int], int]] = None, + distribution_strategy: Optional[str] = "auto", + ) -> pd.DataFrame: + """Return predictions for `dataloader` as a DataFrame. + + Include `additional_attributes` as additional columns in the output + DataFrame. + """ + if prediction_columns is None: + prediction_columns = self.prediction_labels + + if additional_attributes is None: + additional_attributes = [] + assert isinstance(additional_attributes, list) + + if ( + not isinstance(dataloader.sampler, SequentialSampler) + and additional_attributes + ): + print(dataloader.sampler) + raise UserWarning( + "DataLoader has a `sampler` that is not `SequentialSampler`, " + "indicating that shuffling is enabled. Using " + "`predict_as_dataframe` with `additional_attributes` assumes " + "that the sequence of batches in `dataloader` are " + "deterministic. Either call this method a `dataloader` which " + "doesn't resample batches; or do not request " + "`additional_attributes`." 
+ ) + self.info(f"Column names for predictions are: \n {prediction_columns}") + predictions_torch = self.predict( + dataloader=dataloader, + gpus=gpus, + distribution_strategy=distribution_strategy, + ) + predictions = ( + torch.cat(predictions_torch, dim=1).detach().cpu().numpy() + ) + assert len(prediction_columns) == predictions.shape[1], ( + f"Number of provided column names ({len(prediction_columns)}) and " + f"number of output columns ({predictions.shape[1]}) don't match." + ) + + # Check if predictions are on event- or pulse-level + pulse_level_predictions = len(predictions) > len(dataloader.dataset) + + # Get additional attributes + attributes: Dict[str, List[np.ndarray]] = OrderedDict( + [(attr, []) for attr in additional_attributes] + ) + for batch in dataloader: + for attr in attributes: + attribute = batch[attr] + if isinstance(attribute, torch.Tensor): + attribute = attribute.detach().cpu().numpy() + + # Check if node level predictions + # If true, additional attributes are repeated + # to make dimensions fit + if pulse_level_predictions: + if len(attribute) < np.sum( + batch.n_pulses.detach().cpu().numpy() + ): + attribute = np.repeat( + attribute, batch.n_pulses.detach().cpu().numpy() + ) + attributes[attr].extend(attribute) + + # Confirm that attributes match length of predictions + skip_attributes = [] + for attr in attributes.keys(): + try: + assert len(attributes[attr]) == len(predictions) + except AssertionError: + self.warning_once( + "Could not automatically adjust length" + f" of additional attribute '{attr}' to match length of" + f" predictions.This error can be caused by heavy" + " disagreement between number of examples in the" + " dataset vs. actual events in the dataloader, e.g. " + " heavy filtering of events in `collate_fn` passed to" + " `dataloader`. This can also be caused by requesting" + " pulse-level attributes for `Task`s that produce" + " event-level predictions. Attribute skipped." 
+ ) + skip_attributes.append(attr) + + # Remove bad attributes + for attr in skip_attributes: + attributes.pop(attr) + additional_attributes.remove(attr) + + data = np.concatenate( + [predictions] + + [ + np.asarray(values)[:, np.newaxis] + for values in attributes.values() + ], + axis=1, + ) + + results = pd.DataFrame( + data, columns=prediction_columns + additional_attributes + ) + return results + + def _create_default_callbacks( + self, + val_dataloader: DataLoader, + early_stopping_patience: Optional[int] = None, + ) -> List: + """Create default callbacks. + + Used in cases where no callbacks are specified by the user in .fit + """ + callbacks = [ProgressBar()] + if val_dataloader is not None: + assert early_stopping_patience is not None + # Add Early Stopping + callbacks.append( + EarlyStopping( + monitor="val_loss", + patience=early_stopping_patience, + ) + ) + # Add Model Check Point + callbacks.append( + ModelCheckpoint( + save_top_k=1, + monitor="val_loss", + mode="min", + filename=f"{self.backbone.__class__.__name__}" + + "-{epoch}-{val_loss:.2f}-{train_loss:.2f}", + ) + ) + self.info( + "EarlyStopping has been added" + f" with a patience of {early_stopping_patience}." + ) + return callbacks + + def _add_early_stopping( + self, val_dataloader: DataLoader, callbacks: List + ) -> List: + if val_dataloader is None: + return callbacks + has_early_stopping = False + assert isinstance(callbacks, list) + for callback in callbacks: + if isinstance(callback, EarlyStopping): + has_early_stopping = True + + if not has_early_stopping: + callbacks.append( + EarlyStopping( + monitor="val_loss", + patience=5, + ) + ) + self.warning_once( + "Got validation dataloader but no EarlyStopping callback. An " + "EarlyStopping callback has been added automatically with " + "patience=5 and monitor = 'val_loss'." 
+ ) + return callbacks diff --git a/src/graphnet/models/graphs/__init__.py b/src/graphnet/models/graphs/__init__.py index ea5066307..a07d1308d 100644 --- a/src/graphnet/models/graphs/__init__.py +++ b/src/graphnet/models/graphs/__init__.py @@ -7,4 +7,4 @@ from .graph_definition import GraphDefinition -from .graphs import KNNGraph +from .graphs import KNNGraph, EdgelessGraph diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index 6366fc390..e384425f9 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -24,7 +24,7 @@ class GraphDefinition(Model): def __init__( self, detector: Detector, - node_definition: NodeDefinition = NodesAsPulses(), + node_definition: NodeDefinition = None, edge_definition: Optional[EdgeDefinition] = None, input_feature_names: Optional[List[str]] = None, dtype: Optional[torch.dtype] = torch.float, @@ -69,6 +69,9 @@ def __init__( # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) + if node_definition is None: + node_definition = NodesAsPulses() + # Member Variables self._detector = detector self._edge_definition = edge_definition diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py index d486bba0a..0289b943d 100644 --- a/src/graphnet/models/graphs/graphs.py +++ b/src/graphnet/models/graphs/graphs.py @@ -54,3 +54,44 @@ def __init__( perturbation_dict=perturbation_dict, seed=seed, ) + + +class EdgelessGraph(GraphDefinition): + """A Data representation without edge assignment. + + I.e the resulting representation is created without an EdgeDefinition. 
+ """ + + def __init__( + self, + detector: Detector, + node_definition: NodeDefinition = None, + input_feature_names: Optional[List[str]] = None, + dtype: Optional[torch.dtype] = torch.float, + perturbation_dict: Optional[Dict[str, float]] = None, + seed: Optional[Union[int, Generator]] = None, + ) -> None: + """Construct isolated nodes graph representation. + + Args: + detector: Detector that represents your data. + node_definition: Definition of nodes in the graph. + input_feature_names: Name of input feature columns. + dtype: data type for node features. + perturbation_dict: Dictionary mapping a feature name to a standard + deviation according to which the values for this + feature should be randomly perturbed. Defaults + to None. + seed: seed or Generator used to randomly sample perturbations. + Defaults to None. + """ + # Base class constructor + super().__init__( + detector=detector, + node_definition=node_definition or NodesAsPulses(), + edge_definition=None, + dtype=dtype, + input_feature_names=input_feature_names, + perturbation_dict=perturbation_dict, + seed=seed, + ) diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py index e38308ba1..cfb814f94 100644 --- a/src/graphnet/models/standard_model.py +++ b/src/graphnet/models/standard_model.py @@ -1,40 +1,32 @@ """Standard model class(es).""" -from collections import OrderedDict -from typing import Any, Dict, List, Optional, Union, Type -import numpy as np +from typing import Any, Dict, List, Optional, Union, Type import torch -from pytorch_lightning import Callback, Trainer -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from torch import Tensor -from torch.nn import ModuleList -from torch.optim import Adam -from torch.utils.data import DataLoader, SequentialSampler from torch_geometric.data import Data -import pandas as pd -from pytorch_lightning.loggers import Logger as LightningLogger +from torch.optim import Adam -from 
graphnet.training.callbacks import ProgressBar -from graphnet.models.graphs import GraphDefinition from graphnet.models.gnn.gnn import GNN -from graphnet.models.model import Model +from .easy_model import EasySyntax from graphnet.models.task import StandardLearnedTask +from graphnet.models.graphs import GraphDefinition + +class StandardModel(EasySyntax): + """A Standard way of combining model components in GraphNeT. -class StandardModel(Model): - """Main class for standard models in graphnet. + This model is compatible with the vast majority of supervised learning + tasks such as regression, binary and multi-label classification. - This class chains together the different elements of a complete GNN- based - model (detector read-in, GNN backbone, and task-specific read-outs). + Capable of producing both event-level and pulse-level predictions. """ def __init__( self, - *, graph_definition: GraphDefinition, + tasks: Union[StandardLearnedTask, List[StandardLearnedTask]], backbone: GNN = None, gnn: Optional[GNN] = None, - tasks: Union[StandardLearnedTask, List[StandardLearnedTask]], optimizer_class: Type[torch.optim.Optimizer] = Adam, optimizer_kwargs: Optional[Dict] = None, scheduler_class: Optional[type] = None, @@ -43,258 +35,37 @@ def __init__( ) -> None: """Construct `StandardModel`.""" # Base class constructor - super().__init__(name=__name__, class_name=self.__class__.__name__) - - # Check(s) - if isinstance(tasks, StandardLearnedTask): - tasks = [tasks] - assert isinstance(tasks, (list, tuple)) - assert all(isinstance(task, StandardLearnedTask) for task in tasks) - assert isinstance(graph_definition, GraphDefinition) + super().__init__( + tasks=tasks, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + scheduler_class=scheduler_class, + scheduler_kwargs=scheduler_kwargs, + scheduler_config=scheduler_config, + ) # deprecation warnings if (backbone is None) & (gnn is not None): backbone = gnn # Code continues after warning self.warning( - 
"""DeprecationWarning: Argument `gnn` will be deprecated in GraphNeT 2.0. Please use `backbone` instead.""" + "DeprecationWarning: Argument `gnn` will be deprecated in" + " GraphNeT 2.0. Please use `backbone` instead." + "" ) elif (backbone is None) & (gnn is None): # Code stops raise TypeError( - "__init__() missing 1 required keyword-only argument: 'backbone'" + "__init__() missing 1 required keyword argument:'backbone'" ) + + # Checks assert isinstance(backbone, GNN) + assert isinstance(graph_definition, GraphDefinition) # Member variable(s) self._graph_definition = graph_definition self.backbone = backbone - self._tasks = ModuleList(tasks) - self._optimizer_class = optimizer_class - self._optimizer_kwargs = optimizer_kwargs or dict() - self._scheduler_class = scheduler_class - self._scheduler_kwargs = scheduler_kwargs or dict() - self._scheduler_config = scheduler_config or dict() - - # set dtype of GNN from graph_definition - self.backbone.type(self._graph_definition._dtype) - - @staticmethod - def _construct_trainer( - max_epochs: int = 10, - gpus: Optional[Union[List[int], int]] = None, - callbacks: Optional[List[Callback]] = None, - logger: Optional[LightningLogger] = None, - log_every_n_steps: int = 1, - gradient_clip_val: Optional[float] = None, - distribution_strategy: Optional[str] = "ddp", - **trainer_kwargs: Any, - ) -> Trainer: - if gpus: - accelerator = "gpu" - devices = gpus - else: - accelerator = "cpu" - devices = 1 - - trainer = Trainer( - accelerator=accelerator, - devices=devices, - max_epochs=max_epochs, - callbacks=callbacks, - log_every_n_steps=log_every_n_steps, - logger=logger, - gradient_clip_val=gradient_clip_val, - strategy=distribution_strategy, - **trainer_kwargs, - ) - - return trainer - - def fit( - self, - train_dataloader: DataLoader, - val_dataloader: Optional[DataLoader] = None, - *, - max_epochs: int = 10, - early_stopping_patience: int = 5, - gpus: Optional[Union[List[int], int]] = None, - callbacks: Optional[List[Callback]] 
= None, - ckpt_path: Optional[str] = None, - logger: Optional[LightningLogger] = None, - log_every_n_steps: int = 1, - gradient_clip_val: Optional[float] = None, - distribution_strategy: Optional[str] = "ddp", - **trainer_kwargs: Any, - ) -> None: - """Fit `StandardModel` using `pytorch_lightning.Trainer`.""" - # Checks - if callbacks is None: - # We create the bare-minimum callbacks for you. - callbacks = self._create_default_callbacks( - val_dataloader=val_dataloader, - early_stopping_patience=early_stopping_patience, - ) - self.debug("No Callbacks specified. Default callbacks added.") - else: - # You are on your own! - self.debug("Initializing training with user-provided callbacks.") - pass - self._print_callbacks(callbacks) - has_early_stopping = self._contains_callback(callbacks, EarlyStopping) - has_model_checkpoint = self._contains_callback( - callbacks, ModelCheckpoint - ) - - if (has_early_stopping) & (has_model_checkpoint is False): - self.warning( - """No ModelCheckpoint found in callbacks. 
Best-fit model will not automatically be loaded after training!""" - ) - - self.train(mode=True) - trainer = self._construct_trainer( - max_epochs=max_epochs, - gpus=gpus, - callbacks=callbacks, - logger=logger, - log_every_n_steps=log_every_n_steps, - gradient_clip_val=gradient_clip_val, - distribution_strategy=distribution_strategy, - **trainer_kwargs, - ) - - try: - trainer.fit( - self, train_dataloader, val_dataloader, ckpt_path=ckpt_path - ) - except KeyboardInterrupt: - self.warning("[ctrl+c] Exiting gracefully.") - pass - - # Load weights from best-fit model after training if possible - if has_early_stopping & has_model_checkpoint: - for callback in callbacks: - if isinstance(callback, ModelCheckpoint): - checkpoint_callback = callback - self.load_state_dict( - torch.load(checkpoint_callback.best_model_path)["state_dict"] - ) - self.info("Best-fit weights from EarlyStopping loaded.") - - def _print_callbacks(self, callbacks: List[Callback]) -> None: - callback_names = [] - for cbck in callbacks: - callback_names.append(cbck.__class__.__name__) - self.info( - f"Training initiated with callbacks: {', '.join(callback_names)}" - ) - - def _contains_callback( - self, callbacks: List[Callback], callback: Callback - ) -> bool: - """Check if `callback` is in `callbacks`.""" - for cbck in callbacks: - if isinstance(cbck, callback): - return True - return False - - @property - def target_labels(self) -> List[str]: - """Return target label.""" - return [label for task in self._tasks for label in task._target_labels] - - @property - def prediction_labels(self) -> List[str]: - """Return prediction labels.""" - return [ - label for task in self._tasks for label in task._prediction_labels - ] - - def configure_optimizers(self) -> Dict[str, Any]: - """Configure the model's optimizer(s).""" - optimizer = self._optimizer_class( - self.parameters(), **self._optimizer_kwargs - ) - config = { - "optimizer": optimizer, - } - if self._scheduler_class is not None: - scheduler = 
self._scheduler_class( - optimizer, **self._scheduler_kwargs - ) - config.update( - { - "lr_scheduler": { - "scheduler": scheduler, - **self._scheduler_config, - }, - } - ) - return config - - def forward( - self, data: Union[Data, List[Data]] - ) -> List[Union[Tensor, Data]]: - """Forward pass, chaining model components.""" - if isinstance(data, Data): - data = [data] - x_list = [] - for d in data: - x = self.backbone(d) - x_list.append(x) - x = torch.cat(x_list, dim=0) - - preds = [task(x) for task in self._tasks] - return preds - - def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor: - """Perform shared step. - - Applies the forward pass and the following loss calculation, shared - between the training and validation step. - """ - preds = self(batch) - loss = self.compute_loss(preds, batch) - return loss - - def training_step( - self, train_batch: Union[Data, List[Data]], batch_idx: int - ) -> Tensor: - """Perform training step.""" - if isinstance(train_batch, Data): - train_batch = [train_batch] - loss = self.shared_step(train_batch, batch_idx) - self.log( - "train_loss", - loss, - batch_size=self._get_batch_size(train_batch), - prog_bar=True, - on_epoch=True, - on_step=False, - sync_dist=True, - ) - - current_lr = self.trainer.optimizers[0].param_groups[0]["lr"] - self.log("lr", current_lr, prog_bar=True, on_step=True) - return loss - - def validation_step( - self, val_batch: Union[Data, List[Data]], batch_idx: int - ) -> Tensor: - """Perform validation step.""" - if isinstance(val_batch, Data): - val_batch = [val_batch] - loss = self.shared_step(val_batch, batch_idx) - self.log( - "val_loss", - loss, - batch_size=self._get_batch_size(val_batch), - prog_bar=True, - on_epoch=True, - on_step=False, - sync_dist=True, - ) - return loss def compute_loss( self, preds: Tensor, data: List[Data], verbose: bool = False @@ -321,216 +92,33 @@ def compute_loss( ), "Please reduce loss for each task separately" return torch.sum(torch.stack(losses)) - def 
inference(self) -> None: - """Activate inference mode.""" - for task in self._tasks: - task.inference() - - def train(self, mode: bool = True) -> "Model": - """Deactivate inference mode.""" - super().train(mode) - if mode: - for task in self._tasks: - task.train_eval() - return self - - def predict( - self, - dataloader: DataLoader, - gpus: Optional[Union[List[int], int]] = None, - distribution_strategy: Optional[str] = "auto", - ) -> List[Tensor]: - """Return predictions for `dataloader`.""" - self.inference() - self.train(mode=False) - - callbacks = self._create_default_callbacks( - val_dataloader=None, - ) - - inference_trainer = self._construct_trainer( - gpus=gpus, - distribution_strategy=distribution_strategy, - callbacks=callbacks, - ) - - predictions_list = inference_trainer.predict(self, dataloader) - assert len(predictions_list), "Got no predictions" - - nb_outputs = len(predictions_list[0]) - predictions: List[Tensor] = [ - torch.cat([preds[ix] for preds in predictions_list], dim=0) - for ix in range(nb_outputs) - ] - return predictions - - def predict_as_dataframe( - self, - dataloader: DataLoader, - prediction_columns: Optional[List[str]] = None, - *, - additional_attributes: Optional[List[str]] = None, - gpus: Optional[Union[List[int], int]] = None, - distribution_strategy: Optional[str] = "auto", - ) -> pd.DataFrame: - """Return predictions for `dataloader` as a DataFrame. - - Include `additional_attributes` as additional columns in the output - DataFrame. - """ - if prediction_columns is None: - prediction_columns = self.prediction_labels - - if additional_attributes is None: - additional_attributes = [] - assert isinstance(additional_attributes, list) - - if ( - not isinstance(dataloader.sampler, SequentialSampler) - and additional_attributes - ): - print(dataloader.sampler) - raise UserWarning( - "DataLoader has a `sampler` that is not `SequentialSampler`, " - "indicating that shuffling is enabled. 
Using " - "`predict_as_dataframe` with `additional_attributes` assumes " - "that the sequence of batches in `dataloader` are " - "deterministic. Either call this method a `dataloader` which " - "doesn't resample batches; or do not request " - "`additional_attributes`." - ) - self.info(f"Column names for predictions are: \n {prediction_columns}") - predictions_torch = self.predict( - dataloader=dataloader, - gpus=gpus, - distribution_strategy=distribution_strategy, - ) - predictions = ( - torch.cat(predictions_torch, dim=1).detach().cpu().numpy() - ) - assert len(prediction_columns) == predictions.shape[1], ( - f"Number of provided column names ({len(prediction_columns)}) and " - f"number of output columns ({predictions.shape[1]}) don't match." - ) - - # Check if predictions are on event- or pulse-level - pulse_level_predictions = len(predictions) > len(dataloader.dataset) - - # Get additional attributes - attributes: Dict[str, List[np.ndarray]] = OrderedDict( - [(attr, []) for attr in additional_attributes] - ) - for batch in dataloader: - for attr in attributes: - attribute = batch[attr] - if isinstance(attribute, torch.Tensor): - attribute = attribute.detach().cpu().numpy() - - # Check if node level predictions - # If true, additional attributes are repeated - # to make dimensions fit - if pulse_level_predictions: - if len(attribute) < np.sum( - batch.n_pulses.detach().cpu().numpy() - ): - attribute = np.repeat( - attribute, batch.n_pulses.detach().cpu().numpy() - ) - attributes[attr].extend(attribute) - - # Confirm that attributes match length of predictions - skip_attributes = [] - for attr in attributes.keys(): - try: - assert len(attributes[attr]) == len(predictions) - except AssertionError: - self.warning_once( - "Could not automatically adjust length" - f" of additional attribute '{attr}' to match length of" - f" predictions.This error can be caused by heavy" - " disagreement between number of examples in the" - " dataset vs. 
actual events in the dataloader, e.g. " - " heavy filtering of events in `collate_fn` passed to" - " `dataloader`. This can also be caused by requesting" - " pulse-level attributes for `Task`s that produce" - " event-level predictions. Attribute skipped." - ) - skip_attributes.append(attr) - - # Remove bad attributes - for attr in skip_attributes: - attributes.pop(attr) - additional_attributes.remove(attr) - - data = np.concatenate( - [predictions] - + [ - np.asarray(values)[:, np.newaxis] - for values in attributes.values() - ], - axis=1, - ) + def forward( + self, data: Union[Data, List[Data]] + ) -> List[Union[Tensor, Data]]: + """Forward pass, chaining model components.""" + if isinstance(data, Data): + data = [data] + x_list = [] + for d in data: + x = self.backbone(d) + x_list.append(x) + x = torch.cat(x_list, dim=0) - results = pd.DataFrame( - data, columns=prediction_columns + additional_attributes - ) - return results + preds = [task(x) for task in self._tasks] + return preds - def _create_default_callbacks( - self, - val_dataloader: DataLoader, - early_stopping_patience: Optional[int] = None, - ) -> List: - """Create default callbacks. + def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor: + """Perform shared step. - Used in cases where no callbacks are specified by the user in .fit + Applies the forward pass and the following loss calculation, shared + between the training and validation step. """ - callbacks = [ProgressBar()] - if val_dataloader is not None: - assert early_stopping_patience is not None - # Add Early Stopping - callbacks.append( - EarlyStopping( - monitor="val_loss", - patience=early_stopping_patience, - ) - ) - # Add Model Check Point - callbacks.append( - ModelCheckpoint( - save_top_k=1, - monitor="val_loss", - mode="min", - filename=f"{self.backbone.__class__.__name__}" - + "-{epoch}-{val_loss:.2f}-{train_loss:.2f}", - ) - ) - self.info( - f"EarlyStopping has been added with a patience of {early_stopping_patience}." 
- ) - return callbacks - - def _add_early_stopping( - self, val_dataloader: DataLoader, callbacks: List - ) -> List: - if val_dataloader is None: - return callbacks - has_early_stopping = False - assert isinstance(callbacks, list) - for callback in callbacks: - if isinstance(callback, EarlyStopping): - has_early_stopping = True + preds = self(batch) + loss = self.compute_loss(preds, batch) + return loss - if not has_early_stopping: - callbacks.append( - EarlyStopping( - monitor="val_loss", - patience=5, - ) - ) - self.warning_once( - "Got validation dataloader but no EarlyStopping callback. An " - "EarlyStopping callback has been added automatically with " - "patience=5 and monitor = 'val_loss'." - ) - return callbacks + def validate_tasks(self) -> None: + """Verify that self._tasks contain compatible elements.""" + accepted_tasks = StandardLearnedTask + for task in self._tasks: + assert isinstance(task, accepted_tasks) diff --git a/src/graphnet/training/loss_functions.py b/src/graphnet/training/loss_functions.py index 624a5fa53..d3fc43f7e 100644 --- a/src/graphnet/training/loss_functions.py +++ b/src/graphnet/training/loss_functions.py @@ -71,6 +71,8 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: """Implement loss calculation.""" # Check(s) assert prediction.dim() == 2 + if target.dim() != prediction.dim(): + target = target.squeeze(1) assert prediction.size() == target.size() elements = torch.mean((prediction - target) ** 2, dim=-1) @@ -443,3 +445,93 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: kappa = prediction[:, 3] p = kappa.unsqueeze(1) * prediction[:, [0, 1, 2]] return self._evaluate(p, target) + + +class EnsembleLoss(LossFunction): + """Chain multiple loss functions together.""" + + def __init__( + self, + loss_functions: List[LossFunction], + loss_factors: List[float] = None, + prediction_keys: Optional[List[List[int]]] = None, + *args: Any, + **kwargs: Any, + ) -> None: + """Chain multiple loss functions 
together. + + Optionally apply a weight to each loss function contribution. + + E.g. Loss = RMSE*0.5 + LogCoshLoss*1.5 + + Args: + loss_functions: A list of loss functions to use. + Each loss function contributes a term to the overall loss. + loss_factors: An optional list of factors that will be multiplied + to each loss function contribution. Must be ordered according + to `loss_functions`. If not given, the weights default to 1. + prediction_keys: An optional list of lists of indices for which + prediction columns to use for each loss function. If not + given, all columns are used for all loss functions. + """ + if loss_factors is None: + # add weight of 1 - i.e. no discrimination + loss_factors = np.repeat(1, len(loss_functions)).tolist() + + assert len(loss_functions) == len(loss_factors) + self._factors = loss_factors + self._loss_functions = loss_functions + + if prediction_keys is not None: + self._prediction_keys: Optional[List[List[int]]] = prediction_keys + else: + self._prediction_keys = None + super().__init__(*args, **kwargs) + + def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: + """Calculate loss using multiple loss functions. + + Args: + prediction: Output of the model. + target: Target tensor, extracted from graph object. + + Returns: + Elementwise loss terms.
Shape [N,] + """ + if self._prediction_keys is None: + prediction_keys = [list(range(prediction.size(1)))] * len( + self._loss_functions + ) + else: + prediction_keys = self._prediction_keys + for k, (loss_function, prediction_key) in enumerate( + zip(self._loss_functions, prediction_keys) + ): + if k == 0: + elements = self._factors[k] * loss_function._forward( + prediction=prediction[:, prediction_key], target=target + ) + else: + elements += self._factors[k] * loss_function._forward( + prediction=prediction[:, prediction_key], target=target + ) + return elements + + +class RMSEVonMisesFisher3DLoss(EnsembleLoss): + """Combine the VonMisesFisher3DLoss with RMSELoss.""" + + def __init__(self, vmfs_factor: float = 0.05) -> None: + """VonMisesFisher3DLoss with an RMSE penalty term. + + The VonMisesFisher3DLoss will be weighted with `vmfs_factor`. + + Args: + vmfs_factor: A factor applied to the VonMisesFisher3DLoss term. + Defaults to 0.05. + """ + super().__init__( + loss_functions=[RMSELoss(), VonMisesFisher3DLoss()], + loss_factors=[1, vmfs_factor], + prediction_keys=[[0, 1, 2], [0, 1, 2, 3]], + )