Skip to content

Commit

Permalink
debug OOM - state dict convert during offload
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed Jul 19, 2024
1 parent cf66ad9 commit fb73a88
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import psutil
import torch
import torch.nn as nn
from composer.core import Callback, Event, Precision, State, Time, TimeUnit
Expand Down Expand Up @@ -394,6 +395,12 @@ def _save_checkpoint(self, state: State, logger: Logger):

gc_cuda()

# Log the percentage of system RAM currently in use
log.debug(f'used memory {psutil.virtual_memory().percent}')

# Log the percentage of system RAM still available
log.debug(f'available memory {psutil.virtual_memory().available * 100 / psutil.virtual_memory().total}')

if state.is_model_ddp:
composer_model = state.model.module
original_model: PreTrainedModel = state.model.module.model
Expand Down Expand Up @@ -438,9 +445,6 @@ def dtensor_to_tensor_hook(
else:
state_dict[fqn] = None
del tensor
else:
log.debug(f'Not a DTensor {fqn}')
gc_cuda()

if dist.get_global_rank() != 0:
for fqn in dtensor_fqns:
Expand Down Expand Up @@ -481,6 +485,15 @@ def dtensor_to_tensor_hook(

new_model_instance = None # Need this for pyright because variable could be unbound

gc_cuda()


# Log the percentage of system RAM in use after gathering the state dict
log.debug(f'after gather state dict used memory {psutil.virtual_memory().percent}')

# Log the percentage of system RAM still available after gathering the state dict
log.debug(f'after gather state dict available memory {psutil.virtual_memory().available * 100 / psutil.virtual_memory().total}')

if dist.get_global_rank() == 0:
log.debug('Saving Hugging Face checkpoint in global rank 0')

Expand Down

0 comments on commit fb73a88

Please sign in to comment.