diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 992dcd446ad6..030a050b88e2 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -312,6 +312,7 @@ def __init__(self, enabled=True, mem_efficient_linear=True, ds_config=None, dtyp
             torch.half, torch.bfloat16, torch.float
         ], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.bfloat16, torch.float]"
         self.wrapped_cls = set()
+        self.skip_init_depth = 0
 
         self.quantized_initialization = None
         if ds_config is not None and ds_config.weight_quantization_config and ds_config.weight_quantization_config.quantized_initialization:
@@ -435,6 +436,51 @@ def wrapped_apply(module: Module, fn_to_apply: Callable) -> None:
 
             return wrapped_apply
 
+        def hook_for_skip_init(module):
+            # this function handles the logic of torch.nn.utils.skip_init
+            # skip_init: module_cls(*args, **kwargs).to_empty(device=final_device), where kwargs['device']='meta'
+            # this hook is installed between module_cls(*args, **kwargs) and to_empty(device=final_device).
+            def partition_after_empty_init(f):
+
+                @functools.wraps(f)
+                def wrapper(module, *args, **kwargs):
+                    _module = f(module, *args, **kwargs)
+                    # this is the post-hook for module.apply(empty_like...)
+                    # after module.apply(empty_like...), the module has completed its empty init on the real device
+                    # since skip_init does not involve any computation or weight adjustment, we can directly run post_init
+                    self._post_init_method(_module)
+                    return _module
+
+                return wrapper
+
+            def post_wrapper_to_empty(f):
+                # restore the wrappers after the to_empty() call
+                @functools.wraps(f)
+                def wrapper(*args, **kwargs):
+                    res = f(*args, **kwargs)
+                    # restore the _apply hooks
+                    for subclass in get_all_subclasses(torch.nn.modules.module.Module):
+                        _disable_class_apply(subclass)
+                    # restore to_empty itself
+                    module.to_empty = f
+                    return res
+
+                return wrapper
+
+            def _enable_class_apply(cls):
+                cls._old_apply_of_skip_init_hook = cls._apply
+                cls._apply = partition_after_empty_init(cls._apply)
+
+            def _disable_class_apply(cls):
+                cls._apply = cls._old_apply_of_skip_init_hook
+
+            # add hooks for to_empty: apply_(empty_like)
+            for subclass in get_all_subclasses(torch.nn.modules.module.Module):
+                _enable_class_apply(subclass)
+
+            # add a restore hook that fires when exiting skip_init
+            module.to_empty = post_wrapper_to_empty(module.to_empty)
+
         def partition_after(f):
 
             @functools.wraps(f)
@@ -456,16 +502,25 @@ def wrapper(module, *args, **kwargs):
                     is_child_module = True
                     setattr(module, "_ds_child_entered", True)
 
-                f(module, *args, **kwargs)
+                init_on_meta = 'device' in kwargs and kwargs['device'] == 'meta'
+                if init_on_meta:
+                    self.skip_init_depth += 1
 
+                f(module, *args, **kwargs)
+                if init_on_meta and self.skip_init_depth == 1:
+                    # check and handle the skip_init (empty init) case
+                    hook_for_skip_init(module)
                 if is_child_module:
                     # child's __init__ is done, now we can run a single post_init on the child object
                     delattr(module, "_ds_child_entered")
 
                     print_rank_0(f'Running post_init for {module.__class__.__name__}', force=False)
-                    self._post_init_method(module)
+                    if self.skip_init_depth == 0:
+                        self._post_init_method(module)
 
                 print_rank_0(f'After initializing followed by post init for {module.__class__.__name__}', force=False)
+                if init_on_meta:
+                    self.skip_init_depth -= 1
 
             return wrapper
 
@@ -512,7 +567,6 @@ def _init_subclass(cls, **kwargs):
         self.patched = True
 
     def unpatch_init_and_builtins(self):
-
         if self.patched:
 
             def _disable_class(cls):
diff --git a/tests/unit/runtime/zero/test_zero.py b/tests/unit/runtime/zero/test_zero.py
index 6d66ff704416..bc31e3b9a968 100644
--- a/tests/unit/runtime/zero/test_zero.py
+++ b/tests/unit/runtime/zero/test_zero.py
@@ -14,6 +14,7 @@
 from torch.nn.modules.container import ModuleList
 from torch.nn.modules.loss import L1Loss
 from torch.nn.parameter import Parameter
+from torch.nn.utils import skip_init
 
 from unit.common import DistributedTest
 from unit.simple_model import SimpleModel, random_dataloader
@@ -1193,6 +1194,82 @@ def create_tensor(vals):
         _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE})
 
 
+class TestParamPartitioningSkipInit(DistributedTest):
+    world_size = 2
+
+    def test(self):
+        config_dict = {
+            "train_batch_size": 4,
+            "steps_per_print": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 1e-4
+                }
+            },
+            "fp16": {
+                "enabled": True
+            },
+            "zero_optimization": {
+                "stage": 3
+            },
+        }
+        hidden_dim = 10
+
+        class SubModel(torch.nn.Module):
+
+            def __init__(self, input_size, output_size, dropout_prob=0.5, device=None):
+                super(SubModel, self).__init__()
+                self.linear = torch.nn.Linear(input_size, output_size, device=device)
+                self.dropout = torch.nn.Dropout(dropout_prob)
+                self.module_list = torch.nn.ModuleList([torch.nn.Linear(input_size, output_size, device=device)])
+
+            def forward(self, x):
+                x = self.linear(x)
+                x = self.dropout(x)
+                x = self.module_list[0](x)
+                return x
+
+        class MyModel(torch.nn.Module):
+
+            def __init__(self, hidden_dim):
+                super(MyModel, self).__init__()
+                self.l1 = skip_init(Linear, hidden_dim, hidden_dim)
+                self.l2 = skip_init(SubModel, hidden_dim, hidden_dim)
+                self.l3 = torch.nn.Linear(hidden_dim, hidden_dim)
+                self.cel = torch.nn.CrossEntropyLoss()
+                self.l4 = skip_init(SubModel, hidden_dim, hidden_dim)
+
+            def forward(self, x, y):
+                x = self.l1(x)
+                x = self.l2(x)
+                x = self.l3(x)
+                x = self.l4(x)
+                loss = self.cel(x, y)
+                val = [x, loss]
+                return val
+
+        with deepspeed.zero.Init(config=config_dict):
+            model = MyModel(hidden_dim)
+        world_size = dist.get_world_size()
+        ds_tensor_numel = math.ceil(hidden_dim * hidden_dim / world_size)
+        assert model.l1.weight.ds_tensor.numel() == ds_tensor_numel
+        assert model.l2.linear.weight.ds_tensor.numel() == ds_tensor_numel
+        assert model.l2.module_list[0].weight.ds_tensor.numel() == ds_tensor_numel
+        assert model.l3.weight.ds_tensor.numel() == ds_tensor_numel
+        assert model.l4.linear.weight.ds_tensor.numel() == ds_tensor_numel
+        assert model.l4.module_list[0].weight.ds_tensor.numel() == ds_tensor_numel
+
+        model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+        data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device)
+        dist.barrier()
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            loss = loss[1]
+            model.backward(loss)
+            model.step()
+
+
 class TestZeroOffloadStage1(DistributedTest):
     world_size = 2
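
Note (illustrative sketch, not part of the patch): torch.nn.utils.skip_init constructs the module with device='meta' and then materializes it via to_empty(device=final_device), bypassing the usual weight initialization. The hooks above target exactly that sequence: an __init__ call with device='meta' increments skip_init_depth and defers _post_init_method until the _apply() triggered by to_empty() has allocated real storage, at which point the parameters are partitioned. A minimal end-to-end sketch under ZeRO-3, mirroring the new test; the config dict and variable names below are for illustration only, and the script is assumed to run under the deepspeed launcher so that torch.distributed gets initialized:

    import deepspeed
    import torch
    from torch.nn.utils import skip_init

    # assumed minimal ZeRO stage-3 config, following the test's config_dict
    ds_config = {"train_batch_size": 4, "zero_optimization": {"stage": 3}}

    with deepspeed.zero.Init(config_dict_or_path=ds_config):
        # skip_init builds the Linear on the meta device and then calls to_empty(),
        # which allocates real (uninitialized) storage; the patched _apply hook
        # partitions the weights immediately afterwards.
        layer = skip_init(torch.nn.Linear, 1024, 1024)

    # each rank now holds only its shard of the partitioned weight
    print(layer.weight.ds_tensor.numel())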