diff --git a/examples/train_model.py b/examples/train_model.py index 49439f8ba..f2a0f417e 100644 --- a/examples/train_model.py +++ b/examples/train_model.py @@ -33,17 +33,18 @@ features = FEATURES.DEEPCORE truth = TRUTH.DEEPCORE[:-1] +# Make sure W&B output directory exists +WANDB_DIR = "./wandb/" +os.makedirs(WANDB_DIR, exist_ok=True) + # Initialise Weights & Biases (W&B) run wandb_logger = WandbLogger( project="example-script", entity="graphnet-team", - save_dir="./wandb/", + save_dir=WANDB_DIR, log_model=True, ) -# -- Ensure that custom output directory exists -os.makedirs(wandb_logger.save_dir, exists_ok=True) - # Main function definition def main():