From 502eb12ff40c69b8a7d693ace8120057afd34338 Mon Sep 17 00:00:00 2001 From: Charles Tang Date: Thu, 11 Jul 2024 22:55:06 -0700 Subject: [PATCH] Fix MLFlow Save Model for TE (#1353) --- llmfoundry/callbacks/hf_checkpointer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index c2c8cabf4d..4de7f9f2c6 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -586,7 +586,12 @@ def dtensor_to_tensor_hook( model_saving_kwargs['transformers_model'] = components model_saving_kwargs.update(self.mlflow_logging_config) - mlflow_logger.save_model(**model_saving_kwargs) + context_manager = te.onnx_export( + True, + ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext( + ) + with context_manager: + mlflow_logger.save_model(**model_saving_kwargs) # Upload the license file generated by mlflow during the model saving. license_filename = _maybe_get_license_filename(