From 16350657ee32765f1ad646e696a004a0627ff0bc Mon Sep 17 00:00:00 2001
From: Omar Elayan
Date: Tue, 10 Dec 2024 14:10:51 +0200
Subject: [PATCH] [inf] Add config var to enable keeping module on host

The keep_module_on_host config var controls whether checkpoint weights loaded
into model parameters are moved to the device or kept on the host.

---
 deepspeed/inference/config.py             |  9 +++++++
 deepspeed/inference/engine.py             |  2 +-
 deepspeed/module_inject/auto_tp.py        | 33 ++++++++++++++++-------
 deepspeed/module_inject/replace_module.py |  3 ++-
 tests/unit/inference/test_inference.py    | 13 +++++++--
 5 files changed, 46 insertions(+), 14 deletions(-)

diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py
index c7c7684fff79..365536b3dd85 100644
--- a/deepspeed/inference/config.py
+++ b/deepspeed/inference/config.py
@@ -171,6 +171,15 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     values for :any:`DeepSpeedMoEConfig`.
     """
 
+    keep_module_on_host: bool = False
+    """
+    When loading checkpoints into model parameters, they are moved to the device. In very large models
+    this might fill the device and cause OOM. Setting this flag to true will keep checkpoints on the
+    host and not move them directly to the device (for example, to allow quantizing checkpoint data
+    before moving it to the device).
+    Supported only for models with injection policies and auto TP.
+    """
+
     quant: QuantizationConfig = {}
     """
     NOTE: only works for int8 dtype.
diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index cfca1ff4fe4c..f9eb264adc7b 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -170,7 +170,7 @@ def __init__(self, model, config):
         is_meta_device = hasattr(self.module, "device") and self.module.device.type == 'meta'
         if is_meta_device:
             self.module.to_empty(device=device)
-        else:
+        elif not config.keep_module_on_host:
             self.module.to(device)
 
         if config.tensor_parallel.tp_size > 1:
diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 221d490a37d2..49a741a1e814 100755
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -17,9 +17,11 @@
 from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
 
 
-def move(tensor, device):
+def move(tensor, device, keep_module_on_host=False):
     if tensor.is_meta:
-        return torch.empty_like(tensor, device=device)
+        return torch.empty_like(tensor, device='cpu' if keep_module_on_host else device)
+    elif keep_module_on_host:
+        return tensor.to('cpu') if device != 'cpu' else tensor
     else:
         # Using new tensors help in freeing memory (after split for example) was done before by calling clone().
         # Using copy=True instead of clone() will help in case of cpu --> cpu.
@@ -188,7 +190,14 @@ def load(module, state_dict, prefix, mp_group=None):
 
 class AutoTP():
 
-    def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl):
+    def __init__(self,
+                 module,
+                 all_reduce_linears,
+                 prefix,
+                 state_dict,
+                 linear_layer_setting,
+                 orig_layer_impl,
+                 keep_module_on_host=False):
         self.module = module
         self.all_reduce_linears = all_reduce_linears
         self.prefix = prefix
@@ -200,6 +209,7 @@ def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_
         self.orig_layer_impl = orig_layer_impl
         self.linear_policies = None
         self.conv_linear_layer = False
+        self.keep_module_on_host = keep_module_on_host
 
     def in_module_list(module, module_list):
         for item in module_list:
@@ -359,7 +369,8 @@ def _replace(self, child, name, conv_linear_layer):
             data = child.weight.data.split(get_shard_size_list(
                 weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
                                            dim=1)
-            data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
+            data_dc = move(data[mp_replace.gpu_index],
+                           get_accelerator().current_device_name(), self.keep_module_on_host).detach()
             del data
 
             setattr(child, "replaced", True)
@@ -368,9 +379,9 @@ def _replace(self, child, name, conv_linear_layer):
                     torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
                     child.bias if child.bias is None else torch.nn.parameter.Parameter(
                         move(child.bias,
-                             get_accelerator().current_device_name())), self.mp_group)
+                             get_accelerator().current_device_name(), self.keep_module_on_host)), self.mp_group)
             return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
-                        torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name())), self.mp_group)
+                        torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name(), self.keep_module_on_host)), self.mp_group)
         else:
 
             # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
@@ -383,22 +394,24 @@ def _replace(self, child, name, conv_linear_layer):
                 #The copy is a regular copy, The shape of dst and src is the same
                 data_dc = move(
                     prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index),
-                    get_accelerator().current_device_name())
+                    get_accelerator().current_device_name(), self.keep_module_on_host)
 
                 bias_data_dc = None if child.bias is None else move(
                     prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index),
-                    get_accelerator().current_device_name())
+                    get_accelerator().current_device_name(), self.keep_module_on_host)
             else:
                 data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
                                                dim=1 if self.conv_linear_layer else 0)
-                data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
+                data_dc = move(data[mp_replace.gpu_index],
+                               get_accelerator().current_device_name(), self.keep_module_on_host).detach()
                 del data
 
                 if child.bias is not None:
                     bias_data = child.bias.data.split(get_shard_size_list(
                         weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
                                                       dim=0)
-                    bias_data = move(bias_data[mp_replace.gpu_index], get_accelerator().current_device_name())
+                    bias_data = move(bias_data[mp_replace.gpu_index],
+                                     get_accelerator().current_device_name(), self.keep_module_on_host)
                     bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
                     del bias_data
                 else:
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 7afe6ca903fb..f7129165756d 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -268,7 +268,8 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
     #mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group)
 
     # 1. Create AutoTP object
-    _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl)
+    _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl,
+                     config.keep_module_on_host)
 
     # 2. Set the tensor parallelism config
     _autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)
diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py
index 9b563523dbeb..fa57466d14b0 100644
--- a/tests/unit/inference/test_inference.py
+++ b/tests/unit/inference/test_inference.py
@@ -554,6 +554,7 @@ def test(self, model_w_task, injection_policy, query, inf_kwargs, assert_fn, dty
 
 
 @pytest.mark.seq_inference
+@pytest.mark.parametrize('keep_module_on_host', [True, False])
 @pytest.mark.parametrize(
     "model_w_task",
     [("Helsinki-NLP/opus-mt-en-de", "translation"), ("Salesforce/codegen-350M-mono", "text-generation")],
@@ -570,6 +571,7 @@ def test(
         inf_kwargs,
         assert_fn,
         dtype,
+        keep_module_on_host,
     ):
         invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
         if invalid_test_msg:
@@ -592,7 +594,10 @@ def test(
                         framework="pt")
         bs_output = pipe(query, **inf_kwargs)
 
-        pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
+        pipe.model = deepspeed.init_inference(pipe.model,
+                                              mp_size=world_size,
+                                              dtype=dtype,
+                                              keep_module_on_host=keep_module_on_host)
         ds_output = pipe(query, **inf_kwargs)
 
         print(local_rank, "baseline", bs_output)
@@ -607,6 +612,7 @@ def test_odd_world_size(
         inf_kwargs,
         assert_fn,
         dtype,
+        keep_module_on_host,
     ):
         invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
         if invalid_test_msg:
@@ -624,7 +630,10 @@ def test_odd_world_size(
                         framework="pt")
         bs_output = pipe(query, **inf_kwargs)
 
-        pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
+        pipe.model = deepspeed.init_inference(pipe.model,
+                                              mp_size=world_size,
+                                              dtype=dtype,
+                                              keep_module_on_host=keep_module_on_host)
         ds_output = pipe(query, **inf_kwargs)
 
         print(local_rank, "baseline", bs_output)
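
Usage sketch (not part of the patch): the snippet below mirrors the updated test and shows how the new flag is passed through deepspeed.init_inference. The model name, prompt, and mp_size are illustrative placeholders; with mp_size > 1 it is expected to be launched with the deepspeed launcher so auto TP kicks in.

# Illustrative use of keep_module_on_host (placeholder model/prompt/mp_size;
# run under the deepspeed launcher, e.g. `deepspeed --num_gpus 2 example.py`).
import torch
import deepspeed
from transformers import pipeline

pipe = pipeline("text-generation", model="Salesforce/codegen-350M-mono", framework="pt")

# With keep_module_on_host=True the engine skips module.to(device), so the
# auto-TP sharded parameters stay in host memory after checkpoint loading.
pipe.model = deepspeed.init_inference(pipe.model,
                                      mp_size=2,
                                      dtype=torch.float,
                                      keep_module_on_host=True)

print(pipe("def hello_world():", do_sample=False, max_new_tokens=16))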