diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 9c7dc6836c1e..0908fa40dcb8 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -10,6 +10,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader +from torch.distributed.distributed_c10d import _get_default_group from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO from colossalai.checkpoint_io.utils import ( @@ -34,8 +35,7 @@ SUPPORTED_PRECISION = ["fp16", "bf16"] PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16} -DP_AXIS = 0 -TP_AXIS = 1 +ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2 def get_param_info(optim: Optimizer): # Get a backup of necessary information of parameters for future use, which includes: @@ -304,8 +304,8 @@ class GeminiPlugin(DPPluginBase): max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm. norm_type (float, optional): norm_type used for `clip_grad_norm`. - enable_tensor_parallelism (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False. - tp_size (int, optional): If 'enable_tensor_parallelism' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1. + tp_size (int, optional): If 'tp_size' is set to be greater than 1, it means using tensor parallelism strategy, which is implemented in Shardformer, 'tp_size' determines the size of the tensor parallel process group. Default to 1. + extra_dp_size (int, optional): If 'extra_dp_size' is set to be greater than 1, it means creating another group to run with a ddp-like strategy. Default to 1. enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. Currently all the optimization methods include fused normalization, flash attention and JIT. Defaults to False. @@ -347,8 +347,8 @@ def __init__( max_scale: float = 2**32, max_norm: float = 0.0, norm_type: float = 2.0, - enable_tensor_parallelism: bool = False, tp_size: int = 1, + extra_dp_size:int = 1, enable_all_optimization: bool = False, enable_fused_normalization: bool = False, enable_flash_attention: bool = False, @@ -393,7 +393,7 @@ def __init__( max_norm=max_norm, norm_type=norm_type, ) - self.enable_tensor_parallelism = enable_tensor_parallelism + self.enable_tensor_parallelism = tp_size > 1 self.enable_all_optimization = enable_all_optimization self.enable_fused_normalization = enable_fused_normalization self.enable_flash_attention = enable_flash_attention @@ -402,12 +402,17 @@ def __init__( self.enable_sequence_overlap = enable_sequence_overlap self.verbose = verbose - self.tp_size = tp_size if self.enable_tensor_parallelism else 1 - self.dp_size = dist.get_world_size() // self.tp_size - assert self.dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size." - self.pg_mesh = ProcessGroupMesh(self.dp_size, self.tp_size) - self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) - self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + self.tp_size = tp_size + self.extra_dp_size = extra_dp_size + world_size = dist.get_world_size() + self.zero_size = world_size // (self.tp_size * self.extra_dp_size) + assert world_size == (self.tp_size * self.extra_dp_size) * self.zero_size, f"The global group size can't be evenly divided by the subgroup size." + + self.pg_mesh = ProcessGroupMesh(self.zero_size, self.extra_dp_size, self.tp_size) + self.zero_group = self.pg_mesh.get_group_along_axis(ZERO_AXIS) if self.zero_size < world_size else _get_default_group() + self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None + self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None + self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, enable_tensor_parallelism=self.enable_tensor_parallelism, @@ -458,7 +463,7 @@ def configure( shardformer = ShardFormer(self.shard_config) model, _ = shardformer.optimize(model) - model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose) + model = GeminiDDP(model, **self.gemini_config, zero_group=self.zero_group, extra_dp_group=self.extra_dp_group, verbose=self.verbose) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): optimizer = GeminiOptimizer( diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index 4ea6cc662025..42a8cdbb34cb 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -61,12 +61,13 @@ class Chunk: def __init__( self, chunk_size: int, - process_group: ProcessGroup, + zero_group: ProcessGroup, dtype: torch.dtype, init_device: Optional[torch.device] = None, cpu_shard_init: bool = False, keep_gathered: bool = False, pin_memory: bool = False, + extra_dp_group: ProcessGroup = None, ) -> None: """ Chunk: A container owning a piece of contiguous memory space for tensors @@ -76,7 +77,7 @@ def __init__( Args: chunk_size (int): the number of elements in the chunk - process_group (ProcessGroup): the process group of this chunk + zero_group (ProcessGroup): the process group of this chunk dtype (torch.dtype): the data type of the chunk init_device (torch.device): optional, During the chunk construction process, where the tensor is stored. The default value is None, which is the current GPU @@ -90,9 +91,11 @@ def __init__( self.chunk_size = chunk_size self.utilized_size = 0 - self.torch_pg = process_group + self.torch_pg = zero_group self.pg_size = dist.get_world_size(self.torch_pg) self.pg_rank = dist.get_rank(self.torch_pg) + self.extra_dp_group = extra_dp_group + self.extra_dp_size = dist.get_world_size(self.extra_dp_group) if self.extra_dp_group is not None else 1 # the chunk size should be divisible by the dp degree if not keep_gathered: @@ -384,14 +387,20 @@ def reduce(self): # just move cuda_global_chunk to cuda_shard # the communication is not necessary self.__scatter() + if self.extra_dp_group is not None: + dist.all_reduce(self.cuda_shard, group=self.extra_dp_group) elif self.keep_gathered: # we use all-reduce here dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg) + if self.extra_dp_group is not None: + dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group) else: self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device()) input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0)) dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg) + if self.extra_dp_group is not None: + dist.all_reduce(self.cuda_shard, group=self.extra_dp_group) free_storage(self.cuda_global_chunk) self.is_gathered = False @@ -608,10 +617,11 @@ def init_grad_chunk(self) -> "Chunk": # grad chunk is not initialized grad_chunk = Chunk( chunk_size=self.chunk_size, - process_group=self.torch_pg, + zero_group=self.torch_pg, dtype=self.dtype, keep_gathered=self.keep_gathered, pin_memory=self.pin_memory, + extra_dp_group=self.extra_dp_group, ) grad_chunk.num_tensors = self.num_tensors grad_chunk.utilized_size = self.utilized_size @@ -640,4 +650,4 @@ def init_grad_chunk(self) -> "Chunk": self.grad_chunk.l2_norm = None alloc_storage(self.grad_chunk.cuda_global_chunk) - return self.grad_chunk + return self.grad_chunk \ No newline at end of file diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index d3c512fe978d..5ad622a13910 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -38,7 +38,8 @@ def register_tensor( tensor: torch.Tensor, group_type: str, config_key: int, - process_group: ProcessGroup, + zero_group: ProcessGroup, + extra_dp_group: ProcessGroup = None, cpu_offload: bool = False, pin_memory: bool = False, ) -> None: @@ -76,15 +77,16 @@ def register_tensor( if tensor.numel() > chunk_size: chunk_size = tensor.numel() - dp_size = dist.get_world_size(process_group) + dp_size = dist.get_world_size(zero_group) chunk_size = chunk_size + (-chunk_size % dp_size) chunk = Chunk( chunk_size=chunk_size, - process_group=process_group, + zero_group=zero_group, dtype=tensor.dtype, cpu_shard_init=cpu_offload, pin_memory=pin_memory, + extra_dp_group=extra_dp_group, **chunk_kwargs, ) @@ -288,4 +290,4 @@ def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk: # Release accumulated_grad free_storage(accumulated_grad) - return grad_chunk + return grad_chunk \ No newline at end of file diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index ade0a4909902..ff943f4b49e0 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -86,9 +86,10 @@ def __init__( strict_ddp_mode: bool = False, scatter_after_inference: bool = True, mixed_precision: torch.dtype = torch.float16, - process_group: Optional[ProcessGroup] = None, + zero_group: Optional[ProcessGroup] = None, memstats: Optional[MemStats] = None, # genimi memory stats master_weights: bool = True, + extra_dp_group: Optional[ProcessGroup] = None, verbose: bool = False, ) -> None: assert mixed_precision in (torch.float16, torch.bfloat16) @@ -105,7 +106,7 @@ def __init__( search_range_m=search_range_m, min_chunk_size_m=min_chunk_size_m, strict_ddp_flag=strict_ddp_mode, - process_group=process_group, + process_group=zero_group, verbose=verbose, ) self.gemini_manager = GeminiManager( @@ -128,7 +129,8 @@ def __init__( self.name2param: Dict[str, nn.Parameter] = dict() self.scatter_after_inference = scatter_after_inference self.mixed_precision = mixed_precision - self.dp_process_group = process_group or _get_default_group() + self.zero_group = zero_group or _get_default_group() + self.extra_dp_group = extra_dp_group self.reuse_fp16_chunk = master_weights self.master_weights = master_weights @@ -377,8 +379,12 @@ def grad_handle(self, p, grad): self.chunk_manager.release_chunk(chunk) if grad_chunk.is_gathered: grad_chunk.cuda_global_chunk.div_(chunk.pg_size) + if self.extra_dp_group is not None: + grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size) else: grad_chunk.cuda_shard.div_(chunk.pg_size) + if self.extra_dp_group is not None: + grad_chunk.cuda_shard.div_(chunk.extra_dp_size) # check overflow elements self.overflow_counter += grad_chunk.has_inf_or_nan # record l2 norm for gradient clipping. flag is bound to fp16 chunk @@ -733,7 +739,7 @@ def load_parameter(chunk_slice, data): unexpected_keys.append(key) def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool): - dp_world_size = dist.get_world_size(self.dp_process_group) + zero_world_size = dist.get_world_size(self.zero_group) for p in param_order.generate(): self._preprocess_param(p) assert type(p) is ColoParameter @@ -753,8 +759,9 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pi self.chunk_manager.register_tensor( tensor=p, group_type="fp16_param", - config_key=dp_world_size, - process_group=self.dp_process_group, + config_key=zero_world_size, + zero_group=self.zero_group, + extra_dp_group=self.extra_dp_group, cpu_offload=cpu_offload, pin_memory=pin_memory, ) @@ -767,8 +774,9 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pi self.chunk_manager.register_tensor( tensor=fp32_p, group_type="fp32_param", - config_key=dp_world_size, - process_group=self.dp_process_group, + config_key=zero_world_size, + zero_group=self.zero_group, + extra_dp_group=self.extra_dp_group, cpu_offload=cpu_offload, pin_memory=pin_memory, ) @@ -881,4 +889,4 @@ def state_dict_shard( if block is not None: yield block, block_size - yield sharder.current_block, sharder.current_block_size + yield sharder.current_block, sharder.current_block_size \ No newline at end of file diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 97ec0233f766..61debe47b599 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -1,5 +1,6 @@ from contextlib import nullcontext from typing import Optional +import pytest import torch import torch.distributed as dist @@ -17,14 +18,15 @@ from tests.kit.model_zoo import model_zoo -def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) -> Optional[str]: +def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size) -> Optional[str]: try: if init_method == "lazy": ctx = LazyInitContext() else: ctx = nullcontext() - enable_all_optimization = True if enable_tensor_parallelism else False - plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, enable_tensor_parallelism=enable_tensor_parallelism, enable_all_optimization=enable_all_optimization) + extra_dp_size = dist.get_world_size() // (zero_size * tp_size) + enable_all_optimization = True if tp_size > 1 else False + plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization) booster = Booster(plugin=plugin) with ctx: model = model_fn() @@ -62,8 +64,9 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tenso @parameterize("subset", ["torchvision", "transformers", "diffusers"]) @parameterize("init_method", ["none"]) -@parameterize("enable_tensor_parallelism", [True, False]) -def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_parallelism: bool = True, early_stop: bool = True): +@parameterize("zero_size", [2]) +@parameterize("tp_size", [2]) +def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1): """check gemini plugin over model zoo Args: @@ -125,9 +128,9 @@ def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_pa # TODO debug blip2 when using tp, something wrong with shift_logits's shape if "transformers_blip2" in name: - enable_tensor_parallelism = False + tp_size = 1 - err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) + err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size) torch.cuda.empty_cache() if err is None: passed_models.append(name) @@ -153,6 +156,11 @@ def run_dist(rank, world_size, port, early_stop: bool = True): def test_gemini_plugin(early_stop: bool = True): spawn(run_dist, 4, early_stop=early_stop) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +def test_gemini_plugin_3d(early_stop: bool = True): + spawn(run_dist, 8, early_stop=early_stop) + if __name__ == "__main__": - test_gemini_plugin(early_stop=False) + test_gemini_plugin(early_stop=False) \ No newline at end of file diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 821ce9fbbbd9..8343c5f07e30 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -37,20 +37,21 @@ @parameterize("placement_config", MODEL_PLACEMENT_CONFIGS) @parameterize("model_name", ["transformers_bert_for_sequence_classification"]) @parameterize("use_safetensors", [False, True]) -@parameterize("enable_tensor_parallelism", [True, False]) -@parameterize("tp_size", [2]) -def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, enable_tensor_parallelism: bool, tp_size: int): +@parameterize("tp_size", [1, 2]) +@parameterize("zero_size", [2]) +def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int): from transformers import BertForSequenceClassification (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) bert_model = model_fn() - enable_all_optimization = True if enable_tensor_parallelism else False + enable_all_optimization = True if tp_size > 1 else False with shared_tempdir() as tempdir: pretrained_path = os.path.join(tempdir, "pretrained") bert_model.config.save_pretrained(save_directory=pretrained_path) - plugin = GeminiPlugin(**placement_config, enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization) + extra_dp_size = dist.get_world_size() // (zero_size * tp_size) + plugin = GeminiPlugin(**placement_config, tp_size=tp_size, enable_all_optimization=enable_all_optimization, extra_dp_size=extra_dp_size) booster = Booster(plugin=plugin) bert_model, _, _, _, _ = booster.boost(bert_model) model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 @@ -69,13 +70,14 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b @parameterize("shard", [True, False]) @parameterize("model_name", ["transformers_gpt"]) @parameterize("size_per_shard", [32]) -@parameterize("enable_tensor_parallelism", [True, False]) -@parameterize("tp_size", [2]) -def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, enable_tensor_parallelism: bool, tp_size: int): +@parameterize("tp_size", [1, 2]) +@parameterize("zero_size", [2]) +def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int): (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() - enable_all_optimization = True if enable_tensor_parallelism else False - plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization) + enable_all_optimization = True if tp_size > 1 else False + extra_dp_size = dist.get_world_size() // (zero_size * tp_size) + plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization) booster = Booster(plugin=plugin) model = model_fn() @@ -158,3 +160,9 @@ def run_dist(rank, world_size, port): @rerun_if_address_is_in_use() def test_gemini_ckpIO(world_size): spawn(run_dist, world_size) + +@pytest.mark.largedist +@pytest.mark.parametrize("world_size", [8]) +@rerun_if_address_is_in_use() +def test_gemini_ckpIO_3d(world_size): + spawn(run_dist, world_size) \ No newline at end of file diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py index 9e7336b93b3a..f652d18e9494 100644 --- a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py +++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py @@ -124,25 +124,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 4, - "enable_all_optimization": True, - "use_lazy_init": True, - "precision": "fp16", - "max_norm": 5, - "initial_scale": 1, - }, - { - "tp_size": 2, - "pp_size": 1, - "enable_all_optimization": True, - "use_lazy_init": False, - "precision": "fp16", - "max_norm": 5, - "initial_scale": 1, - }, { "tp_size": 2, "pp_size": 2, @@ -153,23 +134,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "max_norm": 5, "initial_scale": 1, }, - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 4, - "enable_all_optimization": True, - "use_lazy_init": True, - "precision": "bf16", - "max_norm": 5, - }, - { - "tp_size": 2, - "pp_size": 1, - "enable_all_optimization": True, - "use_lazy_init": False, - "precision": "bf16", - "max_norm": 5, - }, { "tp_size": 2, "pp_size": 2, diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py index b8ead795da76..a749a2966fde 100644 --- a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py +++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py @@ -102,28 +102,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 4, - "enable_all_optimization": True, - "use_lazy_init": True, - "precision": "fp32", - "max_norm": 5, - }, - { - "tp_size": 2, - "pp_size": 1, - "enable_all_optimization": True, - "use_lazy_init": False, - "precision": "fp32", - "max_norm": 5, - }, { "tp_size": 2, "pp_size": 2, "num_microbatches": 4, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", "max_norm": 5, @@ -148,7 +131,7 @@ def run_test(test_config): "tp_size": 2, "pp_size": 2, "num_microbatches": 4, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", "max_norm": 5, diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py index 061c702552cf..41f06a4c3888 100644 --- a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py +++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py @@ -106,17 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, - "enable_all_optimization": True, - "use_lazy_init": True, - "precision": "fp16", - "max_norm": 5, - "initial_scale": 1, - }, - { - "tp_size": 2, - "pp_size": 1, - "zero_stage": 1, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp16", "max_norm": 5, @@ -126,36 +116,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 1, "zero_stage": 2, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp16", "max_norm": 5, "initial_scale": 1, }, - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 4, - "zero_stage": 1, - "enable_all_optimization": True, - "use_lazy_init": True, - "precision": "bf16", - "max_norm": 5, - }, { "tp_size": 2, "pp_size": 1, "zero_stage": 1, - "enable_all_optimization": True, - "use_lazy_init": False, - "precision": "bf16", - "max_norm": 5, - }, - { - "tp_size": 2, - "pp_size": 1, - "zero_stage": 2, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "bf16", "max_norm": 5, diff --git a/tests/test_zero/test_gemini/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py index a31c888e966d..5977c706fdd1 100644 --- a/tests/test_zero/test_gemini/test_chunkv2.py +++ b/tests/test_zero/test_gemini/test_chunkv2.py @@ -39,7 +39,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory): pg = _get_default_group() my_chunk = Chunk( chunk_size=1024, - process_group=pg, + zero_group=pg, dtype=torch.float32, init_device=init_device, cpu_shard_init=True,