diff --git a/composer/loggers/mlflow_logger.py b/composer/loggers/mlflow_logger.py index 0e13791d64..c8a0be47d0 100644 --- a/composer/loggers/mlflow_logger.py +++ b/composer/loggers/mlflow_logger.py @@ -144,7 +144,7 @@ def init(self, state: State, logger: Logger) -> None: # Store the Composer run name in the MLFlow run tags so it can be retrieved for autoresume. self.tags = self.tags or {} - self.tags['composer_run_name'] = state.run_name + self.tags['run_name'] = state.run_name # Adjust name and group based on `rank_zero_only`. if not self._rank_zero_only: @@ -162,8 +162,17 @@ def init(self, state: State, logger: Logger) -> None: # Search for an existing run tagged with this Composer run. assert self._experiment_id is not None existing_runs = mlflow.search_runs(experiment_ids=[self._experiment_id], - filter_string=f'tags.composer_run_name = "{state.run_name}"', + filter_string=f'tags.run_name = "{state.run_name}"', output_format='list') + + # Check for the old tag (`composer_run_name`) for backwards compatibility in case a run using the old + # tag fails and the run is resumed with a newer version of Composer that uses `run_name` instead of + # `composer_run_name`. + if len(existing_runs) == 0: + existing_runs = mlflow.search_runs(experiment_ids=[self._experiment_id], + filter_string=f'tags.composer_run_name = "{state.run_name}"', + output_format='list') + if len(existing_runs) > 0: self._run_id = existing_runs[0].info.run_id else: diff --git a/tests/loggers/test_mlflow_logger.py b/tests/loggers/test_mlflow_logger.py index d5de5b8171..6dd02ab30e 100644 --- a/tests/loggers/test_mlflow_logger.py +++ b/tests/loggers/test_mlflow_logger.py @@ -167,7 +167,7 @@ def test_mlflow_experiment_init_existing_composer_run(monkeypatch): def test_mlflow_experiment_init_existing_composer_run(monkeypatch): - """ Test that an existing MLFlow run is used if one already exists in the experiment for the Composer run. 
+ """ Test that an existing MLFlow run is used if one tagged with `run_name` exists in the experiment for the Composer run. """ mlflow = pytest.importorskip('mlflow') @@ -186,6 +186,26 @@ def test_mlflow_experiment_init_existing_composer_run(monkeypatch): assert test_logger._run_id == existing_id +def test_mlflow_experiment_init_existing_composer_run_with_old_tag(monkeypatch): + """ Test that init falls back to a run with the old `composer_run_name` tag when none is tagged `run_name`. + """ + mlflow = pytest.importorskip('mlflow') + + monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock()) + monkeypatch.setattr(mlflow, 'start_run', MagicMock()) + + mock_state = MagicMock() + mock_state.run_name = 'dummy-run-name' + + existing_id = 'dummy-id' + mock_search_runs = MagicMock(side_effect=[[], [MagicMock(info=MagicMock(run_id=existing_id))]]) + monkeypatch.setattr(mlflow, 'search_runs', mock_search_runs) + + test_logger = MLFlowLogger() + test_logger.init(state=mock_state, logger=MagicMock()) + assert test_logger._run_id == existing_id + + def test_mlflow_experiment_set_up(tmp_path): """ Test that MLFlow experiment is set up correctly within mlflow """ @@ -231,7 +251,7 @@ def test_mlflow_experiment_set_up(tmp_path): assert actual_run_name == expected_run_name # Check run tagged with Composer run name. - assert tags['composer_run_name'] == mock_state.run_name + assert tags['run_name'] == mock_state.run_name # Check run ended. test_mlflow_logger.post_close()