NormalizingFlow #649 (Closed)
RasmusOrsoe wants to merge 23 commits into graphnet-team:main from RasmusOrsoe:final_normalizing_flow (+914 −43)
Commits (23 total; the file diff below reflects changes from the first 15 commits)
All commits by RasmusOrsoe:

641e4c6  add spline funcs
117c799  add NormalizingFlow
807bbd5  add INGA
dc0d8f3  add example
7d5e6de  add MultivariateGaussianFlowLoss
82fa707  example plosh
55bf1f4  standard model checks
3563ec9  update forward in standardmodel
6402d0b  backbone rename
10c6af6  polish
9c43255  edit checks in loss_function
1083f13  Make output of NormalizingFlow a single tensor
48b4a19  update example
7b4b16a  remove misplaced abstractmethod decorator
6b0a55b  remove unused import
14f71cb  remove commented-out lines of code in spline_blocks.py
45e8e13  Merge pull request #23 from RasmusOrsoe/main
966ddec  remove unused imports
87ed567  remove self._eps
1386914  Remove assert
9db41d3  remove comment
e44804a  remove unused import
4ae5f1b  set default value for `target` in loss function
Example training script (new file, +199 lines):

"""Example of training Model."""

import os
from typing import Any, Dict, List, Optional

from pytorch_lightning.loggers import WandbLogger
import numpy as np
import pandas as pd

from graphnet.constants import EXAMPLE_DATA_DIR, EXAMPLE_OUTPUT_DIR
from graphnet.data.constants import FEATURES, TRUTH
from graphnet.models import StandardModel
from graphnet.models.detector.prometheus import Prometheus
from graphnet.models.flows import INGA
from graphnet.models.graphs import GraphDefinition
from graphnet.models.graphs.nodes import NodesAsPulses

from graphnet.models.task import StandardFlowTask
from graphnet.training.loss_functions import (
    MultivariateGaussianFlowLoss,
)
from graphnet.training.utils import make_train_validation_dataloader
from graphnet.utilities.argparse import ArgumentParser
from graphnet.utilities.logging import Logger

# Constants
features = FEATURES.PROMETHEUS
truth = TRUTH.PROMETHEUS


def main(
    path: str,
    pulsemap: str,
    target: str,
    truth_table: str,
    gpus: Optional[List[int]],
    max_epochs: int,
    early_stopping_patience: int,
    batch_size: int,
    num_workers: int,
    wandb: bool = False,
) -> None:
    """Run example."""
    # Construct Logger
    logger = Logger()
    # Configuration
    config: Dict[str, Any] = {
        "path": path,
        "pulsemap": pulsemap,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "target": target,
        "early_stopping_patience": early_stopping_patience,
        "fit": {
            "gpus": gpus,
            "max_epochs": max_epochs,
        },
    }

    archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model_without_configs")
    run_name = "INGA_example_1mio"

    # Define graph representation
    detector = Prometheus()

    graph_definition = GraphDefinition(
        detector=detector,
        node_definition=NodesAsPulses(),
        input_feature_names=features,
    )
    (
        training_dataloader,
        validation_dataloader,
    ) = make_train_validation_dataloader(
        db=config["path"],
        graph_definition=graph_definition,
        pulsemaps=config["pulsemap"],
        features=features,
        truth=truth,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        truth_table=truth_table,
        selection=None,
    )

    # Building model
    flow = INGA(
        nb_inputs=graph_definition.nb_outputs,
        n_knots=120,
        num_blocks=4,
        b=100,
        c=100,
    )
    task = StandardFlowTask(
        target_labels=graph_definition.output_feature_names,
        loss_function=MultivariateGaussianFlowLoss(),
        coordinate_columns=flow.coordinate_columns,
        jacobian_columns=flow.jacobian_columns,
    )
    model = StandardModel(
        graph_definition=graph_definition,
        backbone=flow,
        tasks=[task],
    )

    model.fit(
        training_dataloader,
        validation_dataloader,
        **config["fit"],
    )
    results = model.predict_as_dataframe(
        validation_dataloader,
        additional_attributes=["event_no"],
    )

    # Save predictions and model to file
    db_name = path.split("/")[-1].split(".")[0]
    path = os.path.join(archive, db_name, run_name)
    logger.info(f"Writing results to {path}")
    os.makedirs(path, exist_ok=True)

    # Save results as .csv
    results.to_csv(f"{path}/results.csv")

    # Save full model (including weights) to .pth file - not version safe
    # Note: Models saved as .pth files in one version of graphnet
    # may not be compatible with a different version of graphnet.
    model.save(f"{path}/model.pth")

    # Save model config and state dict - Version safe save method.
    # This method of saving models is the safest way.
    model.save_state_dict(f"{path}/state_dict.pth")
    model.save_config(f"{path}/model_config.yml")


if __name__ == "__main__":

    # Parse command-line arguments
    parser = ArgumentParser(
        description="""
Train GNN model without the use of config files.
"""
    )

    parser.add_argument(
        "--path",
        help="Path to dataset file (default: %(default)s)",
        default=f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db",
    )

    parser.add_argument(
        "--pulsemap",
        help="Name of pulsemap to use (default: %(default)s)",
        default="total",
    )

    parser.add_argument(
        "--target",
        help=(
            "Name of feature to use as regression target (default: "
            "%(default)s)"
        ),
        default="total_energy",
    )

    parser.add_argument(
        "--truth-table",
        help="Name of truth table to be used (default: %(default)s)",
        default="mc_truth",
    )

    parser.with_standard_arguments(
        "gpus",
        ("max-epochs", 1),
        "early-stopping-patience",
        ("batch-size", 16),
        "num-workers",
    )

    parser.add_argument(
        "--wandb",
        action="store_true",
        help="If True, Weights & Biases are used to track the experiment.",
    )

    args, unknown = parser.parse_known_args()

    main(
        args.path,
        args.pulsemap,
        args.target,
        args.truth_table,
        args.gpus,
        args.max_epochs,
        args.early_stopping_patience,
        args.batch_size,
        args.num_workers,
        args.wandb,
    )
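A usage note, not part of the diff: the version-safe pair saved at the end of `main` can be reloaded later. A minimal sketch, assuming graphnet's generic `Model.from_config` and path-accepting `load_state_dict` behave as elsewhere in the library (both calls are assumptions, not shown in this PR):

from graphnet.models import Model

# Rebuild the architecture from the saved config, then restore weights.
# trust=True is needed because the config instantiates arbitrary classes.
model = Model.from_config("model_config.yml", trust=True)  # assumed API
model.load_state_dict("state_dict.pth")  # assumed to accept a .pth path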
Flows module __init__ (new file, +3 lines):

"""Module for Normalizing Flows in GraphNeT."""
from .normalizing_flow import NormalizingFlow
from .inga import INGA
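The `NormalizingFlow` base class lives in `normalizing_flow.py`, which this excerpt does not show. Judging from the `INGA` subclass below, a subclass passes `nb_inputs` to `super().__init__` and implements `forward` (input to latent space) and `inverse` (latent back to input space). A hedged sketch of that inferred interface; the identity transform is purely illustrative and not part of the PR:

import torch
from torch_geometric.data import Data

from graphnet.models.flows import NormalizingFlow


class IdentityFlow(NormalizingFlow):  # hypothetical example subclass
    """Trivial flow: y = x, with all Jacobian factors equal to 1."""

    def __init__(self, nb_inputs: int):
        super().__init__(nb_inputs)

    def forward(self, data: Data) -> torch.Tensor:
        x = data.x
        jacobian = torch.ones_like(x)  # identity map: |det J| = 1
        # Same output convention as INGA: transformed coordinates
        # concatenated with per-dimension Jacobian terms.
        return torch.concat([x, jacobian], dim=1)

    def inverse(self, y: torch.Tensor) -> torch.Tensor:
        return y  # the identity map is its own inverse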
INGA implementation (new file, +136 lines):

"""Normalizing flow using parameterized splines.

Implemented by Rasmus Ørsøe, 2023.
"""
from typing import List, Optional, Tuple

import numpy as np
import torch

from graphnet.models.flows import NormalizingFlow
from graphnet.models.flows.spline_blocks import (
    SplineBlock,
    TwoPartitionSplineBlock,
)
from torch_geometric.data import Data


class INGA(NormalizingFlow):
    """Implementation of spline-based neural flows.

    Inspired by https://arxiv.org/pdf/1906.04032.pdf
    """

    def __init__(
        self,
        nb_inputs: int,
        b: int = 100,
        n_knots: int = 5,
        num_blocks: int = 1,
        partitions: Optional[List[Tuple[List[int], List[int]]]] = None,
        c: int = 1,
    ):
        """Construct INGA.

        Args:
            nb_inputs: Number of input dimensions to be transformed.
            b: The bounding parameter.
                All input dimensions are assumed to be in the range [-b, b].
                Defaults to 100.
            n_knots: Number of knots per spline. Defaults to 5.
            num_blocks: Number of spline blocks. Defaults to 1.
            partitions: A set of partitions that designate which dimensions of
                the input are used to transform each other.
                E.g. [[0,1,2,3,4], [5,6,7,8,9]] (for the 10-dimensional case)
                means dimensions 0 through 4 are used to transform
                dimensions 5 through 9, and vice versa.
                Defaults to None, which will create an even partition.
            c: Scaling parameter for the neural network.
        """
        self._coordinate_columns = np.arange(0, nb_inputs).tolist()
        self._jacobian_columns = np.arange(nb_inputs, 2 * nb_inputs).tolist()
        super().__init__(nb_inputs)

        # Set member variables
        self.n_knots = n_knots
        self.num_blocks = num_blocks

        if partitions is None:
            partitions = self._create_default_partitions()

        self.partitions = partitions

        # Checks
        assert len(partitions) == self.num_blocks

        # Build one TwoPartitionSplineBlock per block
        spline_blocks = []
        for k in range(num_blocks):
            nn_0_dim = len(partitions[k][0])
            nn_1_dim = len(partitions[k][1])
            spline_blocks.append(
                TwoPartitionSplineBlock(
                    b=b,
                    n_knots=n_knots,
                    input_dim=self.nb_inputs,
                    nn_0=torch.nn.Sequential(
                        torch.nn.Linear(nn_0_dim, nn_0_dim * c),
                        torch.nn.ReLU(),
                        torch.nn.Linear(
                            nn_0_dim * c, nn_1_dim * (n_knots * 3)
                        ),
                    ),
                    nn_1=torch.nn.Sequential(
                        torch.nn.Linear(nn_1_dim, nn_1_dim * c),
                        torch.nn.ReLU(),
                        torch.nn.Linear(
                            nn_1_dim * c, nn_0_dim * (n_knots * 3)
                        ),
                    ),
                    partition=partitions[k],
                )
            )

        self.spline_blocks = torch.nn.ModuleList(spline_blocks)

    def _create_default_partitions(self) -> List[Tuple[List[int], List[int]]]:
        default_partition = (
            [i for i in range(0, int(self.nb_inputs / 2))],
            [k for k in range(int(self.nb_inputs / 2), self.nb_inputs)],
        )
        partitions = []
        for _ in range(self.num_blocks):
            partitions.append(default_partition)
        return partitions

    def forward(self, data: Data) -> torch.Tensor:
        """Forward call.

        Will transform sample from input distribution to latent distribution.
        """
        is_first = True
        x = data.x
        for spline_block in self.spline_blocks:
            if is_first:
                y, partition_jacobian = spline_block(x=x)
                global_jacobian = partition_jacobian
                is_first = False
            else:
                y, partition_jacobian = spline_block(x=y)
                global_jacobian *= partition_jacobian
        return torch.concat([y, global_jacobian], dim=1)

    def inverse(self, y: torch.Tensor) -> torch.Tensor:
        """Inverse call.

        Will transform sample from latent distribution to input distribution.
        """
        reversed_index = list(range(0, len(self.spline_blocks)))[
            ::-1
        ]  # e.g. 6, 5, 4, ...
        for idx in reversed_index:
            y = self.spline_blocks[idx].inverse(y=y)
        return self.inverse_transform(y)
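To make the forward pass above concrete: each `TwoPartitionSplineBlock` transforms one half of the dimensions conditioned on the other half, and because the blocks are composed, the chain rule says their per-block Jacobian factors multiply into a global factor, which is what the `global_jacobian *= partition_jacobian` accumulation implements. A toy numerical sketch of that composition rule with two hypothetical elementwise maps (plain torch, standing in for the PR's spline blocks):

import torch

# Toy invertible elementwise maps; illustrations only. The PR's
# TwoPartitionSplineBlock is a conditional spline transform, not a
# fixed elementwise function.
def block_a(x: torch.Tensor):
    return 2.0 * x, torch.full_like(x, 2.0)  # y, dy/dx

def block_b(x: torch.Tensor):
    return x + x**3, 1.0 + 3.0 * x**2  # y, dy/dx (monotone, so invertible)

x = torch.randn(4, 6)  # batch of 4 events, 6 dimensions

y, global_jacobian = block_a(x)          # first block seeds the product
y, jac = block_b(y)
global_jacobian = global_jacobian * jac  # chain rule: factors multiply

# A flow loss then consumes y (latent coordinates) together with
# global_jacobian, which is why INGA returns them concatenated along dim=1.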
Review comment on the example script's imports: the `WandbLogger`, `numpy` and `pandas` imports are not used.

Reply: Thanks!