This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

'a:f32[32,12,512,512] = add_any b c' cannot be delayed in apply_grad.py #808

Open
cksmll opened this issue Dec 7, 2022 · 8 comments

@cksmll

cksmll commented Dec 7, 2022

Please describe the bug
I am using Alpa with the T5-base model and get the errors below.
I took the baseline code (run_t5_mlm_flax.py) from https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py
and revised it similarly to alpa/example/opt_finetune/run_clm_flax.py.
The full training code and run script are here: https://gist.github.com/cksmll/1663014f698d15c3fb6a665578ad7c99

Screenshots
Error:
(screenshot of the error traceback)
The error occurs in v0.2.2.

When I install the main branch from this repo with pip install git+https://github.com/alpa-projects/alpa@main,
I get a dictionary KeyError (AssertionError).

Please describe the expected behavior
I want to run Alpa's pipeshard parallelism with T5.
System information and environment

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04, docker): Ubuntu 20.04
  • Python version: 3.8
  • CUDA version: 11.3
  • NCCL version: 2.9.6
  • cupy version: cupy-cuda113
  • GPU model and memory: RTX 3090

  • Alpa version: v0.2.2
  • TensorFlow version: 2.10.0
  • JAX version: 0.3.22

To Reproduce
Steps to reproduce the behavior:

  1. This is my script: https://gist.github.com/cksmll/1663014f698d15c3fb6a665578ad7c99
  2. Loading the real data and tokenizer takes some time, but in this script I run a fake batch (jnp.ones) in the training loop, so you can skip loading the data and tokenizer (a hedged sketch of such a fake batch is shown below).
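
Not from the gist: a hypothetical illustration of such a fake batch; the actual keys and shapes in the script may differ.

```python
# Hypothetical jnp.ones fake batch (keys/shapes are assumptions, not copied from
# the gist) so the training loop can run without real data or a tokenizer.
import jax.numpy as jnp

batch_size, seq_len = 8, 512
fake_batch = {
    "input_ids": jnp.ones((batch_size, seq_len), dtype=jnp.int32),
    "attention_mask": jnp.ones((batch_size, seq_len), dtype=jnp.int32),
    "decoder_input_ids": jnp.ones((batch_size, seq_len), dtype=jnp.int32),
    "decoder_attention_mask": jnp.ones((batch_size, seq_len), dtype=jnp.int32),
    "labels": jnp.ones((batch_size, seq_len), dtype=jnp.int32),
}
```
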
@cksmll
Author

cksmll commented Dec 7, 2022

Oh, after getting the new version it runs. (Until a few days ago, the latest code did not fix it.)

However, it only runs when TensorFlow is not installed.
If I run the script after pip3 install tensorflow, the following memory error occurs:
(screenshot of the memory error)

I wonder why.

@ZYHowell
Collaborator

ZYHowell commented Dec 7, 2022

This seems to be an OOM error. If the model is not very large, it may be related to this issue: by default, the XLA runtime from jaxlib tries to preallocate 90% of all available GPU memory. If two XLA runtimes are launched separately (Alpa's jaxlib and TensorFlow's), one of them may end up with only 90% of the remaining 10% of memory, because the other library has already occupied 90% of it.
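
For reference (not part of the original reply), a hedged sketch of the standard memory knobs, assuming the problem is simply that both runtimes preallocate GPU memory; whether these settings actually reach Alpa's Ray workers depends on how the cluster is launched:

```python
# Hedged sketch: set these before importing tensorflow / jax / alpa.
import os

# Let jaxlib's XLA claim only part of the GPU instead of the large default fraction...
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"
# ...or disable preallocation entirely (slower allocation, but avoids the conflict).
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Tell TensorFlow to grow its GPU usage on demand instead of preallocating.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

import tensorflow as tf  # imported only after the env vars are set
import alpa
```
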

@cksmll
Author

cksmll commented Dec 8, 2022

Thank you for answering.
But I got the dictionary KeyError again with the Alpa@main version.
This is the error for the T5-11b model (Hugging Face: https://huggingface.co/t5-11b):
(screenshot of the KeyError traceback)
This error does not appear with T5-base or T5-large.
As you advised, TensorFlow is not installed.

This is my script: https://gist.github.com/cksmll/1663014f698d15c3fb6a665578ad7c99

@ZYHowell
Collaborator

ZYHowell commented Dec 8, 2022

Actually, a fix for this bug is WIP: #807

@ZYHowell added the "known bug (Something isn't working)" label Dec 8, 2022
@ZYHowell
Collaborator

ZYHowell commented Dec 9, 2022

@cksmll Could you please try the nightly alpa after #807?

@cksmll
Author

cksmll commented Dec 10, 2022

Thanks. It finally works well with T5-11b!!!
Lastly, I want to ask one minor thing. (Sorry for bothering you.)

To train T5, I have to turn off the dropout layer (by setting train=False in the train_step function).
Otherwise, if I turn the dropout layer on as shown below, it prints an error.
Error:
(screenshot of the error)
How I apply the dropout layer:
(screenshot of the train_step code with dropout enabled)

So, to use Alpa, is it right to turn off the dropout layer?

@ZYHowell
Collaborator

It seems like another bug...I'll try to fix it.

@merrymercy added the "unknown error" label and removed the "known bug (Something isn't working)" label Dec 20, 2022
@ZYHowell
Collaborator

@cksmll In Alpa, we monkey-patch the RNG from JAX's stateless version to TF's stateful one, and JAX uses a specific dtype for its RNG keys, which is not handled. As a workaround, could you please try moving the key out of the loss_fn? (And, in my opinion, for strict correctness you also need to move it out of train_step, so that each iteration does not apply dropout at the same positions; see the sketch below.)
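
A minimal sketch of that workaround (hypothetical code, not taken from the gist; state is assumed to be a Flax TrainState wrapping the HF Flax T5 model, and the Alpa parallelization decorator is omitted for brevity):

```python
import jax
import optax

def train_step(state, batch, dropout_rng):
    labels = batch.pop("labels")

    def loss_fn(params):
        # The key is only *used* here; it is never created inside loss_fn.
        logits = state.apply_fn(
            **batch, params=params, dropout_rng=dropout_rng, train=True
        )[0]
        return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss

# Outer training loop: split the key out here so each iteration
# applies dropout at different positions.
rng = jax.random.PRNGKey(0)
for batch in train_loader:  # train_loader is a placeholder
    rng, dropout_rng = jax.random.split(rng)
    state, loss = train_step(state, batch, dropout_rng)
```
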
