diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index bcf920da..fd010f59 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -28,6 +28,12 @@ def log_image(self, key, images): import mlflow import io from PIL import Image + + # Retrieve the active run ID from the logger + run_id = self.run_id + # Ensure mlflow uses the same run + mlflow.start_run(run_id=run_id) + # Need to save the image to a temporary file, then log that file # mlflow.log_image, should do this automatically, but it doesn't work temporary_image = f"{key}.png" @@ -39,6 +45,7 @@ def log_image(self, key, images): mlflow.log_image(img, f"{key}.png") #mlflow.log_figure(images[0], key) + mlflow.end_run() def _setup_training_logger(config, datastore, args, run_name): @@ -346,7 +353,7 @@ def main(input_args=None): deterministic=True, #strategy="ddp", #devices=2, - devices=[1, 3], + devices=[0, 1], strategy="auto", accelerator=device_name, logger=training_logger,