
Potential memory leak for axolotl v0.5.2 pretrain streaming datasets with liger kernel #2108

Open · deter3 opened this issue Nov 30, 2024 · 4 comments
Labels: bug Something isn't working

deter3 commented Nov 30, 2024

Please check that this issue hasn't been reported before.

  • I searched previous Bug Reports and didn't find any similar reports.

Expected Behavior

I am using axolotl v0.5.2 with liger for Llama 3.2 1B continued pretraining and for Llama 3.1 7B continued pretraining. The previous axolotl 0.4 without liger works perfectly with the same parameters, datasets, and GPUs.

Current behaviour

With axolotl v0.5.2 and liger, using the same parameters, datasets, and GPUs, CPU memory keeps increasing until it is full and the training is killed.

Steps to reproduce

Simply run the yaml file for training and wait 2-5 hours; the training will be stopped. I tested both Llama 3.2 1B continued pretraining (2x A40 48GB) and Llama 3.1 7B continued pretraining (8x H100 90GB); the results are the same.

Running on RunPod template: runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04

(screenshot attached showing CPU memory usage during training)
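
For what it's worth, a minimal way to confirm the steady growth outside of the training logs is to sample the resident memory of the training process while it runs. The sketch below is only an illustration, assuming psutil is installed and that the PID of the axolotl process is passed in; the script name monitor_rss.py is made up and not part of axolotl.

# monitor_rss.py - log resident memory of the training process over time
# (hypothetical helper script; assumes psutil is installed)
import sys
import time

import psutil

def main() -> None:
    pid = int(sys.argv[1])  # PID of the running axolotl training process
    proc = psutil.Process(pid)
    while proc.is_running():
        rss_gib = proc.memory_info().rss / 1024**3  # resident set size in GiB
        print(f"{time.strftime('%H:%M:%S')}  RSS={rss_gib:.2f} GiB", flush=True)
        time.sleep(60)  # sample once per minute

if __name__ == "__main__":
    main()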

Config yaml

base_model: meta-llama/Llama-3.2-1B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true

plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true

load_in_8bit: false
load_in_4bit: false
strict: false

max_steps: 1000
pretraining_dataset:
  - path: private dataset with only a "text" field in the train split
    type: completion
dataset_prepared_path: last_run_prepared
val_set_size: 0.00
output_dir: ./outputs/model-out1

sequence_len: 48500
sample_packing: true
pad_to_sequence_len: true
shuffle_merged_datasets: true


wandb_project: llama32-1b-pretraining
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model: llama32-1b

gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 5
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0001

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
max_grad_norm: 1.0
noisy_embedding_alpha: 5
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 30
evals_per_epoch: 
eval_table_size:
saves_per_epoch: 3
#eval_steps: 30 
save_steps: 
save_total_limit: 10
debug:
deepspeed: axolotl/deepspeed_configs/zero3_bf16.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: <|end_of_text|>
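
To check whether the growth comes from the streaming dataset path rather than from the liger kernels, the streaming loader can be exercised on its own outside of axolotl. The following is a hedged repro sketch under that assumption: the dataset name is a placeholder for the private dataset, and it only drives datasets.load_dataset(..., streaming=True) plus the tokenizer while logging process memory.

# streaming_repro.py - iterate a streamed HF dataset and watch RSS growth
# ("your-org/your-private-dataset" is a placeholder for the private data)
import os

import psutil
from datasets import load_dataset
from transformers import AutoTokenizer

proc = psutil.Process(os.getpid())
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
ds = load_dataset("your-org/your-private-dataset", split="train", streaming=True)

for i, example in enumerate(ds):
    tokenizer(example["text"])  # mimic the tokenization the trainer performs
    if i % 1000 == 0:
        rss_gib = proc.memory_info().rss / 1024**3
        print(f"examples={i}  RSS={rss_gib:.2f} GiB", flush=True)
    if i >= 100_000:
        break

If RSS climbs steadily here as well, the leak sits in the streaming/tokenization path rather than in liger.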

Possible solution

Changing from dataset streaming to downloading the dataset keeps CPU memory from increasing: change pretraining_dataset: to datasets:

base_model: meta-llama/Llama-3.2-1B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true

plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true

load_in_8bit: false
load_in_4bit: false
strict: false

max_steps: 1000
datasets:
  - path: private dataset with only a "text" field in the train split
    type: completion
dataset_prepared_path: last_run_prepared
val_set_size: 0.00
output_dir: ./outputs/model-out1

sequence_len: 48500
sample_packing: true
pad_to_sequence_len: true
shuffle_merged_datasets: true


wandb_project: llama32-1b-pretraining
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model: llama32-1b

gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 5
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0001

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
max_grad_norm: 1.0
noisy_embedding_alpha: 5
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 30
evals_per_epoch: 
eval_table_size:
saves_per_epoch: 3
#eval_steps: 30 
save_steps: 
save_total_limit: 10
debug:
deepspeed: axolotl/deepspeed_configs/zero3_bf16.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: <|end_of_text|>

Which Operating Systems are you using?

  • Linux
  • macOS
  • Windows

Python Version

4.10

axolotl branch-commit

main

Acknowledgements

  • My issue title is concise, descriptive, and in title casing.
  • I have searched the existing issues to make sure this bug has not been reported yet.
  • I am using the latest version of axolotl.
  • I have provided enough information for the maintainers to reproduce and diagnose the issue.
@deter3 deter3 added the bug Something isn't working label Nov 30, 2024
winglian (Collaborator) commented Dec 1, 2024

@deter3 Are you using torch 2.1.0 or did you upgrade to a newer version?

deter3 (Author) commented Dec 1, 2024

@deter3 Are you using torch 2.1.0 or did you upgrade to a newer version?

runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04, it's 2.4.0. @winglian

I noticed that when streaming the data from Hugging Face, CPU memory increased a lot and never went back down.

If the pretraining dataset is downloaded directly, CPU memory does not increase at all.

In the image below, the first smem output was taken before streaming the data from Hugging Face, and the second smem output after.

(screenshot: smem output before and after streaming the data)
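
A rough in-process stand-in for that before/after smem comparison is to snapshot RSS around the dataset setup for each mode. A minimal sketch, assuming psutil and the datasets library; the dataset name is again a placeholder, and each mode should be run in its own process so the numbers stay comparable.

# compare_streaming_memory.py - run once with "streaming" and once with "download"
import os
import sys

import psutil
from datasets import load_dataset

proc = psutil.Process(os.getpid())

def rss_gib() -> float:
    return proc.memory_info().rss / 1024**3

streaming = sys.argv[1] == "streaming"
before = rss_gib()
ds = load_dataset("your-org/your-private-dataset", split="train", streaming=streaming)
for i, _ in enumerate(ds):  # touch a few thousand examples so data is actually read
    if i >= 5000:
        break
after = rss_gib()
print(f"streaming={streaming}: RSS {before:.2f} GiB -> {after:.2f} GiB")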

deter3 changed the title from "Potential memory leak for axolotl v0.5.2 with liger kernel" to "Potential memory leak for axolotl v0.5.2 pretrain streaming datasets with liger kernel" on Dec 4, 2024
ByronHsu (Contributor) commented Dec 5, 2024

Seems like CPU memory has little to do with the liger kernel. I am wondering if you have tried using axolotl v0.5.2 without the liger kernel?

deter3 (Author) commented Dec 6, 2024

Seems like CPU memory has little to do with the liger kernel. I am wondering if you have tried using axolotl v0.5.2 without the liger kernel?

@ByronHsu I did not. I found out that the main cause is likely the streaming dataset, which has nothing to do with the liger kernel.

bursteratom self-assigned this on Dec 6, 2024