Use create_model_version instead of register_model (#953)
dakinggg authored Feb 9, 2024
1 parent 2e59620 · commit 2f64a14
Showing 2 changed files with 8 additions and 8 deletions.
llmfoundry/callbacks/hf_checkpointer.py — 4 changes: 2 additions & 2 deletions
@@ -394,8 +394,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
                     os.path.join(local_save_path, license_filename),
                 )
 
-                mlflow_logger.register_model(
+                mlflow_logger.register_model_with_run_id(
                     model_uri=local_save_path,
                     name=self.mlflow_registered_model_name,
-                    await_registration_for=3600,
+                    await_creation_for=3600,
                 )
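For reference, a minimal sketch of the updated registration call, assuming Composer's MLFlowLogger exposes register_model_with_run_id with the keyword arguments shown in the hunk above (per the commit title, this path goes through MLflow's create_model_version rather than register_model). The logger construction, paths, and names here are hypothetical; in the callback the logger comes from the Trainer's configured loggers.

    from composer.loggers import MLFlowLogger

    # Hypothetical logger; inside the callback this object is looked up from
    # the Trainer's loggers rather than constructed by hand.
    mlflow_logger = MLFlowLogger(experiment_name='hf-checkpointer-demo')

    # Register the locally saved checkpoint as a new model version attached to
    # the active MLflow run. await_creation_for bounds the wait (in seconds)
    # for the version to become ready, replacing the old
    # await_registration_for argument.
    mlflow_logger.register_model_with_run_id(
        model_uri='/tmp/local_save_path',   # hypothetical checkpoint directory
        name='my-registered-model',         # hypothetical registry name
        await_creation_for=3600,
    )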
tests/a_scripts/inference/test_convert_composer_to_hf.py — 12 changes: 6 additions & 6 deletions
@@ -306,7 +306,7 @@ def test_huggingface_conversion_callback_interval(
     mlflow_logger_mock = MagicMock(spec=MLFlowLogger)
     mlflow_logger_mock.state_dict = lambda *args, **kwargs: {}
     mlflow_logger_mock.save_model = MagicMock(wraps=_save_model_mock)
-    mlflow_logger_mock.register_model = MagicMock()
+    mlflow_logger_mock.register_model_with_run_id = MagicMock()
     mlflow_logger_mock.model_registry_prefix = ''
     mlflow_logger_mock._experiment_id = 'mlflow-experiment-id'
     mlflow_logger_mock._run_id = 'mlflow-run-id'
@@ -334,10 +334,10 @@ def test_huggingface_conversion_callback_interval(
                 input_example=ANY,
                 signature=ANY,
                 metadata={'task': 'llm/v1/completions'})
-        assert mlflow_logger_mock.register_model.call_count == 1
+        assert mlflow_logger_mock.register_model_with_run_id.call_count == 1
     else:
         assert mlflow_logger_mock.save_model.call_count == 0
-        assert mlflow_logger_mock.register_model.call_count == 0
+        assert mlflow_logger_mock.register_model_with_run_id.call_count == 0
 
     normal_checkpoints = [
         name for name in os.listdir(os.path.join(tmp_path, 'checkpoints'))
@@ -564,7 +564,7 @@ def test_huggingface_conversion_callback(
     mlflow_logger_mock = MagicMock(spec=MLFlowLogger)
     mlflow_logger_mock.state_dict = lambda *args, **kwargs: {}
     mlflow_logger_mock.save_model = MagicMock(wraps=_save_model_mock)
-    mlflow_logger_mock.register_model = MagicMock()
+    mlflow_logger_mock.register_model_with_run_id = MagicMock()
     mlflow_logger_mock.model_registry_prefix = ''
     mlflow_logger_mock._experiment_id = 'mlflow-experiment-id'
     mlflow_logger_mock._run_id = 'mlflow-run-id'
@@ -628,10 +628,10 @@ def test_huggingface_conversion_callback(
             }
         }
         mlflow_logger_mock.save_model.assert_called_with(**expectation)
-        assert mlflow_logger_mock.register_model.call_count == 1
+        assert mlflow_logger_mock.register_model_with_run_id.call_count == 1
     else:
         assert mlflow_logger_mock.log_model.call_count == 0
-        assert mlflow_logger_mock.register_model.call_count == 0
+        assert mlflow_logger_mock.register_model_with_run_id.call_count == 0
 
     # summon full params to check equivalence
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
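Both test functions follow the same pattern: the logger is mocked with spec=MLFlowLogger, the new method is replaced with a bare MagicMock, and the tests assert on its call count. A minimal standalone sketch of that pattern, with the Trainer run stubbed out by a direct call and hypothetical argument values throughout:

    from unittest.mock import MagicMock

    from composer.loggers import MLFlowLogger

    # spec=MLFlowLogger keeps the mock honest: reading attributes that do not
    # exist on the real logger raises AttributeError.
    mlflow_logger_mock = MagicMock(spec=MLFlowLogger)
    mlflow_logger_mock.register_model_with_run_id = MagicMock()

    # Stand-in for exercising the checkpointer callback through the Trainer;
    # here the mocked method is called directly to keep the sketch
    # self-contained.
    mlflow_logger_mock.register_model_with_run_id(
        model_uri='/tmp/local_save_path',
        name='my-registered-model',
        await_creation_for=3600,
    )

    assert mlflow_logger_mock.register_model_with_run_id.call_count == 1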
