Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MLflow log_model option #1544

Merged
merged 51 commits into from
Nov 1, 2024
Merged
Changes from 4 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
06d77db
Register model with MLflow PySDK now that retries are baked in. This …
nancyhung Sep 24, 2024
e40e5dd
Register model with MLflow PySDK now that retries are baked in. This …
nancyhung Sep 24, 2024
454e18b
small changes
nancyhung Sep 24, 2024
c8bd06f
isolated changes
nancyhung Sep 24, 2024
0d3f9ce
pr feedback with a print statement for testing
nancyhung Oct 2, 2024
6ea8de5
some more todos and need to test
nancyhung Oct 4, 2024
81306d8
need to test
nancyhung Oct 11, 2024
b854bb2
Merge branch 'main' into nancy/log-model
nancyhung Oct 11, 2024
bc73f65
use mlflow log model by default
nancyhung Oct 15, 2024
1915042
patch push
nancyhung Oct 15, 2024
bc29278
Merge branch 'main' into nancy/log-model
nancyhung Oct 22, 2024
99589c7
add log statements
nancyhung Oct 22, 2024
04ddfaa
add log outside of process
nancyhung Oct 23, 2024
8e42217
fix
nancyhung Oct 25, 2024
be04e3d
bug
nancyhung Oct 25, 2024
5ab2cc7
print the registered model name
nancyhung Oct 26, 2024
79356d8
update the model registry prefix
nancyhung Oct 26, 2024
4327257
move the download code out of the if statement
nancyhung Oct 26, 2024
6c5fb05
try registering just the model name
nancyhung Oct 26, 2024
bb0dd6a
connect the existing mlflow run id
nancyhung Oct 26, 2024
c5ae4ff
omg it works
nancyhung Oct 26, 2024
4c86e63
pr feedback
nancyhung Oct 29, 2024
b1477bc
add test helper
nancyhung Oct 29, 2024
e376621
fix tests
nancyhung Oct 29, 2024
e2a9d86
mocking mlflow start run
nancyhung Oct 29, 2024
5784c26
fix
nancyhung Oct 29, 2024
625cc29
Merge branch 'main' into nancy/log-model
nancyhung Oct 29, 2024
c939752
pr
nancyhung Oct 30, 2024
9e59a21
json format
nancyhung Oct 30, 2024
5e13b83
patches
nancyhung Oct 30, 2024
19862d2
still not fully working
nancyhung Oct 30, 2024
1eefb84
fixed the final_register_only test case. now need to pass the others
nancyhung Oct 31, 2024
9282fe0
overloading the config mapper still not working
nancyhung Nov 1, 2024
9f9e027
Merge branch 'main' into nancy/log-model
nancyhung Nov 1, 2024
04b520b
using irenes changes
nancyhung Nov 1, 2024
687e48b
default name logic
nancyhung Nov 1, 2024
65a5a1c
typo
nancyhung Nov 1, 2024
ff2f4ac
precommit
nancyhung Nov 1, 2024
67a3acc
precommit again
nancyhung Nov 1, 2024
598d4f3
precommit
nancyhung Nov 1, 2024
6c34a23
fix tests
nancyhung Nov 1, 2024
a5fe322
license
nancyhung Nov 1, 2024
5d37fd5
pr ffeedback and test remove start_run
nancyhung Nov 1, 2024
0132611
precommit
nancyhung Nov 1, 2024
18025f7
start run unnecessary
nancyhung Nov 1, 2024
356e3b2
typing
nancyhung Nov 1, 2024
6daecb8
fix ci
dakinggg Nov 1, 2024
e3b28bf
fix
dakinggg Nov 1, 2024
674506d
clean up tests
dakinggg Nov 1, 2024
7a423b1
fix conflict
dakinggg Nov 1, 2024
30b6927
type ignore
dakinggg Nov 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 64 additions & 1 deletion llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _register_model_with_run_id_multiprocess(
name: str,
await_creation_for: int,
):
"""Call MLFlowLogger.register_model_with_run_id.
"""Call MLFlowLogger.register_model.

Used mainly to register from a child process.
"""
Expand All @@ -135,6 +135,43 @@ def _register_model_with_run_id_multiprocess(
)


def _log_model_multiprocess(
    mlflow_logger: MLFlowLogger,
    composer_logging_level: int,
    input_example: dict[str, Any],
    task: str,
    name: str,
    model_name: str,
    await_creation_for: int,
):
    """Call MLFlowLogger.log_model.

    Used mainly to log from a child process.

    Args:
        mlflow_logger: Logger whose ``log_model`` method is invoked.
        composer_logging_level: Level to apply to the ``composer`` logger in
            this process. Values ``<= 0`` skip the logging setup entirely.
        input_example: Forwarded to ``log_model`` as ``input_example``.
        task: Forwarded to ``log_model`` as ``task`` and also recorded in the
            model metadata under ``'task'``.
        name: Forwarded to ``log_model`` as ``registered_model_name``.
        model_name: Recorded in the model metadata under
            ``'pretrained_model_name'``.
        await_creation_for: Forwarded to ``log_model`` as
            ``await_creation_for``.
    """
    # Setup logging for child process. This ensures that any logs from composer are surfaced.
    if composer_logging_level > 0:
        # If logging_level is 0, then the composer logger was unset.
        logging.basicConfig(
            format=
            f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s',
        )
        logging.getLogger('composer').setLevel(composer_logging_level)

    mlflow_logger.log_model(
        flavor='transformers',
        artifact_path='model',
        input_example=input_example,
        task=task,
        metadata={
            'task': task,
            'databricks_model_source': 'genai-fine-tuning',
            'pretrained_model_name': model_name,
        },  # This metadata is currently needed for optimized serving
        registered_model_name=name,
        await_creation_for=await_creation_for,
    )


class HuggingFaceCheckpointer(Callback):
"""Save a huggingface formatted checkpoint during training.

Expand Down Expand Up @@ -202,6 +239,7 @@ def __init__(
+
f'Defaulting to final_register_only=False and saving the HuggingFace checkpoint to {save_folder=}.',
)
self.use_mlflow_log_model = False

# mlflow config setup
if mlflow_logging_config is None:
Expand Down Expand Up @@ -232,6 +270,8 @@ def __init__(
'input_example',
default_input_example,
)
if mlflow_logging_config['use_mlflow_log_model']:
self.use_mlflow_log_model = True

self.mlflow_logging_config = mlflow_logging_config
if 'metadata' in self.mlflow_logging_config:
Expand Down Expand Up @@ -729,6 +769,29 @@ def tensor_hook(
monitor_process = None

# Spawn a new process to register the model.
# Slower method to register the model via log_model.
nancyhung marked this conversation as resolved.
Show resolved Hide resolved
if self.use_mlflow_log_model:
nancyhung marked this conversation as resolved.
Show resolved Hide resolved
process = SpawnProcess(
target=_log_model_multiprocess,
kwargs={
'mlflow_logger':
mlflow_logger,
'composer_logging_level':
logging.getLogger('composer').level,
'model_uri':
local_save_path,
'name':
self.mlflow_registered_model_name,
'model_name':
self.pretrained_model_name,
'input_example':
self.mlflow_logging_config['input_example'],
'await_creation_for':
3600,
},
)
process.start()
# Faster method to register model in parallel.
process = SpawnProcess(
target=_register_model_with_run_id_multiprocess,
kwargs={
Expand Down
Loading