Merge branch 'main' into dev/zero_bubble
duanjunwen committed Nov 18, 2024
2 parents 41fdd21 + 5a03d26 commit cb9e5cc
Showing 24 changed files with 324 additions and 78 deletions.
2 changes: 1 addition & 1 deletion applications/ColossalChat/README.md
@@ -284,7 +284,7 @@ For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tre
## O1 Journey
### Inference with Self-refined MCTS
We provide the implementation of MCT Self-Refine (MCTSr) algorithm, an innovative integration of Large Language Models with Monte Carlo Tree Search.
To run inference with MCTS, simply use the following script.
You can serve model using vLLM and update the config file in `Qwen32B_prompt_CFG` and then run the following script.
```python
from coati.reasoner.guided_search.mcts import MCTS
from coati.reasoner.guided_search.prompt_store.qwen import Qwen32B_prompt_CFG
@@ -58,7 +58,7 @@ def initialization(self):
"""
Root Initiation.
"""
# Dummy answer as root.
# Simple answer as root. You can also use negative response such as "I do not know" as a response.
base_answer = self.sample_base_answer()
self.root = MCTSNode(answer=base_answer)
self.self_evaluate(self.root)
@@ -190,7 +190,7 @@ def sample_base_answer(self):
messages=[
{
"role": "system",
"content": "The user will provide a problem. Solve the problem. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer]. \nThe answer is [answer] \n#### [answer].",
"content": self.cfg.base_system_prompt,
},
{
"role": "user",
@@ -5,6 +5,7 @@ class PromptCFG(BaseModel):
model: str
base_url: str
max_tokens: int = 4096
base_system_prompt: str
critic_system_prompt: str
refine_system_prompt: str
evaluate_system_prompt: str
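
Since the new `base_system_prompt` field has no default, every `PromptCFG` instance now has to supply it. A minimal sketch; the import path and prompt strings below are illustrative assumptions, not taken from this commit:

```python
# Hypothetical construction of a PromptCFG after this change; the module path is an
# assumption and the field values are placeholders.
from coati.reasoner.guided_search.prompt_store.base import PromptCFG

my_cfg = PromptCFG(
    model="Qwen2.5-32B-Instruct",
    base_url="http://0.0.0.0:8008/v1",
    base_system_prompt="The user will present a problem. Reason step by step, verify, then answer.",
    critic_system_prompt="Critique the answer and suggest concrete improvements.",
    refine_system_prompt="Refine the answer based on the critique.",
    evaluate_system_prompt="Score the answer from -100 to 100 and return only the integer.",
)
```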
@@ -7,14 +7,16 @@
Qwen32B_prompt_CFG = PromptCFG(
base_url="http://0.0.0.0:8008/v1",
model="Qwen2.5-32B-Instruct",
critic_system_prompt="Provide a detailed and constructive critique to improve the answer. "
"Highlight specific areas that need refinement or correction.",
base_system_prompt="The user will present a problem. Analyze and solve the problem in the following structure:\n"
"Begin with [Reasoning Process] to explain the approach. \n Proceed with [Verification] to confirm the solution. \n Conclude with [Final Answer] in the format: 'Answer: [answer]'",
critic_system_prompt="Provide a detailed and constructive critique of the answer, focusing on ways to improve its clarity, accuracy, and relevance."
"Highlight specific areas that need refinement or correction, and offer concrete suggestions for enhancing the overall quality and effectiveness of the response.",
refine_system_prompt="""# Instruction
Refine the answer based on the critique. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer].
""",
evaluate_system_prompt=(
"Analyze this answer strictly and critic, provide a reward score between -100 and 100 for the answer quality, using very strict standards. "
"Do not give a full score above 95. Make sure the reward score is an integer. "
"Return *ONLY* the score."
"Critically analyze this answer and provide a reward score between -100 and 100 based on strict standards."
"The score should clearly reflect the quality of the answer."
"Make sure the reward score is an integer. You should only return the score. If the score is greater than 95, return 95."
),
)
7 changes: 6 additions & 1 deletion colossalai/amp/naive_amp/mixed_precision_optimizer.py
@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor, inf
@@ -84,6 +84,7 @@ def __init__(
self.master_to_working_map[master_p] = p
master_params.append(master_p)
group["params"] = master_params
self._current_grad_norm: Optional[float] = None

def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
loss = self.mixed_precision.pre_backward(loss)
@@ -192,6 +193,7 @@ def step(self, *args, **kwargs):
if p.grad is not None
]
total_norm = self._compute_grad_norm(param_gradient_pairs)
self._current_grad_norm = total_norm
self._unscale_and_clip_grads(total_norm)

self.optim.step(*args, **kwargs)
@@ -217,3 +219,6 @@ def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:

def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}

def get_grad_norm(self, norm_type=2, **kwargs):
return self._current_grad_norm
5 changes: 5 additions & 0 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -301,6 +301,7 @@ def __init__(
self.pp_pg = pp_process_group
self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
self._current_grad_norm: Optional[float] = None
super().__init__(optim)

def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
@@ -372,6 +373,7 @@ def step(self, *args, **kwargs):
(p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None
]
total_norm = self._compute_grad_norm(param_gradient_pairs)
self._current_grad_norm = total_norm

# Clip the gradients to prevent exploding gradients.
self._clip_grad_norm(total_norm)
@@ -485,6 +487,9 @@ def get_working_to_master_map(self):
def get_master_to_working_map(self):
return None

def get_grad_norm(self, norm_type=2, **kwargs):
return self._current_grad_norm


class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
def __init__(
10 changes: 10 additions & 0 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -29,6 +29,7 @@
save_state_dict,
sharded_optimizer_loading_epilogue,
)
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
@@ -333,6 +334,7 @@ class LowLevelZeroPlugin(DPPluginBase):
verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False.
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
extra_dp_size (int, optional): The number of extra data parallel groups. Defaults to 1.
"""

def __init__(
@@ -358,11 +360,16 @@ def __init__(
cast_inputs: bool = True,
fp8_communication: bool = False,
use_fp8: bool = False,
extra_dp_size: int = 1,
) -> None:
super().__init__()
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training"
assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now"
if extra_dp_size > 1:
assert dist.get_world_size() % extra_dp_size == 0, "extra_dp_size should be a factor of world_size"
inner_dp_size = dist.get_world_size() // extra_dp_size
self.pg_mesh = ProcessGroupMesh(extra_dp_size, inner_dp_size)
self.stage = stage
self.precision = precision
self.zero_optim_kwargs = dict(
@@ -383,6 +390,9 @@ def __init__(
overlap_allgather=overlap_allgather,
fp8_communication=fp8_communication,
)
if extra_dp_size > 1:
self.zero_optim_kwargs["extra_dp_group"] = self.pg_mesh.get_group_along_axis(0)
self.zero_optim_kwargs["dp_process_group"] = self.pg_mesh.get_group_along_axis(1)
self.lora_enabled = False
self.verbose = verbose
self.logger = get_dist_logger()
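
The new `extra_dp_size` option splits the world into a 2-D `ProcessGroupMesh`: axis 0 becomes the extra data-parallel group and axis 1 the inner ZeRO group, as wired into `zero_optim_kwargs` above. A minimal sketch, assuming 8 launched processes; the other plugin arguments are illustrative:

```python
# Hedged sketch: with world_size == 8 and extra_dp_size == 2, the plugin builds
# ProcessGroupMesh(2, 4) and hands the two axis groups to the ZeRO optimizer.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch()  # assumes the script is started via torchrun / colossalai run
plugin = LowLevelZeroPlugin(stage=2, precision="bf16", extra_dp_size=2)
booster = Booster(plugin=plugin)
# model, optimizer, _, _, _ = booster.boost(model, optimizer)
```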
27 changes: 22 additions & 5 deletions colossalai/cli/launcher/__init__.py
@@ -64,7 +64,8 @@
"This will be converted to --arg1=1 --arg2=2 during execution",
)
@click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection")
@click.argument("user_script", type=str)
@click.option("-m", type=str, default=None, help="run library module as a script (terminates option list)")
@click.argument("user_script", type=str, required=False, default=None)
@click.argument("user_args", nargs=-1)
def run(
host: str,
@@ -77,8 +78,9 @@ def run(
master_port: int,
extra_launch_args: str,
ssh_port: int,
m: str,
user_script: str,
user_args: str,
user_args: tuple,
) -> None:
"""
To launch multiple processes on a single node or multiple nodes via command line.
@@ -102,9 +104,24 @@
# run with hostfile excluding the hosts selected
colossalai run --hostfile <file_path> --master_addr host1 --exclude host2 --nprocs_per_node 4 train.py
"""
if not user_script.endswith(".py"):
click.echo(f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help")
exit()
if m is not None:
if m.endswith(".py"):
click.echo(f"Error: invalid Python module {m}. Did you use a wrong option? Try colossalai run --help")
exit()
if user_script is not None:
user_args = (user_script,) + user_args
user_script = m
m = True
else:
if user_script is None:
click.echo("Error: missing script argument. Did you use a wrong option? Try colossalai run --help")
exit()
if not user_script.endswith(".py"):
click.echo(
f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help"
)
exit()
m = False

args_dict = locals()
args = Config(args_dict)
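
With this change `colossalai run` can also launch a library module, mirroring `python -m`: when `-m` is given, any positional script argument is treated as the module's first argument rather than the entry point, and module mode requires torch >= 2.0 (see the check added in run.py below). A hypothetical invocation, with the module name as a placeholder:

# run a library module instead of a script file
colossalai run --nproc_per_node 4 -m mypackage.train --epochs 1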
9 changes: 8 additions & 1 deletion colossalai/cli/launcher/run.py
@@ -113,6 +113,7 @@ def get_launch_command(
user_args: List[str],
node_rank: int,
num_nodes: int,
run_as_module: bool,
extra_launch_args: str = None,
) -> str:
"""
@@ -155,6 +156,8 @@ def _arg_dict_to_list(arg_dict):

torch_version = version.parse(torch.__version__)
assert torch_version.major >= 1
if torch_version.major < 2 and run_as_module:
raise ValueError("Torch version < 2.0 does not support running as module")

if torch_version.major == 1 and torch_version.minor < 9:
# torch distributed launch cmd with torch < 1.9
@@ -198,7 +201,10 @@ def _arg_dict_to_list(arg_dict):
]
cmd += _arg_dict_to_list(default_torchrun_rdzv_args)

cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args
cmd += _arg_dict_to_list(extra_launch_args)
if run_as_module:
cmd.append("-m")
cmd += [user_script] + user_args
cmd = " ".join(cmd)
return cmd

@@ -294,6 +300,7 @@ def launch_multi_processes(args: Config) -> None:
user_args=args.user_args,
node_rank=node_id,
num_nodes=len(active_device_pool),
run_as_module=args.m,
extra_launch_args=args.extra_launch_args,
)
runner.send(hostinfo=hostinfo, cmd=cmd)
12 changes: 12 additions & 0 deletions colossalai/interface/optimizer.py
@@ -152,6 +152,18 @@ def unwrap(self):
"""
return self.optim

def get_grad_norm(self, norm_type: Union[float, int] = 2.0, **kwargs) -> Optional[float]:
"""
Returns the gradient norm of an iterable of parameters. This method should be called after optimizer.step().
Args:
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
Returns:
Optional[float]: Total norm of the gradients (viewed as a single vector). If there are no valid gradients, returns None.
"""
raise NotImplementedError("The method get_grad_norm is not implemented yet.")


class DistributedOptim(Optimizer):
def setup_distributed(
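
The new `get_grad_norm` hook exposes the total gradient norm computed during the last `step()`. A minimal logging sketch, assuming an optimizer already wrapped by `booster.boost(...)`; the forward pass and variable names are placeholders:

```python
# Hedged usage sketch for the get_grad_norm() hook added in this commit.
# `booster`, `model`, `optimizer`, and `batch` are assumed to come from a normal
# ColossalAI training setup; only the call pattern is the point here.
optimizer.zero_grad()
loss = model(**batch).loss             # placeholder forward pass
booster.backward(loss, optimizer)      # standard Booster backward
optimizer.step()

grad_norm = optimizer.get_grad_norm()  # None before the first step, or if clipping never ran
if grad_norm is not None:
    print(f"total grad norm (L2): {grad_norm:.4f}")
```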
7 changes: 6 additions & 1 deletion colossalai/zero/gemini/gemini_optimizer.py
@@ -1,7 +1,7 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import copy
import math
from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
from typing import Any, Dict, Iterator, Optional, OrderedDict, Set, Tuple, Union

import torch
import torch.distributed as dist
@@ -195,6 +195,7 @@ def __init__(
self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0])

self._register_states = disposable(self._register_states_)
self._current_grad_norm: Optional[float] = None

def _set_grad_ptr(self):
for group in self.param_groups:
@@ -255,6 +256,7 @@ def _get_combined_scale(self):

if self.clipping_flag:
total_norm = self._calc_global_norm()
self._current_grad_norm = total_norm
clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
if clip > 1:
div_scale = clip * div_scale
@@ -848,6 +850,9 @@ def clip_grad_by_norm(
f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm", ranks=[0]
)

def get_grad_norm(self, norm_type=2, **kwargs):
return self._current_grad_norm


class GeminiAdamOptimizer(GeminiOptimizer):
def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
42 changes: 41 additions & 1 deletion colossalai/zero/low_level/_utils.py
@@ -1,6 +1,7 @@
import math
from typing import Optional
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@@ -209,3 +210,42 @@ def sync_tensor(flat_tensor, tensor_list):
# update the tensor data
for p, q in zip(tensor_list, updated_params):
p.data = q.data


def all_gather_into_flat_tensor_nd(
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
group: Union[dist.ProcessGroup, Tuple[dist.ProcessGroup, ...]],
async_op: bool = False,
):
if isinstance(group, dist.ProcessGroup):
group = (group,)
sizes = [dist.get_world_size(pg) for pg in group]
ranks = [dist.get_rank(pg) for pg in group]
for i, pg in list(enumerate(group))[::-1]:
if i == 0:
out = output_tensor
else:
prev_sizes = sizes[:i]
prev_ranks = ranks[:i]
chunks = output_tensor.chunk(np.prod(prev_sizes))
out = chunks[np.ravel_multi_index(prev_ranks, prev_sizes)]
handle = dist.all_gather_into_tensor(out, input_tensor, group=pg, async_op=async_op)
input_tensor = out
return handle


def get_nd_world_size(group) -> int:
if isinstance(group, tuple):
return int(np.prod([dist.get_world_size(pg) for pg in group]))
else:
return dist.get_world_size(group)


def get_nd_rank(group) -> int:
if isinstance(group, tuple):
return np.ravel_multi_index(
tuple(dist.get_rank(group=pg) for pg in group), [dist.get_world_size(pg) for pg in group]
)
else:
return dist.get_rank(group)
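
The new helpers treat a tuple of process groups as one flattened group: `all_gather_into_flat_tensor_nd` gathers each rank's chunk into the output slot addressed by its ranks in the outer groups, while `get_nd_world_size` and `get_nd_rank` return the product of sizes and the raveled rank. A small sketch, assuming 4 processes launched with torchrun and an illustrative (2, 2) mesh; tensor sizes are placeholders:

```python
# Hedged sketch of the new n-D helpers; the (2, 2) mesh layout and tensor sizes are
# illustrative. Assumes `torchrun --nproc_per_node 4 this_script.py` on a 4-GPU node.
import torch
import torch.distributed as dist

from colossalai.cluster import ProcessGroupMesh
from colossalai.zero.low_level._utils import (
    all_gather_into_flat_tensor_nd,
    get_nd_rank,
    get_nd_world_size,
)

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

mesh = ProcessGroupMesh(2, 2)
groups = (mesh.get_group_along_axis(0), mesh.get_group_along_axis(1))

chunk = torch.randn(1024, device="cuda")
out = torch.empty(1024 * get_nd_world_size(groups), device="cuda")
# Every rank's chunk ends up in `out`, ordered by the flattened (outer, inner) rank.
all_gather_into_flat_tensor_nd(out, chunk, group=groups)
print(get_nd_rank(groups), out.shape)
```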
15 changes: 12 additions & 3 deletions colossalai/zero/low_level/bookkeeping/base_store.py
@@ -1,11 +1,20 @@
from typing import Tuple, Union

import numpy as np
import torch.distributed as dist
from torch.distributed import ProcessGroup


class BaseStore:
def __init__(self, torch_pg: ProcessGroup):
self._world_size = dist.get_world_size(group=torch_pg)
self._local_rank = dist.get_rank(group=torch_pg)
def __init__(self, torch_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]]):
if isinstance(torch_pg, tuple):
self.sizes = [dist.get_world_size(group=pg) for pg in torch_pg]
self._world_size = int(np.prod(self.sizes))
self._local_rank = np.ravel_multi_index(tuple(dist.get_rank(group=pg) for pg in torch_pg), self.sizes)
else:
self._world_size = dist.get_world_size(group=torch_pg)
self._local_rank = dist.get_rank(group=torch_pg)
self.sizes = [self._world_size]
self.torch_pg = torch_pg

@property
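
The flattened rank above is simply `np.ravel_multi_index` over the per-group ranks; a quick worked example for a (2, 4) layout:

```python
# Worked example of the rank flattening in BaseStore: the process that is rank 1 in the
# outer group and rank 3 in the inner group gets flat rank 1 * 4 + 3 = 7 out of 8.
import numpy as np

sizes = [2, 4]
print(int(np.prod(sizes)))                  # 8  (the combined world size)
print(np.ravel_multi_index((1, 3), sizes))  # 7  (the combined local rank)
```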
6 changes: 3 additions & 3 deletions colossalai/zero/low_level/bookkeeping/bucket_store.py
@@ -78,13 +78,13 @@ def build_grad_in_bucket(self):
}
"""
for param, padding_size in zip(self._param_list, self._padding_size):
grad = param.grad.clone().detach().flatten()
grad = param.grad.detach().flatten()
if padding_size > 0:
with torch.no_grad():
grad = torch.nn.functional.pad(grad.view(-1), [0, padding_size])
grad_list = grad.split(grad.numel() // self._world_size)
for rank in range(self._world_size):
grad_current_rank = grad_list[rank].clone().detach()
grad_current_rank = grad_list[rank].detach()
self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
self._grad_in_bucket[rank].append(grad_current_rank)
param.grad = None
@@ -110,7 +110,7 @@ def get_flatten_grad(self) -> Tensor:

flat_grad = []
for grad_list in self._grad_in_bucket.values():
flat_grad.append(_flatten_dense_tensors(grad_list))
flat_grad.extend(grad_list)
flat_grad = _flatten_dense_tensors(flat_grad)
return flat_grad
