From 2620bd1b7677c2ab9841316481c1fa34603761de Mon Sep 17 00:00:00 2001 From: Kasper Hintz Date: Mon, 9 Dec 2024 15:02:24 +0000 Subject: [PATCH] set logger url --- neural_lam/train_model.py | 25 ++++++++++++++++++------- tests/test_config.py | 2 ++ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 7b2b055..32ec921 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -29,6 +29,10 @@ class CustomMLFlowLogger(pl.loggers.MLFlowLogger): + """ + Custom MLFlow logger that adds functionality not present in the default + """ + def __init__(self, experiment_name, tracking_uri): super().__init__( experiment_name=experiment_name, tracking_uri=tracking_uri @@ -38,9 +42,22 @@ def __init__(self, experiment_name, tracking_uri): @property def save_dir(self): + """ + Returns the directory where the MLFlow artifacts are saved + """ return "mlruns" def log_image(self, key, images, step=None): + """ + Log a matplotlib figure as an image to MLFlow + + key: str + Key to log the image under + images: list + List of matplotlib figures to log + step: Union[int, None] + Step to log the image under. If None, logs under the key directly + """ # Third-party from PIL import Image @@ -61,7 +78,7 @@ def log_model(self, data_module, model): with torch.no_grad(): model_output = model.common_step(input_example)[ 0 - ] # expects batch, returns tuple (prediction, target, pred_std, _) + ] # common_step returns tuple (prediction, target, pred_std, _) log_model_input_example = { name: tensor.cpu().numpy() @@ -81,8 +98,6 @@ def log_model(self, data_module, model): signature=signature, ) - # validate_serving_input(model_uri, validate_example) - def create_input_example(self, data_module): if data_module.val_dataset is None: @@ -405,9 +420,5 @@ def main(input_args=None): else: trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load) - # Log model. TODO: only log for mlflow - training_logger.log_model(data_module, model) - - if __name__ == "__main__": main() diff --git a/tests/test_config.py b/tests/test_config.py index 1ff40bc..ede6b24 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -44,6 +44,8 @@ def test_config_serialization(state_weighting_config): kind: mdp config_path: "" training: + logger: wandb + logger_url: https://wandb.ai state_feature_weighting: __config_class__: ManualStateFeatureWeighting weights: