
Commit

clean up

wenfeiy-db committed Nov 10, 2023
1 parent fff9d07 commit a0626e2
Showing 3 changed files with 16 additions and 38 deletions.
9 changes: 0 additions & 9 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -69,9 +69,6 @@ def __init__(
 
         # mlflow config setup
         self.mlflow_registered_model_name = mlflow_registered_model_name
-        print(
-            f'__init__::mlflow_logging_config={mlflow_logging_config}\n[{type(mlflow_logging_config)}]'
-        )
         if mlflow_logging_config is None:
             mlflow_logging_config = {}
         if self.mlflow_registered_model_name is not None:
@@ -262,12 +259,6 @@ def _save_checkpoint(self, state: State, logger: Logger):
             # TODO: Remove after mlflow fixes the bug that makes this necessary
             import mlflow
             mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
-            print(
-                f'_save_checkpoint::self.mlflow_logging_config={self.mlflow_logging_config}\n[{type(self.mlflow_logging_config)}]'
-            )
-            print(
-                f"_save_checkpoint::self.mlflow_logging_config['metadata']={self.mlflow_logging_config['metadata']}\n{type(self.mlflow_logging_config['metadata'])}]"
-            )
             mlflow_logger.save_model(
                 flavor='transformers',
                 transformers_model=components,
17 changes: 0 additions & 17 deletions llmfoundry/utils/builders.py
@@ -120,23 +120,6 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
     elif name == 'hf_checkpointer':
         if isinstance(kwargs, DictConfig):
             kwargs = om.to_object(kwargs)
-        # print(type(kwargs))
-        # kwargs_copy = deepcopy(kwargs)
-        # mlflow_logging_config = kwargs_copy.pop('mlflow_logging_config', None)
-        # print(f"build_callback::mlflow_logging_config={mlflow_logging_config}")
-        # print(f"build_callback::isinstance(mlflow_logging_config, DictConfig)={isinstance(mlflow_logging_config, DictConfig)}")
-        # if isinstance(mlflow_logging_config, DictConfig):
-        #     print("converting mlflow_logging_config")
-        #     mlflow_logging_config = om.to_object(mlflow_logging_config)
-        # print(f"build_callback::mlflow_logging_config={mlflow_logging_config}")
-        # print(f"[{type(mlflow_logging_config)}]")
-        # print(f"after if statement: build_callback::mlflow_logging_config={mlflow_logging_config}")
-        # print(f"[{type(mlflow_logging_config)}]")
-        # print(f"before reassign: kwargs_copy.get('mlflow_logging_config', None)={kwargs_copy.get('mlflow_logging_config', None)}")
-        # print(f"[{type(kwargs_copy.get('mlflow_logging_config', None))}]")
-        # kwargs_copy['mlflow_logging_config'] = mlflow_logging_config
-        # print(f"after reassign - build_callback::kwargs_copy['mlflow_logging_config']={kwargs_copy['mlflow_logging_config']}")
-        # print(f"[{type(kwargs_copy['mlflow_logging_config'])}")
         return HuggingFaceCheckpointer(**kwargs)
     else:
         raise ValueError(f'Not sure how to build callback: {name}')
28 changes: 16 additions & 12 deletions tests/test_hf_conversion_script.py
@@ -421,18 +421,22 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
     trainer.fit()
 
     if dist.get_global_rank() == 0:
-        if log_to_mlflow:
-            mlflow_logger_mock.save_model.assert_called_with(
-                flavor='transformers',
-                transformers_model=ANY,
-                path=ANY,
-                task='text-generation',
-                metatdata={'task': 'llm/v1/completions'}
-            )
-            assert mlflow_logger_mock.register_model.call_count == 1
-        else:
-            assert mlflow_logger_mock.save_model.call_count == 0
-            assert mlflow_logger_mock.register_model.call_count == 0
+        assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow
+                                                            else 0)
+        assert mlflow_logger_mock.register_model.call_count == (
+            1 if log_to_mlflow else 0)
+        # if log_to_mlflow:
+        #     # mlflow_logger_mock.save_model.assert_called_with(
+        #     #     flavor='transformers',
+        #     #     transformers_model=ANY,
+        #     #     path=ANY,
+        #     #     task='text-generation',
+        #     #     metatdata={'task': 'llm/v1/completions'}
+        #     # )
+        #     assert mlflow_logger_mock.register_model.call_count == 1
+        # else:
+        #     assert mlflow_logger_mock.save_model.call_count == 0
+        #     assert mlflow_logger_mock.register_model.call_count == 0
     else:
         assert mlflow_logger_mock.log_model.call_count == 0
         assert mlflow_logger_mock.register_model.call_count == 0
