'a:f32[32,12,512,512] = add_any b c' cannot be delayed in apply_grad.py #808
Comments
This looks like an OOM error. If the model is not very large, it may be related to this issue. For example, the XLA from jaxlib will try to claim 90% of all available memory by default. If two XLA instances are launched separately (alpa's jaxlib and tensorflow), one of them may get only 90% of the remaining 10% of memory, because the other library already occupies 90% of it.
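To make the memory arithmetic above concrete, here is a minimal sketch of one way to stop the two libraries from competing for the same GPU memory. `XLA_PYTHON_CLIENT_MEM_FRACTION` and `TF_FORCE_GPU_ALLOW_GROWTH` are the environment variables documented by JAX and TensorFlow respectively; whether the jaxlib build shipped with alpa honors them in the same way is an assumption, and the helper names and the 0.45 fraction are only illustrative.

```python
import os


def configure_gpu_memory(jax_fraction: float = 0.45) -> dict:
    """Set allocator-related environment variables.

    Must be called before jax or tensorflow is imported, because both
    libraries read these variables when they initialize their GPU backend.
    """
    if not 0.0 < jax_fraction <= 1.0:
        raise ValueError("jax_fraction must be in (0, 1]")
    settings = {
        # Fraction of GPU memory the XLA client preallocates (JAX docs).
        "XLA_PYTHON_CLIENT_MEM_FRACTION": f"{jax_fraction:.2f}",
        # Ask TensorFlow to grow its allocation on demand (TF docs).
        "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    }
    os.environ.update(settings)
    return settings


def second_process_share(first_fraction: float, second_fraction: float) -> float:
    """Share of total memory left for a second allocator that claims
    `second_fraction` of whatever the first one did not take."""
    return (1.0 - first_fraction) * second_fraction


settings = configure_gpu_memory(0.45)
print(settings)
# The scenario described above: both libraries try to take 90%.
print(f"{second_process_share(0.9, 0.9):.2f}")
```

Import `jax`, `alpa`, and `tensorflow` only after calling `configure_gpu_memory`, otherwise the settings have no effect on the already-initialized backends.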
Thank you for answering. This is my script: https://gist.github.com/cksmll/1663014f698d15c3fb6a665578ad7c99
Actually WIP with the bug: #807
It seems like another bug... I'll try to fix it.
@cksmll In alpa, we monkey-patch the rng from jax's stateless version to tf's stateful one, and jax uses a specific dtype for its rng, which is not handled. As a workaround, could you please try to move the key out of the
Please describe the bug
I am using alpa with the T5-base model and get the errors below.
I took the baseline code (run_t5_mlm_flax.py) from https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py
and revised it along the lines of alpa/example/opt_finetune/run_clm_flax.py.
The full training code and run script are here: https://gist.github.com/cksmll/1663014f698d15c3fb6a665578ad7c99
Screenshots
Error:
The error occurs in v0.2.2.
When I install the main branch from GitHub with pip install git+https://github.com/alpa-projects/alpa@main,
I get a dictionary key error (AssertionError).
Please describe the expected behavior
I want to run alpa's pipeshard class with T5.
System information and environment
To Reproduce
Steps to reproduce the behavior: