From de27e9a9676dbf3115ed7e2691493c73aa265fc6 Mon Sep 17 00:00:00 2001 From: Kasper Hintz Date: Mon, 7 Oct 2024 16:59:23 +0000 Subject: [PATCH] log model --- neural_lam/train_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index f21ac96e..792a00f6 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -48,7 +48,7 @@ def log_image(self, key, images): def log_model(self, model): # Create model signature - #signature = infer_signature(X.numpy(), model(X).detach().numpy()) + #signature = infer_signature(train_dataset.numpy(), model(train_dataset).detach().numpy()) mlflow.pytorch.log_model(model, "model") @@ -361,7 +361,7 @@ def main(input_args=None): # Log the model training_logger.log_model(model) - + # data_module.train_dataloader().dataset.data if __name__ == "__main__": main()