Skip to content

Commit

Permalink
Merge pull request #329 from MortenHolmRep/multiclass
Browse files Browse the repository at this point in the history
Multi-class classification implementation
  • Loading branch information
Peterandresen12 authored Dec 6, 2022
2 parents 8c2394c + ba5c9e3 commit b77b25c
Show file tree
Hide file tree
Showing 8 changed files with 526 additions and 12 deletions.
43 changes: 43 additions & 0 deletions configs/datasets/PID_classification_last_one_lvl3MC.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Dataset configuration for multi-class particle-ID (PID) classification on
# the `last_one_lvl3MC` database.
path: /groups/icecube/petersen/GraphNetDatabaseRepository/Leon2022_DataAndMC_CSVandDB_StoppedMuons/last_one_lvl3MC.db
pulsemaps:
  - SRTInIcePulses
# Per-pulse input features.
features:
  - dom_x
  - dom_y
  - dom_z
  - dom_time
  - charge
  - rde
  - pmt_area
# Event-level truth columns read from `truth_table`.
truth:
  - energy
  - position_x
  - position_y
  - position_z
  - azimuth
  - zenith
  - pid
  - elasticity
  - sim_type
  - interaction_type
index_column: event_no
truth_table: truth
seed: 21
# Named event selections; these must be nested under `selection`, otherwise
# `selection` is empty and the entries become unrelated top-level keys.
# The three splits use disjoint values of `event_no % 5` (0: test,
# 1: validation, >1: train), with one selection per |pid| value.
selection:
  test_nu_e: 10000 random events ~ event_no % 5 == 0 & abs(pid) == 12
  test_nu_mu: 10000 random events ~ event_no % 5 == 0 & abs(pid) == 14
  test_nu_tau: 10000 random events ~ event_no % 5 == 0 & abs(pid) == 16
  test_mu: 10000 random events ~ event_no % 5 == 0 & abs(pid) == 13
  test_noise: 10000 random events ~ event_no % 5 == 0 & abs(pid) == 1

  validation_nu_e: 10000 random events ~ event_no % 5 == 1 & abs(pid) == 12
  validation_nu_mu: 10000 random events ~ event_no % 5 == 1 & abs(pid) == 14
  validation_nu_tau: 10000 random events ~ event_no % 5 == 1 & abs(pid) == 16
  validation_mu: 10000 random events ~ event_no % 5 == 1 & abs(pid) == 13
  validation_noise: 10000 random events ~ event_no % 5 == 1 & abs(pid) == 1

  train_nu_e: 50000 random events ~ event_no % 5 > 1 & abs(pid) == 12
  train_nu_mu: 50000 random events ~ event_no % 5 > 1 & abs(pid) == 14
  train_nu_tau: 50000 random events ~ event_no % 5 > 1 & abs(pid) == 16
  train_mu: 50000 random events ~ event_no % 5 > 1 & abs(pid) == 13
  train_noise: 50000 random events ~ event_no % 5 > 1 & abs(pid) == 1
45 changes: 45 additions & 0 deletions configs/models/dynedge_PID_classification_example.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Model configuration: DynEdge GNN with a multi-class classification task on
# the `pid` truth label.
arguments:
  # Maps PID codes to class indices: 0 for |pid| == 1, 1 for |pid| == 13,
  # 2 for |pid| in {12, 14, 16}.
  class_options: {1: 0, -1: 0, 13: 1, -13: 1, 12: 2, -12: 2, 14: 2, -14: 2, 16: 2, -16: 2}
  coarsening: null
  detector:
    ModelConfig:
      arguments:
        graph_builder:
          ModelConfig:
            arguments: {columns: null, device: null, nb_nearest_neighbours: 8}
            class_name: KNNGraphBuilder
        scalers: null
      class_name: IceCubeDeepCore
  gnn:
    ModelConfig:
      arguments:
        add_global_variables_after_pooling: false
        dynedge_layer_sizes: null
        features_subset: null
        global_pooling_schemes: [min, max, mean, sum]
        nb_inputs: 7
        nb_neighbours: 8
        post_processing_layer_sizes: null
        readout_layer_sizes: null
      class_name: DynEdge
  optimizer_class: '!class torch.optim.adam Adam'
  optimizer_kwargs: {eps: 0.001, lr: 1e-05}
  scheduler_class: '!class torch.optim.lr_scheduler ReduceLROnPlateau'
  scheduler_config: {frequency: 1, monitor: val_loss}
  scheduler_kwargs: {patience: 5}
  tasks:
    - ModelConfig:
        arguments:
          hidden_size: 128
          loss_function:
            ModelConfig:
              # YAML cannot evaluate Python expressions, so the class mapping
              # is spelled out literally; keep it in sync with `class_options`.
              arguments: {options: {1: 0, -1: 0, 13: 1, -13: 1, 12: 2, -12: 2, 14: 2, -14: 2, 16: 2, -16: 2}}
              class_name: CrossEntropyLoss
          loss_weight: null
          target_labels: pid
          transform_inference: '!lambda x: softmax(x, dim=-1)'
          transform_prediction: null
          transform_support: null
          transform_target: null
        class_name: MulticlassClassificationTask
