diff --git a/examples/train_model.py b/examples/train_model.py index 67779e302..f2a0f417e 100644 --- a/examples/train_model.py +++ b/examples/train_model.py @@ -1,4 +1,4 @@ -import os.path +import os from pytorch_lightning import Trainer from pytorch_lightning.callbacks import EarlyStopping @@ -33,11 +33,15 @@ features = FEATURES.DEEPCORE truth = TRUTH.DEEPCORE[:-1] +# Make sure W&B output directory exists +WANDB_DIR = "./wandb/" +os.makedirs(WANDB_DIR, exist_ok=True) + # Initialise Weights & Biases (W&B) run wandb_logger = WandbLogger( project="example-script", entity="graphnet-team", - save_dir="./wandb/", + save_dir=WANDB_DIR, log_model=True, )