Skip to content

Commit

Permalink
Use log_model API to register the model (#1544)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel King <[email protected]>
  • Loading branch information
nancyhung and dakinggg authored Nov 1, 2024
1 parent 7c991e9 commit 2ce6296
Show file tree
Hide file tree
Showing 2 changed files with 282 additions and 121 deletions.
332 changes: 242 additions & 90 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,26 @@

_LICENSE_FILE_PATTERN = re.compile(r'license(\.[a-z]+|$)', re.IGNORECASE)

from contextlib import contextmanager


@contextmanager
def _monitor_process_saver(mlflow_logger: MLFlowLogger):
# Save the current monitor process
if hasattr(mlflow_logger, 'monitor_process'):
original_monitor_process = mlflow_logger.monitor_process # type: ignore
mlflow_logger.monitor_process = None # type: ignore
else:
original_monitor_process = None

try:
# Yield control back to the calling code
yield
finally:
# Restore the monitor process
if original_monitor_process is not None:
mlflow_logger.monitor_process = original_monitor_process # type: ignore


def _maybe_get_license_filename(
local_dir: str,
Expand Down Expand Up @@ -108,6 +128,91 @@ def _maybe_get_license_filename(
return None


def _log_model_with_multi_process(
mlflow_logger: MLFlowLogger,
python_logging_level: int,
transformers_model: str,
artifact_path: str,
pretrained_model_name: str,
registered_model_name: Optional[str],
await_registration_for: int,
mlflow_logging_config: dict[str, Any],
):
"""Call MLFlowLogger.log_model.
First, patch the mlflow save_model function by removing duplicate tokenizer
files in the model directory. Then, register the model to mlflow from a
child process.
"""
# Setup logging for child process. This ensures that any logs from composer are surfaced.
if python_logging_level > 0:
# If logging_level is 0, then the composer logger was unset.
logging.basicConfig(
format=
f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s',
force=True,
)
logging.getLogger('composer').setLevel(python_logging_level)
logging.getLogger('llmfoundry').setLevel(python_logging_level)

import mlflow
original_save_model = mlflow.transformers.save_model

def save_model_patch(*args: Any, **kwargs: Any):
original_save_model(*args, **kwargs)
tokenizer_files = []
save_path = kwargs['path']
tokenizer_path = os.path.join(save_path, 'components', 'tokenizer')
if os.path.exists(tokenizer_path):
tokenizer_files = os.listdir(
os.path.join(save_path, 'components', 'tokenizer'),
)
try:
# Check if there are duplicate tokenizer files in the model directory and remove them.
for tokenizer_file_name in tokenizer_files:
dupe_file = os.path.isfile(
os.path.join(save_path, 'model', tokenizer_file_name),
)
if dupe_file:
log.debug(
f'Removing duplicate tokenizer file: {tokenizer_file_name}',
)
os.remove(
os.path.join(save_path, 'model', tokenizer_file_name),
)
license_filename = _maybe_get_license_filename(
save_path,
pretrained_model_name,
)
if license_filename is not None:
mlflow_logger._mlflow_client.log_artifact(
mlflow_logger._run_id,
os.path.join(save_path, license_filename),
)
except Exception as e:
log.error(
f'Exception when removing duplicate tokenizer files in the model directory',
e,
)

mlflow.transformers.save_model = save_model_patch # type: ignore

mlflow.set_tracking_uri(mlflow_logger.tracking_uri)
if mlflow_logger.model_registry_uri is not None:
mlflow.set_registry_uri(mlflow_logger.model_registry_uri)

register_model_path = f'{mlflow_logger.model_registry_prefix}.{registered_model_name}' if mlflow_logger.model_registry_prefix and registered_model_name else registered_model_name
mlflow_logger.log_model(
transformers_model=transformers_model,
flavor='transformers',
artifact_path=artifact_path,
registered_model_name=register_model_path,
run_id=mlflow_logger._run_id,
await_registration_for=await_registration_for,
**mlflow_logging_config,
)


def _register_model_with_run_id_multiprocess(
mlflow_logger: MLFlowLogger,
composer_logging_level: int,
Expand Down Expand Up @@ -676,102 +781,149 @@ def tensor_hook(

if dist.get_global_rank() == 0:
if register_to_mlflow:
new_model_instance = self.transform_model_pre_registration(
new_model_instance,
)

components = {'model': new_model_instance}
if original_tokenizer is not None:
components['tokenizer'] = original_tokenizer

log.debug('Logging Hugging Face model to MLFlow')
for i, mlflow_logger in enumerate(self.mlflow_loggers):
log.debug(
f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}',
)
local_save_path = str(
Path(temp_save_dir) / f'mlflow_save_{i}',
)

# TODO: Remove after mlflow fixes the bug that makes this necessary
import mlflow
import mlflow.store
mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
model_saving_kwargs: dict[str, Any] = {
'path': local_save_path,
}
if self.using_peft:
model_saving_kwargs['flavor'] = 'peft'
model_saving_kwargs['save_pretrained_dir'
] = temp_save_dir
model_saving_kwargs[
'metadata'] = self.mlflow_logging_config['metadata']
else:
model_saving_kwargs['flavor'] = 'transformers'
model_saving_kwargs['transformers_model'] = components
model_saving_kwargs.update(self.mlflow_logging_config)
if self.using_peft:

context_manager = te.onnx_export(
True,
) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
# Save and register peft model to mlflow, this code path uses our older two step logic
self._save_and_register_peft_model(
state,
new_model_instance,
original_tokenizer,
temp_save_dir,
)
with context_manager:
# Add the pip requirements directly to avoid mlflow
# attempting to run inference on the model
model_saving_kwargs['pip_requirements'] = [
'transformers',
'torch',
]
mlflow_logger.save_model(**model_saving_kwargs)

# Upload the license file generated by mlflow during the model saving.
license_filename = _maybe_get_license_filename(
local_save_path,
self.pretrained_model_name,
else:
register_save_dir = os.path.join(
temp_save_dir,
'register_save',
)
if license_filename is not None:
mlflow_logger._mlflow_client.log_artifact(
mlflow_logger._run_id,
os.path.join(local_save_path, license_filename),
)

self.pre_register_edit(local_save_path,)

# Save the monitor process to be restored after registering the model.
if hasattr(mlflow_logger, 'monitor_process'):
monitor_process = mlflow_logger.monitor_process # type: ignore
mlflow_logger.monitor_process = None # type: ignore
else:
monitor_process = None

# Spawn a new process to register the model.
process = SpawnProcess(
target=_register_model_with_run_id_multiprocess,
kwargs={
'mlflow_logger':
mlflow_logger,
'composer_logging_level':
logging.getLogger('composer').level,
'model_uri':
local_save_path,
'name':
self.mlflow_registered_model_name,
'await_creation_for':
3600,
},
assert new_model_instance is not None
new_model_instance = self.transform_model_pre_registration(
new_model_instance,
)
process.start()

# Restore the monitor process.
if monitor_process is not None:
mlflow_logger.monitor_process = monitor_process # type: ignore
self.register_processes.append(process)

# Save the temporary directory to be cleaned up later.
if use_temp_dir:
self.temp_save_dir = temp_save_dir
new_model_instance.save_pretrained(register_save_dir)
if original_tokenizer:
original_tokenizer.save_pretrained(register_save_dir)

self.pre_register_edit(register_save_dir)

for mlflow_logger in self.mlflow_loggers:
if self.mlflow_registered_model_name:
log.debug(
f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}',
)

# Save the monitor process to be restored after registering the model.
with _monitor_process_saver(mlflow_logger):
process = SpawnProcess(
target=_log_model_with_multi_process,
kwargs={
'mlflow_logger':
mlflow_logger,
'python_logging_level':
logging.getLogger('llmfoundry').level,
'transformers_model':
register_save_dir,
'artifact_path':
'final_model_checkpoint',
'pretrained_model_name':
self.pretrained_model_name,
'registered_model_name':
self.mlflow_registered_model_name,
'await_registration_for':
3600,
'mlflow_logging_config':
self.mlflow_logging_config,
},
)

process.start()
self.register_processes.append(process)

# Save the temporary directory to be cleaned up later.
if use_temp_dir:
self.temp_save_dir = temp_save_dir
else:
# Clean up the temporary directory if we don't need to register to mlflow.
if use_temp_dir:
shutil.rmtree(temp_save_dir)
dist.barrier()

def _save_and_register_peft_model(
self,
state: State,
new_model_instance: Any,
original_tokenizer: Optional[Any],
save_dir: str,
):
new_model_instance = self.transform_model_pre_registration(
new_model_instance,
)
components = {'model': new_model_instance}
if original_tokenizer is not None:
components['tokenizer'] = original_tokenizer

log.debug('Logging Hugging Face model to MLFlow')
for i, mlflow_logger in enumerate(self.mlflow_loggers):
log.debug(
f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}',
)

local_save_path = str(Path(save_dir) / f'mlflow_save_{i}',)

# TODO: Remove after mlflow fixes the bug that makes this necessary
import mlflow
import mlflow.store
mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''

model_saving_kwargs: dict[str, Any] = {
'path': local_save_path,
}
model_saving_kwargs['flavor'] = 'peft'
model_saving_kwargs['save_pretrained_dir'] = save_dir
model_saving_kwargs['metadata'] = self.mlflow_logging_config[
'metadata']

context_manager = te.onnx_export(
True,
) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
)
with context_manager:
# Add the pip requirements directly to avoid mlflow
# attempting to run inference on the model
model_saving_kwargs['pip_requirements'] = [
'transformers',
'torch',
]
mlflow_logger.save_model(**model_saving_kwargs)

# Upload the license file generated by mlflow during the model saving.
# Get and log the license file.
license_filename = _maybe_get_license_filename(
local_save_path,
self.pretrained_model_name,
)
if license_filename is not None:
mlflow_logger._mlflow_client.log_artifact(
mlflow_logger._run_id,
os.path.join(local_save_path, license_filename),
)

self.pre_register_edit(local_save_path)

with _monitor_process_saver(mlflow_logger):
process = SpawnProcess(
target=_register_model_with_run_id_multiprocess,
kwargs={
'mlflow_logger':
mlflow_logger,
'composer_logging_level':
logging.getLogger('composer').level,
'model_uri':
local_save_path,
'name':
self.mlflow_registered_model_name,
'await_creation_for':
3600,
},
)
process.start()
self.register_processes.append(process)
Loading

0 comments on commit 2ce6296

Please sign in to comment.