
bug: Crash due to tensor size mismatch at the end of an epoch #411

Open
HarikrishnanBalagopal opened this issue Dec 9, 2024 · 9 comments

HarikrishnanBalagopal (Contributor) commented Dec 9, 2024

Describe the bug

Training crashes when it reaches the last batch of an epoch. The run uses the granite-20b-base model and a JSONL dataset with input and output columns. The error appears to be a tensor size mismatch that causes a torch concatenate (torch.cat) operation to fail.

Platform

Running in an OpenShift GPU cluster.

Sample Code

accelerate launch \
  --use_fsdp \
  --fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP \
  --fsdp_forward_prefetch=false \
  --fsdp_offload_params=false \
  --fsdp_sharding_strategy=FULL_SHARD \
  --fsdp_state_dict_type=FULL_STATE_DICT \
  --fsdp_cpu_ram_efficient_loading=true \
  --fsdp_sync_module_states=true \
  --num_processes=8 \
  --dynamo_backend="no" \
  --machine_rank="${RANK}" \
  --main_process_ip="${MASTER_ADDR}" \
  --main_process_port="${MASTER_PORT}" \
  --mixed_precision="no" \
  --num_machines="${WORLD_SIZE}" \
  --rdzv_backend="static" \
  --same_network \
  -m tuning.sft_trainer \
  --adam_beta1="0.9" \
  --adam_beta2="0.98" \
  --adam_epsilon="1e-10" \
  --aim_repo="${AIMSTACK_DB}" \
  --data_config_path="dataset_config.yaml" \
  --dataloader_drop_last="true" \
  --evaluation_strategy="no" \
  --experiment="train-6608e801-85b8-4463-bc38-17f8a3d4a87f" \
  --gradient_accumulation_steps="4" \
  --gradient_checkpointing="true" \
  --learning_rate="1e-05" \
  --log_level="debug" \
  --logging_steps="5" \
  --logging_strategy="steps" \
  --lr_scheduler_type="cosine" \
  --max_steps="2250" \
  --model_name_or_path="/modeling/models/granite-20b-base-ept-merged-70-30" \
  --optim="adamw_torch" \
  --output_dir="/modeling/checkpoints/train-6608e801-85b8-4463-bc38-17f8a3d4a87f" \
  --packing="False" \
  --per_device_train_batch_size="8" \
  --save_steps="250" \
  --save_strategy="steps" \
  --split_batches="true" \
  --torch_dtype="bfloat16" \
  --tracker="aim" \
  --training_data_path= \
  --use_flash_attn="true" \
  --use_reentrant="true" \
  --warmup_ratio="0.1" \
  --warmup_steps="200" \
  --weight_decay="0.1"

The dataset_config.yaml file:

    train_datasets:
      - path: /modeling/data/instruction_tuning/train/train_1000_rows.jsonl
        prob: 1.0
    # column name where input data is available
    input_feature: "input"
    # column name where output data is available
    output_feature: "output"
    streaming: false
    data_sampler: "samples_based"
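
For context, a training file in this layout can be produced with a short script like the one below. This is a minimal sketch: the example rows and the train_sample.jsonl path are hypothetical; only the "input"/"output" column names come from the config above.

import json

# Hypothetical rows; only the "input"/"output" column names come from
# the dataset_config.yaml above (input_feature / output_feature).
rows = [
    {"input": "What is the capital of France?", "output": "Paris."},
    {"input": "Add 2 and 3.", "output": "5"},
]

# JSONL layout: one JSON object per line.
with open("train_sample.jsonl", "w", encoding="utf-8") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")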

Expected behavior

Training should finish without errors.

Observed behavior

Crash on the last batch of the epoch:

DEBUG:__init__.py:incrementing progress idx -> 111
 34%|███▎      | 754/2250 [01:50<04:18,  5.78it/s]DEBUG:__init__.py:incrementing progress idx -> 112
{'loss': 0.3445, 'grad_norm': 11.25, 'learning_rate': 8.297929097711207e-06, 'epoch': 0.34}
 34%|███▎      | 755/2250 [01:57<05:03,  4.93it/s]DEBUG:__init__.py:incrementing progress idx -> 113
 34%|███▎      | 756/2250 [02:03<06:27,  3.86it/s]DEBUG:__init__.py:incrementing progress idx -> 114
DEBUG:__init__.py:incrementing progress idx -> 115
 34%|███▎      | 757/2250 [02:09<07:54,  3.15it/s]ERROR:sft_trainer.py:Traceback (most recent call last):
  File "/home/tuning/.local/lib/python3.11/site-packages/tuning/sft_trainer.py", line 646, in main
    trainer, additional_train_info = train(
                                     ^^^^^^
  File "/home/tuning/.local/lib/python3.11/site-packages/tuning/sft_trainer.py", line 432, in train
    trainer.train(resume_from_checkpoint)
  File "/home/tuning/.local/lib/python3.11/site-packages/transformers/trainer.py", line 2123, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.11/site-packages/transformers/trainer.py", line 2427, in _inner_training_loop
    batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.11/site-packages/transformers/trainer.py", line 5045, in get_batch_samples
    batch_samples += [next(epoch_iterator)]
                      ^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.11/site-packages/accelerate/data_loader.py", line 840, in __iter__
    batch = concatenate([batch, first_batch], dim=0)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.11/site-packages/accelerate/utils/operations.py", line 621, in concatenate
    return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.11/site-packages/accelerate/utils/operations.py", line 621, in 
    return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.11/site-packages/accelerate/utils/operations.py", line 624, in concatenate
    return torch.cat(data, dim=dim)
           ^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 899 but got size 766 for tensor number 1 in the list.
 34%|███▎      | 757/2250 [02:09<04:15,  5.84it/s]
DEBUG:run.py:Closing resource 
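
Reading the traceback, the failure seems to come from accelerate's end-of-dataloader wrap-around: data_loader.py pads the final short batch by concatenating it with the first batch along dim 0, and torch.cat requires every other dimension to match. With packing disabled and dynamic padding, the two batches can be padded to different sequence lengths (899 vs 766 here), which is enough to trigger the crash. Below is a minimal sketch of that failure mode with plain tensors; the batch sizes are made up, only the 899/766 lengths come from the log.

import torch

# Two batches dynamically padded to different sequence lengths; the
# 899/766 lengths come from the error above, the batch sizes are made up.
last_batch  = {"input_ids": torch.zeros(3, 899, dtype=torch.long)}
first_batch = {"input_ids": torch.zeros(8, 766, dtype=torch.long)}

# accelerate concatenates the short final batch with the first batch along
# dim 0; torch.cat requires all other dimensions to match, so this raises:
# RuntimeError: Sizes of tensors must match except in dimension 0.
# Expected size 899 but got size 766 for tensor number 1 in the list.
torch.cat([last_batch["input_ids"], first_batch["input_ids"]], dim=0)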
HarikrishnanBalagopal changed the title from "Crash due to tensor size mismatch at the end of an epoch" to "bug: Crash due to tensor size mismatch at the end of an epoch" on Dec 9, 2024
ashokponkumar (Collaborator) commented:

@HarikrishnanBalagopal please recreate the issue if you are able to reproduce it on the main branch.

HarikrishnanBalagopal (Contributor, Author) commented:

A different crash, this time during the optimizer step:

ERROR:sft_trainer.py:Traceback (most recent call last):
  File "/home/tuning/.local/lib/python3.11/site-packages/tuning/sft_trainer.py", line 646, in main
    trainer, additional_train_info = train(
                                     ^^^^^^
  File "/home/tuning/.local/lib/python3.11/site-packages/tuning/sft_trainer.py", line 432, in train
    trainer.train(resume_from_checkpoint)
  File "/home/tuning/.local/lib/python3.11/site-packages/transformers/trainer.py", line 2123, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.11/site-packages/transformers/trainer.py", line 2534, in _inner_training_loop
    self.optimizer.step()
  File "/home/tuning/.local/lib/python3.11/site-packages/accelerate/optimizer.py", line 171, in step
    self.optimizer.step(closure)
  File "/home/tuning/.local/lib/python3.11/site-packages/torch/optim/lr_scheduler.py", line 137, in wrapper
    return func.__get__(opt, opt.__class__)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.11/site-packages/torch/optim/optimizer.py", line 487, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.11/site-packages/torch/optim/optimizer.py", line 91, in _use_grad
    ret = func(self, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.11/site-packages/torch/optim/adamw.py", line 220, in step
    adamw(
  File "/home/tuning/.local/lib/python3.11/site-packages/torch/optim/optimizer.py", line 154, in maybe_fallback
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.11/site-packages/torch/optim/adamw.py", line 782, in adamw
    func(
  File "/home/tuning/.local/lib/python3.11/site-packages/torch/optim/adamw.py", line 531, in _multi_tensor_adamw
    torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)
RuntimeError: The size of tensor a (44041728) must match the size of tensor b (6144) at non-singleton dimension 1

  1%|▏         | 31/2250 [02:35<3:05:56,  5.03s/it]
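
For reference, this second crash happens inside the multi-tensor AdamW update rather than the dataloader: as the traceback shows, _multi_tensor_adamw updates the first-moment buffers with torch._foreach_lerp_, which requires each exp_avg tensor to have a shape compatible with the corresponding gradient. A minimal sketch that reproduces the same class of error with deliberately mismatched shapes (the real sizes in the log were 44041728 vs 6144; the shapes below are illustrative):

import torch

# Optimizer first-moment buffers and gradients with deliberately
# mismatched shapes, mimicking the exp_avg / grad mismatch in the log
# above (the real sizes were 44041728 vs 6144).
exp_avgs = [torch.zeros(4, 8)]
grads    = [torch.zeros(4, 2)]

beta1 = 0.9
# The multi-tensor AdamW path updates the first moment with an in-place
# lerp; mismatched shapes raise a RuntimeError about tensor sizes at a
# non-singleton dimension, analogous to the error above.
torch._foreach_lerp_(exp_avgs, grads, 1 - beta1)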

HarikrishnanBalagopal (Contributor, Author) commented:

The bug doesn't seem to happen with the official main-branch image with AIM: na.artifactory.swg-devops.com/wcp-ai-foundation-team-docker-virtual/sft-trainer-aim:38bbcf5_ubi9_py311

accelerate launch \
  --use_fsdp \
  --fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP \
  --fsdp_forward_prefetch=false \
  --fsdp_offload_params=false \
  --fsdp_sharding_strategy=FULL_SHARD \
  --fsdp_state_dict_type=FULL_STATE_DICT \
  --fsdp_cpu_ram_efficient_loading=true \
  --fsdp_sync_module_states=true \
  --num_processes=8 \
  --dynamo_backend="no" \
  --machine_rank="${RANK}" \
  --main_process_ip="${MASTER_ADDR}" \
  --main_process_port="${MASTER_PORT}" \
  --mixed_precision="no" \
  --num_machines="${WORLD_SIZE}" \
  --rdzv_backend="static" \
  --same_network \
  -m tuning.sft_trainer \
  --adam_beta1="0.9" \
  --adam_beta2="0.98" \
  --adam_epsilon="1e-10" \
  --aim_repo="${AIMSTACK_DB}" \
  --data_config_path="dataset_config.yaml" \
  --dataloader_drop_last="true" \
  --evaluation_strategy="no" \
  --experiment="train-03577468-9d72-4c2a-baaf-3336242d597f" \
  --gradient_accumulation_steps="4" \
  --gradient_checkpointing="true" \
  --learning_rate="1e-05" \
  --log_level="debug" \
  --logging_steps="5" \
  --logging_strategy="steps" \
  --lr_scheduler_type="cosine" \
  --max_steps="2250" \
  --model_name_or_path="/modeling/models/granite-20b-base-ept-merged-70-30" \
  --optim="adamw_torch" \
  --output_dir="/modeling/checkpoints/train-03577468-9d72-4c2a-baaf-3336242d597f" \
  --packing="False" \
  --per_device_train_batch_size="8" \
  --save_steps="250" \
  --save_strategy="steps" \
  --split_batches="true" \
  --torch_dtype="bfloat16" \
  --tracker="aim" \
  --use_flash_attn="true" \
  --use_reentrant="true" \
  --warmup_ratio="0.1" \
  --warmup_steps="200" \
  --weight_decay="0.1"
{"data": {"epoch": 0.16, "step": 5, "timestamp": "2024-12-18T08:20:19.540051", "value": 0.559}, "name": "training_loss"}
{"data": {"epoch": 0.32, "step": 10, "timestamp": "2024-12-18T08:20:44.335258", "value": 0.5034}, "name": "training_loss"}
{"data": {"epoch": 0.48, "step": 15, "timestamp": "2024-12-18T08:21:10.139879", "value": 0.5204}, "name": "training_loss"}
{"data": {"epoch": 0.64, "step": 20, "timestamp": "2024-12-18T08:21:35.006164", "value": 0.4812}, "name": "training_loss"}
{"data": {"epoch": 0.8, "step": 25, "timestamp": "2024-12-18T08:22:00.369152", "value": 0.4855}, "name": "training_loss"}
{"data": {"epoch": 0.96, "step": 30, "timestamp": "2024-12-18T08:22:22.809873", "value": 0.4168}, "name": "training_loss"}
{"data": {"epoch": 1.12, "step": 35, "timestamp": "2024-12-18T08:22:47.667162", "value": 0.4999}, "name": "training_loss"}
{"data": {"epoch": 1.28, "step": 40, "timestamp": "2024-12-18T08:23:11.652898", "value": 0.4688}, "name": "training_loss"}
{"data": {"epoch": 1.44, "step": 45, "timestamp": "2024-12-18T08:23:34.835309", "value": 0.4106}, "name": "training_loss"}
{"data": {"epoch": 1.6, "step": 50, "timestamp": "2024-12-18T08:23:58.842765", "value": 0.4867}, "name": "training_loss"}
{"data": {"epoch": 1.76, "step": 55, "timestamp": "2024-12-18T08:24:23.748723", "value": 0.4496}, "name": "training_loss"}
{"data": {"epoch": 1.92, "step": 60, "timestamp": "2024-12-18T08:24:48.970410", "value": 0.4115}, "name": "training_loss"}
{"data": {"epoch": 2.08, "step": 65, "timestamp": "2024-12-18T08:25:12.456075", "value": 0.3914}, "name": "training_loss"}
{"data": {"epoch": 2.24, "step": 70, "timestamp": "2024-12-18T08:25:38.317406", "value": 0.437}, "name": "training_loss"}
{"data": {"epoch": 2.4, "step": 75, "timestamp": "2024-12-18T08:26:02.539473", "value": 0.3401}, "name": "training_loss"}
{"data": {"epoch": 2.56, "step": 80, "timestamp": "2024-12-18T08:26:28.991092", "value": 0.3629}, "name": "training_loss"}
{"data": {"epoch": 2.72, "step": 85, "timestamp": "2024-12-18T08:26:54.300396", "value": 0.4046}, "name": "training_loss"}
{"data": {"epoch": 2.88, "step": 90, "timestamp": "2024-12-18T08:27:18.536794", "value": 0.3466}, "name": "training_loss"}
{"data": {"epoch": 3.04, "step": 95, "timestamp": "2024-12-18T08:27:42.192288", "value": 0.3597}, "name": "training_loss"}
{"data": {"epoch": 3.2, "step": 100, "timestamp": "2024-12-18T08:28:04.970346", "value": 0.3293}, "name": "training_loss"}
{"data": {"epoch": 3.36, "step": 105, "timestamp": "2024-12-18T08:28:29.373667", "value": 0.3678}, "name": "training_loss"}
{"data": {"epoch": 3.52, "step": 110, "timestamp": "2024-12-18T08:28:54.527638", "value": 0.3639}, "name": "training_loss"}
{"data": {"epoch": 3.68, "step": 115, "timestamp": "2024-12-18T08:29:20.037901", "value": 0.3121}, "name": "training_loss"}
{"data": {"epoch": 3.84, "step": 120, "timestamp": "2024-12-18T08:29:43.694220", "value": 0.3009}, "name": "training_loss"}
{"data": {"epoch": 4.0, "step": 125, "timestamp": "2024-12-18T08:30:08.959777", "value": 0.355}, "name": "training_loss"}
{"data": {"epoch": 4.16, "step": 130, "timestamp": "2024-12-18T08:30:34.549892", "value": 0.3049}, "name": "training_loss"}
{"data": {"epoch": 4.32, "step": 135, "timestamp": "2024-12-18T08:30:57.256689", "value": 0.2938}, "name": "training_loss"}
{"data": {"epoch": 4.48, "step": 140, "timestamp": "2024-12-18T08:31:21.639842", "value": 0.3258}, "name": "training_loss"}
{"data": {"epoch": 4.64, "step": 145, "timestamp": "2024-12-18T08:31:48.985025", "value": 0.3153}, "name": "training_loss"}
{"data": {"epoch": 4.8, "step": 150, "timestamp": "2024-12-18T08:32:13.743223", "value": 0.2664}, "name": "training_loss"}
{"data": {"epoch": 4.96, "step": 155, "timestamp": "2024-12-18T08:32:37.979222", "value": 0.3321}, "name": "training_loss"}
{"data": {"epoch": 5.12, "step": 160, "timestamp": "2024-12-18T08:33:02.821396", "value": 0.2714}, "name": "training_loss"}
{"data": {"epoch": 5.28, "step": 165, "timestamp": "2024-12-18T08:33:26.692441", "value": 0.2762}, "name": "training_loss"}
{"data": {"epoch": 5.44, "step": 170, "timestamp": "2024-12-18T08:33:49.990244", "value": 0.2638}, "name": "training_loss"}
{"data": {"epoch": 5.6, "step": 175, "timestamp": "2024-12-18T08:34:14.315508", "value": 0.2657}, "name": "training_loss"}
{"data": {"epoch": 5.76, "step": 180, "timestamp": "2024-12-18T08:34:39.632194", "value": 0.2673}, "name": "training_loss"}
{"data": {"epoch": 5.92, "step": 185, "timestamp": "2024-12-18T08:35:04.150179", "value": 0.25}, "name": "training_loss"}
{"data": {"epoch": 6.08, "step": 190, "timestamp": "2024-12-18T08:35:29.809107", "value": 0.2442}, "name": "training_loss"}
{"data": {"epoch": 6.24, "step": 195, "timestamp": "2024-12-18T08:35:54.034458", "value": 0.2177}, "name": "training_loss"}
{"data": {"epoch": 6.4, "step": 200, "timestamp": "2024-12-18T08:36:19.598923", "value": 0.2299}, "name": "training_loss"}
{"data": {"epoch": 6.56, "step": 205, "timestamp": "2024-12-18T08:36:42.592159", "value": 0.1983}, "name": "training_loss"}
{"data": {"epoch": 6.72, "step": 210, "timestamp": "2024-12-18T08:37:07.855585", "value": 0.2221}, "name": "training_loss"}
{"data": {"epoch": 6.88, "step": 215, "timestamp": "2024-12-18T08:37:33.117416", "value": 0.2615}, "name": "training_loss"}
{"data": {"epoch": 7.04, "step": 220, "timestamp": "2024-12-18T08:37:57.805831", "value": 0.232}, "name": "training_loss"}
{"data": {"epoch": 7.2, "step": 225, "timestamp": "2024-12-18T08:38:24.049801", "value": 0.1917}, "name": "training_loss"}
{"data": {"epoch": 7.36, "step": 230, "timestamp": "2024-12-18T08:38:48.742201", "value": 0.2145}, "name": "training_loss"}
{"data": {"epoch": 7.52, "step": 235, "timestamp": "2024-12-18T08:39:13.978622", "value": 0.1829}, "name": "training_loss"}
{"data": {"epoch": 7.68, "step": 240, "timestamp": "2024-12-18T08:39:36.728941", "value": 0.1867}, "name": "training_loss"}
{"data": {"epoch": 7.84, "step": 245, "timestamp": "2024-12-18T08:40:01.621554", "value": 0.1921}, "name": "training_loss"}
{"data": {"epoch": 8.0, "step": 250, "timestamp": "2024-12-18T08:40:26.636845", "value": 0.1672}, "name": "training_loss"}

ashokponkumar (Collaborator) commented:

@kmehant In this case the number of gradient accumulation steps is less than the number of steps in an epoch, isn't it?

kmehant (Collaborator) commented Dec 18, 2024

@ashokponkumar let me summarize.

There were two issues:

  1. One that happens in the lerp function in AdamW.
  2. The other is the torch.cat error reported in this issue.

Hari confirms that issue (2) does not happen with transformers==4.45.2 and accelerate==1.0.1 in the official image. I am assuming the WCA branch, when used with these updated dependencies, should help us circumvent issue (2).

I was looking into issue (1), which is also being actively investigated by the HF team. There is a rewrite of the inner training loop that moves away from the accelerator.accumulate context, which should potentially solve it. However, for me the training simply hangs.

ashokponkumar (Collaborator) commented:

@ashokponkumar let me summarize.

There were two issues:

  1. One that happens in the lerp function in AdamW.
  2. The other is the torch.cat error reported in this issue.

Sure. Under what conditions does issue (1) happen?

kmehant (Collaborator) commented Dec 18, 2024

Updated my comment; there seems to be some lag.

kmehant (Collaborator) commented Dec 18, 2024

@ashokponkumar @HarikrishnanBalagopal Issue (1) is also gone now with:

  - transformers==4.46.2 (the inner-training-loop rewrite first started to land in this release)
  - accelerate==1.2.0

kmehant reopened this Dec 18, 2024

kmehant (Collaborator) commented Dec 18, 2024

Reopened this issue.

We should set a minimum bound of transformers>=4.46.2, together with the minimum accelerate version compatible with that transformers release, to overcome these two issues, after regression testing for performance of course.

cc: @ashokponkumar @HarikrishnanBalagopal
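
For completeness, a quick way to check an environment against those bounds (a minimal sketch using the packaging library; the accelerate>=1.2.0 bound is simply the version reported to work above and may not be the true minimum):

from importlib.metadata import version

from packaging.version import Version

# Proposed minimum bounds from this thread; accelerate>=1.2.0 is the
# version reported to work and still needs the compatibility/regression
# check mentioned above.
MIN_VERSIONS = {"transformers": "4.46.2", "accelerate": "1.2.0"}

for pkg, minimum in MIN_VERSIONS.items():
    installed = Version(version(pkg))
    if installed < Version(minimum):
        raise RuntimeError(f"{pkg}=={installed} is below the required minimum {minimum}")
    print(f"{pkg}=={installed} satisfies >={minimum}")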
