Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
wenfeiy-db committed Nov 9, 2023
1 parent 622f141 commit e416b96
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
12 changes: 9 additions & 3 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def __init__(

# mlflow config setup
self.mlflow_registered_model_name = mlflow_registered_model_name
print(f"__init__::mlflow_logging_config={mlflow_logging_config}\n[{type(mlflow_logging_config)}]")
print(
f'__init__::mlflow_logging_config={mlflow_logging_config}\n[{type(mlflow_logging_config)}]'
)
if mlflow_logging_config is None:
mlflow_logging_config = {}
if self.mlflow_registered_model_name is not None:
Expand Down Expand Up @@ -260,8 +262,12 @@ def _save_checkpoint(self, state: State, logger: Logger):
# TODO: Remove after mlflow fixes the bug that makes this necessary
import mlflow
mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
print(f"_save_checkpoint::self.mlflow_logging_config={self.mlflow_logging_config}\n[{type(self.mlflow_logging_config)}]")
print(f"_save_checkpoint::self.mlflow_logging_config['metadata']={self.mlflow_logging_config['metadata']}\n{type(self.mlflow_logging_config['metadata'])}]")
print(
f'_save_checkpoint::self.mlflow_logging_config={self.mlflow_logging_config}\n[{type(self.mlflow_logging_config)}]'
)
print(
f"_save_checkpoint::self.mlflow_logging_config['metadata']={self.mlflow_logging_config['metadata']}\n{type(self.mlflow_logging_config['metadata'])}]"
)
mlflow_logger.save_model(
flavor='transformers',
transformers_model=components,
Expand Down
4 changes: 2 additions & 2 deletions llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
elif name == 'early_stopper':
return EarlyStopper(**kwargs)
elif name == 'hf_checkpointer':
mlflow_logging_config = kwargs.pop("mlflow_logging_config", None)
mlflow_logging_config = kwargs.pop('mlflow_logging_config', None)
if isinstance(mlflow_logging_config, omegaconf.dictconfig.DictConfig):
mlflow_logging_config = om.to_object(mlflow_logging_config)
kwargs["mlflow_logging_config"] = mlflow_logging_config
kwargs['mlflow_logging_config'] = mlflow_logging_config
return HuggingFaceCheckpointer(**kwargs)
else:
raise ValueError(f'Not sure how to build callback: {name}')
Expand Down

0 comments on commit e416b96

Please sign in to comment.