Fix memory leak in fp8 causing OOM (and potentially 3x vRAM usage) #2089

Merged
merged 5 commits on Oct 31, 2023
Changes from 4 commits
59 changes: 29 additions & 30 deletions src/accelerate/accelerator.py
@@ -1357,35 +1357,6 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
" Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`."
)

if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr(
model, "hf_device_map", False
):
model_devices = set(model.hf_device_map.values())
if len(model_devices) > 1 and self.distributed_type != DistributedType.NO:
raise ValueError(
"You can't train a model that has been loaded in 8-bit precision on multiple devices in any distributed mode."
" In order to use 8-bit models that have been loaded across multiple GPUs the solution is to use Naive Pipeline Parallelism."
" Therefore you should not specify that you are under any distributed regime in your accelerate config."
)
current_device = list(model_devices)[0]
current_device_index = current_device.index if isinstance(current_device, torch.device) else current_device

if torch.device(current_device_index) != self.device:
# if on the first device (GPU 0) we don't care
if (self.device.index is not None) or (current_device_index != 0):
raise ValueError(
"You can't train a model that has been loaded in 8-bit precision on a different device than the one "
"you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device()}"
"you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device() or device_map={'':torch.xpu.current_device()}"
)

if "cpu" in model_devices or "disk" in model_devices:
raise ValueError(
"You can't train a model that has been loaded in 8-bit precision with CPU or disk offload."
)
elif device_placement and not self.verify_device_map(model):
model = model.to(self.device)

if self.native_amp:
model._original_forward = model.forward
model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward
@@ -1401,7 +1372,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
with torch.no_grad():
convert_model(model)
model._converted_to_transformer_engine = True
model._original_forward = model.forward
#model._original_forward = model.forward

kwargs = self.fp8_recipe_handler.to_kwargs() if self.fp8_recipe_handler is not None else {}
if "fp8_format" in kwargs:
@@ -1416,6 +1387,34 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
"or higher, compute capability of 8.9 or higher). Will use FP16 instead."
)
model.forward = fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe)(model.forward)

if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr(
model, "hf_device_map", False
):
model_devices = set(model.hf_device_map.values())
if len(model_devices) > 1 and self.distributed_type != DistributedType.NO:
raise ValueError(
"You can't train a model that has been loaded in 8-bit precision on multiple devices in any distributed mode."
" In order to use 8-bit models that have been loaded across multiple GPUs the solution is to use Naive Pipeline Parallelism."
" Therefore you should not specify that you are under any distributed regime in your accelerate config."
)
current_device = list(model_devices)[0]
current_device_index = current_device.index if isinstance(current_device, torch.device) else current_device

if torch.device(current_device_index) != self.device:
# if on the first device (GPU 0) we don't care
if (self.device.index is not None) or (current_device_index != 0):
raise ValueError(
"You can't train a model that has been loaded in 8-bit precision on a different device than the one "
"you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device() or device_map={'':torch.xpu.current_device()}"
)

if "cpu" in model_devices or "disk" in model_devices:
raise ValueError(
"You can't train a model that has been loaded in 8-bit precision with CPU or disk offload."
)
elif device_placement and not self.verify_device_map(model):
model = model.to(self.device)
if not evaluation_mode:
if self.distributed_type in (
DistributedType.MULTI_GPU,
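Reading this diff: the transformer-engine branch had been saving `model._original_forward = model.forward` a second time, after the `native_amp` branch above had already stored it, and this PR comments that duplicate assignment out; the 8-bit/4-bit `hf_device_map` validation block is also moved, unchanged, from before the autocast/FP8 handling to after it. A minimal, self-contained sketch (not Accelerate's actual API; `wrap` stands in for `fp8_autocast()(...)`) of why re-saving `forward` after wrapping is the problematic pattern:

import functools

class TinyModel:
    def forward(self, x):
        return x * 2

def wrap(fn):
    # Stand-in for fp8_autocast()(fn): any decorator whose closure keeps `fn` alive.
    @functools.wraps(fn)
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)
    return inner

model = TinyModel()
model._original_forward = model.forward   # saved once, before wrapping
model.forward = wrap(model.forward)       # the wrapper closes over the bound method

# A second `model._original_forward = model.forward` at this point would store the
# wrapper rather than the original method, so anything that later restores forward
# from `_original_forward` keeps the wrapper (and whatever its closure references) alive.
assert model._original_forward.__func__ is TinyModel.forward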
16 changes: 8 additions & 8 deletions src/accelerate/utils/transformer_engine.py
@@ -36,31 +36,31 @@ def convert_model(model, to_transformer_engine=True, _convert_linear=True, _conv
te_module = te.Linear(
module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype
)
te_module.weight.data = module.weight.data.clone()
module.weight.copy_(te_module.weight)
if has_bias:
te_module.bias.data = module.bias.data.clone()
module.bias.copy_(te_module.bias)

setattr(model, name, te_module)
elif isinstance(module, nn.LayerNorm) and to_transformer_engine and _convert_ln:
te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
te_module.weight.data = module.weight.data.clone()
te_module.bias.data = module.bias.data.clone()
module.weight.copy_(te_module.weight)
module.bias.copy_(te_module.bias)

setattr(model, name, te_module)
elif isinstance(module, te.Linear) and not to_transformer_engine and _convert_linear:
has_bias = module.bias is not None
new_module = nn.Linear(
module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype
)
new_module.weight.data = module.weight.data.clone()
module.weight.copy_(new_module.weight)
if has_bias:
new_module.bias.data = module.bias.data.clone()
module.bias.copy_(new_module.bias)

setattr(model, name, new_module)
elif isinstance(module, te.LayerNorm) and not to_transformer_engine and _convert_ln:
new_module = nn.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
new_module.weight.data = module.weight.data.clone()
new_module.bias.data = module.bias.data.clone()
module.weight.copy_(new_module.weight)
module.bias.copy_(new_module.bias)

setattr(model, name, new_module)
else:
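In `convert_model`, the weight transfer for each converted `Linear`/`LayerNorm` switches from assigning a `.data.clone()` to an in-place `copy_`. The clone allocates a second full-size buffer for every parameter, and the buffer it replaces is only freed once nothing else references it, which lines up with the extra vRAM usage mentioned in the title. A rough sketch of that allocation difference in plain PyTorch (no transformer_engine; the 4096x4096 shape is illustrative):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
src = torch.nn.Linear(4096, 4096, device=device)  # module whose weights are transferred
dst = torch.nn.Linear(4096, 4096, device=device)  # freshly created replacement module

# Old style: clone() materialises an extra 64 MiB (fp32) buffer per weight, on top of
# the storage `dst.weight` already owns.
dst.weight.data = src.weight.data.clone()

# New style: write into the existing storage of the destination parameter; no new
# buffer is allocated. convert_model itself runs under torch.no_grad(), as shown above.
with torch.no_grad():
    dst.weight.copy_(src.weight)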