diff --git a/.github/actions/install/action.yml b/.github/actions/install/action.yml index b2d6d2896..19e23be01 100644 --- a/.github/actions/install/action.yml +++ b/.github/actions/install/action.yml @@ -38,4 +38,5 @@ runs: run: | echo requirements/torch_${{ inputs.hardware }}.txt ${{ env.PIP_FLAGS }} .${{ inputs.extras }} pip install -r requirements/torch_${{ inputs.hardware }}.txt ${{ env.PIP_FLAGS }} .${{ inputs.extras }} + pip install git+https://github.com/thoglu/jammy_flows.git shell: bash diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 50cf3c879..46b42f216 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -63,6 +63,14 @@ jobs: uses: ./.github/actions/install with: editable: true + - name: Print packages in pip + run: | + pip show torch + pip show torch-geometric + pip show torch-cluster + pip show torch-sparse + pip show torch-scatter + pip show jammy_flows - name: Run unit tests and generate coverage report run: | coverage run --source=graphnet -m pytest tests/ --ignore=tests/examples/04_training --ignore=tests/utilities @@ -110,6 +118,8 @@ jobs: pip show torch-sparse pip show torch-scatter pip show numpy + + - name: Run unit tests and generate coverage report run: | set -o pipefail # To propagate exit code from pytest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4794b3745..fd6bae19e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,16 +10,20 @@ repos: rev: 4.0.1 hooks: - id: flake8 + language_version: python3 - repo: https://github.com/pycqa/docformatter rev: v1.5.0 hooks: - id: docformatter + language_version: python3 - repo: https://github.com/pycqa/pydocstyle rev: 6.1.1 hooks: - id: pydocstyle + language_version: python3 - repo: https://github.com/pre-commit/mirrors-mypy rev: v0.982 hooks: - id: mypy args: [--follow-imports=silent, --disallow-untyped-defs, --disallow-incomplete-defs, --disallow-untyped-calls] + language_version: python3 \ No newline 
at end of file diff --git a/docs/source/installation/quick-start.html b/docs/source/installation/quick-start.html index aff34659e..e80fd5b8d 100644 --- a/docs/source/installation/quick-start.html +++ b/docs/source/installation/quick-start.html @@ -107,20 +107,20 @@ } if (os == "linux" && cuda != "cpu" && torch != "no_torch"){ - $("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[torch,develop]`); + $("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[torch,develop]\n\n#Optionally, install jammy_flows for normalizing flow support:\npip install git+https://github.com/thoglu/jammy_flows.git`); } else if (os == "linux" && cuda == "cpu" && torch != "no_torch"){ - $("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[torch,develop]`); + $("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[torch,develop]\n\n#Optionally, install jammy_flows for normalizing flow support:\npip install git+https://github.com/thoglu/jammy_flows.git`); } else if (os == "linux" && cuda == "cpu" && torch == "no_torch"){ - $("#command pre").text(`# Installations without PyTorch are intended for file conversion only\ngit clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[develop]`); + $("#command pre").text(`# Installations without PyTorch are intended for file conversion only\ngit clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[develop]\n\n#Optionally, install 
"""Example of training a conditional NormalizingFlow."""

import os
from typing import Any, Dict, List, Optional

from pytorch_lightning.loggers import WandbLogger
import torch
from torch.optim.adam import Adam

from graphnet.constants import EXAMPLE_DATA_DIR, EXAMPLE_OUTPUT_DIR
from graphnet.data.constants import FEATURES, TRUTH
from graphnet.models.detector.prometheus import Prometheus
from graphnet.models.gnn import DynEdge
from graphnet.models.graphs import KNNGraph
from graphnet.training.callbacks import PiecewiseLinearLR
from graphnet.training.utils import make_train_validation_dataloader
from graphnet.utilities.argparse import ArgumentParser
from graphnet.utilities.logging import Logger
from graphnet.utilities.imports import has_jammy_flows_package

# Make sure `jammy_flows` is installed before importing anything that needs
# it. An explicit check is used instead of `assert`, because assert statements
# are stripped when Python runs with `-O`, which would silently disable the
# guard and surface a less helpful import failure instead.
if not has_jammy_flows_package():
    raise AssertionError(
        "This example requires the package `jammy_flows` to be installed. "
        "It appears that the package is not installed. Please install it, "
        "e.g. `pip install git+https://github.com/thoglu/jammy_flows.git`."
    )
from graphnet.models import NormalizingFlow  # noqa: E402

# Constants
features = FEATURES.PROMETHEUS
truth = TRUTH.PROMETHEUS


def main(
    path: str,
    pulsemap: str,
    target: str,
    truth_table: str,
    gpus: Optional[List[int]],
    max_epochs: int,
    early_stopping_patience: int,
    batch_size: int,
    num_workers: int,
    wandb: bool = False,
) -> None:
    """Train a conditional NormalizingFlow and write its outputs to disk.

    Args:
        path: Path to the dataset file passed to the dataloader factory.
        pulsemap: Name of the pulsemap to build graphs from.
        target: Name of the truth field whose pdf is learned.
        truth_table: Name of the truth table in the dataset.
        gpus: GPU selection forwarded to `fit` and to prediction.
        max_epochs: Maximum number of training epochs.
        early_stopping_patience: Patience forwarded to `fit`.
        batch_size: Batch size used by both dataloaders.
        num_workers: Number of dataloader workers.
        wandb: If True, log the run to Weights & Biases.
    """
    # Construct Logger
    logger = Logger()

    # Initialise Weights & Biases (W&B) run
    if wandb:
        # Make sure W&B output directory exists
        wandb_dir = "./wandb/"
        os.makedirs(wandb_dir, exist_ok=True)
        wandb_logger = WandbLogger(
            project="example-script",
            entity="graphnet-team",
            save_dir=wandb_dir,
            log_model=True,
        )

    logger.info(f"features: {features}")
    logger.info(f"truth: {truth}")

    # Configuration
    config: Dict[str, Any] = {
        "path": path,
        "pulsemap": pulsemap,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "target": target,
        "early_stopping_patience": early_stopping_patience,
        "fit": {
            "gpus": gpus,
            "max_epochs": max_epochs,
        },
    }

    archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model_without_configs")
    run_name = "dynedge_{}_example".format(config["target"])
    if wandb:
        # Log configuration to W&B
        wandb_logger.experiment.config.update(config)

    # Define graph representation
    graph_definition = KNNGraph(detector=Prometheus())

    (
        training_dataloader,
        validation_dataloader,
    ) = make_train_validation_dataloader(
        db=config["path"],
        graph_definition=graph_definition,
        pulsemaps=config["pulsemap"],
        features=features,
        truth=truth,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        truth_table=truth_table,
        selection=None,
    )

    # Building model

    backbone = DynEdge(
        nb_inputs=graph_definition.nb_outputs,
        global_pooling_schemes=["min", "max", "mean", "sum"],
    )

    model = NormalizingFlow(
        graph_definition=graph_definition,
        backbone=backbone,
        optimizer_class=Adam,
        target_labels=config["target"],
        optimizer_kwargs={"lr": 1e-03, "eps": 1e-03},
        scheduler_class=PiecewiseLinearLR,
        scheduler_kwargs={
            # Milestones are expressed in optimizer steps (see the "step"
            # interval below): ramp up over half an epoch, then decay until
            # the last step of training.
            "milestones": [
                0,
                len(training_dataloader) / 2,
                len(training_dataloader) * config["fit"]["max_epochs"],
            ],
            "factors": [1e-2, 1, 1e-02],
        },
        scheduler_config={
            "interval": "step",
        },
    )

    # Training model
    model.fit(
        training_dataloader,
        validation_dataloader,
        early_stopping_patience=config["early_stopping_patience"],
        logger=wandb_logger if wandb else None,
        **config["fit"],
    )

    # Get predictions
    additional_attributes = model.target_labels
    assert isinstance(additional_attributes, list)  # mypy

    results = model.predict_as_dataframe(
        validation_dataloader,
        additional_attributes=additional_attributes + ["event_no"],
        gpus=config["fit"]["gpus"],
    )

    # Save predictions and model to file
    db_name = path.split("/")[-1].split(".")[0]
    path = os.path.join(archive, db_name, run_name)
    logger.info(f"Writing results to {path}")
    os.makedirs(path, exist_ok=True)

    # Save results as .csv
    results.to_csv(f"{path}/results.csv")

    # Save full model (including weights) to .pth file - not version safe
    # Note: Models saved as .pth files in one version of graphnet
    #       may not be compatible with a different version of graphnet.
    model.save(f"{path}/model.pth")

    # Save model config and state dict - Version safe save method.
    # This method of saving models is the safest way.
    model.save_state_dict(f"{path}/state_dict.pth")
    model.save_config(f"{path}/model_config.yml")


if __name__ == "__main__":

    # Parse command-line arguments
    parser = ArgumentParser(
        description="""
Train conditional NormalizingFlow without the use of config files.
"""
    )

    parser.add_argument(
        "--path",
        help="Path to dataset file (default: %(default)s)",
        default=f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db",
    )

    parser.add_argument(
        "--pulsemap",
        help="Name of pulsemap to use (default: %(default)s)",
        default="total",
    )

    parser.add_argument(
        "--target",
        help=(
            "Name of feature to use as regression target (default: "
            "%(default)s)"
        ),
        default="total_energy",
    )

    parser.add_argument(
        "--truth-table",
        help="Name of truth table to be used (default: %(default)s)",
        default="mc_truth",
    )

    parser.with_standard_arguments(
        "gpus",
        ("max-epochs", 1),
        "early-stopping-patience",
        ("batch-size", 50),
        "num-workers",
    )

    parser.add_argument(
        "--wandb",
        action="store_true",
        help="If True, Weights & Biases are used to track the experiment.",
    )

    args, unknown = parser.parse_known_args()

    main(
        args.path,
        args.pulsemap,
        args.target,
        args.truth_table,
        args.gpus,
        args.max_epochs,
        args.early_stopping_patience,
        args.batch_size,
        args.num_workers,
        args.wandb,
    )
import StandardAveragedModel + +if has_jammy_flows_package(): + from .normalizing_flow import NormalizingFlow diff --git a/src/graphnet/models/easy_model.py b/src/graphnet/models/easy_model.py index d26d88fa0..d3ed4f419 100644 --- a/src/graphnet/models/easy_model.py +++ b/src/graphnet/models/easy_model.py @@ -16,7 +16,6 @@ from pytorch_lightning.loggers import Logger as LightningLogger from graphnet.training.callbacks import ProgressBar -from graphnet.models.graphs import GraphDefinition from graphnet.models.model import Model from graphnet.models.task import StandardLearnedTask @@ -292,6 +291,7 @@ def predict( dataloader: DataLoader, gpus: Optional[Union[List[int], int]] = None, distribution_strategy: Optional[str] = "auto", + **trainer_kwargs: Any, ) -> List[Tensor]: """Return predictions for `dataloader`.""" self.inference() @@ -305,6 +305,7 @@ def predict( gpus=gpus, distribution_strategy=distribution_strategy, callbacks=callbacks, + **trainer_kwargs, ) predictions_list = inference_trainer.predict(self, dataloader) @@ -325,6 +326,7 @@ def predict_as_dataframe( additional_attributes: Optional[List[str]] = None, gpus: Optional[Union[List[int], int]] = None, distribution_strategy: Optional[str] = "auto", + **trainer_kwargs: Any, ) -> pd.DataFrame: """Return predictions for `dataloader` as a DataFrame. @@ -357,6 +359,7 @@ def predict_as_dataframe( dataloader=dataloader, gpus=gpus, distribution_strategy=distribution_strategy, + **trainer_kwargs, ) predictions = ( torch.cat(predictions_torch, dim=1).detach().cpu().numpy() diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index e384425f9..0338225b8 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -34,6 +34,7 @@ def __init__( sensor_mask: Optional[List[int]] = None, string_mask: Optional[List[int]] = None, sort_by: str = None, + repeat_labels: bool = False, ): """Construct ´GraphDefinition´. 
The ´detector´ holds. @@ -62,9 +63,14 @@ def __init__( add_inactive_sensors: If True, inactive sensors will be appended to the graph with padded pulse information. Defaults to False. sensor_mask: A list of sensor id's to be masked from the graph. Any - sensor listed here will be removed from the graph. Defaults to None. - string_mask: A list of string id's to be masked from the graph. Defaults to None. + sensor listed here will be removed from the graph. + Defaults to None. + string_mask: A list of string id's to be masked from the graph. + Defaults to None. sort_by: Name of node feature to sort by. Defaults to None. + repeat_labels: If True, labels will be repeated to match the + the number of rows in the output of the GraphDefinition. + Defaults to False. """ # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) @@ -80,6 +86,7 @@ def __init__( self._sensor_mask = sensor_mask self._string_mask = string_mask self._add_inactive_sensors = add_inactive_sensors + self._repeat_labels = repeat_labels self._resolve_masks() @@ -408,10 +415,14 @@ def _add_truth( """ # Write attributes, either target labels, truth info or original # features. + for truth_dict in truth_dicts: for key, value in truth_dict.items(): try: - graph[key] = torch.tensor(value) + label = torch.tensor(value) + if self._repeat_labels: + label = label.repeat(graph.x.shape[0], 1) + graph[key] = label except TypeError: # Cannot convert `value` to Tensor due to its data type, # e.g. `str`. 
@@ -448,5 +459,8 @@ def _add_custom_labels( ) -> Data: # Add custom labels to the graph for key, fn in custom_label_functions.items(): - graph[key] = fn(graph) + label = fn(graph) + if self._repeat_labels: + label = label.repeat(graph.x.shape[0], 1) + graph[key] = label return graph diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py index 0289b943d..525675ca7 100644 --- a/src/graphnet/models/graphs/graphs.py +++ b/src/graphnet/models/graphs/graphs.py @@ -1,6 +1,6 @@ """A module containing different graph representations in GraphNeT.""" -from typing import List, Optional, Dict, Union +from typing import List, Optional, Dict, Union, Any import torch from numpy.random import Generator @@ -23,6 +23,7 @@ def __init__( seed: Optional[Union[int, Generator]] = None, nb_nearest_neighbours: int = 8, columns: List[int] = [0, 1, 2], + **kwargs: Any, ) -> None: """Construct k-nn graph representation. @@ -53,6 +54,7 @@ def __init__( input_feature_names=input_feature_names, perturbation_dict=perturbation_dict, seed=seed, + **kwargs, ) @@ -70,6 +72,7 @@ def __init__( dtype: Optional[torch.dtype] = torch.float, perturbation_dict: Optional[Dict[str, float]] = None, seed: Optional[Union[int, Generator]] = None, + **kwargs: Any, ) -> None: """Construct isolated nodes graph representation. 
class NormalizingFlow(EasySyntax):
    """A model for building (conditional) normalizing flows in GraphNeT.

    This model relies on `jammy_flows` for building and evaluating
    normalizing flows.
    https://thoglu.github.io/jammy_flows/usage/introduction.html
    for details.
    """

    def __init__(
        self,
        graph_definition: GraphDefinition,
        target_labels: Union[str, List[str]],
        backbone: Optional[GNN] = None,
        condition_on: Union[str, List[str], None] = None,
        flow_layers: str = "gggt",
        optimizer_class: Type[torch.optim.Optimizer] = Adam,
        optimizer_kwargs: Optional[Dict] = None,
        scheduler_class: Optional[type] = None,
        scheduler_kwargs: Optional[Dict] = None,
        scheduler_config: Optional[Dict] = None,
    ) -> None:
        """Build NormalizingFlow to learn (conditional) normalizing flows.

        NormalizingFlow is able to build, train and evaluate a wide suite of
        normalizing flows. Instead of optimizing a point-estimate loss, the
        flow is trained by minimizing the negative log-likelihood of the
        targets under the learned pdf, providing a posterior distribution
        for every example instead of point-like predictions.

        `NormalizingFlow` can be conditioned on existing fields in the
        DataRepresentation or latent representations from `Models`. If
        neither `backbone` nor `condition_on` is given, the flow is
        unconditional.

        NormalizingFlow is built upon https://github.com/thoglu/jammy_flows,
        and we refer to their documentation for details on the flows.

        Args:
            graph_definition: The `GraphDefinition` to train the model on.
            target_labels: Name of target(s) to learn the pdf of.
            backbone: Architecture used to produce latent representations of
                    the input data on which the pdf will be conditioned.
                    Defaults to None.
            condition_on: List of fields in Data objects to condition the
                        pdf on. Defaults to None.
            flow_layers: A string defining the flow layers.
                        See https://thoglu.github.io/jammy_flows/usage/introduction.html
                        for details. Defaults to "gggt".
            optimizer_class: Optimizer to use. Defaults to Adam.
            optimizer_kwargs: Optimizer arguments. Defaults to None.
            scheduler_class: Learning rate scheduler to use. Defaults to None.
            scheduler_kwargs: Arguments to learning rate scheduler.
                                Defaults to None.
            scheduler_config: Defaults to None.

        Raises:
            ValueError: if both `backbone` and `condition_on` is specified.
        """
        # Checks
        if backbone is not None and condition_on is not None:
            # If user wants to condition on both
            raise ValueError(
                f"{self.__class__.__name__} got values for both "
                "`backbone` and `condition_on`, but can only "
                "condition on one of those. Please specify just "
                "one of these arguments."
            )

        # Handle args
        hidden_size: Optional[int]
        if backbone is not None:
            assert isinstance(backbone, GNN)
            hidden_size = backbone.nb_outputs
        elif condition_on is not None:
            if isinstance(condition_on, str):
                condition_on = [condition_on]
            hidden_size = len(condition_on)
        else:
            # Unconditional flow: nothing to condition on.
            hidden_size = None

        # Build Flow Task
        task = StandardFlowTask(
            hidden_size=hidden_size,
            flow_layers=flow_layers,
            target_labels=target_labels,
        )

        # Base class constructor
        super().__init__(
            tasks=task,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            scheduler_class=scheduler_class,
            scheduler_kwargs=scheduler_kwargs,
            scheduler_config=scheduler_config,
        )

        # Member variable(s)
        self._graph_definition = graph_definition
        self.backbone = backbone
        self._condition_on = condition_on
        # `BatchNorm1d` needs a concrete feature count, so it cannot be built
        # for an unconditional flow (`hidden_size is None`). The layer is
        # still created whenever a size is known, so existing state dicts
        # keep the same keys.
        self._norm: Optional[torch.nn.BatchNorm1d] = (
            torch.nn.BatchNorm1d(hidden_size)
            if hidden_size is not None
            else None
        )

    def forward(self, data: Union[Data, List[Data]]) -> List[Tensor]:
        """Forward pass, chaining model components.

        Each `Data` object is passed through the conditioning path (backbone
        plus batch norm, named fields, or nothing) and then through the flow
        task. The per-object task outputs are concatenated along dim 0 and
        returned as a single-element list.
        """
        if isinstance(data, Data):
            data = [data]
        x_list = []
        for d in data:
            if self.backbone is not None:
                assert self._norm is not None  # set whenever backbone is
                x = self._backbone(d)
                x = self._norm(x)
            elif self._condition_on is not None:
                assert isinstance(self._condition_on, list)
                x = get_fields(data=d, fields=self._condition_on)
            else:
                # Unconditional flow
                x = None
            x = self._tasks[0](x, d)
            x_list.append(x)
        x = torch.cat(x_list, dim=0)
        return [x]

    def _backbone(
        self, data: Union[Data, List[Data]]
    ) -> List[Union[Tensor, Data]]:
        """Apply the backbone to `data`; requires a backbone to be set."""
        assert self.backbone is not None
        return self.backbone(data)

    def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor:
        """Perform shared step.

        Applies the forward pass and the following loss calculation, shared
        between the training and validation step. The loss is the mean over
        dim 0 of the forward output.
        """
        loss = self(batch)
        if isinstance(loss, list):
            assert len(loss) == 1
            loss = loss[0]
        return torch.mean(loss, dim=0)

    def validate_tasks(self) -> None:
        """Verify that self._tasks contain compatible elements."""
        accepted_tasks = StandardFlowTask
        for task in self._tasks:
            assert isinstance(task, accepted_tasks)
@@ -237,11 +241,13 @@ def __init__( hidden_size: The number of columns in the output of the last latent layer of `Model` using this Task. Available through `Model.nb_outputs` + loss_function: Loss function appropriate to the task. """ # Base class constructor super().__init__(**task_kwargs) # Mapping from last hidden layer to required size of input + self._loss_function = loss_function self._affine = Linear(hidden_size, self.nb_inputs) @abstractmethod @@ -380,66 +386,85 @@ def _forward(self, x: Union[Tensor, Data]) -> Tensor: # type: ignore class StandardFlowTask(Task): - """A `Task` for `NormalizingFlow`s in GraphNeT.""" + """A `Task` for `NormalizingFlow`s in GraphNeT. + + This Task requires the support package`jammy_flows` for constructing and + evaluating normalizing flows. + """ def __init__( self, - target_labels: List[str], + hidden_size: Union[int, None], + flow_layers: str = "gggt", + target_norm: float = 1000.0, **task_kwargs: Any, ): - """Construct `StandardLearnedTask`. + """Construct `StandardFlowTask`. Args: target_labels: A list of names for the targets of this Task. - hidden_size: The number of columns in the output of - the last latent layer of `Model` using this Task. - Available through `Model.nb_outputs` + flow_layers: A string indicating the flow layer types. See + https://thoglu.github.io/jammy_flows/usage/introduction.html + for details. + target_norm: A normalization constant used to divide the target + values. Value is applied to all targets. Defaults to 1000. + hidden_size: The number of columns on which the normalizing flow + is conditioned on. May be `None`, indicating non-conditional flow. 
""" # Base class constructor - super().__init__(target_labels=target_labels, **task_kwargs) - def nb_inputs(self) -> int: - """Return number of inputs assumed by task.""" - return len(self._target_labels) + # Member variables + self._default_prediction_labels = ["nllh"] + self._hidden_size = hidden_size + super().__init__(**task_kwargs) + self._flow = jammy_flows.pdf( + f"e{len(self._target_labels)}", + flow_layers, + conditional_input_dim=hidden_size, + ) + self._initialized = False + self._norm = target_norm - def _forward(self, x: Tensor, jacobian: Tensor) -> Tensor: # type: ignore - # Leave it as is. - return x + @property + def default_prediction_labels(self) -> List[str]: + """Return default prediction labels.""" + return self._default_prediction_labels + + def nb_inputs(self) -> Union[int, None]: # type: ignore + """Return number of conditional inputs assumed by task.""" + return self._hidden_size + + def _forward(self, x: Optional[Tensor], y: Tensor) -> Tensor: # type: ignore + y = y / self._norm + if x is not None: + if x.shape[0] != y.shape[0]: + raise AssertionError( + f"Targets {self._target_labels} have " + f"{y.shape[0]} rows while conditional " + f"inputs have {x.shape[0]} rows. " + "The number of rows must match." + ) + log_pdf, _, _ = self._flow(y, conditional_input=x) + else: + log_pdf, _, _ = self._flow(y) + return -log_pdf.reshape(-1, 1) @final def forward( - self, x: Union[Tensor, Data], jacobian: Optional[Tensor] + self, x: Union[Tensor, Data], data: List[Data] ) -> Union[Tensor, Data]: """Forward pass.""" - self._regularisation_loss = 0 # Reset - x = self._forward(x, jacobian) + # Manually cast pdf to correct dtype - is there a better way? 
def get_fields(data: Union[Data, List[Data]], fields: List[str]) -> Tensor:
    """Extract named fields from one or more Data objects as a 2D tensor.

    Each field is reshaped to a single column and concatenated across all
    objects in `data` along dim 0; the resulting columns are then
    concatenated along dim 1 in the order given by `fields`.

    Args:
        data: A single Data object or a list of them. Each object is
            indexed by field name.
        fields: Names of the fields to extract. A single name given as a
            plain string is treated as a one-element list.

    Returns:
        Tensor with one column per entry in `fields`.
    """
    if isinstance(fields, str):
        # Iterating a bare string would look up every character as its own
        # field name; treat it as exactly one field instead.
        fields = [fields]
    if not isinstance(data, list):
        data = [data]
    columns = [
        torch.cat([d[field].reshape(-1, 1) for d in data], dim=0)
        for field in fields
    ]
    return torch.cat(columns, dim=1)