Add barriers and manually clean up tempdir
irenedea committed Mar 31, 2024
1 parent c9c22a7 commit 1cdc78d
Showing 1 changed file with 173 additions and 161 deletions.
334 changes: 173 additions & 161 deletions llmfoundry/callbacks/hf_checkpointer.py
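What changed, in brief: the checkpointer previously wrote the Hugging Face checkpoint into a tempfile.TemporaryDirectory, which deletes its contents as soon as the with block exits, while the MLflow registration it spawns keeps reading from that directory in child processes. This commit switches to tempfile.mkdtemp plus an explicit shutil.rmtree once every child process has exited, and adds dist.barrier() calls so ranks stay in step around the rank-0 save. A minimal sketch of the two lifetimes (hypothetical illustration, not the callback's code; the name save_dir is assumed):

    import shutil
    import tempfile

    # Before: the directory is deleted when the `with` exits, even if a
    # child process spawned inside the block is still reading from it.
    with tempfile.TemporaryDirectory() as save_dir:
        ...  # write checkpoint, spawn registration process
    # save_dir's contents are already gone here

    # After: the caller owns the lifetime and deletes only when safe.
    save_dir = tempfile.mkdtemp()
    ...  # write checkpoint, spawn registration process
    ...  # wait for the child processes to finish
    shutil.rmtree(save_dir)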
@@ -7,6 +7,7 @@
 import math
 import os
 import re
+import shutil
 import tempfile
 import time
 from multiprocessing.context import SpawnProcess
@@ -192,7 +193,10 @@ def __init__(
 
         self.last_checkpoint_batch: Optional[Time] = None
         self.mlflow_loggers = []
+
         self.child_processes: List[SpawnProcess] = []
+        # Temporary save directory used by child_processes.
+        self.temp_save_dir = None
 
     def run_event(self, event: Event, state: State, logger: Logger) -> None:
         # The interval scheduler handles only returning True for the appropriate events
@@ -230,6 +234,10 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
             while not self._all_child_processes_done():
                 time.sleep(2)
 
+            # Clean up temporary save directory; all processes are done with it.
+            if self.temp_save_dir is not None:
+                shutil.rmtree(self.temp_save_dir)
+
     def _is_last_batch(self, state: State):
         elapsed_duration = state.get_elapsed_duration()
         if elapsed_duration is not None and elapsed_duration >= 1.0:
@@ -246,6 +254,8 @@ def _is_last_batch(self, state: State):
         return False
 
     def _all_child_processes_done(self) -> bool:
+        if len(self.child_processes) == 0:
+            return True
         not_done = any(process.is_alive() for process in self.child_processes)
         x = torch.tensor(1 if not_done else 0).to(device='cuda')
         dist.all_reduce(x, reduce_operation='MAX')
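The completion check above is collective: each rank contributes a local "still busy" flag and the ranks agree via a MAX all-reduce, so the cleanup that follows only happens once every rank reports done. A standalone sketch of the same pattern, written against plain torch.distributed rather than composer.utils.dist (that swap, the function name, and an initialized CUDA process group are all assumptions here):

    import torch
    import torch.distributed as dist

    def all_ranks_done(child_processes) -> bool:
        # Mirrors the early exit added in this commit.
        if len(child_processes) == 0:
            return True
        # 1 if any local child process is still alive, else 0.
        not_done = any(p.is_alive() for p in child_processes)
        flag = torch.tensor(1 if not_done else 0, device='cuda')
        # MAX across ranks: the result is 0 only when every rank reports done.
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
        return flag.item() == 0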
@@ -268,170 +278,172 @@ def _save_checkpoint(self, state: State, logger: Logger):
                 Path(self.save_dir_format_str) /
                 self.huggingface_folder_name_fstr), state.run_name,
                                                     state.timestamp)
-        dir_context_mgr = tempfile.TemporaryDirectory(
-        ) if self.remote_ud is not None else contextlib.nullcontext(
-            enter_result=save_dir)
-
-        with dir_context_mgr as temp_save_dir:
-            assert isinstance(temp_save_dir,
-                              str)  # pyright doesn't know about enter_result
-
-            log.debug('Gathering state dict')
-            from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-
-            if state.is_model_ddp:
-                composer_model = state.model.module
-                original_model: PreTrainedModel = state.model.module.model
-                state_dict_model = state.model.module.model
-                original_tokenizer = state.model.module.tokenizer
-            elif isinstance(state.model.model, FSDP):
-                composer_model = state.model
-                original_model: PreTrainedModel = state.model.model.module
-                state_dict_model = state.model.model
-                original_tokenizer = state.model.tokenizer
-            else:
-                composer_model = state.model
-                original_model: PreTrainedModel = state.model.model
-                state_dict_model = state.model.model
-                original_tokenizer = state.model.tokenizer
-
-            state_dict_context = fsdp_state_dict_type_context(
-                original_model, state_dict_type='full') if (
-                    (not state.is_model_ddp) and isinstance(
-                        state_dict_model, FSDP)) else contextlib.nullcontext()
-
-            with state_dict_context:
-                state_dict = state_dict_model.state_dict()
-
-                # convert the state dict to the requested precision
-                for k, v in state_dict.items():
-                    if isinstance(v, torch.Tensor):
-                        state_dict[k] = v.to(dtype=self.dtype)
-
-            if dist.get_global_rank() == 0:
-                log.debug('Saving Hugging Face checkpoint in global rank 0')
-
-                copied_config = copy.deepcopy(original_model.config)
-                if copied_config.model_type == 'mpt':
-                    copied_config.attn_config['attn_impl'] = 'torch'
-                    copied_config.init_device = 'cpu'
-
-                log.debug(f'Creating new model instance')
-
-                if composer_model.using_peft:
-                    # We don't use meta here because the state dict does not contain the full
-                    # model, only the adapter weights.
-                    active_adapter = original_model.active_adapter
-                    base_model = original_model.get_base_model()
-                    new_base_model_instance = type(base_model)(copied_config)
-
-                    new_model_instance = type(original_model)(
-                        new_base_model_instance,
-                        original_model.peft_config[active_adapter])
-                    new_model_instance.to(dtype=self.dtype)
-                else:
-                    # First create the model instance on meta device to avoid the
-                    # initialization cost.
-                    with init_empty_weights():
-                        new_model_instance = type(original_model)(copied_config)
-
-                    # Then load the state dict in with "assign" so that the state dict
-                    # is loaded properly even though the model is initially on meta device.
-                    new_model_instance.load_state_dict(state_dict, assign=True)
-                del state_dict
-
-                log.debug('Saving Hugging Face checkpoint to disk')
-                new_model_instance.save_pretrained(temp_save_dir)
-                if original_tokenizer is not None:
-                    assert isinstance(original_tokenizer,
-                                      PreTrainedTokenizerBase)
-                    original_tokenizer.save_pretrained(temp_save_dir)
-
-                # Only need to edit files for MPT because it has custom code
-                if original_model.config.model_type == 'mpt':
-                    log.debug('Editing MPT files for HuggingFace compatibility')
-                    edit_files_for_hf_compatibility(
-                        temp_save_dir,
-                        self.flatten_imports,
-                    )
-
-                if self.remote_ud is not None:
-                    for filename in os.listdir(temp_save_dir):
-                        remote_file_name = os.path.join(save_dir, filename)
-                        remote_file_uri = self.remote_ud.remote_backend.get_uri(
-                            remote_file_name)
-                        log.info(
-                            f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}'
-                        )
-                        self.remote_ud.upload_file(
-                            state=state,
-                            remote_file_name=remote_file_name,
-                            file_path=Path(os.path.join(temp_save_dir,
-                                                        filename)),
-                            overwrite=self.overwrite,
-                        )
-
-                if self.mlflow_registered_model_name and self._is_last_batch(
-                        state):
-                    components = {'model': new_model_instance}
-                    if original_tokenizer is not None:
-                        components['tokenizer'] = original_tokenizer
-
-                    log.debug('Logging Hugging Face model to MLFlow')
-                    for i, mlflow_logger in enumerate(self.mlflow_loggers):
-                        log.debug(
-                            f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}'
-                        )
-                        local_save_path = str(
-                            Path(temp_save_dir) / f'mlflow_save_{i}')
-
-                        # TODO: Remove after mlflow fixes the bug that makes this necessary
-                        import mlflow
-                        mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
-                        model_saving_kwargs: Dict[str, Any] = {
-                            'path': local_save_path
-                        }
-                        if composer_model.using_peft:
-                            model_saving_kwargs['flavor'] = 'peft'
-                            model_saving_kwargs[
-                                'save_pretrained_dir'] = temp_save_dir
-                            model_saving_kwargs[
-                                'metadata'] = self.mlflow_logging_config[
-                                    'metadata']
-                        else:
-                            model_saving_kwargs['flavor'] = 'transformers'
-                            model_saving_kwargs[
-                                'transformers_model'] = components
-                            model_saving_kwargs.update(
-                                self.mlflow_logging_config)
-
-                        mlflow_logger.save_model(**model_saving_kwargs)
-
-                        # Upload the license file generated by mlflow during the model saving.
-                        license_filename = _maybe_get_license_filename(
-                            local_save_path,
-                            self.mlflow_logging_config['metadata'].get(
-                                'pretrained_model_name', None))
-                        if license_filename is not None:
-                            mlflow_logger._mlflow_client.log_artifact(
-                                mlflow_logger._run_id,
-                                os.path.join(local_save_path, license_filename),
-                            )
-
-                        # Spawn a new process to register the model.
-                        process = SpawnProcess(
-                            target=_register_model_with_run_id_multiprocess,
-                            kwargs={
-                                'mlflow_logger':
-                                    mlflow_logger,
-                                'logging_level':
-                                    logging.getLogger('composer').level,
-                                'model_uri':
-                                    local_save_path,
-                                'name':
-                                    self.mlflow_registered_model_name,
-                                'await_creation_for':
-                                    3600,
-                            })
-                        process.start()
-                        self.child_processes.append(process)
+        temp_save_dir = tempfile.mkdtemp(
+        ) if self.remote_ud is not None else save_dir
+
+        log.debug('Gathering state dict')
+        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+        if state.is_model_ddp:
+            composer_model = state.model.module
+            original_model: PreTrainedModel = state.model.module.model
+            state_dict_model = state.model.module.model
+            original_tokenizer = state.model.module.tokenizer
+        elif isinstance(state.model.model, FSDP):
+            composer_model = state.model
+            original_model: PreTrainedModel = state.model.model.module
+            state_dict_model = state.model.model
+            original_tokenizer = state.model.tokenizer
+        else:
+            composer_model = state.model
+            original_model: PreTrainedModel = state.model.model
+            state_dict_model = state.model.model
+            original_tokenizer = state.model.tokenizer
+
+        state_dict_context = fsdp_state_dict_type_context(
+            original_model,
+            state_dict_type='full') if ((not state.is_model_ddp) and isinstance(
+                state_dict_model, FSDP)) else contextlib.nullcontext()
+
+        with state_dict_context:
+            state_dict = state_dict_model.state_dict()
+
+            # convert the state dict to the requested precision
+            for k, v in state_dict.items():
+                if isinstance(v, torch.Tensor):
+                    state_dict[k] = v.to(dtype=self.dtype)
+
+        new_model_instance = None  # Need this for pyright because variable could be unbound
+
+        if dist.get_global_rank() == 0:
+            log.debug('Saving Hugging Face checkpoint in global rank 0')
+
+            copied_config = copy.deepcopy(original_model.config)
+            if copied_config.model_type == 'mpt':
+                copied_config.attn_config['attn_impl'] = 'torch'
+                copied_config.init_device = 'cpu'
+
+            log.debug(f'Creating new model instance')
+
+            if composer_model.using_peft:
+                # We don't use meta here because the state dict does not contain the full
+                # model, only the adapter weights.
+                active_adapter = original_model.active_adapter
+                base_model = original_model.get_base_model()
+                new_base_model_instance = type(base_model)(copied_config)
+
+                new_model_instance = type(original_model)(
+                    new_base_model_instance,
+                    original_model.peft_config[active_adapter])
+                new_model_instance.to(dtype=self.dtype)
+            else:
+                # First create the model instance on meta device to avoid the
+                # initialization cost.
+                with init_empty_weights():
+                    new_model_instance = type(original_model)(copied_config)
+
+                # Then load the state dict in with "assign" so that the state dict
+                # is loaded properly even though the model is initially on meta device.
+                new_model_instance.load_state_dict(state_dict, assign=True)
+            del state_dict
+
+            log.debug('Saving Hugging Face checkpoint to disk')
+            new_model_instance.save_pretrained(temp_save_dir)
+            if original_tokenizer is not None:
+                assert isinstance(original_tokenizer, PreTrainedTokenizerBase)
+                original_tokenizer.save_pretrained(temp_save_dir)
+
+            # Only need to edit files for MPT because it has custom code
+            if original_model.config.model_type == 'mpt':
+                log.debug('Editing MPT files for HuggingFace compatibility')
+                edit_files_for_hf_compatibility(
+                    temp_save_dir,
+                    self.flatten_imports,
+                )
+
+            if self.remote_ud is not None:
+                for filename in os.listdir(temp_save_dir):
+                    remote_file_name = os.path.join(save_dir, filename)
+                    remote_file_uri = self.remote_ud.remote_backend.get_uri(
+                        remote_file_name)
+                    log.info(
+                        f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}'
+                    )
+                    self.remote_ud.upload_file(
+                        state=state,
+                        remote_file_name=remote_file_name,
+                        file_path=Path(os.path.join(temp_save_dir, filename)),
+                        overwrite=self.overwrite,
+                    )
+
+        dist.barrier()
+
+        if dist.get_global_rank() == 0:
+            if self.mlflow_registered_model_name and self._is_last_batch(state):
+                components = {'model': new_model_instance}
+                if original_tokenizer is not None:
+                    components['tokenizer'] = original_tokenizer
+
+                log.debug('Logging Hugging Face model to MLFlow')
+                for i, mlflow_logger in enumerate(self.mlflow_loggers):
+                    log.debug(
+                        f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}'
+                    )
+                    local_save_path = str(
+                        Path(temp_save_dir) / f'mlflow_save_{i}')
+
+                    # TODO: Remove after mlflow fixes the bug that makes this necessary
+                    import mlflow
+                    mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
+                    model_saving_kwargs: Dict[str, Any] = {
+                        'path': local_save_path
+                    }
+                    if composer_model.using_peft:
+                        model_saving_kwargs['flavor'] = 'peft'
+                        model_saving_kwargs[
+                            'save_pretrained_dir'] = temp_save_dir
+                        model_saving_kwargs[
+                            'metadata'] = self.mlflow_logging_config['metadata']
+                    else:
+                        model_saving_kwargs['flavor'] = 'transformers'
+                        model_saving_kwargs['transformers_model'] = components
+                        model_saving_kwargs.update(self.mlflow_logging_config)
+
+                    mlflow_logger.save_model(**model_saving_kwargs)
+
+                    # Upload the license file generated by mlflow during the model saving.
+                    license_filename = _maybe_get_license_filename(
+                        local_save_path,
+                        self.mlflow_logging_config['metadata'].get(
+                            'pretrained_model_name', None))
+                    if license_filename is not None:
+                        mlflow_logger._mlflow_client.log_artifact(
+                            mlflow_logger._run_id,
+                            os.path.join(local_save_path, license_filename),
+                        )
+
+                    # Spawn a new process to register the model.
+                    process = SpawnProcess(
+                        target=_register_model_with_run_id_multiprocess,
+                        kwargs={
+                            'mlflow_logger':
+                                mlflow_logger,
+                            'logging_level':
+                                logging.getLogger('composer').level,
+                            'model_uri':
+                                local_save_path,
+                            'name':
+                                self.mlflow_registered_model_name,
+                            'await_creation_for':
+                                3600,
+                        })
+                    process.start()
+                    self.child_processes.append(process)
+
+                # Save the temporary directory to be cleaned up later.
+                self.temp_save_dir = temp_save_dir
+            else:
+                # Clean up the temporary directory if we don't need to register to mlflow.
+                shutil.rmtree(temp_save_dir)
+        dist.barrier()
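Taken together, the new flow hands the mkdtemp directory to spawned registration processes and defers shutil.rmtree until run_event has seen them all exit. A self-contained sketch of that hand-off (hypothetical names, standard library only; the callback polls is_alive() across ranks rather than joining directly):

    import multiprocessing as mp
    import shutil
    import tempfile
    import time

    def register_model(path: str) -> None:
        ...  # read files under `path`, e.g. upload or register them

    if __name__ == '__main__':
        temp_save_dir = tempfile.mkdtemp()  # caller owns cleanup
        ctx = mp.get_context('spawn')
        child = ctx.Process(target=register_model, args=(temp_save_dir,))
        child.start()
        while child.is_alive():  # poll, as the callback does
            time.sleep(2)
        shutil.rmtree(temp_save_dir)  # safe: the child has exited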