class_name: StandardModel
121 changes: 121 additions & 0 deletions examples/train_classification_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""Example of training Model."""

import os
from typing import Dict, Any

from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import WandbLogger

from graphnet.data.dataloader import DataLoader
from graphnet.models import Model
from graphnet.training.callbacks import ProgressBar
from graphnet.utilities.config import (
DatasetConfig,
ModelConfig,
TrainingConfig,
)

# Weights & Biases (W&B) keeps its run files in this directory; it has to
# exist before the logger below is constructed.
WANDB_DIR = "./wandb/"
os.makedirs(WANDB_DIR, exist_ok=True)

# Module-level W&B logger shared by the training run in this script.
wandb_logger = WandbLogger(
    save_dir=WANDB_DIR,
    log_model=True,
    entity="graphnet-team",
    project="example-script",
)


def _config_path(directory: str, name: str) -> str:
    """Return the path of the YAML config file `name` inside `directory`.

    `name` may be given with or without the ".yml" extension, so callers
    passing a full file name do not end up with a doubled extension.
    """
    if not name.endswith(".yml"):
        name += ".yml"
    return os.path.join(directory, name)


def train(general_config: Dict[str, Any]) -> None:
    """Train a classification model and save its predictions and weights.

    Args:
        general_config: Dictionary with the keys "dataset" and "model"
            (names of YAML files under `configs/datasets/` and
            `configs/models/`, with or without the ".yml" extension) and
            "archive" (root directory for the output files).
    """
    # Configuration
    config = TrainingConfig(
        target="pid",
        early_stopping_patience=5,
        fit={"gpus": [0], "max_epochs": 5},
        dataloader={"batch_size": 512, "num_workers": 10},
    )

    run_name = "dynedge_{}_classification_example".format(config.target)

    # Log configuration to W&B
    wandb_logger.experiment.config.update(config)

    # Construct dataloaders; the result is indexed by selection name below.
    dataset_config = DatasetConfig.load(
        _config_path("configs/datasets", general_config["dataset"])
    )
    dataloaders = DataLoader.from_dataset_config(
        dataset_config,
        **config.dataloader,
    )
    wandb_logger.experiment.config.update(dataset_config.as_dict())

    # Build model
    model_config = ModelConfig.load(
        _config_path("configs/models", general_config["model"])
    )
    model = Model.from_config(model_config, trust=True)
    wandb_logger.experiment.config.update(model_config.as_dict())

    # Training model
    callbacks = [
        EarlyStopping(
            monitor="val_loss",
            patience=config.early_stopping_patience,
        ),
        ProgressBar(),
    ]

    model.fit(
        dataloaders["train"],
        dataloaders["validation"],
        callbacks=callbacks,
        logger=wandb_logger,
        **config.fit,
    )

    # Get predictions
    if isinstance(config.target, str):
        # Single target: one prediction column per output class.
        prediction_columns = [
            config.target + "_noise_pred",
            config.target + "_muon_pred",
            config.target + "_neutrino_pred",
        ]
        additional_attributes = [config.target]
    else:
        prediction_columns = [target + "_pred" for target in config.target]
        additional_attributes = config.target

    results = model.predict_as_dataframe(
        dataloaders["test"],
        prediction_columns=prediction_columns,
        additional_attributes=additional_attributes + ["event_no"],
    )

    # Save predictions and model to file
    db_name = dataset_config.path.split("/")[-1].split(".")[0]
    path = os.path.join(general_config["archive"], db_name, run_name)
    # The run-specific output directory may not exist yet; writing into a
    # missing directory would fail only after training has finished.
    os.makedirs(path, exist_ok=True)

    results.to_csv(f"{path}/results.csv")
    model.save_state_dict(f"{path}/state_dict.pth")
    model.save(f"{path}/model.pth")


def main() -> None:
    """Run example."""
    # General configuration. "dataset" and "model" name YAML files under
    # `configs/datasets/` and `configs/models/`; they are given without the
    # ".yml" extension because `train` supplies it when building the path.
    general_config = {
        "dataset": "PID_classification_last_one_lvl3MC",
        "model": "dynedge_PID_classification_example",
        "archive": "/groups/icecube/petersen/GraphNetDatabaseRepository/example_results/train_classification_model",
    }

    train(general_config)


if __name__ == "__main__":
    main()
Loading

0 comments on commit b77b25c

Please sign in to comment.