Config files refactor, Examples polish, Perturbation in GraphDefinition #603

Merged
Changes from all commits (27 commits)
a3b722e
Create MetaClasses to save Model/Dataset configs
AMHermansen Sep 13, 2023
682213c
Remove usage of save_model_config and save_dataset_config
AMHermansen Sep 13, 2023
e4c06fd
Remove redundant config saving baseclasses.
AMHermansen Sep 13, 2023
63a8715
Remove redundant config saving baseclasses.
AMHermansen Sep 13, 2023
9321927
Reintroduce save_model_config and save_dataset_config but added depre…
AMHermansen Sep 14, 2023
4af2872
Fixed typehints for make_(train_validation)_dataloader
AMHermansen Sep 19, 2023
368b9a8
Fixed typehints for make_(train_validation)_dataloader
AMHermansen Sep 19, 2023
396eb77
Update utils.py
AMHermansen Sep 19, 2023
0b73f6f
Merge branch 'graphnet-team:main' into add-ConfSaverMeta
AMHermansen Sep 20, 2023
b846b40
fix example 02-02
RasmusOrsoe Sep 22, 2023
ae226e0
default arguments, fix 02-01
RasmusOrsoe Sep 22, 2023
8e04af2
tito_example update
RasmusOrsoe Sep 22, 2023
15a14f7
Polish examples
RasmusOrsoe Sep 22, 2023
a11014c
delete shell script example
RasmusOrsoe Sep 22, 2023
c840b44
rename examples, update readme.md
RasmusOrsoe Sep 22, 2023
1fbf534
Move perturbations to graph_definition
RasmusOrsoe Sep 22, 2023
d4e166a
minor adjustments, unit test
RasmusOrsoe Sep 22, 2023
e93f514
Unit tests
RasmusOrsoe Sep 22, 2023
b567581
delete perturbedsqlitedataset
RasmusOrsoe Sep 22, 2023
8c54c77
remove old import statements
RasmusOrsoe Sep 22, 2023
6f014fa
replace np.float with float
RasmusOrsoe Sep 22, 2023
726e653
shorten warning
RasmusOrsoe Sep 22, 2023
822c0d7
shorten doc string
RasmusOrsoe Sep 22, 2023
1758b74
shorten error strings
RasmusOrsoe Sep 22, 2023
afeec3e
Replace GenericExtractor in 01-03 for FeatureExtractor
RasmusOrsoe Sep 22, 2023
999e26d
fix typo in readme.md
RasmusOrsoe Sep 22, 2023
03dc5d2
Update code comment in 04-01
RasmusOrsoe Sep 23, 2023
8 changes: 2 additions & 6 deletions examples/01_icetray/01_convert_i3_files.py
@@ -5,6 +5,7 @@
from graphnet.constants import EXAMPLE_OUTPUT_DIR, TEST_DATA_DIR
from graphnet.data.extractors import (
I3FeatureExtractorIceCubeUpgrade,
I3FeatureExtractorIceCube86,
I3RetroExtractor,
I3TruthExtractor,
I3GenericExtractor,
@@ -34,12 +35,7 @@ def main_icecube86(backend: str) -> None:

converter: DataConverter = CONVERTER_CLASS[backend](
[
I3GenericExtractor(
keys=[
"SRTInIcePulses",
"I3MCTree",
]
),
I3FeatureExtractorIceCube86("SRTInIcePulses"),
I3TruthExtractor(),
],
outdir,
21 changes: 14 additions & 7 deletions examples/02_data/01_read_dataset.py
@@ -17,7 +17,10 @@
from graphnet.data.dataset import ParquetDataset
from graphnet.utilities.argparse import ArgumentParser
from graphnet.utilities.logging import Logger

from graphnet.models.graphs import KNNGraph
from graphnet.models.detector.icecube import (
IceCubeDeepCore,
)

DATASET_CLASS = {
"sqlite": SQLiteDataset,
@@ -44,6 +47,9 @@ def main(backend: str) -> None:
num_workers = 30
wait_time = 0.00 # sec.

# Define graph representation
graph_definition = KNNGraph(detector=IceCubeDeepCore())

for table in [pulsemap, truth_table]:
# Get column names from backend
if backend == "sqlite":
@@ -62,15 +68,16 @@

# Common variables
dataset = DATASET_CLASS[backend](
path,
pulsemap,
features,
truth,
path=path,
pulsemaps=pulsemap,
features=features,
truth=truth,
truth_table=truth_table,
graph_definition=graph_definition,
)
assert isinstance(dataset, Dataset)

logger.info(dataset[1])
logger.info(str(dataset[1]))
logger.info(dataset[1].x)
if backend == "sqlite":
assert isinstance(dataset, SQLiteDataset)
@@ -92,7 +99,7 @@ def main(backend: str) -> None:
for batch in tqdm(dataloader, unit=" batches", colour="green"):
time.sleep(wait_time)

logger.info(batch)
logger.info(str(batch))
logger.info(batch.size())
logger.info(batch.num_graphs)

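Pieced together from the hunks above, the dataset is now built with keyword arguments and an explicit graph definition. A minimal sketch of the new construction, assuming the SQLite backend; the path and column lists below are placeholders rather than values from the example:

```python
from graphnet.data.dataset import SQLiteDataset
from graphnet.models.detector.icecube import IceCubeDeepCore
from graphnet.models.graphs import KNNGraph

# Placeholder inputs; the example derives these from the test data shipped
# with graphnet.
path = "/path/to/database.db"
pulsemap = "SRTInIcePulses"
features = ["dom_x", "dom_y", "dom_z", "dom_time", "charge"]
truth = ["energy", "zenith", "azimuth"]

# The graph definition now owns detector standardization and graph building.
graph_definition = KNNGraph(detector=IceCubeDeepCore())

dataset = SQLiteDataset(
    path=path,
    pulsemaps=pulsemap,
    features=features,
    truth=truth,
    truth_table="truth",
    graph_definition=graph_definition,
)

print(dataset[0].x)  # node features of the first event, already preprocessed
```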
38 changes: 5 additions & 33 deletions examples/02_data/02_plot_feature_distributions.py
@@ -1,4 +1,4 @@
"""Example of plotting feature distributions from SQLite database."""
"""Example of visualization of input data from a configured Dataset."""

import os.path

@@ -8,8 +8,6 @@

from graphnet.constants import CONFIG_DIR
from graphnet.data.dataset import Dataset
from graphnet.models.detector.icecube import IceCubeDeepCore
from graphnet.models.graph_builders import KNNGraphBuilder
from graphnet.utilities.logging import Logger
from graphnet.utilities.argparse import ArgumentParser

@@ -27,46 +25,20 @@ def main() -> None:
assert isinstance(dataset, Dataset)
features = dataset._features[1:]

# Building model
detector = IceCubeDeepCore(
graph_builder=KNNGraphBuilder(nb_nearest_neighbours=8),
)

# Get feature matrix
x_original_list = []
x_preprocessed_list = []
for batch in tqdm(dataset, colour="green"):
x_original_list.append(batch.x.numpy())
x_preprocessed_list.append(detector(batch).x.numpy())
x_preprocessed_list.append(batch.x.numpy())

x_original = np.concatenate(x_original_list, axis=0)
x_preprocessed = np.concatenate(x_preprocessed_list, axis=0)

logger.info(f"Number of NaNs: {np.sum(np.isnan(x_original))}")
logger.info(f"Number of infs: {np.sum(np.isinf(x_original))}")
logger.info(f"Number of NaNs: {np.sum(np.isnan(x_preprocessed))}")
logger.info(f"Number of infs: {np.sum(np.isinf(x_preprocessed))}")

# Plot feature distributions
nb_features_original = x_original.shape[1]
nb_features_preprocessed = x_preprocessed.shape[1]
dim = int(np.ceil(np.sqrt(nb_features_preprocessed)))
axis_size = 4
bins = 100

# -- Original
fig, axes = plt.subplots(
dim, dim, figsize=(dim * axis_size, dim * axis_size)
)
for ix, ax in enumerate(axes.ravel()[:nb_features_original]):
ax.hist(x_original[:, ix], bins=bins)
ax.set_xlabel(
f"x{ix}: {features[ix] if ix < len(features) else 'N/A'}"
)
ax.set_yscale("log")

fig.tight_layout
figure_name_original = "feature_distribution_original.png"
fig.savefig(figure_name_original)
logger.info(f"Figure written to {figure_name_original}")
bins = 50

# -- Preprocessed
fig, axes = plt.subplots(
@@ -79,12 +79,7 @@ def main(
wandb_logger.experiment.config.update(config)

# Define graph representation
graph_definition = KNNGraph(
detector=Prometheus(),
node_definition=NodesAsPulses(),
nb_nearest_neighbours=8,
node_feature_names=features,
)
graph_definition = KNNGraph(detector=Prometheus())

(
training_dataloader,
@@ -166,10 +161,19 @@ def main(
logger.info(f"Writing results to {path}")
os.makedirs(path, exist_ok=True)

# Save results as .csv
results.to_csv(f"{path}/results.csv")
model.save_state_dict(f"{path}/state_dict.pth")

# Save full model (including weights) to .pth file - not version safe
# Note: Models saved as .pth files in one version of graphnet
# may not be compatible with a different version of graphnet.
model.save(f"{path}/model.pth")

# Save model config and state dict - Version safe save method.
# This method of saving models is the safest way.
model.save_state_dict(f"{path}/state_dict.pth")
model.save_config(f"{path}/model_config.yml")


if __name__ == "__main__":

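As the new comments in this example note, the config plus state dict is the version-safe way to persist a model. A minimal sketch of the corresponding reload step, assuming the `ModelConfig`/`from_config` API used by the config-based example later in this PR and a placeholder output directory:

```python
from graphnet.models import StandardModel
from graphnet.utilities.config import ModelConfig

# Placeholder: the directory the training example wrote its outputs to.
path = "/path/to/results/dynedge_example"

# Rebuild the architecture from the saved config, then restore the weights.
model_config = ModelConfig.load(f"{path}/model_config.yml")
model = StandardModel.from_config(model_config, trust=True)  # config may contain arbitrary code
model.load_state_dict(f"{path}/state_dict.pth")
```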
@@ -14,7 +14,6 @@
from graphnet.models.detector.prometheus import Prometheus
from graphnet.models.gnn import DynEdgeTITO
from graphnet.models.graphs import KNNGraph
from graphnet.models.graphs.nodes import NodesAsPulses
from graphnet.models.task.reconstruction import (
DirectionReconstructionWithKappa,
)
@@ -28,7 +27,6 @@
# Constants
features = FEATURES.PROMETHEUS
truth = TRUTH.PROMETHEUS
DYNTRANS_LAYER_SIZES = [(256, 256), (256, 256), (256, 256)]


def main(
@@ -76,12 +74,7 @@ def main(
},
}

graph_definition = KNNGraph(
detector=Prometheus(),
node_definition=NodesAsPulses(),
nb_nearest_neighbours=8,
node_feature_names=features,
)
graph_definition = KNNGraph(detector=Prometheus())
archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_tito_model")
run_name = "dynedgeTITO_{}_example".format(config["target"])
if wandb:
@@ -115,7 +108,6 @@ def main(
gnn = DynEdgeTITO(
nb_inputs=graph_definition.nb_outputs,
global_pooling_schemes=["max"],
dyntrans_layer_sizes=DYNTRANS_LAYER_SIZES,
)
task = DirectionReconstructionWithKappa(
hidden_size=gnn.nb_outputs,
@@ -182,10 +174,16 @@ def main(
logger.info(f"Writing results to {path}")
os.makedirs(path, exist_ok=True)

# Save results as .csv
results.to_csv(f"{path}/results.csv")
model.save_state_dict(f"{path}/state_dict.pth")

# Save full model (including weights) to .pth file - Not version proof
model.save(f"{path}/model.pth")

# Save model config and state dict - Version safe save method.
model.save_state_dict(f"{path}/state_dict.pth")
model.save_config(f"{path}/model_config.yml")


if __name__ == "__main__":

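The explicit `dyntrans_layer_sizes` argument (and the module-level constant) is dropped above in favour of the class default. If the old behaviour is wanted, it can still be passed explicitly; a short sketch, assuming the default matches the removed constant:

```python
from graphnet.models.detector.prometheus import Prometheus
from graphnet.models.gnn import DynEdgeTITO
from graphnet.models.graphs import KNNGraph

graph_definition = KNNGraph(detector=Prometheus())

# Passing the layer sizes explicitly reproduces the removed
# DYNTRANS_LAYER_SIZES constant; omitting the argument uses the class default.
gnn = DynEdgeTITO(
    nb_inputs=graph_definition.nb_outputs,
    global_pooling_schemes=["max"],
    dyntrans_layer_sizes=[(256, 256), (256, 256), (256, 256)],
)
```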
@@ -1,4 +1,4 @@
"""Simplified example of training Model."""
"""Simplified example of training DynEdge from pre-defined config files."""

from typing import List, Optional
import os
@@ -46,7 +46,7 @@ def main(
log_model=True,
)

# Build model
# Build model from pre-defined config file made from Model.save_config
model_config = ModelConfig.load(model_config_path)
model: StandardModel = StandardModel.from_config(model_config, trust=True)

@@ -69,7 +69,8 @@ def main(
archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model")
run_name = "dynedge_{}_example".format("_".join(config.target))

# Construct dataloaders
# Construct dataloaders from pre-defined dataset config files.
# i.e. from Dataset.save_config
dataset_config = DatasetConfig.load(dataset_config_path)
dataloaders = DataLoader.from_dataset_config(
dataset_config,
62 changes: 0 additions & 62 deletions examples/04_training/03_train_multiple_models.sh

This file was deleted.

@@ -1,4 +1,4 @@
"""Simplified example of multi-class classification training Model."""
"""Multi-class classification using DynEdge from pre-defined config files."""

import os
from typing import List, Optional, Dict, Any
20 changes: 10 additions & 10 deletions examples/04_training/README.md
@@ -2,44 +2,44 @@

This subfolder contains two main training scripts:

**`01_train_model.py`** Shows how to train a GNN on neutrino telescope data **using configuration files** to construct the dataset that loads the data and the model that is trained. This is the recommended way to configure standard dataset and models, as it is easier to ready and share than doing so in pure code. This example can be run using a few different models targeting different physics use cases. For instance, you can try running:
**`01_train_dynedge.py`** Shows how to train a GNN on neutrino telescope data **without configuration files**, i.e., by programmatically constructing the dataset and model used. This is good for debugging and experimenting with different dataset settings and model configurations, as it is easier to build the model using the API than by writing configuration files from scratch. **This is our recommended way of getting started with the library**. For instance, try running:

```bash
# Show the CLI
(graphnet) $ python examples/04_training/01_train_model.py --help
(graphnet) $ python examples/04_training/01_train_dynedge.py --help

# Train energy regression model
(graphnet) $ python examples/04_training/01_train_model.py
(graphnet) $ python examples/04_training/01_train_dynedge.py

# Same as above, as this is the default model config.
(graphnet) $ python examples/04_training/01_train_model.py \
--model-config configs/models/example_energy_reconstruction_model.yml

# Train using a single GPU
(graphnet) $ python examples/04_training/01_train_model.py --gpus 0
(graphnet) $ python examples/04_training/01_train_dynedge.py --gpus 0

# Train using multiple GPUs
(graphnet) $ python examples/04_training/01_train_model.py --gpus 0 1
(graphnet) $ python examples/04_training/01_train_dynedge.py --gpus 0 1

# Train a vertex position reconstruction model
(graphnet) $ python examples/04_training/01_train_model.py \
(graphnet) $ python examples/04_training/01_train_dynedge.py \
--model-config configs/models/example_vertex_position_reconstruction_model.yml

# Trains a direction (zenith, azimuth) reconstruction model. Note that the
# chosen `Task` in the model config file also returns estimated "kappa" values,
# i.e. inverse variance, for each predicted feature, meaning that we need to
# manually specify the names of these.
(graphnet) $ python examples/04_training/01_train_model.py --gpus 0 \
(graphnet) $ python examples/04_training/01_train_dynedge.py --gpus 0 \
--model-config configs/models/example_direction_reconstruction_model.yml \
--prediction-names zenith_pred zenith_kappa_pred azimuth_pred azimuth_kappa_pred
```

**`02_train_model_without_configs.py`** Shows how to train a GNN on neutrino telescope data **without configuration files,** i.e., by programatically constructing the dataset and model used. This is good for debugging and experimenting with different dataset settings and model configurations, as it is easier to build the model using the API than by writing configuration files from scratch. For instance, try running:
**`02_train_dynedge_from_config.py`** Shows how to train a GNN on neutrino telescope data **using configuration files** to construct the dataset that loads the data and the model that is trained. This is the recommended way to configure standard datasets and models, as it is easier to read and share than doing so in pure code. This example can be run using a few different models targeting different physics use cases. For instance, you can try running:

```bash
# Show the CLI
(graphnet) $ python examples/04_training/02_train_model_without_configs.py --help
(graphnet) $ python examples/04_training/02_train_dynedge_from_config.py --help

# Train energy regression model
(graphnet) $ python examples/04_training/02_train_model_without_configs.py
(graphnet) $ python examples/04_training/02_train_dynedge_from_config.py
```
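The YAML files these scripts consume are ordinary `ModelConfig`/`DatasetConfig` files, produced by the `save_config` calls shown in the diffs above. A minimal sketch of inspecting the bundled example configs; the dataset config filename is an assumption based on the configs shipped with graphnet:

```python
from graphnet.utilities.config import DatasetConfig, ModelConfig

# The default model config referenced in the commands above.
model_config = ModelConfig.load(
    "configs/models/example_energy_reconstruction_model.yml"
)
# Assumed name of a bundled dataset config.
dataset_config = DatasetConfig.load(
    "configs/datasets/training_example_data_sqlite.yml"
)

print(model_config.class_name)  # class of the model to be built, e.g. "StandardModel"
print(dataset_config.path)      # location of the training database
```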
1 change: 0 additions & 1 deletion src/graphnet/data/dataset/__init__.py
@@ -7,7 +7,6 @@
from .dataset import EnsembleDataset, Dataset, ColumnMissingException
from .parquet.parquet_dataset import ParquetDataset
from .sqlite.sqlite_dataset import SQLiteDataset
from .sqlite.sqlite_dataset_perturbed import SQLiteDatasetPerturbed

torch.multiprocessing.set_sharing_strategy("file_system")

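With `SQLiteDatasetPerturbed` removed from the package, input perturbation now belongs to the graph definition (per the PR title and the "Move perturbations to graph_definition" commit). A minimal sketch of the new pattern; the `perturbation_dict` and `seed` keyword names follow this PR's direction but should be treated as assumptions:

```python
from graphnet.models.detector.prometheus import Prometheus
from graphnet.models.graphs import KNNGraph

# Gaussian smearing (standard deviation per input column) applied when graphs
# are built, replacing the deleted SQLiteDatasetPerturbed wrapper.
# Column names and widths are illustrative only.
graph_definition = KNNGraph(
    detector=Prometheus(),
    perturbation_dict={
        "sensor_pos_x": 1.0,
        "sensor_pos_y": 1.0,
        "sensor_pos_z": 1.0,
        "t": 5.0,
    },
    seed=42,
)

# Any Dataset constructed with this graph definition then yields perturbed
# node features, e.g.:
#   dataset = SQLiteDataset(..., graph_definition=graph_definition)
```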