add gc_cuda()
irenedea committed Jul 19, 2024
1 parent 51c024d commit cf66ad9
Showing 1 changed file with 14 additions and 3 deletions.
llmfoundry/callbacks/hf_checkpointer.py (14 additions, 3 deletions)
@@ -41,6 +41,8 @@
 from llmfoundry.utils.huggingface_hub_utils import \
     edit_files_for_hf_compatibility
 
+from llmfoundry.callbacks.scheduled_gc_callback import gc_cuda
+
 try:
     import transformer_engine.pytorch as te
     is_te_imported = True
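
Note: the body of gc_cuda is not part of this diff; it is imported from llmfoundry.callbacks.scheduled_gc_callback. As a rough sketch, a helper with this name typically amounts to a Python garbage-collection pass followed by releasing PyTorch's cached CUDA blocks (an assumption about the implementation, not code taken from this repository):

    import gc

    import torch


    def gc_cuda():
        # Illustrative sketch only -- the real helper lives in
        # llmfoundry/callbacks/scheduled_gc_callback.py and may differ.
        gc.collect()  # drop unreachable Python objects that hold tensor refs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # return cached blocks to the CUDA allocator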
@@ -390,6 +392,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
         log.debug('Gathering state dict')
         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
+        gc_cuda()
+
         if state.is_model_ddp:
             composer_model = state.model.module
             original_model: PreTrainedModel = state.model.module.model
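
Note: the surrounding checkpointing code is not shown in this hunk. As general background (not taken from hf_checkpointer.py), gathering a full state dict from an FSDP-wrapped model all-gathers every shard onto one device, which is why freeing cached CUDA memory with gc_cuda() immediately beforehand is useful. A hedged sketch of the standard PyTorch pattern, using a hypothetical model variable:

    from torch.distributed.fsdp import FullStateDictConfig, StateDictType
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Gather the full (unsharded) state dict, offloaded to CPU and kept only on rank 0.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        full_state_dict = model.state_dict()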
@@ -426,11 +430,18 @@ def dtensor_to_tensor_hook(
                 tensor = state_dict[fqn]
                 if isinstance(tensor, DTensor):
                     dtensor_fqns.append(fqn)
-                    tensor = tensor.full_tensor()  # type: ignore
+                    tensor = tensor.full_tensor()
                     if dist.get_global_rank() == 0:
                         if cpu_offload:
-                            tensor = tensor.cpu()
-                        state_dict[fqn] = tensor.to(dtype=self.dtype)
+                            tensor = tensor.to(dtype=self.dtype, device=torch.device('cpu'))
+                        state_dict[fqn] = tensor
+                    else:
+                        state_dict[fqn] = None
+                    del tensor
+                else:
+                    log.debug(f'Not a DTensor {fqn}')
+                gc_cuda()
+
             if dist.get_global_rank() != 0:
                 for fqn in dtensor_fqns:
                     del state_dict[fqn]
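
Note: the hook above processes one state-dict entry at a time. The following self-contained sketch shows the same pattern in isolation, with gc.collect()/torch.cuda.empty_cache() standing in for gc_cuda() and torch.distributed.get_rank() standing in for the dist.get_global_rank() helper used in the diff; the function name, signature, and DTensor import path are illustrative assumptions, not the actual hf_checkpointer code:

    import gc

    import torch
    import torch.distributed as dist
    from torch.distributed._tensor import DTensor


    def materialize_dtensors(state_dict, dtype, cpu_offload=True):
        # Replace each DTensor with a plain tensor on rank 0 and drop it elsewhere.
        for fqn in list(state_dict.keys()):
            tensor = state_dict[fqn]
            if isinstance(tensor, DTensor):
                # Every rank participates in the all-gather behind full_tensor().
                tensor = tensor.full_tensor()
                if dist.get_rank() == 0:
                    if cpu_offload:
                        # Move to CPU and cast in a single .to() call, avoiding an
                        # extra GPU-resident copy at the target dtype.
                        tensor = tensor.to(dtype=dtype, device=torch.device('cpu'))
                    state_dict[fqn] = tensor
                else:
                    state_dict[fqn] = None
            del tensor
            # Stand-in for gc_cuda(): free unreachable objects and cached CUDA
            # blocks after each entry so peak GPU memory stays low.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        return state_dict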