Add MLflow log_model option #1544

Merged · merged 51 commits into main from nancy/log-model on Nov 1, 2024 · showing changes from 49 commits

Commits
06d77db  Register model with MLflow PySDK now that retries are baked in. This … (nancyhung, Sep 24, 2024)
e40e5dd  Register model with MLflow PySDK now that retries are baked in. This … (nancyhung, Sep 24, 2024)
454e18b  small changes (nancyhung, Sep 24, 2024)
c8bd06f  isolated changes (nancyhung, Sep 24, 2024)
0d3f9ce  pr feedback with a print statement for testing (nancyhung, Oct 2, 2024)
6ea8de5  some more todos and need to test (nancyhung, Oct 4, 2024)
81306d8  need to test (nancyhung, Oct 11, 2024)
b854bb2  Merge branch 'main' into nancy/log-model (nancyhung, Oct 11, 2024)
bc73f65  use mlflow log model by default (nancyhung, Oct 15, 2024)
1915042  patch push (nancyhung, Oct 15, 2024)
bc29278  Merge branch 'main' into nancy/log-model (nancyhung, Oct 22, 2024)
99589c7  add log statements (nancyhung, Oct 22, 2024)
04ddfaa  add log outside of process (nancyhung, Oct 23, 2024)
8e42217  fix (nancyhung, Oct 25, 2024)
be04e3d  bug (nancyhung, Oct 25, 2024)
5ab2cc7  print the registered model name (nancyhung, Oct 26, 2024)
79356d8  update the model registry prefix (nancyhung, Oct 26, 2024)
4327257  move the download code out of the if statement (nancyhung, Oct 26, 2024)
6c5fb05  try registering just the model name (nancyhung, Oct 26, 2024)
bb0dd6a  connect the existing mlflow run id (nancyhung, Oct 26, 2024)
c5ae4ff  omg it works (nancyhung, Oct 26, 2024)
4c86e63  pr feedback (nancyhung, Oct 29, 2024)
b1477bc  add test helper (nancyhung, Oct 29, 2024)
e376621  fix tests (nancyhung, Oct 29, 2024)
e2a9d86  mocking mlflow start run (nancyhung, Oct 29, 2024)
5784c26  fix (nancyhung, Oct 29, 2024)
625cc29  Merge branch 'main' into nancy/log-model (nancyhung, Oct 29, 2024)
c939752  pr (nancyhung, Oct 30, 2024)
9e59a21  json format (nancyhung, Oct 30, 2024)
5e13b83  patches (nancyhung, Oct 30, 2024)
19862d2  still not fully working (nancyhung, Oct 30, 2024)
1eefb84  fixed the final_register_only test case. now need to pass the others (nancyhung, Oct 31, 2024)
9282fe0  overloading the config mapper still not working (nancyhung, Nov 1, 2024)
9f9e027  Merge branch 'main' into nancy/log-model (nancyhung, Nov 1, 2024)
04b520b  using irenes changes (nancyhung, Nov 1, 2024)
687e48b  default name logic (nancyhung, Nov 1, 2024)
65a5a1c  typo (nancyhung, Nov 1, 2024)
ff2f4ac  precommit (nancyhung, Nov 1, 2024)
67a3acc  precommit again (nancyhung, Nov 1, 2024)
598d4f3  precommit (nancyhung, Nov 1, 2024)
6c34a23  fix tests (nancyhung, Nov 1, 2024)
a5fe322  license (nancyhung, Nov 1, 2024)
5d37fd5  pr ffeedback and test remove start_run (nancyhung, Nov 1, 2024)
0132611  precommit (nancyhung, Nov 1, 2024)
18025f7  start run unnecessary (nancyhung, Nov 1, 2024)
356e3b2  typing (nancyhung, Nov 1, 2024)
6daecb8  fix ci (dakinggg, Nov 1, 2024)
e3b28bf  fix (dakinggg, Nov 1, 2024)
674506d  clean up tests (dakinggg, Nov 1, 2024)
7a423b1  fix conflict (dakinggg, Nov 1, 2024)
30b6927  type ignore (dakinggg, Nov 1, 2024)
llmfoundry/callbacks/hf_checkpointer.py: 330 changes (241 additions, 89 deletions)
@@ -60,6 +60,26 @@

_LICENSE_FILE_PATTERN = re.compile(r'license(\.[a-z]+|$)', re.IGNORECASE)
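
For reference, a quick probe of what this pattern accepts (a minimal sketch; the filenames below are hypothetical examples, not taken from this PR):

import re

_LICENSE_FILE_PATTERN = re.compile(r'license(\.[a-z]+|$)', re.IGNORECASE)

for name in ['LICENSE', 'license.txt', 'License.md', 'LICENSES', 'licensing.txt']:
    print(name, bool(_LICENSE_FILE_PATTERN.search(name)))
# LICENSE -> True, license.txt -> True, License.md -> True
# LICENSES -> False ('license' is followed by neither '.<ext>' nor end-of-string)
# licensing.txt -> False (the string never contains 'license')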

from contextlib import contextmanager


@contextmanager
def _monitor_process_saver(mlflow_logger: MLFlowLogger):
    # Save the current monitor process
    if hasattr(mlflow_logger, 'monitor_process'):
        original_monitor_process = mlflow_logger.monitor_process  # type: ignore
        mlflow_logger.monitor_process = None  # type: ignore
    else:
        original_monitor_process = None

    try:
        # Yield control back to the calling code
        yield
    finally:
        # Restore the monitor process
        if original_monitor_process is not None:
            mlflow_logger.monitor_process = original_monitor_process  # type: ignore
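
For context, this helper is used at the call sites later in this diff to detach the logger's monitor process while a registration child process is spawned, then restore it. A minimal usage sketch, assuming _monitor_process_saver is in scope and using a stand-in object rather than a real MLFlowLogger:

class _FakeLogger:
    # Hypothetical stand-in; a real MLFlowLogger's monitor_process is a process handle.
    monitor_process = 'sentinel-monitor-process'

logger = _FakeLogger()
with _monitor_process_saver(logger):  # type: ignore[arg-type]
    # Detached inside the block, so a child spawned here won't inherit it.
    assert logger.monitor_process is None
assert logger.monitor_process == 'sentinel-monitor-process'  # restored on exit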


def _maybe_get_license_filename(
    local_dir: str,

@@ -108,6 +128,91 @@ def _maybe_get_license_filename(
    return None


def _log_model_with_multi_process(
    mlflow_logger: MLFlowLogger,
    python_logging_level: int,
    transformers_model: str,
    artifact_path: str,
    pretrained_model_name: str,
    registered_model_name: Optional[str],
    await_registration_for: int,
    mlflow_logging_config: dict[str, Any],
):
    """Call MLFlowLogger.log_model.

    First, patch the mlflow save_model function by removing duplicate tokenizer
    files in the model directory. Then, register the model to mlflow from a
    child process.
    """
    # Setup logging for child process. This ensures that any logs from composer are surfaced.
    if python_logging_level > 0:
        # If logging_level is 0, then the composer logger was unset.
        logging.basicConfig(
            format=
            f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s',
            force=True,
        )
        logging.getLogger('composer').setLevel(python_logging_level)
        logging.getLogger('llmfoundry').setLevel(python_logging_level)

    import mlflow
    original_save_model = mlflow.transformers.save_model

    def save_model_patch(*args: Any, **kwargs: Any):
        original_save_model(*args, **kwargs)
        tokenizer_files = []
        save_path = kwargs['path']
        tokenizer_path = os.path.join(save_path, 'components', 'tokenizer')
        if os.path.exists(tokenizer_path):
            tokenizer_files = os.listdir(tokenizer_path)
        try:
            # Check if there are duplicate tokenizer files in the model directory and remove them.
            for tokenizer_file_name in tokenizer_files:
                dupe_file = os.path.isfile(
                    os.path.join(save_path, 'model', tokenizer_file_name),
                )
                if dupe_file:
                    log.debug(
                        f'Removing duplicate tokenizer file: {tokenizer_file_name}',
                    )
                    os.remove(
                        os.path.join(save_path, 'model', tokenizer_file_name),
                    )
            license_filename = _maybe_get_license_filename(
                save_path,
                pretrained_model_name,
            )
            if license_filename is not None:
                mlflow_logger._mlflow_client.log_artifact(
                    mlflow_logger._run_id,
                    os.path.join(save_path, license_filename),
                )
        except Exception as e:
            log.error(
                f'Exception when removing duplicate tokenizer files in the model directory: {e}',
            )

    mlflow.transformers.save_model = save_model_patch

    mlflow.set_tracking_uri(mlflow_logger.tracking_uri)
    if mlflow_logger.model_registry_uri is not None:
        mlflow.set_registry_uri(mlflow_logger.model_registry_uri)

    register_model_path = f'{mlflow_logger.model_registry_prefix}.{registered_model_name}' if mlflow_logger.model_registry_prefix and registered_model_name else registered_model_name
    mlflow_logger.log_model(
        transformers_model=transformers_model,
        flavor='transformers',
        artifact_path=artifact_path,
        registered_model_name=register_model_path,
        run_id=mlflow_logger._run_id,
        await_registration_for=await_registration_for,
        **mlflow_logging_config,
    )
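
The registered name above becomes a three-level Unity Catalog name when the logger has a model_registry_prefix configured. A tiny illustration of that composition logic with hypothetical values (not from this PR):

model_registry_prefix = 'main.my_schema'  # hypothetical catalog.schema
registered_model_name = 'my_model'        # hypothetical model name

register_model_path = (
    f'{model_registry_prefix}.{registered_model_name}'
    if model_registry_prefix and registered_model_name else registered_model_name
)
print(register_model_path)  # main.my_schema.my_model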


def _register_model_with_run_id_multiprocess(
    mlflow_logger: MLFlowLogger,
    composer_logging_level: int,

@@ -675,101 +780,148 @@ def tensor_hook(

        if dist.get_global_rank() == 0:
            if register_to_mlflow:
                if self.using_peft:
                    # Save and register peft model to mlflow, this code path uses our older two step logic
                    self._save_and_register_peft_model(
                        state,
                        new_model_instance,
                        original_tokenizer,
                        temp_save_dir,
                    )
                else:
                    register_save_dir = os.path.join(
                        temp_save_dir,
                        'register_save',
                    )
                    assert new_model_instance is not None
                    new_model_instance = self.transform_model_pre_registration(
                        new_model_instance,
                    )
                    new_model_instance.save_pretrained(register_save_dir)
                    if original_tokenizer:
                        original_tokenizer.save_pretrained(register_save_dir)

                    self.pre_register_edit(register_save_dir)

                    for mlflow_logger in self.mlflow_loggers:
                        if self.mlflow_registered_model_name:
                            log.debug(
                                f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}',
                            )

                        # Save the monitor process to be restored after registering the model.
                        with _monitor_process_saver(mlflow_logger):
                            process = SpawnProcess(
                                target=_log_model_with_multi_process,
                                kwargs={
                                    'mlflow_logger': mlflow_logger,
                                    'python_logging_level': logging.getLogger('llmfoundry').level,
                                    'transformers_model': register_save_dir,
                                    'artifact_path': 'final_model_checkpoint',
                                    'pretrained_model_name': self.pretrained_model_name,
                                    'registered_model_name': self.mlflow_registered_model_name,
                                    'await_registration_for': 3600,
                                    'mlflow_logging_config': self.mlflow_logging_config,
                                },
                            )

                            process.start()
                            self.register_processes.append(process)

                # Save the temporary directory to be cleaned up later.
                if use_temp_dir:
                    self.temp_save_dir = temp_save_dir
            else:
                # Clean up the temporary directory if we don't need to register to mlflow.
                if use_temp_dir:
                    shutil.rmtree(temp_save_dir)
            dist.barrier()
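
A note on SpawnProcess, which this hunk uses but does not define: it is presumably multiprocessing's spawn-context Process class, so the registration child re-imports modules instead of forking the CUDA-initialized training process. A one-line sketch of that assumption:

import multiprocessing

# Assumption: SpawnProcess is the 'spawn' start method's Process class
# (equivalently `from multiprocessing.context import SpawnProcess`).
SpawnProcess = multiprocessing.get_context('spawn').Process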

    def _save_and_register_peft_model(
        self,
        state: State,
        new_model_instance: Any,
        original_tokenizer: Optional[Any],
        save_dir: str,
    ):
        new_model_instance = self.transform_model_pre_registration(
            new_model_instance,
        )
        components = {'model': new_model_instance}
        if original_tokenizer is not None:
            components['tokenizer'] = original_tokenizer

        log.debug('Logging Hugging Face model to MLFlow')
        for i, mlflow_logger in enumerate(self.mlflow_loggers):
            log.debug(
                f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}',
            )

            local_save_path = str(Path(save_dir) / f'mlflow_save_{i}')

            # TODO: Remove after mlflow fixes the bug that makes this necessary
            import mlflow
            mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''

            model_saving_kwargs: dict[str, Any] = {
                'path': local_save_path,
            }
            model_saving_kwargs['flavor'] = 'peft'
            model_saving_kwargs['save_pretrained_dir'] = save_dir
            model_saving_kwargs['metadata'] = self.mlflow_logging_config[
                'metadata']

            context_manager = te.onnx_export(
                True,
            ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext()
            with context_manager:
                # Add the pip requirements directly to avoid mlflow
                # attempting to run inference on the model
                model_saving_kwargs['pip_requirements'] = [
                    'transformers',
                    'torch',
                ]
                mlflow_logger.save_model(**model_saving_kwargs)

            # Upload the license file generated by mlflow during the model saving.
            license_filename = _maybe_get_license_filename(
                local_save_path,
                self.pretrained_model_name,
            )
            if license_filename is not None:
                mlflow_logger._mlflow_client.log_artifact(
                    mlflow_logger._run_id,
                    os.path.join(local_save_path, license_filename),
                )

            self.pre_register_edit(local_save_path)

            with _monitor_process_saver(mlflow_logger):
                process = SpawnProcess(
                    target=_register_model_with_run_id_multiprocess,
                    kwargs={
                        'mlflow_logger': mlflow_logger,
                        'composer_logging_level': logging.getLogger('composer').level,
                        'model_uri': local_save_path,
                        'name': self.mlflow_registered_model_name,
                        'await_creation_for': 3600,
                    },
                )
                process.start()
                self.register_processes.append(process)