[RFC] Megatron-LM and MCore maintaining issues for veRL #15

Open
PeterSH6 opened this issue Nov 19, 2024 · 0 comments
Labels: enhancement (New feature or request), megatron

Comments

@PeterSH6
Collaborator

PeterSH6 commented Nov 19, 2024

How veRL uses Megatron-LM and MCore v0.4.0

  • Initialization: We build the necessary parallel groups with mpu.initialize_model_parallel, without initializing Megatron-LM's global args (see the sketch after this list).
  • Model:
    • We implement our own ParallelLlamaModel with TP, PP, and SP, plus flash-attn and remove-padding support.
    • To support building the model, we modify get_model() and unwrap_model() from the megatron/training directory and remove their get_args() usage.
  • Optimizer:
    • In Megatron-LM v0.4.0, the optimizer is not packaged in megatron.core but lives in the megatron/optimizer directory, and its hyper-parameters are controlled by the global args. Therefore, we remove all get_args() usage from the optimizer classes.
    • To support building the optimizer, we modify get_megatron_optimizer() in megatron.optimizer and remove its get_args() usage.
    • We also follow the latest MCore versions, which use an OptimizerConfig to manage the optimizer hyper-parameters.
  • Pipeline schedule:
    • In Megatron-LM v0.4.0, get_forward_backward_func doesn't support flash-attn's sequence packing. Megatron-LM v0.4.0 computes the pipeline send/recv shapes as seq_len * micro_batch_size * hidden_size, which assumes every micro batch has the same shape. If all samples in a micro batch are padded to the same sequence length, this works. With sequence packing / remove padding, however, each micro batch has shape (1, total_seq_len), where total_seq_len is the sum of the unpadded sequence lengths, so the send/recv shapes differ across micro batches and the internal shape computation breaks.
    • Our solution: provide the shape of each micro batch explicitly instead of computing it through Megatron's internal function.
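
For reference, a minimal sketch of the initialization style described above; the function name, parallel sizes, and launcher assumptions here are illustrative rather than veRL's actual code:

# Minimal sketch (illustrative, not veRL's exact code) of building only the
# Megatron parallel groups without touching the global get_args() state.
# Assumes the script is launched with torchrun so RANK/WORLD_SIZE/LOCAL_RANK are set;
# the parallel sizes below are placeholders.
import os
import torch
from megatron.core import parallel_state as mpu

def init_megatron_parallel_state(tp_size: int = 2, pp_size: int = 2) -> None:
    torch.distributed.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    # Builds the TP/PP/DP process groups; no Megatron global args are initialized.
    mpu.initialize_model_parallel(
        tensor_model_parallel_size=tp_size,
        pipeline_model_parallel_size=pp_size,
    )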

Upgrade to higher megatron-core version (>=0.7.0)

We have identified some potential work items/issues when upgrading directly from "Megatron-LM + MCore v0.4.0" to "MCore v0.7.0". In MCore v0.7.0 the optimizer is packaged inside MCore, so we may be able to depend solely on the MCore package.

  • Since MCore v0.6, the GradBuffer in grad_buffer.py has been replaced by ParamAndGradBuffer in param_and_grad_buffer.py. Therefore, when synchronizing weights between Megatron-LM and vLLM through AllGatherPPModel in veRL, we need to additionally sync the weights wherever param_buffer is used.
  • Check whether the pipeline scheduling / remove-padding issue described above still exists in the current version.
  • Fix some inconsistencies when using ModelParallelConfig and OptimizerConfig during model and optimizer initialization (see the sketch below).
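
For illustration, a rough sketch of how the MCore-packaged optimizer could be built from an OptimizerConfig in MCore >= 0.7; the exact field names and the get_megatron_optimizer signature may differ across MCore releases, so treat this as an assumption rather than verified usage:

# Rough sketch, assuming MCore >= 0.7 where the optimizer lives under megatron.core.optimizer.
# Field names and the get_megatron_optimizer signature may differ between MCore releases.
from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer

optim_config = OptimizerConfig(
    optimizer="adam",
    lr=1e-6,
    weight_decay=0.01,
    adam_beta1=0.9,
    adam_beta2=0.95,
    clip_grad=1.0,
    bf16=True,
    use_distributed_optimizer=False,
)
# model_chunks: the list of MCore model chunks returned by get_model().
optimizer = get_megatron_optimizer(optim_config, model_chunks)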

Pain points in MCore

1. Supporting several device meshes (i.e., parallel groups) for different models in one Megatron process group

Currently, a Megatron training job supports only one parallelism strategy (TP-PP-DP), since the corresponding parallel groups are managed by global variables in parallel_state.py. However, for RL post-training, we want different models to be trained within the same Megatron process group but with different parallelism strategies. In essence, we would like Megatron-LM/MCore to support creating multiple device meshes (i.e., parallel groups), each managing the parallelism strategy of a specific model.

A potential API for this functionality could look like this:

# FSDP with device mesh
self.device_mesh = torch.distributed.device_mesh.init_device_mesh("cuda", mesh_shape=(8,))
fsdp_module = FSDP(module,
                   ...
                   sharding_strategy=sharding_strategy,  # ZeRO-3
                   device_mesh=self.device_mesh)  # each model in the same process is bound to one device mesh

# Megatron models (proposed API)
self.device_mesh = mpu.initialize_device_mesh(tp=2, pp=1, dp=4)

megatron_module = get_model(model_provider_func=megatron_actor_model_provider,
                            model_type=ModelType.encoder_or_decoder,
                            wrap_with_ddp=False,
                            device_mesh=self.device_mesh)

# inside get_model()
model = model_provider_func(pre_process=pre_process,
                            post_process=post_process,
                            add_encoder=add_encoder,
                            add_decoder=add_decoder,
                            device_mesh=self.device_mesh)  # each model in the same process is bound to one device mesh

In this way, under colocated placement, we can execute different models (actor, critic, reference, and reward models) in the same WorkerGroup, each under its own optimal parallelism strategy.

2. GPTModel in MCore may not be bitwise-aligned with the Hugging Face model

The GPTModel in MCore offers a variety of model implementations by customizing the TransformerConfig.

However, we discovered that GPTModel with TransformerEngine does not achieve bitwise alignment with the Hugging Face Llama model. This misalignment further affects convergence when training with PPO and other RL algorithms. One potential reason may be the discrepancy between the vLLM Llama model (which follows Hugging Face) and GPTModel. To address this, we have to implement our own Llama-architecture model for training with Megatron.
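
To make the alignment check concrete, below is a minimal sketch that compares logits between a Hugging Face Llama checkpoint and a Megatron-side model; build_megatron_llama() and the checkpoint name are hypothetical placeholders, not MCore or veRL APIs:

# Minimal sketch of a logit-alignment check between a Hugging Face Llama model
# and a Megatron-side model. `build_megatron_llama` is a hypothetical helper
# (not an MCore API) standing in for the actual Megatron model construction.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).cuda().eval()
megatron_model = build_megatron_llama(model_name).cuda().eval()  # hypothetical helper

inputs = tokenizer("The quick brown fox", return_tensors="pt").to("cuda")
with torch.no_grad():
    hf_logits = hf_model(**inputs).logits
    mg_logits = megatron_model(inputs["input_ids"])  # assumed to return logits

print("bitwise equal:", torch.equal(hf_logits, mg_logits))
print("max abs diff :", (hf_logits - mg_logits).abs().max().item())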

By resolving this misalignment, we believe it will be much easier to support training a wide range of models with the Megatron backend (and also to support Context Parallel).

PeterSH6 added the enhancement and megatron labels on Nov 19, 2024
PeterSH6 pinned this issue on Nov 22, 2024