
Unexpected Increase in Rollout Time After Reducing num_hidden_layers in deepseek-llm-7b-chat Model #24

Open
metaqiang opened this issue Nov 25, 2024 · 2 comments

Comments

@metaqiang

Description

When running the examples/ppo_trainer/run_deepseek_megatron.sh script with the base model deepseek-llm-7b-chat, I encountered unexpected behavior related to the num_hidden_layers parameter. Originally, the model has num_hidden_layers set to 30 and the rollout time is approximately 35 seconds. I reduced num_hidden_layers to 15, anticipating that the rollout time would roughly halve. Instead, the rollout time increased to about 71 seconds.

Steps to Reproduce

  1. Original Configuration:

    • Run the script examples/ppo_trainer/run_deepseek_megatron.sh with the base model deepseek-llm-7b-chat having num_hidden_layers=30.
    • Observe the rollout time, which is approximately 35 seconds.
  2. Modified Configuration:

    • Change the num_hidden_layers parameter in the model configuration from 30 to 15 (a sketch of this edit follows this list).
    • Rerun the same script with the modified configuration.
    • Notice that the rollout time increases to approximately 71 seconds instead of decreasing.
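
A minimal sketch (not from the issue) of how the edit in step 2 might be made with transformers; the output path, and the choice to rewrite only config.json while leaving the checkpoint weights untouched, are assumptions for illustration:

# Hedged sketch: reduce num_hidden_layers in the Hugging Face config.
# Only config.json is rewritten here; the checkpoint weights are not modified.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("deepseek-ai/deepseek-llm-7b-chat")
print("original num_hidden_layers:", config.num_hidden_layers)  # 30

config.num_hidden_layers = 15
config.save_pretrained("./deepseek-llm-7b-chat-15layers")  # writes the edited config.json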

Expected Behavior

Reducing the num_hidden_layers from 30 to 15 should lead to a proportional decrease in rollout generation time, ideally halving the time from around 35 seconds to approximately 17-18 seconds.

Actual Behavior

After modifying num_hidden_layers to 15, the rollout time unexpectedly doubled from ~35 seconds to ~71 seconds.

Additional Information

Model Structure (after reducing):

(WorkerDict pid=3206005) parallel_model: ParallelLlamaForCausalLMRmPadPP(
(WorkerDict pid=3206005)   (model): ParallelLlamaModelRmPadPP(
(WorkerDict pid=3206005)     (embed_tokens): VocabParallelEmbedding()
(WorkerDict pid=3206005)     (layers): ModuleList(
(WorkerDict pid=3206005)       (0-14): 15 x ParallelLlamaDecoderLayerRmPad(
(WorkerDict pid=3206005)         (self_attn): ParallelLlamaAttentionRmPad(
(WorkerDict pid=3206005)           (qkv_proj): QKVParallelLinear()
(WorkerDict pid=3206005)           (o_proj): RowParallelLinear()
(WorkerDict pid=3206005)           (rotary_emb): LlamaRotaryEmbedding()
(WorkerDict pid=3206005)         )
(WorkerDict pid=3206005)         (mlp): ParallelLlamaMLP(
(WorkerDict pid=3206005)           (gate_up_proj): MergedColumnParallelLinear()
(WorkerDict pid=3206005)           (down_proj): RowParallelLinear()
(WorkerDict pid=3206005)           (act_fn): SiLU()
(WorkerDict pid=3206005)         )
(WorkerDict pid=3206005)         (input_layernorm): ParallelLlamaRMSNorm()
(WorkerDict pid=3206005)         (post_attention_layernorm): ParallelLlamaRMSNorm()
(WorkerDict pid=3206005)       )
(WorkerDict pid=3206005)     )
(WorkerDict pid=3206005)     (norm): ParallelLlamaRMSNorm()
(WorkerDict pid=3206005)   )
(WorkerDict pid=3206005)   (lm_head): ColumnParallelLinear()
(WorkerDict pid=3206005) )

Timing Code:

@register(dispatch_mode=Dispatch.MEGATRON_PP_AS_DP_PROTO)
def generate_sequences(self, prompts: DataProto):
    assert self._is_rollout
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    prompts.batch = prompts.batch.cuda()
    meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id}
    prompts.meta_info.update(meta_info)
    with self.sharding_manager:
        log_gpu_memory_usage('After entering sharding manager', logger=logger)

        prompts = self.sharding_manager.preprocess_data(prompts)
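        # Note: these CUDA events time only self.rollout.generate_sequences;
        # start.elapsed_time(end) is read after the torch.cuda.synchronize() below.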
        start.record()
        output = self.rollout.generate_sequences(prompts=prompts)
        end.record()
        
        log_gpu_memory_usage('After rollout generation', logger=logger)

        output = self.sharding_manager.postprocess_data(output)

    validate = prompts.meta_info.get('validate', False)
    if self._is_actor and not validate:
        # we should always recompute old_log_probs when it is HybridEngine
        output.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size
        output.meta_info['temperature'] = self.config.rollout.temperature
        old_log_probs = self.actor.compute_log_prob(data=output)
        output.batch['old_log_probs'] = old_log_probs

    output = output.to('cpu')
    # clear kv cache
    torch.cuda.empty_cache()
    log_gpu_memory_usage('After recompute log prob', logger=logger)
    torch.cuda.synchronize()
    elapsed_time = start.elapsed_time(end)
    print(f'elapsed_time: {elapsed_time}')
    return output

Question

What could be causing the rollout time to increase when reducing the num_hidden_layers from 30 to 15 in the deepseek-llm-7b-chat model? Are there any configuration or implementation issues that might lead to this performance degradation?

@PeterSH6
Copy link
Collaborator

I'm not sure why this happens. How did you modify num_hidden_layers? Did you make sure that both the vLLM model and the Megatron model configurations were modified correctly?

I have three suggestions for debugging:

  1. Compare the number of GPU blocks reported in the vLLM logs for the two configurations to check whether the allocated KV cache differs.
  2. Test this setting with the official vLLM offline inference script (a minimal sketch follows this list).
  3. Make sure GPU memory is fully cleaned up between runs.
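
A minimal sketch of what suggestion 2 could look like, assuming the standard vLLM offline LLM API; the model path, prompt, and sampling parameters are placeholders rather than values from this issue:

# Hedged sketch of an offline vLLM sanity check (suggestion 2); the model path
# and sampling settings are illustrative assumptions.
from vllm import LLM, SamplingParams

llm = LLM(model="./deepseek-llm-7b-chat-15layers", tensor_parallel_size=1)
sampling_params = SamplingParams(temperature=1.0, max_tokens=512)

outputs = llm.generate(["Write a short story about a robot."], sampling_params)
for output in outputs:
    print(output.outputs[0].text)

The startup log of such a run also reports the number of allocated GPU blocks, which can be compared across the two configurations for suggestion 1.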

@metaqiang
Author

Thank you! We modified num_hidden_layers by changing the Hugging Face config file. Is this approach wrong? Is there a better way to change num_hidden_layers?
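
For the Hugging Face side, a minimal sketch (not from the thread) of loading the unchanged 30-layer checkpoint under a 15-layer config; whether the Megatron and vLLM loaders in this repo behave the same way is not shown here and would need to be checked separately:

# Hedged sketch: override num_hidden_layers at load time instead of editing
# config.json by hand. transformers builds a 15-layer model and skips the
# weights of the remaining layers (reported as unused keys in the log).
from transformers import AutoConfig, AutoModelForCausalLM

cfg = AutoConfig.from_pretrained("deepseek-ai/deepseek-llm-7b-chat", num_hidden_layers=15)
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/deepseek-llm-7b-chat", config=cfg)
print(model.config.num_hidden_layers)  # 15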

[screenshot of the modified Hugging Face config file]
