From 5b83b16933523164a9c21bbd6b762149c348f6b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Fri, 3 May 2024 12:55:51 +0200 Subject: [PATCH 01/16] Add EasyModel --- src/graphnet/models/easy_model.py | 493 ++++++++++++++++++++++++++++++ 1 file changed, 493 insertions(+) create mode 100644 src/graphnet/models/easy_model.py diff --git a/src/graphnet/models/easy_model.py b/src/graphnet/models/easy_model.py new file mode 100644 index 000000000..dd40465ea --- /dev/null +++ b/src/graphnet/models/easy_model.py @@ -0,0 +1,493 @@ +"""Suggested Model subclass that enables simple user syntax.""" + +from collections import OrderedDict +from typing import Any, Dict, List, Optional, Union, Type + +import numpy as np +import torch +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from torch import Tensor +from torch.nn import ModuleList +from torch.optim import Adam +from torch.utils.data import DataLoader, SequentialSampler +from torch_geometric.data import Data +import pandas as pd +from pytorch_lightning.loggers import Logger as LightningLogger + +from graphnet.training.callbacks import ProgressBar +from graphnet.models.graphs import GraphDefinition +from graphnet.models.model import Model +from graphnet.models.task import StandardLearnedTask + + +class EasyModel(Model): + """A suggested Model class that comes with simple user syntax. + + This class delivers simple user syntax for training and prediction, while + imposing minimal constraints on structure. + """ + + def __init__( + self, + *, + graph_definition: GraphDefinition, + tasks: Union[StandardLearnedTask, List[StandardLearnedTask]], + optimizer_class: Type[torch.optim.Optimizer] = Adam, + optimizer_kwargs: Optional[Dict] = None, + scheduler_class: Optional[type] = None, + scheduler_kwargs: Optional[Dict] = None, + scheduler_config: Optional[Dict] = None, + ) -> None: + """Construct `StandardModel`.""" + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + # Check(s) + if not isinstance(tasks, (list, tuple)): + tasks = [tasks] + self.validate_tasks() + + assert isinstance(graph_definition, GraphDefinition) + + # Member variable(s) + self._graph_definition = graph_definition + self._tasks = ModuleList(tasks) + self._optimizer_class = optimizer_class + self._optimizer_kwargs = optimizer_kwargs or dict() + self._scheduler_class = scheduler_class + self._scheduler_kwargs = scheduler_kwargs or dict() + self._scheduler_config = scheduler_config or dict() + + def compute_loss( + self, preds: Tensor, data: List[Data], verbose: bool = False + ) -> Tensor: + """Compute and sum losses across tasks.""" + raise NotImplementedError + + def forward( + self, data: Union[Data, List[Data]] + ) -> List[Union[Tensor, Data]]: + """Forward pass, chaining model components.""" + raise NotImplementedError + + def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor: + """Perform shared step. + + Applies the forward pass and the following loss calculation, shared + between the training and validation step. 
+ """ + raise NotImplementedError + + def validate_tasks(self) -> None: + """Verify that self._tasks contain compatible elements.""" + raise NotImplementedError + + @staticmethod + def _construct_trainer( + max_epochs: int = 10, + gpus: Optional[Union[List[int], int]] = None, + callbacks: Optional[List[Callback]] = None, + logger: Optional[LightningLogger] = None, + log_every_n_steps: int = 1, + gradient_clip_val: Optional[float] = None, + distribution_strategy: Optional[str] = "ddp", + **trainer_kwargs: Any, + ) -> Trainer: + if gpus: + accelerator = "gpu" + devices = gpus + else: + accelerator = "cpu" + devices = 1 + + trainer = Trainer( + accelerator=accelerator, + devices=devices, + max_epochs=max_epochs, + callbacks=callbacks, + log_every_n_steps=log_every_n_steps, + logger=logger, + gradient_clip_val=gradient_clip_val, + strategy=distribution_strategy, + **trainer_kwargs, + ) + + return trainer + + def fit( + self, + train_dataloader: DataLoader, + val_dataloader: Optional[DataLoader] = None, + *, + max_epochs: int = 10, + early_stopping_patience: int = 5, + gpus: Optional[Union[List[int], int]] = None, + callbacks: Optional[List[Callback]] = None, + ckpt_path: Optional[str] = None, + logger: Optional[LightningLogger] = None, + log_every_n_steps: int = 1, + gradient_clip_val: Optional[float] = None, + distribution_strategy: Optional[str] = "ddp", + **trainer_kwargs: Any, + ) -> None: + """Fit `StandardModel` using `pytorch_lightning.Trainer`.""" + # Checks + if callbacks is None: + # We create the bare-minimum callbacks for you. + callbacks = self._create_default_callbacks( + val_dataloader=val_dataloader, + early_stopping_patience=early_stopping_patience, + ) + self.debug("No Callbacks specified. Default callbacks added.") + else: + # You are on your own! + self.debug("Initializing training with user-provided callbacks.") + pass + self._print_callbacks(callbacks) + has_early_stopping = self._contains_callback(callbacks, EarlyStopping) + has_model_checkpoint = self._contains_callback( + callbacks, ModelCheckpoint + ) + + if (has_early_stopping) & (has_model_checkpoint is False): + self.warning( + "No ModelCheckpoint found in callbacks. Best-fit model will" + " not automatically be loaded after training!" 
+ "" + ) + + self.train(mode=True) + trainer = self._construct_trainer( + max_epochs=max_epochs, + gpus=gpus, + callbacks=callbacks, + logger=logger, + log_every_n_steps=log_every_n_steps, + gradient_clip_val=gradient_clip_val, + distribution_strategy=distribution_strategy, + **trainer_kwargs, + ) + + try: + trainer.fit( + self, train_dataloader, val_dataloader, ckpt_path=ckpt_path + ) + except KeyboardInterrupt: + self.warning("[ctrl+c] Exiting gracefully.") + pass + + # Load weights from best-fit model after training if possible + if has_early_stopping & has_model_checkpoint: + for callback in callbacks: + if isinstance(callback, ModelCheckpoint): + checkpoint_callback = callback + self.load_state_dict( + torch.load(checkpoint_callback.best_model_path)["state_dict"] + ) + self.info("Best-fit weights from EarlyStopping loaded.") + + def _print_callbacks(self, callbacks: List[Callback]) -> None: + callback_names = [] + for cbck in callbacks: + callback_names.append(cbck.__class__.__name__) + self.info( + f"Training initiated with callbacks: {', '.join(callback_names)}" + ) + + def _contains_callback( + self, callbacks: List[Callback], callback: Callback + ) -> bool: + """Check if `callback` is in `callbacks`.""" + for cbck in callbacks: + if isinstance(cbck, callback): + return True + return False + + @property + def target_labels(self) -> List[str]: + """Return target label.""" + return [label for task in self._tasks for label in task._target_labels] + + @property + def prediction_labels(self) -> List[str]: + """Return prediction labels.""" + return [ + label for task in self._tasks for label in task._prediction_labels + ] + + def configure_optimizers(self) -> Dict[str, Any]: + """Configure the model's optimizer(s).""" + optimizer = self._optimizer_class( + self.parameters(), **self._optimizer_kwargs + ) + config = { + "optimizer": optimizer, + } + if self._scheduler_class is not None: + scheduler = self._scheduler_class( + optimizer, **self._scheduler_kwargs + ) + config.update( + { + "lr_scheduler": { + "scheduler": scheduler, + **self._scheduler_config, + }, + } + ) + return config + + def training_step( + self, train_batch: Union[Data, List[Data]], batch_idx: int + ) -> Tensor: + """Perform training step.""" + if isinstance(train_batch, Data): + train_batch = [train_batch] + loss = self.shared_step(train_batch, batch_idx) + self.log( + "train_loss", + loss, + batch_size=self._get_batch_size(train_batch), + prog_bar=True, + on_epoch=True, + on_step=False, + sync_dist=True, + ) + + current_lr = self.trainer.optimizers[0].param_groups[0]["lr"] + self.log("lr", current_lr, prog_bar=True, on_step=True) + return loss + + def validation_step( + self, val_batch: Union[Data, List[Data]], batch_idx: int + ) -> Tensor: + """Perform validation step.""" + if isinstance(val_batch, Data): + val_batch = [val_batch] + loss = self.shared_step(val_batch, batch_idx) + self.log( + "val_loss", + loss, + batch_size=self._get_batch_size(val_batch), + prog_bar=True, + on_epoch=True, + on_step=False, + sync_dist=True, + ) + return loss + + def inference(self) -> None: + """Activate inference mode.""" + for task in self._tasks: + task.inference() + + def train(self, mode: bool = True) -> "Model": + """Deactivate inference mode.""" + super().train(mode) + if mode: + for task in self._tasks: + task.train_eval() + return self + + def predict( + self, + dataloader: DataLoader, + gpus: Optional[Union[List[int], int]] = None, + distribution_strategy: Optional[str] = "auto", + ) -> List[Tensor]: + """Return 
predictions for `dataloader`."""
+        self.inference()
+        self.train(mode=False)
+
+        callbacks = self._create_default_callbacks(
+            val_dataloader=None,
+        )
+
+        inference_trainer = self._construct_trainer(
+            gpus=gpus,
+            distribution_strategy=distribution_strategy,
+            callbacks=callbacks,
+        )
+
+        predictions_list = inference_trainer.predict(self, dataloader)
+        assert len(predictions_list), "Got no predictions"
+
+        nb_outputs = len(predictions_list[0])
+        predictions: List[Tensor] = [
+            torch.cat([preds[ix] for preds in predictions_list], dim=0)
+            for ix in range(nb_outputs)
+        ]
+        return predictions
+
+    def predict_as_dataframe(
+        self,
+        dataloader: DataLoader,
+        prediction_columns: Optional[List[str]] = None,
+        *,
+        additional_attributes: Optional[List[str]] = None,
+        gpus: Optional[Union[List[int], int]] = None,
+        distribution_strategy: Optional[str] = "auto",
+    ) -> pd.DataFrame:
+        """Return predictions for `dataloader` as a DataFrame.
+
+        Include `additional_attributes` as additional columns in the output
+        DataFrame.
+        """
+        if prediction_columns is None:
+            prediction_columns = self.prediction_labels
+
+        if additional_attributes is None:
+            additional_attributes = []
+        assert isinstance(additional_attributes, list)
+
+        if (
+            not isinstance(dataloader.sampler, SequentialSampler)
+            and additional_attributes
+        ):
+            print(dataloader.sampler)
+            raise UserWarning(
+                "DataLoader has a `sampler` that is not `SequentialSampler`, "
+                "indicating that shuffling is enabled. Using "
+                "`predict_as_dataframe` with `additional_attributes` assumes "
+                "that the sequence of batches in `dataloader` is "
+                "deterministic. Either call this method with a `dataloader` "
+                "which doesn't resample batches; or do not request "
+                "`additional_attributes`."
+            )
+        self.info(f"Column names for predictions are: \n {prediction_columns}")
+        predictions_torch = self.predict(
+            dataloader=dataloader,
+            gpus=gpus,
+            distribution_strategy=distribution_strategy,
+        )
+        predictions = (
+            torch.cat(predictions_torch, dim=1).detach().cpu().numpy()
+        )
+        assert len(prediction_columns) == predictions.shape[1], (
+            f"Number of provided column names ({len(prediction_columns)}) and "
+            f"number of output columns ({predictions.shape[1]}) don't match."
+        )
+
+        # Check if predictions are on event- or pulse-level
+        pulse_level_predictions = len(predictions) > len(dataloader.dataset)
+
+        # Get additional attributes
+        attributes: Dict[str, List[np.ndarray]] = OrderedDict(
+            [(attr, []) for attr in additional_attributes]
+        )
+        for batch in dataloader:
+            for attr in attributes:
+                attribute = batch[attr]
+                if isinstance(attribute, torch.Tensor):
+                    attribute = attribute.detach().cpu().numpy()
+
+                # Check if node level predictions
+                # If true, additional attributes are repeated
+                # to make dimensions fit
+                if pulse_level_predictions:
+                    if len(attribute) < np.sum(
+                        batch.n_pulses.detach().cpu().numpy()
+                    ):
+                        attribute = np.repeat(
+                            attribute, batch.n_pulses.detach().cpu().numpy()
+                        )
+                attributes[attr].extend(attribute)
+
+        # Confirm that attributes match length of predictions
+        skip_attributes = []
+        for attr in attributes.keys():
+            try:
+                assert len(attributes[attr]) == len(predictions)
+            except AssertionError:
+                self.warning_once(
+                    "Could not automatically adjust length"
+                    f" of additional attribute '{attr}' to match length of"
+                    f" predictions. This error can be caused by heavy"
+                    " disagreement between number of examples in the"
+                    " dataset vs. actual events in the dataloader, e.g. 
" + " heavy filtering of events in `collate_fn` passed to" + " `dataloader`. This can also be caused by requesting" + " pulse-level attributes for `Task`s that produce" + " event-level predictions. Attribute skipped." + ) + skip_attributes.append(attr) + + # Remove bad attributes + for attr in skip_attributes: + attributes.pop(attr) + additional_attributes.remove(attr) + + data = np.concatenate( + [predictions] + + [ + np.asarray(values)[:, np.newaxis] + for values in attributes.values() + ], + axis=1, + ) + + results = pd.DataFrame( + data, columns=prediction_columns + additional_attributes + ) + return results + + def _create_default_callbacks( + self, + val_dataloader: DataLoader, + early_stopping_patience: Optional[int] = None, + ) -> List: + """Create default callbacks. + + Used in cases where no callbacks are specified by the user in .fit + """ + callbacks = [ProgressBar()] + if val_dataloader is not None: + assert early_stopping_patience is not None + # Add Early Stopping + callbacks.append( + EarlyStopping( + monitor="val_loss", + patience=early_stopping_patience, + ) + ) + # Add Model Check Point + callbacks.append( + ModelCheckpoint( + save_top_k=1, + monitor="val_loss", + mode="min", + filename=f"{self.backbone.__class__.__name__}" + + "-{epoch}-{val_loss:.2f}-{train_loss:.2f}", + ) + ) + self.info( + "EarlyStopping has been added" + f" with a patience of {early_stopping_patience}." + ) + return callbacks + + def _add_early_stopping( + self, val_dataloader: DataLoader, callbacks: List + ) -> List: + if val_dataloader is None: + return callbacks + has_early_stopping = False + assert isinstance(callbacks, list) + for callback in callbacks: + if isinstance(callback, EarlyStopping): + has_early_stopping = True + + if not has_early_stopping: + callbacks.append( + EarlyStopping( + monitor="val_loss", + patience=5, + ) + ) + self.warning_once( + "Got validation dataloader but no EarlyStopping callback. An " + "EarlyStopping callback has been added automatically with " + "patience=5 and monitor = 'val_loss'." 
+ ) + return callbacks From f431a25638aa38a9b838b7d82e61a525197a2a9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Fri, 3 May 2024 13:05:35 +0200 Subject: [PATCH 02/16] refactor StandardModel --- src/graphnet/models/standard_model.py | 509 ++------------------------ 1 file changed, 39 insertions(+), 470 deletions(-) diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py index e38308ba1..a3a539a5e 100644 --- a/src/graphnet/models/standard_model.py +++ b/src/graphnet/models/standard_model.py @@ -1,300 +1,52 @@ """Standard model class(es).""" -from collections import OrderedDict -from typing import Any, Dict, List, Optional, Union, Type -import numpy as np +from typing import Any, Dict, List, Optional, Union import torch -from pytorch_lightning import Callback, Trainer -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from torch import Tensor -from torch.nn import ModuleList -from torch.optim import Adam -from torch.utils.data import DataLoader, SequentialSampler from torch_geometric.data import Data -import pandas as pd -from pytorch_lightning.loggers import Logger as LightningLogger -from graphnet.training.callbacks import ProgressBar -from graphnet.models.graphs import GraphDefinition from graphnet.models.gnn.gnn import GNN -from graphnet.models.model import Model +from .easy_model import EasyModel from graphnet.models.task import StandardLearnedTask -class StandardModel(Model): - """Main class for standard models in graphnet. +class StandardModel(EasyModel): + """A Standard way of combining model components in GraphNeT. - This class chains together the different elements of a complete GNN- based - model (detector read-in, GNN backbone, and task-specific read-outs). + This model is compatible with the vast majority of supervised learning + tasks such as regression, binary and multi-label classification. + + Capable of producing both event-level and pulse-level predictions. """ def __init__( self, - *, - graph_definition: GraphDefinition, backbone: GNN = None, gnn: Optional[GNN] = None, - tasks: Union[StandardLearnedTask, List[StandardLearnedTask]], - optimizer_class: Type[torch.optim.Optimizer] = Adam, - optimizer_kwargs: Optional[Dict] = None, - scheduler_class: Optional[type] = None, - scheduler_kwargs: Optional[Dict] = None, - scheduler_config: Optional[Dict] = None, + **easy_model_kwargs: Any, ) -> None: """Construct `StandardModel`.""" # Base class constructor - super().__init__(name=__name__, class_name=self.__class__.__name__) - - # Check(s) - if isinstance(tasks, StandardLearnedTask): - tasks = [tasks] - assert isinstance(tasks, (list, tuple)) - assert all(isinstance(task, StandardLearnedTask) for task in tasks) - assert isinstance(graph_definition, GraphDefinition) + super().__init__(**easy_model_kwargs) # deprecation warnings if (backbone is None) & (gnn is not None): backbone = gnn # Code continues after warning self.warning( - """DeprecationWarning: Argument `gnn` will be deprecated in GraphNeT 2.0. Please use `backbone` instead.""" + "DeprecationWarning: Argument `gnn` will be deprecated in" + " GraphNeT 2.0. Please use `backbone` instead." 
+ "" ) elif (backbone is None) & (gnn is None): # Code stops raise TypeError( - "__init__() missing 1 required keyword-only argument: 'backbone'" + "__init__() missing 1 required keyword argument:'backbone'" ) assert isinstance(backbone, GNN) # Member variable(s) - self._graph_definition = graph_definition self.backbone = backbone - self._tasks = ModuleList(tasks) - self._optimizer_class = optimizer_class - self._optimizer_kwargs = optimizer_kwargs or dict() - self._scheduler_class = scheduler_class - self._scheduler_kwargs = scheduler_kwargs or dict() - self._scheduler_config = scheduler_config or dict() - - # set dtype of GNN from graph_definition - self.backbone.type(self._graph_definition._dtype) - - @staticmethod - def _construct_trainer( - max_epochs: int = 10, - gpus: Optional[Union[List[int], int]] = None, - callbacks: Optional[List[Callback]] = None, - logger: Optional[LightningLogger] = None, - log_every_n_steps: int = 1, - gradient_clip_val: Optional[float] = None, - distribution_strategy: Optional[str] = "ddp", - **trainer_kwargs: Any, - ) -> Trainer: - if gpus: - accelerator = "gpu" - devices = gpus - else: - accelerator = "cpu" - devices = 1 - - trainer = Trainer( - accelerator=accelerator, - devices=devices, - max_epochs=max_epochs, - callbacks=callbacks, - log_every_n_steps=log_every_n_steps, - logger=logger, - gradient_clip_val=gradient_clip_val, - strategy=distribution_strategy, - **trainer_kwargs, - ) - - return trainer - - def fit( - self, - train_dataloader: DataLoader, - val_dataloader: Optional[DataLoader] = None, - *, - max_epochs: int = 10, - early_stopping_patience: int = 5, - gpus: Optional[Union[List[int], int]] = None, - callbacks: Optional[List[Callback]] = None, - ckpt_path: Optional[str] = None, - logger: Optional[LightningLogger] = None, - log_every_n_steps: int = 1, - gradient_clip_val: Optional[float] = None, - distribution_strategy: Optional[str] = "ddp", - **trainer_kwargs: Any, - ) -> None: - """Fit `StandardModel` using `pytorch_lightning.Trainer`.""" - # Checks - if callbacks is None: - # We create the bare-minimum callbacks for you. - callbacks = self._create_default_callbacks( - val_dataloader=val_dataloader, - early_stopping_patience=early_stopping_patience, - ) - self.debug("No Callbacks specified. Default callbacks added.") - else: - # You are on your own! - self.debug("Initializing training with user-provided callbacks.") - pass - self._print_callbacks(callbacks) - has_early_stopping = self._contains_callback(callbacks, EarlyStopping) - has_model_checkpoint = self._contains_callback( - callbacks, ModelCheckpoint - ) - - if (has_early_stopping) & (has_model_checkpoint is False): - self.warning( - """No ModelCheckpoint found in callbacks. 
Best-fit model will not automatically be loaded after training!""" - ) - - self.train(mode=True) - trainer = self._construct_trainer( - max_epochs=max_epochs, - gpus=gpus, - callbacks=callbacks, - logger=logger, - log_every_n_steps=log_every_n_steps, - gradient_clip_val=gradient_clip_val, - distribution_strategy=distribution_strategy, - **trainer_kwargs, - ) - - try: - trainer.fit( - self, train_dataloader, val_dataloader, ckpt_path=ckpt_path - ) - except KeyboardInterrupt: - self.warning("[ctrl+c] Exiting gracefully.") - pass - - # Load weights from best-fit model after training if possible - if has_early_stopping & has_model_checkpoint: - for callback in callbacks: - if isinstance(callback, ModelCheckpoint): - checkpoint_callback = callback - self.load_state_dict( - torch.load(checkpoint_callback.best_model_path)["state_dict"] - ) - self.info("Best-fit weights from EarlyStopping loaded.") - - def _print_callbacks(self, callbacks: List[Callback]) -> None: - callback_names = [] - for cbck in callbacks: - callback_names.append(cbck.__class__.__name__) - self.info( - f"Training initiated with callbacks: {', '.join(callback_names)}" - ) - - def _contains_callback( - self, callbacks: List[Callback], callback: Callback - ) -> bool: - """Check if `callback` is in `callbacks`.""" - for cbck in callbacks: - if isinstance(cbck, callback): - return True - return False - - @property - def target_labels(self) -> List[str]: - """Return target label.""" - return [label for task in self._tasks for label in task._target_labels] - - @property - def prediction_labels(self) -> List[str]: - """Return prediction labels.""" - return [ - label for task in self._tasks for label in task._prediction_labels - ] - - def configure_optimizers(self) -> Dict[str, Any]: - """Configure the model's optimizer(s).""" - optimizer = self._optimizer_class( - self.parameters(), **self._optimizer_kwargs - ) - config = { - "optimizer": optimizer, - } - if self._scheduler_class is not None: - scheduler = self._scheduler_class( - optimizer, **self._scheduler_kwargs - ) - config.update( - { - "lr_scheduler": { - "scheduler": scheduler, - **self._scheduler_config, - }, - } - ) - return config - - def forward( - self, data: Union[Data, List[Data]] - ) -> List[Union[Tensor, Data]]: - """Forward pass, chaining model components.""" - if isinstance(data, Data): - data = [data] - x_list = [] - for d in data: - x = self.backbone(d) - x_list.append(x) - x = torch.cat(x_list, dim=0) - - preds = [task(x) for task in self._tasks] - return preds - - def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor: - """Perform shared step. - - Applies the forward pass and the following loss calculation, shared - between the training and validation step. 
- """ - preds = self(batch) - loss = self.compute_loss(preds, batch) - return loss - - def training_step( - self, train_batch: Union[Data, List[Data]], batch_idx: int - ) -> Tensor: - """Perform training step.""" - if isinstance(train_batch, Data): - train_batch = [train_batch] - loss = self.shared_step(train_batch, batch_idx) - self.log( - "train_loss", - loss, - batch_size=self._get_batch_size(train_batch), - prog_bar=True, - on_epoch=True, - on_step=False, - sync_dist=True, - ) - - current_lr = self.trainer.optimizers[0].param_groups[0]["lr"] - self.log("lr", current_lr, prog_bar=True, on_step=True) - return loss - - def validation_step( - self, val_batch: Union[Data, List[Data]], batch_idx: int - ) -> Tensor: - """Perform validation step.""" - if isinstance(val_batch, Data): - val_batch = [val_batch] - loss = self.shared_step(val_batch, batch_idx) - self.log( - "val_loss", - loss, - batch_size=self._get_batch_size(val_batch), - prog_bar=True, - on_epoch=True, - on_step=False, - sync_dist=True, - ) - return loss def compute_loss( self, preds: Tensor, data: List[Data], verbose: bool = False @@ -321,216 +73,33 @@ def compute_loss( ), "Please reduce loss for each task separately" return torch.sum(torch.stack(losses)) - def inference(self) -> None: - """Activate inference mode.""" - for task in self._tasks: - task.inference() - - def train(self, mode: bool = True) -> "Model": - """Deactivate inference mode.""" - super().train(mode) - if mode: - for task in self._tasks: - task.train_eval() - return self - - def predict( - self, - dataloader: DataLoader, - gpus: Optional[Union[List[int], int]] = None, - distribution_strategy: Optional[str] = "auto", - ) -> List[Tensor]: - """Return predictions for `dataloader`.""" - self.inference() - self.train(mode=False) - - callbacks = self._create_default_callbacks( - val_dataloader=None, - ) - - inference_trainer = self._construct_trainer( - gpus=gpus, - distribution_strategy=distribution_strategy, - callbacks=callbacks, - ) - - predictions_list = inference_trainer.predict(self, dataloader) - assert len(predictions_list), "Got no predictions" - - nb_outputs = len(predictions_list[0]) - predictions: List[Tensor] = [ - torch.cat([preds[ix] for preds in predictions_list], dim=0) - for ix in range(nb_outputs) - ] - return predictions - - def predict_as_dataframe( - self, - dataloader: DataLoader, - prediction_columns: Optional[List[str]] = None, - *, - additional_attributes: Optional[List[str]] = None, - gpus: Optional[Union[List[int], int]] = None, - distribution_strategy: Optional[str] = "auto", - ) -> pd.DataFrame: - """Return predictions for `dataloader` as a DataFrame. - - Include `additional_attributes` as additional columns in the output - DataFrame. - """ - if prediction_columns is None: - prediction_columns = self.prediction_labels - - if additional_attributes is None: - additional_attributes = [] - assert isinstance(additional_attributes, list) - - if ( - not isinstance(dataloader.sampler, SequentialSampler) - and additional_attributes - ): - print(dataloader.sampler) - raise UserWarning( - "DataLoader has a `sampler` that is not `SequentialSampler`, " - "indicating that shuffling is enabled. Using " - "`predict_as_dataframe` with `additional_attributes` assumes " - "that the sequence of batches in `dataloader` are " - "deterministic. Either call this method a `dataloader` which " - "doesn't resample batches; or do not request " - "`additional_attributes`." 
- ) - self.info(f"Column names for predictions are: \n {prediction_columns}") - predictions_torch = self.predict( - dataloader=dataloader, - gpus=gpus, - distribution_strategy=distribution_strategy, - ) - predictions = ( - torch.cat(predictions_torch, dim=1).detach().cpu().numpy() - ) - assert len(prediction_columns) == predictions.shape[1], ( - f"Number of provided column names ({len(prediction_columns)}) and " - f"number of output columns ({predictions.shape[1]}) don't match." - ) - - # Check if predictions are on event- or pulse-level - pulse_level_predictions = len(predictions) > len(dataloader.dataset) - - # Get additional attributes - attributes: Dict[str, List[np.ndarray]] = OrderedDict( - [(attr, []) for attr in additional_attributes] - ) - for batch in dataloader: - for attr in attributes: - attribute = batch[attr] - if isinstance(attribute, torch.Tensor): - attribute = attribute.detach().cpu().numpy() - - # Check if node level predictions - # If true, additional attributes are repeated - # to make dimensions fit - if pulse_level_predictions: - if len(attribute) < np.sum( - batch.n_pulses.detach().cpu().numpy() - ): - attribute = np.repeat( - attribute, batch.n_pulses.detach().cpu().numpy() - ) - attributes[attr].extend(attribute) - - # Confirm that attributes match length of predictions - skip_attributes = [] - for attr in attributes.keys(): - try: - assert len(attributes[attr]) == len(predictions) - except AssertionError: - self.warning_once( - "Could not automatically adjust length" - f" of additional attribute '{attr}' to match length of" - f" predictions.This error can be caused by heavy" - " disagreement between number of examples in the" - " dataset vs. actual events in the dataloader, e.g. " - " heavy filtering of events in `collate_fn` passed to" - " `dataloader`. This can also be caused by requesting" - " pulse-level attributes for `Task`s that produce" - " event-level predictions. Attribute skipped." - ) - skip_attributes.append(attr) - - # Remove bad attributes - for attr in skip_attributes: - attributes.pop(attr) - additional_attributes.remove(attr) - - data = np.concatenate( - [predictions] - + [ - np.asarray(values)[:, np.newaxis] - for values in attributes.values() - ], - axis=1, - ) + def forward( + self, data: Union[Data, List[Data]] + ) -> List[Union[Tensor, Data]]: + """Forward pass, chaining model components.""" + if isinstance(data, Data): + data = [data] + x_list = [] + for d in data: + x = self.backbone(d) + x_list.append(x) + x = torch.cat(x_list, dim=0) - results = pd.DataFrame( - data, columns=prediction_columns + additional_attributes - ) - return results + preds = [task(x) for task in self._tasks] + return preds - def _create_default_callbacks( - self, - val_dataloader: DataLoader, - early_stopping_patience: Optional[int] = None, - ) -> List: - """Create default callbacks. + def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor: + """Perform shared step. - Used in cases where no callbacks are specified by the user in .fit + Applies the forward pass and the following loss calculation, shared + between the training and validation step. 
""" - callbacks = [ProgressBar()] - if val_dataloader is not None: - assert early_stopping_patience is not None - # Add Early Stopping - callbacks.append( - EarlyStopping( - monitor="val_loss", - patience=early_stopping_patience, - ) - ) - # Add Model Check Point - callbacks.append( - ModelCheckpoint( - save_top_k=1, - monitor="val_loss", - mode="min", - filename=f"{self.backbone.__class__.__name__}" - + "-{epoch}-{val_loss:.2f}-{train_loss:.2f}", - ) - ) - self.info( - f"EarlyStopping has been added with a patience of {early_stopping_patience}." - ) - return callbacks - - def _add_early_stopping( - self, val_dataloader: DataLoader, callbacks: List - ) -> List: - if val_dataloader is None: - return callbacks - has_early_stopping = False - assert isinstance(callbacks, list) - for callback in callbacks: - if isinstance(callback, EarlyStopping): - has_early_stopping = True + preds = self(batch) + loss = self.compute_loss(preds, batch) + return loss - if not has_early_stopping: - callbacks.append( - EarlyStopping( - monitor="val_loss", - patience=5, - ) - ) - self.warning_once( - "Got validation dataloader but no EarlyStopping callback. An " - "EarlyStopping callback has been added automatically with " - "patience=5 and monitor = 'val_loss'." - ) - return callbacks + def validate_tasks(self) -> None: + """Verify that self._tasks contain compatible elements.""" + accepted_tasks = StandardLearnedTask + for task in self._tasks: + assert isinstance(task, accepted_tasks) From b762a0de9bcd7457f1906a6c301013e9487ae711 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Fri, 3 May 2024 13:07:12 +0200 Subject: [PATCH 03/16] rename EasyModel -> EasySyntax --- src/graphnet/models/easy_model.py | 2 +- src/graphnet/models/standard_model.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/graphnet/models/easy_model.py b/src/graphnet/models/easy_model.py index dd40465ea..f8cd873be 100644 --- a/src/graphnet/models/easy_model.py +++ b/src/graphnet/models/easy_model.py @@ -21,7 +21,7 @@ from graphnet.models.task import StandardLearnedTask -class EasyModel(Model): +class EasySyntax(Model): """A suggested Model class that comes with simple user syntax. This class delivers simple user syntax for training and prediction, while diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py index a3a539a5e..08d9c83e6 100644 --- a/src/graphnet/models/standard_model.py +++ b/src/graphnet/models/standard_model.py @@ -6,11 +6,11 @@ from torch_geometric.data import Data from graphnet.models.gnn.gnn import GNN -from .easy_model import EasyModel +from .easy_model import EasySyntax from graphnet.models.task import StandardLearnedTask -class StandardModel(EasyModel): +class StandardModel(EasySyntax): """A Standard way of combining model components in GraphNeT. This model is compatible with the vast majority of supervised learning From 5556bb84d8ba60df10cdee2d74d6e528dd3a9da2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Fri, 3 May 2024 14:13:49 +0200 Subject: [PATCH 04/16] move `.validate_tasks` to after `self._tasks` have been set. 
--- src/graphnet/models/easy_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/graphnet/models/easy_model.py b/src/graphnet/models/easy_model.py index f8cd873be..dfed4b205 100644 --- a/src/graphnet/models/easy_model.py +++ b/src/graphnet/models/easy_model.py @@ -46,7 +46,6 @@ def __init__( # Check(s) if not isinstance(tasks, (list, tuple)): tasks = [tasks] - self.validate_tasks() assert isinstance(graph_definition, GraphDefinition) @@ -59,6 +58,8 @@ def __init__( self._scheduler_kwargs = scheduler_kwargs or dict() self._scheduler_config = scheduler_config or dict() + self.validate_tasks() + def compute_loss( self, preds: Tensor, data: List[Data], verbose: bool = False ) -> Tensor: From e75f4a4418bee544af2e0ee38a3b952142fb6b95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Fri, 3 May 2024 14:37:35 +0200 Subject: [PATCH 05/16] make arguments explicit in StandardModel --- src/graphnet/models/standard_model.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py index 08d9c83e6..6069c6deb 100644 --- a/src/graphnet/models/standard_model.py +++ b/src/graphnet/models/standard_model.py @@ -1,13 +1,15 @@ """Standard model class(es).""" -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, Type import torch from torch import Tensor from torch_geometric.data import Data +from torch.optim import Adam from graphnet.models.gnn.gnn import GNN from .easy_model import EasySyntax from graphnet.models.task import StandardLearnedTask +from graphnet.models.graphs import GraphDefinition class StandardModel(EasySyntax): @@ -21,13 +23,27 @@ class StandardModel(EasySyntax): def __init__( self, + graph_definition: GraphDefinition, + tasks: Union[StandardLearnedTask, List[StandardLearnedTask]], backbone: GNN = None, gnn: Optional[GNN] = None, - **easy_model_kwargs: Any, + optimizer_class: Type[torch.optim.Optimizer] = Adam, + optimizer_kwargs: Optional[Dict] = None, + scheduler_class: Optional[type] = None, + scheduler_kwargs: Optional[Dict] = None, + scheduler_config: Optional[Dict] = None, ) -> None: """Construct `StandardModel`.""" # Base class constructor - super().__init__(**easy_model_kwargs) + super().__init__( + graph_definition=graph_definition, + tasks=tasks, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + scheduler_class=scheduler_class, + scheduler_kwargs=scheduler_kwargs, + scheduler_config=scheduler_config, + ) # deprecation warnings if (backbone is None) & (gnn is not None): From 4f1dbaf1b5d389b9c6773586ca5bb5933e4ee743 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Fri, 3 May 2024 14:39:26 +0200 Subject: [PATCH 06/16] remove graph_definition arg to EasySyntax --- src/graphnet/models/easy_model.py | 4 ---- src/graphnet/models/standard_model.py | 5 ++++- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/graphnet/models/easy_model.py b/src/graphnet/models/easy_model.py index dfed4b205..d26d88fa0 100644 --- a/src/graphnet/models/easy_model.py +++ b/src/graphnet/models/easy_model.py @@ -31,7 +31,6 @@ class EasySyntax(Model): def __init__( self, *, - graph_definition: GraphDefinition, tasks: Union[StandardLearnedTask, List[StandardLearnedTask]], optimizer_class: Type[torch.optim.Optimizer] = Adam, optimizer_kwargs: Optional[Dict] = None, @@ -47,10 +46,7 @@ def __init__( if not isinstance(tasks, (list, tuple)): tasks 
= [tasks] - assert isinstance(graph_definition, GraphDefinition) - # Member variable(s) - self._graph_definition = graph_definition self._tasks = ModuleList(tasks) self._optimizer_class = optimizer_class self._optimizer_kwargs = optimizer_kwargs or dict() diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py index 6069c6deb..cfb814f94 100644 --- a/src/graphnet/models/standard_model.py +++ b/src/graphnet/models/standard_model.py @@ -36,7 +36,6 @@ def __init__( """Construct `StandardModel`.""" # Base class constructor super().__init__( - graph_definition=graph_definition, tasks=tasks, optimizer_class=optimizer_class, optimizer_kwargs=optimizer_kwargs, @@ -59,9 +58,13 @@ def __init__( raise TypeError( "__init__() missing 1 required keyword argument:'backbone'" ) + + # Checks assert isinstance(backbone, GNN) + assert isinstance(graph_definition, GraphDefinition) # Member variable(s) + self._graph_definition = graph_definition self.backbone = backbone def compute_loss( From fb4aa2703b00a97039cb19dc063e3c854c1d5af1 Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Tue, 14 May 2024 11:24:34 +0900 Subject: [PATCH 07/16] add RMSEVonMisesFisher3DLoss --- src/graphnet/training/loss_functions.py | 54 +++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/src/graphnet/training/loss_functions.py b/src/graphnet/training/loss_functions.py index 624a5fa53..30bde11a0 100644 --- a/src/graphnet/training/loss_functions.py +++ b/src/graphnet/training/loss_functions.py @@ -21,6 +21,7 @@ from graphnet.models.model import Model from graphnet.utilities.decorators import final +import importlib class LossFunction(Model): @@ -443,3 +444,56 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: kappa = prediction[:, 3] p = kappa.unsqueeze(1) * prediction[:, [0, 1, 2]] return self._evaluate(p, target) + + +# class LossCombiner(LossFunction): +# """Combine multiple loss functions into a single loss function.""" + +# def __init__(self, loss_functions: List[str], **kwargs: Any) -> None: +# """Construct `LossCombiner`.""" + +# super().__init__(**kwargs) +# self._loss_functions = [] +# for loss_function in loss_functions: +# loss = importlib.import_module(f"graphnet.training.loss_functions.{loss_function}") +# self._loss_functions.append(loss) + + +# def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: +# """Calculate combined loss.""" +# for count, loss in enumerate(self._loss_functions): +# if count == 0: +# elements = loss.forward(prediction, target) +# else: +# elements += loss.forward(prediction, target) +# return elements + + +class RMSEVonMisesFisher3DLoss(VonMisesFisherLoss): + """von Mises-Fisher loss function vectors in the 3D plane.""" + + def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: + """Calculate von Mises-Fisher loss for a direction in the 3D. + + Args: + prediction: Output of the model. Must have shape [N, 4] where + columns 0, 1, 2 are predictions of `direction` and last column + is an estimate of `kappa`. + target: Target tensor, extracted from graph object. + + Returns: + Elementwise von Mises-Fisher loss terms. 
Shape [N,] + """ + target = target.reshape(-1, 3) + # Check(s) + assert prediction.dim() == 2 and prediction.size()[1] == 4 + assert target.dim() == 2 + assert prediction.size()[0] == target.size()[0] + + kappa = prediction[:, 3] + p = kappa.unsqueeze(1) * prediction[:, [0, 1, 2]] + elements = 0.05 * self._evaluate(p, target) + elements += torch.sqrt( + torch.mean((prediction[:, :-1] - target) ** 2, dim=-1) + ) + return elements From 169f8fd43bd72da9483a0df543373f6ac00beb9b Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Tue, 14 May 2024 11:56:59 +0900 Subject: [PATCH 08/16] add isolated nodes --- src/graphnet/models/graphs/__init__.py | 2 +- src/graphnet/models/graphs/graphs.py | 38 ++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/src/graphnet/models/graphs/__init__.py b/src/graphnet/models/graphs/__init__.py index ea5066307..6974cdddc 100644 --- a/src/graphnet/models/graphs/__init__.py +++ b/src/graphnet/models/graphs/__init__.py @@ -7,4 +7,4 @@ from .graph_definition import GraphDefinition -from .graphs import KNNGraph +from .graphs import KNNGraph, IsolatedNodes diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py index d486bba0a..bdb442f09 100644 --- a/src/graphnet/models/graphs/graphs.py +++ b/src/graphnet/models/graphs/graphs.py @@ -54,3 +54,41 @@ def __init__( perturbation_dict=perturbation_dict, seed=seed, ) + + +class IsolatedNodes(GraphDefinition): + """A Graph representation where each node is isolated.""" + + def __init__( + self, + detector: Detector, + node_definition: NodeDefinition = None, + input_feature_names: Optional[List[str]] = None, + dtype: Optional[torch.dtype] = torch.float, + perturbation_dict: Optional[Dict[str, float]] = None, + seed: Optional[Union[int, Generator]] = None, + ) -> None: + """Construct isolated nodes graph representation. + + Args: + detector: Detector that represents your data. + node_definition: Definition of nodes in the graph. + input_feature_names: Name of input feature columns. + dtype: data type for node features. + perturbation_dict: Dictionary mapping a feature name to a standard + deviation according to which the values for this + feature should be randomly perturbed. Defaults + to None. + seed: seed or Generator used to randomly sample perturbations. + Defaults to None. 
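+
+        Example (an illustrative sketch; assumes an existing `Detector`
+            instance named `detector`):
+            >>> graph_definition = IsolatedNodes(detector=detector)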
+ """ + # Base class constructor + super().__init__( + detector=detector, + node_definition=node_definition or NodesAsPulses(), + edge_definition=None, + dtype=dtype, + input_feature_names=input_feature_names, + perturbation_dict=perturbation_dict, + seed=seed, + ) From 661111b651d224e5244fbcb31efd5400445496dc Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Tue, 14 May 2024 13:34:45 +0900 Subject: [PATCH 09/16] cleaning --- src/graphnet/training/loss_functions.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/src/graphnet/training/loss_functions.py b/src/graphnet/training/loss_functions.py index 30bde11a0..aa646d744 100644 --- a/src/graphnet/training/loss_functions.py +++ b/src/graphnet/training/loss_functions.py @@ -446,29 +446,6 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: return self._evaluate(p, target) -# class LossCombiner(LossFunction): -# """Combine multiple loss functions into a single loss function.""" - -# def __init__(self, loss_functions: List[str], **kwargs: Any) -> None: -# """Construct `LossCombiner`.""" - -# super().__init__(**kwargs) -# self._loss_functions = [] -# for loss_function in loss_functions: -# loss = importlib.import_module(f"graphnet.training.loss_functions.{loss_function}") -# self._loss_functions.append(loss) - - -# def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: -# """Calculate combined loss.""" -# for count, loss in enumerate(self._loss_functions): -# if count == 0: -# elements = loss.forward(prediction, target) -# else: -# elements += loss.forward(prediction, target) -# return elements - - class RMSEVonMisesFisher3DLoss(VonMisesFisherLoss): """von Mises-Fisher loss function vectors in the 3D plane.""" From 3e3390472b13b3cfacfc84f465dd9cb2bebbea94 Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Tue, 14 May 2024 13:43:50 +0900 Subject: [PATCH 10/16] cleanup --- src/graphnet/training/loss_functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/graphnet/training/loss_functions.py b/src/graphnet/training/loss_functions.py index aa646d744..b41aaa6f5 100644 --- a/src/graphnet/training/loss_functions.py +++ b/src/graphnet/training/loss_functions.py @@ -21,7 +21,6 @@ from graphnet.models.model import Model from graphnet.utilities.decorators import final -import importlib class LossFunction(Model): From 0750dc67d74262982bffa3e70111e7fcdb678926 Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Thu, 16 May 2024 22:30:51 +0900 Subject: [PATCH 11/16] requirements update --- requirements/torch_cu118.txt | 1 + requirements/torch_cu121.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/requirements/torch_cu118.txt b/requirements/torch_cu118.txt index d01bdf439..31642c08e 100644 --- a/requirements/torch_cu118.txt +++ b/requirements/torch_cu118.txt @@ -1,4 +1,5 @@ # Contains packages requirements for GPU installation --find-links https://download.pytorch.org/whl/torch_stable.html torch==2.2.0+cu118 +torchvision==0.17.0+cu118 --find-links https://data.pyg.org/whl/torch-2.2.0+cu118.html diff --git a/requirements/torch_cu121.txt b/requirements/torch_cu121.txt index 477a67fb8..d7422e645 100644 --- a/requirements/torch_cu121.txt +++ b/requirements/torch_cu121.txt @@ -1,4 +1,5 @@ # Contains packages requirements for GPU installation --find-links https://download.pytorch.org/whl/torch_stable.html torch==2.2.0+cu121 +torchvision==0.17.0+cu121 --find-links https://data.pyg.org/whl/torch-2.2.0+cu121.html From c9eadcc5ec9fddae85bff9962f15195719aa75e2 Mon Sep 17 
00:00:00 2001
From: "askerosted@gmail.com"
Date: Mon, 20 May 2024 14:09:42 +0900
Subject: [PATCH 12/16] docstring update

---
 src/graphnet/models/components/embedding.py | 2 +-
 src/graphnet/models/graphs/__init__.py | 2 +-
 src/graphnet/models/graphs/graphs.py | 7 +++++--
 3 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/src/graphnet/models/components/embedding.py b/src/graphnet/models/components/embedding.py
index e97ca90e7..1b49cd901 100644
--- a/src/graphnet/models/components/embedding.py
+++ b/src/graphnet/models/components/embedding.py
@@ -84,7 +84,6 @@ def __init__(
         super().__init__()
         self.sin_emb = SinusoidalPosEmb(dim=seq_length, scaled=scaled)
-        self.aux_emb = nn.Embedding(2, seq_length // 2)
         self.sin_emb2 = SinusoidalPosEmb(dim=seq_length // 2, scaled=scaled)
 
         if n_features < 4:
@@ -93,6 +92,7 @@ def __init__(
                 f"{n_features} features."
             )
         elif n_features >= 6:
+            self.aux_emb = nn.Embedding(2, seq_length // 2)
             hidden_dim = 6 * seq_length
         else:
             hidden_dim = int((n_features + 0.5) * seq_length)
diff --git a/src/graphnet/models/graphs/__init__.py b/src/graphnet/models/graphs/__init__.py
index 6974cdddc..a07d1308d 100644
--- a/src/graphnet/models/graphs/__init__.py
+++ b/src/graphnet/models/graphs/__init__.py
@@ -7,4 +7,4 @@
 
 from .graph_definition import GraphDefinition
-from .graphs import KNNGraph, IsolatedNodes
+from .graphs import KNNGraph, EdgelessGraph
diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py
index bdb442f09..0289b943d 100644
--- a/src/graphnet/models/graphs/graphs.py
+++ b/src/graphnet/models/graphs/graphs.py
@@ -56,8 +56,11 @@ def __init__(
         )
 
 
-class IsolatedNodes(GraphDefinition):
-    """A Graph representation where each node is isolated."""
+class EdgelessGraph(GraphDefinition):
+    """A Data representation without edge assignment.
+
+    I.e. the resulting representation is created without an EdgeDefinition.
+    """
 
     def __init__(
         self,

From b9c3195b48e5586967248660d21d841b7cc624b0 Mon Sep 17 00:00:00 2001
From: "askerosted@gmail.com"
Date: Mon, 20 May 2024 14:55:25 +0900
Subject: [PATCH 13/16] large refactor

---
 src/graphnet/training/loss_functions.py | 93 ++++++++++++++++++++-----
 1 file changed, 75 insertions(+), 18 deletions(-)

diff --git a/src/graphnet/training/loss_functions.py b/src/graphnet/training/loss_functions.py
index b41aaa6f5..6468e5296 100644
--- a/src/graphnet/training/loss_functions.py
+++ b/src/graphnet/training/loss_functions.py
@@ -445,31 +445,88 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
         return self._evaluate(p, target)
 
 
-class RMSEVonMisesFisher3DLoss(VonMisesFisherLoss):
-    """von Mises-Fisher loss function vectors in the 3D plane."""
+class EnsembleLoss(LossFunction):
+    """Chain multiple loss functions together."""
+
+    def __init__(
+        self,
+        loss_functions: List[LossFunction],
+        loss_factors: List[float] = None,
+        prediction_keys: Optional[List[List[int]]] = None,
+    ) -> None:
+        """Chain multiple loss functions together.
+
+        Optionally apply a weight to each loss function contribution.
+
+        E.g. Loss = RMSE*0.5 + LogCoshLoss*1.5
+
+        Args:
+            loss_functions: A list of loss functions to use.
+                Each loss function contributes a term to the overall loss.
+            loss_factors: An optional list of factors that will be multiplied
+                to each loss function contribution. Must be ordered according
+                to `loss_functions`. If not given, the weights default to 1.
+            prediction_keys: An optional list of lists of indices for which
+                prediction columns to use for each loss function. If not
+                given, all columns are used for all loss functions.
+        """
+        if loss_factors is None:
+            # add weight of 1 - i.e. equal weighting
+            loss_factors = np.repeat(1, len(loss_functions)).tolist()
+
+        assert len(loss_functions) == len(loss_factors)
+        self._factors = loss_factors
+        self._loss_functions = loss_functions
+
+        if prediction_keys is not None:
+            self._prediction_keys: Optional[List[List[int]]] = prediction_keys
+        else:
+            self._prediction_keys = None
 
     def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
-        """Calculate von Mises-Fisher loss for a direction in the 3D.
+        """Calculate loss using multiple loss functions.
 
         Args:
-            prediction: Output of the model. Must have shape [N, 4] where
-                columns 0, 1, 2 are predictions of `direction` and last column
-                is an estimate of `kappa`.
+            prediction: Output of the model.
             target: Target tensor, extracted from graph object.
 
         Returns:
-            Elementwise von Mises-Fisher loss terms. Shape [N,]
+            Elementwise loss terms. Shape [N,]
         """
-        target = target.reshape(-1, 3)
-        # Check(s)
-        assert prediction.dim() == 2 and prediction.size()[1] == 4
-        assert target.dim() == 2
-        assert prediction.size()[0] == target.size()[0]
+        if self._prediction_keys is None:
+            prediction_keys = [list(range(prediction.size(1)))] * len(
+                self._loss_functions
+            )
+        else:
+            prediction_keys = self._prediction_keys
+        for k, (loss_function, prediction_key) in enumerate(
+            zip(self._loss_functions, prediction_keys)
+        ):
+            if k == 0:
+                elements = self._factors[k] * loss_function._forward(
+                    prediction=prediction[prediction_key], target=target
+                )
+            else:
+                elements += self._factors[k] * loss_function._forward(
+                    prediction=prediction[prediction_key], target=target
+                )
+        return elements
 
-        kappa = prediction[:, 3]
-        p = kappa.unsqueeze(1) * prediction[:, [0, 1, 2]]
-        elements = 0.05 * self._evaluate(p, target)
-        elements += torch.sqrt(
-            torch.mean((prediction[:, :-1] - target) ** 2, dim=-1)
+
+class RMSEVonMisesFisher3DLoss(EnsembleLoss):
+    """Combine the VonMisesFisher3DLoss with RMSELoss."""
+
+    def __init__(self, vmfs_factor: float = 0.05) -> None:
+        """VonMisesFisher3DLoss with an RMSE penalty term.
+
+        The VonMisesFisher3DLoss will be weighted with `vmfs_factor`.
+
+        Args:
+            vmfs_factor: A factor applied to the VonMisesFisher3DLoss term.
+                Defaults to 0.05.
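+
+        Example (an illustrative sketch):
+            >>> loss = RMSEVonMisesFisher3DLoss(vmfs_factor=0.05)
+            >>> # equivalent to EnsembleLoss(
+            >>> #     loss_functions=[RMSELoss(), VonMisesFisher3DLoss()],
+            >>> #     loss_factors=[1, 0.05],
+            >>> #     prediction_keys=[[0, 1, 2], [0, 1, 2, 3]],
+            >>> # )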
+ """ + super().__init__( + loss_functions=[RMSELoss(), VonMisesFisher3DLoss()], + loss_factors=[1, vmfs_factor], + prediction_keys=[[0, 1, 2], [0, 1, 2, 3]], ) - return elements From 338814c2fc5bad8081d5f8ace77acb7c0c9246fa Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Mon, 20 May 2024 15:31:01 +0900 Subject: [PATCH 14/16] fixing --- src/graphnet/training/loss_functions.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/graphnet/training/loss_functions.py b/src/graphnet/training/loss_functions.py index 6468e5296..d3fc43f7e 100644 --- a/src/graphnet/training/loss_functions.py +++ b/src/graphnet/training/loss_functions.py @@ -71,6 +71,8 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: """Implement loss calculation.""" # Check(s) assert prediction.dim() == 2 + if target.dim() != prediction.dim(): + target = target.squeeze(1) assert prediction.size() == target.size() elements = torch.mean((prediction - target) ** 2, dim=-1) @@ -453,6 +455,8 @@ def __init__( loss_functions: List[LossFunction], loss_factors: List[float] = None, prediction_keys: Optional[List[List[int]]] = None, + *args: Any, + **kwargs: Any, ) -> None: """Chain multiple loss functions together. @@ -482,6 +486,7 @@ def __init__( self._prediction_keys: Optional[List[List[int]]] = prediction_keys else: self._prediction_keys = None + super().__init__(*args, **kwargs) def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: """Calculate loss using multiple loss functions. @@ -504,11 +509,11 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: ): if k == 0: elements = self._factors[k] * loss_function._forward( - prediction=prediction[prediction_key], target=target + prediction=prediction[:, prediction_key], target=target ) else: elements += self._factors[k] * loss_function._forward( - prediction=prediction[prediction_key], target=target + prediction=prediction[:, prediction_key], target=target ) return elements From ba214615cc04a78085ed959dc29ed7e5cb21c5a6 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Mon, 20 May 2024 15:47:27 +0200 Subject: [PATCH 15/16] Change defaulting behavior of GraphDefinition --- src/graphnet/models/graphs/graph_definition.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index 6366fc390..e384425f9 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -24,7 +24,7 @@ class GraphDefinition(Model): def __init__( self, detector: Detector, - node_definition: NodeDefinition = NodesAsPulses(), + node_definition: NodeDefinition = None, edge_definition: Optional[EdgeDefinition] = None, input_feature_names: Optional[List[str]] = None, dtype: Optional[torch.dtype] = torch.float, @@ -69,6 +69,9 @@ def __init__( # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) + if node_definition is None: + node_definition = NodesAsPulses() + # Member Variables self._detector = detector self._edge_definition = edge_definition From 96c7c695ed7cb93dd69b258c49cb4d70712c3e1c Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Tue, 21 May 2024 10:49:10 +0900 Subject: [PATCH 16/16] revert embedding changes --- src/graphnet/models/components/embedding.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/graphnet/models/components/embedding.py b/src/graphnet/models/components/embedding.py index 
1b49cd901..9539fc444 100644 --- a/src/graphnet/models/components/embedding.py +++ b/src/graphnet/models/components/embedding.py @@ -84,6 +84,7 @@ def __init__( super().__init__() self.sin_emb = SinusoidalPosEmb(dim=seq_length, scaled=scaled) + self.aux_emb = nn.Embedding(2, seq_length // 2) self.sin_emb2 = SinusoidalPosEmb(dim=seq_length // 2, scaled=scaled) if n_features < 4: @@ -92,7 +93,7 @@ def __init__( f"{n_features} features." ) elif n_features >= 6: - self.aux_emb = nn.Embedding(2, seq_length // 2) + hidden_dim = 6 * seq_length else: hidden_dim = int((n_features + 0.5) * seq_length)
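
With the series applied, `StandardModel` is a thin `EasySyntax` subclass and
the new pieces compose as below. This is an illustrative sketch only:
`detector`, `backbone`, `task`, and the dataloaders stand in for concrete
`Detector`, `GNN`, `StandardLearnedTask`, and `DataLoader` instances that are
not defined in this series.

    from graphnet.models.graphs import KNNGraph
    from graphnet.models.standard_model import StandardModel

    # Assemble the model from its components
    graph_definition = KNNGraph(detector=detector)
    model = StandardModel(
        graph_definition=graph_definition,
        backbone=backbone,
        tasks=[task],
    )

    # Train with default callbacks, then predict on held-out data
    model.fit(train_dataloader, val_dataloader, max_epochs=10)
    results = model.predict_as_dataframe(test_dataloader)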