From fb73a88a7a117ec5665eb9927e61566329b8b688 Mon Sep 17 00:00:00 2001
From: Irene Dea
Date: Fri, 19 Jul 2024 22:08:52 +0000
Subject: [PATCH] debug OOM - state dict convert during offload

---
 llmfoundry/callbacks/hf_checkpointer.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index f11f488fba..bfa8ab43d6 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -15,6 +15,7 @@
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
+import psutil
 import torch
 import torch.nn as nn
 from composer.core import Callback, Event, Precision, State, Time, TimeUnit
@@ -394,6 +395,12 @@ def _save_checkpoint(self, state: State, logger: Logger):
 
         gc_cuda()
 
+        # Log the percentage of used RAM
+        log.debug(f'used memory {psutil.virtual_memory().percent}')
+
+        # Log the percentage of available RAM
+        log.debug(f'available memory {psutil.virtual_memory().available * 100 / psutil.virtual_memory().total}')
+
         if state.is_model_ddp:
             composer_model = state.model.module
             original_model: PreTrainedModel = state.model.module.model
@@ -438,9 +445,6 @@ def dtensor_to_tensor_hook(
                     else:
                         state_dict[fqn] = None
                     del tensor
-                else:
-                    log.debug(f'Not a DTensor {fqn}')
-            gc_cuda()
 
             if dist.get_global_rank() != 0:
                 for fqn in dtensor_fqns:
@@ -481,6 +485,15 @@ def dtensor_to_tensor_hook(
 
         new_model_instance = None  # Need this for pyright because variable could be unbound
+        gc_cuda()
+
+
+        # Log the percentage of used RAM
+        log.debug(f'after gather state dict used memory {psutil.virtual_memory().percent}')
+
+        # Log the percentage of available RAM
+        log.debug(f'after gather state dict available memory {psutil.virtual_memory().available * 100 / psutil.virtual_memory().total}')
+
 
         if dist.get_global_rank() == 0:
             log.debug('Saving Hugging Face checkpoint in global rank 0')
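
Note: the psutil calls added above can also be wrapped in a small helper so the same host-RAM measurement is taken before and after the suspected leak site (the DTensor gather / CPU offload). This is a minimal standalone sketch, not part of the patch; the helper name log_host_memory and its label argument are hypothetical.

    # Sketch only: mirrors the psutil-based host-RAM logging used in the patch.
    import logging

    import psutil

    log = logging.getLogger(__name__)


    def log_host_memory(label: str) -> None:
        """Log used and available host RAM as percentages of total."""
        mem = psutil.virtual_memory()
        log.debug(f'{label} used memory {mem.percent}')
        log.debug(f'{label} available memory {mem.available * 100 / mem.total}')


    # Example usage: bracket the suspected leak site.
    log_host_memory('before gather state dict')
    # ... state dict gather / CPU offload would happen here ...
    log_host_memory('after gather state dict')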