How veRL uses Megatron-LM and MCore v0.4.0
Initialization: We build the necessary parallel groups using mpu.initialize_model_parallel, without initializing Megatron-LM's global args.
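For illustration, a minimal sketch of this kind of initialization (the TP/PP sizes are arbitrary example values, not veRL's defaults):

# minimal sketch: build Megatron parallel groups without touching the global args
# (tensor/pipeline sizes below are arbitrary example values)
import torch
from megatron.core import parallel_state as mpu

torch.distributed.init_process_group(backend="nccl")    # one process per GPU
mpu.initialize_model_parallel(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=2,
)                                                       # remaining ranks form the data-parallel groups
print(mpu.get_tensor_model_parallel_rank(), mpu.get_data_parallel_rank())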
Model:
We implement our own ParallelLlamaModel with TP, PP, SP, flash-attn, and remove-padding support.
To support building the model, we modify get_model() and unwrap_model() from the megatron/training directory and remove the get_args() usage.
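Roughly, the call site then looks like the sketch below; ParallelLlamaModel's constructor arguments and model_config are illustrative names, and get_model here is the modified copy from megatron/training, not the upstream function:

# rough sketch: all hyper-parameters flow in explicitly instead of via get_args()
from megatron.core.enums import ModelType

def megatron_actor_model_provider(pre_process, post_process):
    # ParallelLlamaModel is veRL's own implementation; its exact signature may differ
    return ParallelLlamaModel(config=model_config,
                              pre_process=pre_process,
                              post_process=post_process)

# get_model is the modified copy from megatron/training with get_args() removed
actor_module = get_model(model_provider_func=megatron_actor_model_provider,
                         model_type=ModelType.encoder_or_decoder,
                         wrap_with_ddp=False)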
Optimizer:
In Megatron-LM v0.4.0, the optimizer is not packaged in megatron.core but lives in the megatron/optimizer directory. Its hyper-parameters are controlled by the global args. Therefore, we remove all get_args() usage from the optimizer classes.
To support building the optimizer, we modify get_megatron_optimizer() in megatron.optimizer and remove the get_args() usage.
We also follow the latest version of MCore in using an OptimizerConfig to manage the optimizer's hyper-parameters.
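A sketch of what the optimizer construction then looks like; the import path follows newer MCore releases, and the field values are placeholders rather than veRL's exact code:

# sketch: optimizer hyper-parameters come from an explicit OptimizerConfig, not get_args()
# (import path follows newer MCore releases; in v0.4.0 veRL keeps its own modified copy,
#  and all field values below are placeholders)
from megatron.core.optimizer import OptimizerConfig

optim_config = OptimizerConfig(
    lr=1e-6,
    weight_decay=0.01,
    adam_beta1=0.9,
    adam_beta2=0.95,
    clip_grad=1.0,
    bf16=True,
)
optimizer = get_megatron_optimizer(optim_config, actor_module)   # modified to take the config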
Pipeline schedule:
In Megatron-LM v0.4.0, get_forward_backward_func does not support flash-attn's sequence packing. When sequence packing/remove padding is used, the actual send/recv shapes of the micro batches inside forward_backward_func can be inconsistent. Megatron-LM v0.4.0 derives the send/recv shape from seq_len * micro_batch_size * hidden_size. If all samples in a micro batch are padded to the same seq len, this works fine. With sequence packing/remove padding, however, each micro batch ends up with shape (1, total_seq_len), where total_seq_len is the sum of the per-sample sequence lengths after padding is removed, so the send/recv shape differs from micro batch to micro batch.
Our solution: provide the shape of each micro batch directly instead of computing it through Megatron's internal function.
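For illustration, a rough sketch of this idea; forward_step, hidden_size, micro_batches, and the tensor_shapes argument of the patched forward_backward_func are illustrative names, not stock Megatron-LM v0.4.0 API:

# sketch: compute the send/recv shape of every micro batch after remove-padding and
# pass it in directly; `tensor_shapes` is the patched argument, and
# forward_step / hidden_size / micro_batches / actor_module are placeholders
tensor_shapes = []
for micro_batch in micro_batches:                              # each batch is (1, total_seq_len)
    total_seq_len = int(micro_batch["attention_mask"].sum())   # valid tokens after unpadding
    # Megatron communicates activations as (seq_len, micro_batch_size, hidden_size)
    tensor_shapes.append((total_seq_len, 1, hidden_size))

losses = forward_backward_func(                                # the patched schedule function
    forward_step_func=forward_step,
    data_iterator=iter(micro_batches),
    model=actor_module,
    num_microbatches=len(micro_batches),
    tensor_shapes=tensor_shapes,                               # provided directly, no internal shape inference
    forward_only=False,
)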
Upgrade to higher megatron-core version (>=0.7.0)
We identify some potential work items/issues when directly upgrading from "Megatron-LM + MCore v0.4.0" to "MCore v0.7.0". In MCore v0.7.0 the optimizer is packaged inside MCore, so we may be able to depend solely on the MCore package.
Since MCore v0.6, the GradBuffer in grad_buffer.py from the previous version has been replaced by ParamAndGradBuffer in param_and_grad_buffer.py. Therefore, when synchronizing weights between Megatron-LM and vLLM through AllGatherPPModel in veRL, we also need to sync the weights to wherever param_buffer is used (see the sketch after this list).
Check whether the pipeline-scheduling and remove-padding issue still exists in the current version.
We need to fix some inconsistencies when using ModelParallelConfig and OptimizerConfig during model and optimizer initialization.
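A rough sketch of the extra weight-sync step mentioned above, assuming a hypothetical param_to_buffer_view mapping; the real lookup goes through ParamAndGradBuffer internals and should be checked against the MCore version in use:

# rough sketch: after AllGatherPPModel refreshes module.parameters(), the flat param
# buffer must see the same data.  `param_to_buffer_view` is a hypothetical
# {param -> flat-buffer view} dict; the real mapping lives in ParamAndGradBuffer.
import torch

@torch.no_grad()
def sync_params_into_buffer(module, param_to_buffer_view):
    for param in module.parameters():
        buffer_view = param_to_buffer_view.get(param)
        if buffer_view is not None:
            buffer_view.copy_(param.data.view(-1))   # keep the flat buffer consistent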
Pain points in MCore
1. Supporting several device meshes (i.e., parallel groups) for different models in one Megatron process group
Currently, a Megatron training job only supports one parallelism strategy (TP-PP-DP) since the corresponding parallel groups are managed by global variables in parallel_state.py. However, for RL Post-Training, we want different models to be trained within the same Megatron process group, but with different parallelism strategies. In essence, we would like Megatron-LM/MCore to support the creation of multiple device meshes (i.e., parallel groups), each managing the parallelism strategy for a specific model.
A potential API for this functionality could look like this:
# FSDP with device mesh
self.device_mesh = torch.distributed.device_mesh.init_device_mesh("cuda", mesh_shape=(8,))
fsdp_module = FSDP(module,
                   ...
                   sharding_strategy=sharding_strategy,  # zero3
                   device_mesh=self.device_mesh)  # each model in the same process is bound to one device mesh

# Megatron models
self.device_mesh = mpu.initialize_device_mesh(tp=2, pp=1, dp=4)
megatron_module = get_model(model_provider_func=megatron_actor_model_provider,
                            model_type=ModelType.encoder_or_decoder,
                            wrap_with_ddp=False,
                            device_mesh=self.device_mesh)

# inside get_model()
model = model_provider_func(pre_process=pre_process,
                            post_process=post_process,
                            add_encoder=add_encoder,
                            add_decoder=add_decoder,
                            device_mesh=self.device_mesh)  # each model in the same process is bound to one device mesh
In this way, under colocated placement, we can support different models (actor, critic, reference, and reward models) running in the same WorkerGroup, each under its own optimal parallelism strategy.
2. GPTModel in MCore may not bitwise align with huggingface model
The GPTModel in MCore offers a variety of model implementations by customizing the TransformerConfig.
However, we discovered that the GPTModel with TransformerEngine does not achieve bitwise alignment with the Hugging Face Llama model. This misalignment further affects the convergence when training with PPO and other RL algorithms. One potential reason for this issue may stem from the discrepancies between the vLLM Llama model (from Hugging Face) and the GPTModel. To address this, we have to implement our own Llama-architecture model for training with Megatron.
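For reference, a minimal way to check this kind of alignment is to compare logits bitwise rather than with a tolerance; the model handles and input tensor below are placeholders for however the two models are actually loaded:

# minimal sketch of a bitwise alignment check; hf_model / megatron_model / input_ids
# are placeholders for however the two models and inputs are actually loaded
import torch

@torch.no_grad()
def check_bitwise_alignment(hf_model, megatron_model, input_ids):
    hf_logits = hf_model(input_ids).logits
    mg_logits = megatron_model(input_ids)
    # torch.equal requires exact equality (bitwise for the same dtype); allclose only checks a tolerance
    print("bitwise equal:", torch.equal(hf_logits, mg_logits))
    print("max abs diff :", (hf_logits.float() - mg_logits.float()).abs().max().item())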
By resolving this misalignment issue, we believe that it will be much easier to support the training of a wide range of models using the Megatron backend (and also support Context Parallel).