diff --git a/applications/ColossalChat/README.md b/applications/ColossalChat/README.md
index ef904b864a14..690c398184ab 100755
--- a/applications/ColossalChat/README.md
+++ b/applications/ColossalChat/README.md
@@ -284,7 +284,7 @@ For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tre
 ## O1 Journey
 ### Inference with Self-refined MCTS
 We provide the implementation of MCT Self-Refine (MCTSr) algorithm, an innovative integration of Large Language Models with Monte Carlo Tree Search.
-To run inference with MCTS, simply use the following script.
+You can serve the model with vLLM, update the config in `Qwen32B_prompt_CFG`, and then run the following script.
 ```python
 from coati.reasoner.guided_search.mcts import MCTS
 from coati.reasoner.guided_search.prompt_store.qwen import Qwen32B_prompt_CFG
diff --git a/applications/ColossalChat/coati/reasoner/guided_search/mcts.py b/applications/ColossalChat/coati/reasoner/guided_search/mcts.py
index 693e2b750539..a87211da210c 100644
--- a/applications/ColossalChat/coati/reasoner/guided_search/mcts.py
+++ b/applications/ColossalChat/coati/reasoner/guided_search/mcts.py
@@ -58,7 +58,7 @@ def initialization(self):
         """
         Root Initiation.
         """
-        # Dummy answer as root.
+        # Simple answer as root. A negative response such as "I do not know" can also be used.
         base_answer = self.sample_base_answer()
         self.root = MCTSNode(answer=base_answer)
         self.self_evaluate(self.root)
@@ -190,7 +190,7 @@ def sample_base_answer(self):
             messages=[
                 {
                     "role": "system",
-                    "content": "The user will provide a problem. Solve the problem. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer]. \nThe answer is [answer] \n#### [answer].",
+                    "content": self.cfg.base_system_prompt,
                 },
                 {
                     "role": "user",
diff --git a/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/base.py b/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/base.py
index b325b8fa2381..57b63def174e 100644
--- a/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/base.py
+++ b/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/base.py
@@ -5,6 +5,7 @@ class PromptCFG(BaseModel):
     model: str
     base_url: str
     max_tokens: int = 4096
+    base_system_prompt: str
     critic_system_prompt: str
     refine_system_prompt: str
     evaluate_system_prompt: str
diff --git a/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/qwen.py b/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/qwen.py
index 8bf0fa959da9..64dbc24155de 100644
--- a/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/qwen.py
+++ b/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/qwen.py
@@ -7,14 +7,16 @@
 Qwen32B_prompt_CFG = PromptCFG(
     base_url="http://0.0.0.0:8008/v1",
     model="Qwen2.5-32B-Instruct",
-    critic_system_prompt="Provide a detailed and constructive critique to improve the answer. "
-    "Highlight specific areas that need refinement or correction.",
+    base_system_prompt="The user will present a problem. Analyze and solve the problem in the following structure:\n"
+    "Begin with [Reasoning Process] to explain the approach. \n Proceed with [Verification] to confirm the solution. \n Conclude with [Final Answer] in the format: 'Answer: [answer]'",
+    critic_system_prompt="Provide a detailed and constructive critique of the answer, focusing on ways to improve its clarity, accuracy, and relevance. "
+ "Highlight specific areas that need refinement or correction, and offer concrete suggestions for enhancing the overall quality and effectiveness of the response.", refine_system_prompt="""# Instruction Refine the answer based on the critique. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer]. """, evaluate_system_prompt=( - "Analyze this answer strictly and critic, provide a reward score between -100 and 100 for the answer quality, using very strict standards. " - "Do not give a full score above 95. Make sure the reward score is an integer. " - "Return *ONLY* the score." + "Critically analyze this answer and provide a reward score between -100 and 100 based on strict standards." + "The score should clearly reflect the quality of the answer." + "Make sure the reward score is an integer. You should only return the score. If the score is greater than 95, return 95." ), ) diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py index 8fb56aee4fce..1539bc01d134 100644 --- a/colossalai/amp/naive_amp/mixed_precision_optimizer.py +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import torch from torch import Tensor, inf @@ -84,6 +84,7 @@ def __init__( self.master_to_working_map[master_p] = p master_params.append(master_p) group["params"] = master_params + self._current_grad_norm: Optional[float] = None def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): loss = self.mixed_precision.pre_backward(loss) @@ -192,6 +193,7 @@ def step(self, *args, **kwargs): if p.grad is not None ] total_norm = self._compute_grad_norm(param_gradient_pairs) + self._current_grad_norm = total_norm self._unscale_and_clip_grads(total_norm) self.optim.step(*args, **kwargs) @@ -217,3 +219,6 @@ def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: def get_master_to_working_map(self) -> Dict[int, torch.Tensor]: return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()} + + def get_grad_norm(self, norm_type=2, **kwargs): + return self._current_grad_norm diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 58d055bb06af..fb73d7e71f87 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -301,6 +301,7 @@ def __init__( self.pp_pg = pp_process_group self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 + self._current_grad_norm: Optional[float] = None super().__init__(optim) def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): @@ -372,6 +373,7 @@ def step(self, *args, **kwargs): (p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None ] total_norm = self._compute_grad_norm(param_gradient_pairs) + self._current_grad_norm = total_norm # Clip the gradients to prevent exploding gradients. 
self._clip_grad_norm(total_norm) @@ -485,6 +487,9 @@ def get_working_to_master_map(self): def get_master_to_working_map(self): return None + def get_grad_norm(self, norm_type=2, **kwargs): + return self._current_grad_norm + class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): def __init__( diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index f3a6901ada6b..d5afa2ba83ce 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -29,6 +29,7 @@ save_state_dict, sharded_optimizer_loading_epilogue, ) +from colossalai.cluster import ProcessGroupMesh from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.interface.optimizer import DistributedOptim from colossalai.logging import get_dist_logger @@ -333,6 +334,7 @@ class LowLevelZeroPlugin(DPPluginBase): verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False. use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False. fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False. + extra_dp_size (int, optional): The number of extra data parallel groups. Defaults to 1. """ def __init__( @@ -358,11 +360,16 @@ def __init__( cast_inputs: bool = True, fp8_communication: bool = False, use_fp8: bool = False, + extra_dp_size: int = 1, ) -> None: super().__init__() assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training" assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training" assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now" + if extra_dp_size > 1: + assert dist.get_world_size() % extra_dp_size == 0, "extra_dp_size should be a factor of world_size" + inner_dp_size = dist.get_world_size() // extra_dp_size + self.pg_mesh = ProcessGroupMesh(extra_dp_size, inner_dp_size) self.stage = stage self.precision = precision self.zero_optim_kwargs = dict( @@ -383,6 +390,9 @@ def __init__( overlap_allgather=overlap_allgather, fp8_communication=fp8_communication, ) + if extra_dp_size > 1: + self.zero_optim_kwargs["extra_dp_group"] = self.pg_mesh.get_group_along_axis(0) + self.zero_optim_kwargs["dp_process_group"] = self.pg_mesh.get_group_along_axis(1) self.lora_enabled = False self.verbose = verbose self.logger = get_dist_logger() diff --git a/colossalai/cli/launcher/__init__.py b/colossalai/cli/launcher/__init__.py index 0f9ead6495db..99d87948cb5f 100644 --- a/colossalai/cli/launcher/__init__.py +++ b/colossalai/cli/launcher/__init__.py @@ -64,7 +64,8 @@ "This will be converted to --arg1=1 --arg2=2 during execution", ) @click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection") -@click.argument("user_script", type=str) +@click.option("-m", type=str, default=None, help="run library module as a script (terminates option list)") +@click.argument("user_script", type=str, required=False, default=None) @click.argument("user_args", nargs=-1) def run( host: str, @@ -77,8 +78,9 @@ def run( master_port: int, extra_launch_args: str, ssh_port: int, + m: str, user_script: str, - user_args: str, + user_args: tuple, ) -> None: """ To launch multiple processes on a single node or multiple nodes via command line. 
@@ -102,9 +104,24 @@ def run( # run with hostfile excluding the hosts selected colossalai run --hostfile --master_addr host1 --exclude host2 --nprocs_per_node 4 train.py """ - if not user_script.endswith(".py"): - click.echo(f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help") - exit() + if m is not None: + if m.endswith(".py"): + click.echo(f"Error: invalid Python module {m}. Did you use a wrong option? Try colossalai run --help") + exit() + if user_script is not None: + user_args = (user_script,) + user_args + user_script = m + m = True + else: + if user_script is None: + click.echo("Error: missing script argument. Did you use a wrong option? Try colossalai run --help") + exit() + if not user_script.endswith(".py"): + click.echo( + f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help" + ) + exit() + m = False args_dict = locals() args = Config(args_dict) diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py index 88f70f02ec27..45b1056fdd5e 100644 --- a/colossalai/cli/launcher/run.py +++ b/colossalai/cli/launcher/run.py @@ -113,6 +113,7 @@ def get_launch_command( user_args: List[str], node_rank: int, num_nodes: int, + run_as_module: bool, extra_launch_args: str = None, ) -> str: """ @@ -155,6 +156,8 @@ def _arg_dict_to_list(arg_dict): torch_version = version.parse(torch.__version__) assert torch_version.major >= 1 + if torch_version.major < 2 and run_as_module: + raise ValueError("Torch version < 2.0 does not support running as module") if torch_version.major == 1 and torch_version.minor < 9: # torch distributed launch cmd with torch < 1.9 @@ -198,7 +201,10 @@ def _arg_dict_to_list(arg_dict): ] cmd += _arg_dict_to_list(default_torchrun_rdzv_args) - cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args + cmd += _arg_dict_to_list(extra_launch_args) + if run_as_module: + cmd.append("-m") + cmd += [user_script] + user_args cmd = " ".join(cmd) return cmd @@ -294,6 +300,7 @@ def launch_multi_processes(args: Config) -> None: user_args=args.user_args, node_rank=node_id, num_nodes=len(active_device_pool), + run_as_module=args.m, extra_launch_args=args.extra_launch_args, ) runner.send(hostinfo=hostinfo, cmd=cmd) diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py index c8cf3ec21360..22115a72c3e3 100644 --- a/colossalai/interface/optimizer.py +++ b/colossalai/interface/optimizer.py @@ -152,6 +152,18 @@ def unwrap(self): """ return self.optim + def get_grad_norm(self, norm_type: Union[float, int] = 2.0, **kwargs) -> Optional[float]: + """ + Returns the gradient norm of an iterable of parameters. This method should be called after optimizer.step(). + + Args: + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. + + Returns: + Optional[float]: Total norm of the gradients (viewed as a single vector). If there are no valid gradients, returns None. 
+ """ + raise NotImplementedError("The method get_grad_norm is not implemented yet.") + class DistributedOptim(Optimizer): def setup_distributed( diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index ccd4634b5fe2..ca91b4d9f27c 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -1,7 +1,7 @@ # this code is inspired by the DeepSpeed library and implemented with our own design from scratch import copy import math -from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union +from typing import Any, Dict, Iterator, Optional, OrderedDict, Set, Tuple, Union import torch import torch.distributed as dist @@ -195,6 +195,7 @@ def __init__( self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0]) self._register_states = disposable(self._register_states_) + self._current_grad_norm: Optional[float] = None def _set_grad_ptr(self): for group in self.param_groups: @@ -255,6 +256,7 @@ def _get_combined_scale(self): if self.clipping_flag: total_norm = self._calc_global_norm() + self._current_grad_norm = total_norm clip = ((total_norm / div_scale) + 1e-6) / self.max_norm if clip > 1: div_scale = clip * div_scale @@ -848,6 +850,9 @@ def clip_grad_by_norm( f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm", ranks=[0] ) + def get_grad_norm(self, norm_type=2, **kwargs): + return self._current_grad_norm + class GeminiAdamOptimizer(GeminiOptimizer): def __init__(self, model: torch.nn.Module, **defaults: Any) -> None: diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py index 5ab703f09063..8a641f71719c 100644 --- a/colossalai/zero/low_level/_utils.py +++ b/colossalai/zero/low_level/_utils.py @@ -1,6 +1,7 @@ import math -from typing import Optional +from typing import Optional, Tuple, Union +import numpy as np import torch import torch.distributed as dist from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors @@ -209,3 +210,42 @@ def sync_tensor(flat_tensor, tensor_list): # update the tensor data for p, q in zip(tensor_list, updated_params): p.data = q.data + + +def all_gather_into_flat_tensor_nd( + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + group: Union[dist.ProcessGroup, Tuple[dist.ProcessGroup, ...]], + async_op: bool = False, +): + if isinstance(group, dist.ProcessGroup): + group = (group,) + sizes = [dist.get_world_size(pg) for pg in group] + ranks = [dist.get_rank(pg) for pg in group] + for i, pg in list(enumerate(group))[::-1]: + if i == 0: + out = output_tensor + else: + prev_sizes = sizes[:i] + prev_ranks = ranks[:i] + chunks = output_tensor.chunk(np.prod(prev_sizes)) + out = chunks[np.ravel_multi_index(prev_ranks, prev_sizes)] + handle = dist.all_gather_into_tensor(out, input_tensor, group=pg, async_op=async_op) + input_tensor = out + return handle + + +def get_nd_world_size(group) -> int: + if isinstance(group, tuple): + return int(np.prod([dist.get_world_size(pg) for pg in group])) + else: + return dist.get_world_size(group) + + +def get_nd_rank(group) -> int: + if isinstance(group, tuple): + return np.ravel_multi_index( + tuple(dist.get_rank(group=pg) for pg in group), [dist.get_world_size(pg) for pg in group] + ) + else: + return dist.get_rank(group) diff --git a/colossalai/zero/low_level/bookkeeping/base_store.py b/colossalai/zero/low_level/bookkeeping/base_store.py index 7f2f9664b7de..291f7a0135bc 100644 --- 
a/colossalai/zero/low_level/bookkeeping/base_store.py +++ b/colossalai/zero/low_level/bookkeeping/base_store.py @@ -1,11 +1,20 @@ +from typing import Tuple, Union + +import numpy as np import torch.distributed as dist from torch.distributed import ProcessGroup class BaseStore: - def __init__(self, torch_pg: ProcessGroup): - self._world_size = dist.get_world_size(group=torch_pg) - self._local_rank = dist.get_rank(group=torch_pg) + def __init__(self, torch_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]]): + if isinstance(torch_pg, tuple): + self.sizes = [dist.get_world_size(group=pg) for pg in torch_pg] + self._world_size = int(np.prod(self.sizes)) + self._local_rank = np.ravel_multi_index(tuple(dist.get_rank(group=pg) for pg in torch_pg), self.sizes) + else: + self._world_size = dist.get_world_size(group=torch_pg) + self._local_rank = dist.get_rank(group=torch_pg) + self.sizes = [self._world_size] self.torch_pg = torch_pg @property diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index 19d20de2b250..6729d4615f20 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -78,13 +78,13 @@ def build_grad_in_bucket(self): } """ for param, padding_size in zip(self._param_list, self._padding_size): - grad = param.grad.clone().detach().flatten() + grad = param.grad.detach().flatten() if padding_size > 0: with torch.no_grad(): grad = torch.nn.functional.pad(grad.view(-1), [0, padding_size]) grad_list = grad.split(grad.numel() // self._world_size) for rank in range(self._world_size): - grad_current_rank = grad_list[rank].clone().detach() + grad_current_rank = grad_list[rank].detach() self.grad_to_param_mapping[id(grad_current_rank)] = id(param) self._grad_in_bucket[rank].append(grad_current_rank) param.grad = None @@ -110,7 +110,7 @@ def get_flatten_grad(self) -> Tensor: flat_grad = [] for grad_list in self._grad_in_bucket.values(): - flat_grad.append(_flatten_dense_tensors(grad_list)) + flat_grad.extend(grad_list) flat_grad = _flatten_dense_tensors(flat_grad) return flat_grad diff --git a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py index 3c95aa6babcd..452080a491c7 100644 --- a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py +++ b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py @@ -1,10 +1,12 @@ from typing import Optional +import numpy as np import torch import torch.distributed as dist from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from colossalai.quantization.fp8 import all_gather_fp8 +from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd class TensorBucket: @@ -65,12 +67,18 @@ def unflatten_and_copy(self, flat_tensor): def all_gather(self, group=None, fp8_communication: bool = False): flat = self.flatten() - buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype) + if isinstance(group, tuple): + world_size = np.prod([dist.get_world_size(pg) for pg in group]) + else: + world_size = dist.get_world_size(group) + buffer = torch.empty(flat.numel() * world_size, device=flat.device, dtype=flat.dtype) if fp8_communication: + # TODO: fit fp8 all_gather_fp8(list(buffer.chunk(dist.get_world_size(group))), flat, group=group, fp8_format="e4m3") else: - dist.all_gather_into_tensor(buffer, flat, group=group) - unflat_buffers = [self.unflatten(buffer) for buffer in 
buffer.chunk(dist.get_world_size(group))] + # dist.all_gather_into_tensor(buffer, flat, group=group) + all_gather_into_flat_tensor_nd(buffer, flat, group=group) + unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(world_size)] # transpose the list of list unflat_buffers = list(map(list, zip(*unflat_buffers))) for unflat_shards, tensor in zip(unflat_buffers, self._bucket): diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 91449497b877..4b237d2fb0f2 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -2,7 +2,7 @@ import copy from contextlib import contextmanager, nullcontext from functools import partial -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Dict, Iterator, List, Optional, Tuple, Union from weakref import proxy import torch @@ -23,7 +23,15 @@ from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8 from colossalai.tensor.moe_tensor.api import is_moe_tensor -from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor +from ._utils import ( + all_gather_into_flat_tensor_nd, + calculate_global_norm_from_list, + get_nd_rank, + get_nd_world_size, + has_inf_or_nan, + release_param_grad, + sync_tensor, +) from .bookkeeping import BucketStore, GradientStore, TensorBucket from .zero_hook import set_all_gather_handle, wait_all_gather_handle @@ -68,7 +76,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): def __init__( self, optimizer: Optimizer, - pg_to_param_list: Optional[Dict[ProcessGroup, List[nn.Parameter]]] = None, + pg_to_param_list: Optional[Dict[Union[ProcessGroup, Tuple[ProcessGroup, ...]], List[nn.Parameter]]] = None, initial_scale: int = 2**16, # grad scaler config min_scale: int = 1, growth_factor: float = 2.0, @@ -84,6 +92,7 @@ def __init__( partition_grad: bool = False, # stage 2 flag cpu_offload: bool = False, # cpu offload dp_process_group: Optional[ProcessGroup] = None, + extra_dp_group: Optional[ProcessGroup] = None, forced_dtype: Optional[torch.dtype] = None, master_weights: bool = True, # master weights overlap_allgather: bool = False, @@ -98,9 +107,17 @@ def __init__( if (dp_process_group is not None) and (pg_to_param_list is not None): raise ValueError("dp_process_group and pg_to_param_list should not be provided at the same time.") + if pg_to_param_list is None and extra_dp_group is not None and dp_process_group is None: + raise ValueError("dp_process_group should be provided when extra_dp_group is provided.") + if pg_to_param_list is None and extra_dp_group is not None and fp8_communication: + raise ValueError( + "fp8_communication is not supported when pg_to_param_list is None and extra_dp_group is provided." 
+ ) if pg_to_param_list is None: unique_dp_group = dist.group.WORLD if dp_process_group is None else dp_process_group + if extra_dp_group is not None: + unique_dp_group = (extra_dp_group, unique_dp_group) pg_to_param_list = {unique_dp_group: []} for group in self.optim.param_groups: pg_to_param_list[unique_dp_group].extend(group["params"]) @@ -218,6 +235,7 @@ def __init__( ) elif self._dtype is torch.bfloat16: self.mixed_precision_mixin = BF16MixedPrecisionMixin() + self._current_grad_norm: Optional[float] = None def __del__(self): for hook in self.grad_handles: @@ -335,10 +353,12 @@ def _run_reduction(self): flat_grads = flat_grads.to(self._communication_dtype) if not self._partition_grads: - if self._fp8_communication: - all_reduce_fp8(flat_grads, group=bucket_store.torch_pg) - else: - dist.all_reduce(flat_grads, group=bucket_store.torch_pg) + for i, sz in enumerate(bucket_store.sizes): + grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i] + if self._fp8_communication: + all_reduce_fp8(flat_grads, group=grp) + else: + dist.all_reduce(flat_grads, group=grp) if flat_grads.dtype != grad_dtype: flat_grads = flat_grads.to(grad_dtype) @@ -346,16 +366,20 @@ def _run_reduction(self): grad_in_bucket = bucket_store.get_grad() self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id) else: - flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size)) - received_grad = torch.zeros_like(flat_grads_list[0]) - if self._fp8_communication: - reduce_scatter_fp8( - received_grad, - flat_grads_list, - group=bucket_store.torch_pg, - ) - else: - dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg) + cur_flat_grads = flat_grads + for i, sz in enumerate(bucket_store.sizes): + grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i] + flat_grads_list = list(cur_flat_grads.split(len(cur_flat_grads) // sz)) + received_grad = torch.zeros_like(flat_grads_list[0]) + if self._fp8_communication: + reduce_scatter_fp8( + received_grad, + flat_grads_list, + group=grp, + ) + else: + dist.reduce_scatter_tensor(received_grad, cur_flat_grads, group=grp) + cur_flat_grads = received_grad if received_grad.dtype != grad_dtype: received_grad = received_grad.to(grad_dtype) @@ -556,6 +580,7 @@ def step(self, closure=None): # unscale and clip grads global_norm = calculate_global_norm_from_list(norm_list=norm_groups) + self._current_grad_norm = global_norm self._unscale_and_clip_grads(grad_partition_groups, global_norm) # update the parameters @@ -580,11 +605,13 @@ def step(self, closure=None): pg = self.param_to_pg[working_param] padded_working_param = self._working_param_to_padded_working_param[working_param] if self._overlap_allgather: - handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True) + # handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True) + handle = all_gather_into_flat_tensor_nd(padded_working_param, param_to_gather, pg, async_op=True) set_all_gather_handle(working_param, handle) else: if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size: if self._fp8_communication: + # TODO: fit fp8 communication all_gather_fp8( list(padded_working_param.chunk(dist.get_world_size(pg))), param_to_gather, @@ -592,7 +619,8 @@ def step(self, closure=None): fp8_format="e4m3", ) else: - dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg) + # 
dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg) + all_gather_into_flat_tensor_nd(padded_working_param, param_to_gather, pg) continue try: self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) @@ -605,7 +633,9 @@ def step(self, closure=None): if not tensor_bucket.is_empty(): tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication) - def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float: + def _compute_grad_norm( + self, dp_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]], gradients: List[Tensor], norm_type: int = 2 + ) -> float: r""" Compute and return the gradient norm for gradient clipping. @@ -628,7 +658,11 @@ def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_ device=get_accelerator().get_current_device(), dtype=torch.float, ) - dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg) + if isinstance(dp_pg, tuple): + for grp in dp_pg: + dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=grp) + else: + dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg) total_norm = total_norm_cuda.item() else: @@ -643,11 +677,19 @@ def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_ device=get_accelerator().get_current_device(), dtype=torch.float, ) - torch.distributed.all_reduce( - total_norm_exponentiated_cuda, - op=torch.distributed.ReduceOp.SUM, - group=dp_pg, - ) + if isinstance(dp_pg, tuple): + for grp in dp_pg: + dist.all_reduce( + total_norm_exponentiated_cuda, + op=torch.distributed.ReduceOp.SUM, + group=grp, + ) + else: + torch.distributed.all_reduce( + total_norm_exponentiated_cuda, + op=torch.distributed.ReduceOp.SUM, + group=dp_pg, + ) total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) return total_norm @@ -747,11 +789,9 @@ def state_dict(self) -> Dict: if isinstance(v, torch.Tensor) and k != "step": working_param = self.master_to_working_param[id(param)] pg = self.param_to_pg[working_param] - gather_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())] - dist.all_gather(gather_tensor, v.to(device), group=pg) - param_state = ( - torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() - ) + gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype) + all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg) + param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param).cpu() zero_state[param][k] = param_state states_dict = self._pack_state(zero_state) @@ -773,15 +813,17 @@ def load_state_dict(self, state_dict: Dict): cnt += 1 for param_idx, state in zero_state_dict["state"].items(): pg = self.param_to_pg[self.master_to_working_param[id(idx2master[param_idx])]] + world_size = get_nd_world_size(pg) + rank = get_nd_rank(pg) for k, v in state.items(): if isinstance(v, torch.Tensor) and k != "step": - padding_size = (pg.size() - v.numel() % pg.size()) % pg.size() + padding_size = (world_size - v.numel() % world_size) % world_size with torch.no_grad(): v = v.flatten() if padding_size > 0: v = torch.nn.functional.pad(v, [0, padding_size]) - v_list = v.split(v.numel() // pg.size()) - zero_state_dict["state"][param_idx][k] = v_list[pg.rank()].detach().clone() + v_list = v.split(v.numel() // world_size) + zero_state_dict["state"][param_idx][k] = v_list[rank].detach().clone() 
self.optim.load_state_dict(zero_state_dict) @@ -817,11 +859,9 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i for k, v in states.items(): if isinstance(v, torch.Tensor) and k != "step": - state_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())] - dist.all_gather(state_tensor, v.to(device), group=pg) - state_tensor = ( - torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() - ) + state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype) + all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg) + state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param).cpu() current_block_size += state_tensor.numel() current_block[k] = state_tensor @@ -845,12 +885,14 @@ def update_master_params(self, model: nn.Module) -> None: p_id = id(p) if p_id in self.working_to_master_param: pg = self.param_to_pg[p] + world_size = get_nd_world_size(pg) + rank = get_nd_rank(pg) master_param = self.working_to_master_param[p_id] padding_size = self.get_param_padding_size(p) working_param = p.data.view(-1) if padding_size > 0: working_param = torch.nn.functional.pad(working_param, [0, padding_size]) - master_param.copy_(working_param.chunk(pg.size())[pg.rank()]) + master_param.copy_(working_param.chunk(world_size)[rank]) def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: return self.working_to_master_param @@ -908,9 +950,12 @@ def get_param_grad(self, working_param: nn.Parameter) -> Tensor: grad = grad_store.get_working_grad_by_param_id(id(working_param)) if grad is None: return None - grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device) - dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg) - return grad_flat.view(-1)[: working_param.numel()].view_as(working_param) + grad_flat = grad.flatten() + output_grad = torch.empty( + grad_flat.numel() * grad_store.world_size, device=grad_flat.device, dtype=grad_flat.dtype + ) + all_gather_into_flat_tensor_nd(output_grad, grad_flat, grad_store.torch_pg) + return output_grad.view(-1)[: working_param.numel()].view_as(working_param) def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]: working_grads = [] @@ -939,3 +984,6 @@ def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> def _force_wait_all_gather(self): for param in self._working_param_to_padded_working_param.keys(): wait_all_gather_handle(param) + + def get_grad_norm(self, norm_type=2, **kwargs): + return self._current_grad_norm diff --git a/requirements/requirements.txt b/requirements/requirements.txt index b77a33b0a151..cf7c37959d61 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,7 +8,7 @@ click fabric contexttimer ninja -torch>=2.2.0,<=2.4.0 +torch>=2.2.0,<=2.4.1 safetensors einops pydantic diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index 3e85329553e0..97995575d986 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -76,6 +76,8 @@ def _criterion(outputs, inputs): booster.execute_pipeline(data_iter, model, _criterion, optimizer, return_loss=True) optimizer.step() + grad_norm = optimizer.get_grad_norm() + assert grad_norm is None or isinstance(grad_norm, float) except Exception as e: return repr(e) diff --git 
a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index b2790c0e7504..2e9b24fecc6d 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -54,6 +54,8 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, t booster.backward(loss, optimizer) optimizer.step() + grad_norm = optimizer.get_grad_norm() + assert grad_norm is None or isinstance(grad_norm, float) except NotImplementedError: print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.") diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index c2a08a541bc7..6616866e3d2b 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -50,6 +50,8 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None) booster.backward(loss, optimizer) optimizer.step() + grad_norm = optimizer.get_grad_norm() + assert grad_norm is None or isinstance(grad_norm, float) except Exception as e: return repr(e) diff --git a/tests/test_zero/test_low_level/test_coll_nd.py b/tests/test_zero/test_low_level/test_coll_nd.py new file mode 100644 index 000000000000..c9d7e6341c48 --- /dev/null +++ b/tests/test_zero/test_low_level/test_coll_nd.py @@ -0,0 +1,42 @@ +import numpy as np +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all +from colossalai.utils import get_current_device +from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd + + +def check_all_gather_2d(): + seed_all(1024) + tensor = torch.rand(128, device=get_current_device()) + extra_dp_size, inner_dp_size = 2, 2 + pg_mesh = ProcessGroupMesh(extra_dp_size, inner_dp_size) + extra_dp_group = pg_mesh.get_group_along_axis(0) + inner_dp_group = pg_mesh.get_group_along_axis(1) + ranks = [dist.get_rank(extra_dp_group), dist.get_rank(inner_dp_group)] + sizes = [dist.get_world_size(extra_dp_group), dist.get_world_size(inner_dp_group)] + chunk = tensor.chunk(dist.get_world_size())[np.ravel_multi_index(ranks, sizes)].clone() + out = torch.zeros_like(tensor) + all_gather_into_flat_tensor_nd(out, chunk, group=(extra_dp_group, inner_dp_group)) + assert torch.equal(out, tensor) + + +def run_dist(rank, world_size, port): + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") + + check_all_gather_2d() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_comm_nd(): + spawn(run_dist, 4) + + +if __name__ == "__main__": + test_comm_nd() diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py index 368c782fe2c4..103854f869c7 100644 --- a/tests/test_zero/test_low_level/test_zero1_2.py +++ b/tests/test_zero/test_low_level/test_zero1_2.py @@ -2,11 +2,13 @@ import pytest import torch +import torch.distributed as dist import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai +from colossalai.cluster import ProcessGroupMesh from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all from colossalai.zero import 
LowLevelZeroOptimizer @@ -123,7 +125,8 @@ def exam_zero_1_2(fp8_communication: bool): @parameterize("dtype", [torch.float16, torch.bfloat16]) @parameterize("master_weights", [True, False]) -def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool): +@parameterize("extra_dp_size", [1, 2]) +def exam_zero_1_torch_ddp(dtype: torch.dtype, master_weights: bool, extra_dp_size: int): """ In this test, two pairs of model and optimizers are created. 1. zero: use sharded optimizer and fp16 parameters @@ -132,6 +135,15 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool): We feed these two sets of models with the same input and check if the differences in model output and updated parameters are within tolerance. """ + if extra_dp_size > 1 and dtype != torch.bfloat16: + return + if extra_dp_size > 1: + pg_mesh = ProcessGroupMesh(extra_dp_size, dist.get_world_size() // extra_dp_size) + extra_dp_group = pg_mesh.get_group_along_axis(0) + dp_group = pg_mesh.get_group_along_axis(1) + else: + extra_dp_group = None + dp_group = None local_rank = torch.distributed.get_rank() seed_all(1453) @@ -153,6 +165,8 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool): initial_scale=1, reduce_bucket_size=1024 * 1024, master_weights=master_weights, + dp_process_group=dp_group, + extra_dp_group=extra_dp_group, ) torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) @@ -200,14 +214,14 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool): def run_dist(rank, world_size, port): colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") - exam_zero_1_torch_ddp(world_size=world_size) + exam_zero_1_torch_ddp() exam_zero_1_2() @pytest.mark.dist @rerun_if_address_is_in_use() def test_zero_1_2(): - spawn(run_dist, 2) + spawn(run_dist, 4) if __name__ == "__main__": diff --git a/tests/test_zero/test_low_level/test_zero_ckpt.py b/tests/test_zero/test_low_level/test_zero_ckpt.py index 8543dfba0c15..656559718518 100644 --- a/tests/test_zero/test_low_level/test_zero_ckpt.py +++ b/tests/test_zero/test_low_level/test_zero_ckpt.py @@ -2,12 +2,14 @@ import pytest import torch +import torch.distributed as dist import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.cluster import ProcessGroupMesh +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all from colossalai.zero import LowLevelZeroOptimizer @@ -40,11 +42,19 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32): assert_close(a, b, rtol=rtol, atol=atol) -def exam_zero_1_torch_ddp_ckpt(): +@parameterize("extra_dp_size", [1, 2]) +def exam_zero_1_torch_ddp_ckpt(extra_dp_size: int): """ We examine the state_dict of zero and DDP. Moreover, we examine the zero's loading checkpoint of a torch ckpt. 
""" + if extra_dp_size > 1: + pg_mesh = ProcessGroupMesh(extra_dp_size, dist.get_world_size() // extra_dp_size) + extra_dp_group = pg_mesh.get_group_along_axis(0) + dp_group = pg_mesh.get_group_along_axis(1) + else: + dp_group = None + extra_dp_group = None local_rank = torch.distributed.get_rank() seed_all(1453) @@ -60,7 +70,12 @@ def exam_zero_1_torch_ddp_ckpt(): # we only test stage 1 here # the state dicts of stage 1 and stage 2 are the same zero_optimizer = LowLevelZeroOptimizer( - zero_optimizer, overlap_communication=True, initial_scale=1, reduce_bucket_size=262144 + zero_optimizer, + overlap_communication=True, + initial_scale=1, + reduce_bucket_size=262144, + dp_process_group=dp_group, + extra_dp_group=extra_dp_group, ) torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) @@ -111,7 +126,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_zero_ckpt(): - spawn(run_dist, 2) + spawn(run_dist, 4) if __name__ == "__main__": diff --git a/version.txt b/version.txt index 0bfccb080404..ef52a648073d 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.4.5 +0.4.6