params partition for skip_init (#4722)
Some models use `skip_init` to initialize their weights. `skip_init` first
runs a module's `__init__` on the meta device and then calls `to_empty()` to
materialize the parameters on the real device. This conflicts with DeepSpeed's
hook on `module.__init__`: `_post_init_method` must not run until `skip_init`
has finished. However, because `from ... import skip_init` typically happens
outside the `Init` context, there is no clean way to hook `skip_init` directly.
The approach here is therefore to delay the execution of `_post_init_method`
until the empty initialization has completed.
Known affected models include HuggingFace models such as chatglm2 and
chatglm3.

---------

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Masahiro Tanaka <[email protected]>
4 people authored Jan 18, 2024
1 parent 870ae04 commit 3110c38
Showing 2 changed files with 134 additions and 3 deletions.
60 changes: 57 additions & 3 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -312,6 +312,7 @@ def __init__(self, enabled=True, mem_efficient_linear=True, ds_config=None, dtyp
torch.half, torch.bfloat16, torch.float
], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.bfloat16, torch.float]"
self.wrapped_cls = set()
self.skip_init_depth = 0

self.quantized_initialization = None
if ds_config is not None and ds_config.weight_quantization_config and ds_config.weight_quantization_config.quantized_initialization:
@@ -435,6 +436,51 @@ def wrapped_apply(module: Module, fn_to_apply: Callable) -> None:

return wrapped_apply

def hook_for_skip_init(module):
# this function is intended for handling the logic of torch.nn.utils.skip_init
# skip_init:module_cls(*args, **kwargs).to_empty(device=final_device), where kwargs['device']='meta'
# the function call occurs between module_cls(*args, **kwargs) and to_empty(device=final_device).
def partition_after_empty_init(f):

@functools.wraps(f)
def wrapper(module, *args, **kwargs):
_module = f(module, *args, **kwargs)
# here is the post-hook for module.apply(empty_like...)
# after module.apply(empty_like...), the module has completed its empty init on real device
# since skip_init won't involve any computations or weight adjustments, we can directly utilize post_init
self._post_init_method(_module)
return _module

return wrapper

def post_wrapper_to_empty(f):
# append some wrapper restoration after to_empty() call
@functools.wraps(f)
def wrapper(*args, **kwargs):
res = f(*args, **kwargs)
# restore _apply hook
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
_disable_class_apply(subclass)
# self restore
module.to_empty = f
return res

return wrapper

def _enable_class_apply(cls):
cls._old_apply_of_skip_init_hook = cls._apply
cls._apply = partition_after_empty_init(cls._apply)

def _disable_class_apply(cls):
cls._apply = cls._old_apply_of_skip_init_hook

# add hooks for to_empty: apply_(empty_like)
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
_enable_class_apply(subclass)

# add a restore hook when exiting skip_init
module.to_empty = post_wrapper_to_empty(module.to_empty)

def partition_after(f):

@functools.wraps(f)
@@ -456,16 +502,25 @@ def wrapper(module, *args, **kwargs):
is_child_module = True
setattr(module, "_ds_child_entered", True)

f(module, *args, **kwargs)
init_on_meta = 'device' in kwargs and kwargs['device'] == 'meta'
if init_on_meta:
self.skip_init_depth += 1

f(module, *args, **kwargs)
if init_on_meta and self.skip_init_depth == 1:
# check and handle the logic of empty_init
hook_for_skip_init(module)
if is_child_module:
# child's __init__ is done, now we can run a single post_init on the child object
delattr(module, "_ds_child_entered")

print_rank_0(f'Running post_init for {module.__class__.__name__}', force=False)
self._post_init_method(module)
if self.skip_init_depth == 0:
self._post_init_method(module)

print_rank_0(f'After initializing followed by post init for {module.__class__.__name__}', force=False)
if init_on_meta:
self.skip_init_depth -= 1

return wrapper

@@ -512,7 +567,6 @@ def _init_subclass(cls, **kwargs):
self.patched = True

def unpatch_init_and_builtins(self):

if self.patched:

def _disable_class(cls):
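
For reference, the wrapping pattern that `hook_for_skip_init` applies above can
be condensed into a standalone sketch. This is a simplified, hypothetical
illustration (the helper name `run_after_apply` is not from the commit, and it
patches `nn.Module._apply` globally rather than per subclass): a callback runs
after `_apply` finishes, which is the point at which `to_empty()` has given the
meta-initialized parameters real storage.

```python
# Simplified, hypothetical sketch of the hook pattern above; not the DeepSpeed code.
import functools
import torch
from torch import nn


def run_after_apply(callback):
    """Patch nn.Module._apply so callback(module) runs after each _apply call,
    i.e. after to_empty() has materialized meta-initialized parameters."""
    original_apply = nn.Module._apply

    @functools.wraps(original_apply)
    def wrapped_apply(module, *args, **kwargs):
        result = original_apply(module, *args, **kwargs)
        callback(module)  # parameters are now backed by real (empty) storage
        return result

    nn.Module._apply = wrapped_apply
    # Return a restore handle, mirroring _disable_class_apply above.
    return lambda: setattr(nn.Module, "_apply", original_apply)


restore = run_after_apply(lambda m: print("materialized:", type(m).__name__))
layer = nn.Linear(4, 4, device="meta").to_empty(device="cpu")
restore()
```
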
77 changes: 77 additions & 0 deletions tests/unit/runtime/zero/test_zero.py
@@ -14,6 +14,7 @@
from torch.nn.modules.container import ModuleList
from torch.nn.modules.loss import L1Loss
from torch.nn.parameter import Parameter
from torch.nn.utils import skip_init

from unit.common import DistributedTest
from unit.simple_model import SimpleModel, random_dataloader
@@ -1193,6 +1194,82 @@ def create_tensor(vals):
_assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE})


class TestParamPartitioningSkipInit(DistributedTest):
world_size = 2

def test(self):
config_dict = {
"train_batch_size": 4,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-4
}
},
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": 3
},
}
hidden_dim = 10

class SubModel(torch.nn.Module):

def __init__(self, input_size, output_size, dropout_prob=0.5, device=None):
super(SubModel, self).__init__()
self.linear = torch.nn.Linear(input_size, output_size, device=device)
self.dropout = torch.nn.Dropout(dropout_prob)
self.module_list = torch.nn.ModuleList([torch.nn.Linear(input_size, output_size, device=device)])

def forward(self, x):
x = self.linear(x)
x = self.dropout(x)
x = self.module_list[0](x)
return x

class MyModel(torch.nn.Module):

def __init__(self, hidden_dim):
super(MyModel, self).__init__()
self.l1 = skip_init(Linear, hidden_dim, hidden_dim)
self.l2 = skip_init(SubModel, hidden_dim, hidden_dim)
self.l3 = torch.nn.Linear(hidden_dim, hidden_dim)
self.cel = torch.nn.CrossEntropyLoss()
self.l4 = skip_init(SubModel, hidden_dim, hidden_dim)

def forward(self, x, y):
x = self.l1(x)
x = self.l2(x)
x = self.l3(x)
x = self.l4(x)
loss = self.cel(x, y)
val = [x, loss]
return val

with deepspeed.zero.Init(config=config_dict):
model = MyModel(hidden_dim)
world_size = dist.get_world_size()
ds_tensor_numel = math.ceil(hidden_dim * hidden_dim / world_size)
assert model.l1.weight.ds_tensor.numel() == ds_tensor_numel
assert model.l2.linear.weight.ds_tensor.numel() == ds_tensor_numel
assert model.l2.module_list[0].weight.ds_tensor.numel() == ds_tensor_numel
assert model.l3.weight.ds_tensor.numel() == ds_tensor_numel
assert model.l4.linear.weight.ds_tensor.numel() == ds_tensor_numel
assert model.l4.module_list[0].weight.ds_tensor.numel() == ds_tensor_numel

model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device)
dist.barrier()
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
loss = loss[1]
model.backward(loss)
model.step()


class TestZeroOffloadStage1(DistributedTest):
world_size = 2
