Background
If we apply tensor rematerialization to a layer that contains RNG computations, we have to make sure those RNG computations generate the same results during the first forward pass and the recomputation.
Because Alpa uses a stateful RNG, we have to save the RNG state during the first forward pass and restore it during the recomputation. The process looks like:
```
# First forward
Forward_1_begin
old_state = get_rng_state()
...
Forward_1_end

# Other layers
...

# Recomputation
Forward_2_begin
cur_state = get_rng_state()
set_rng_state(old_state)
...
set_rng_state(cur_state)
Forward_2_end

# Backward
...
```
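For reference, PyTorch's stateful RNG exposes this save/restore pattern directly. A minimal runnable sketch of the same process (the torch calls are real APIs; the surrounding code is only illustrative):

```python
import torch

old_state = torch.get_rng_state()    # snapshot before the first forward
first = torch.rand(4)                # RNG computation inside the first forward

_ = torch.rand(4)                    # other layers advance the RNG state

cur_state = torch.get_rng_state()    # remember where the stream is now
torch.set_rng_state(old_state)       # rewind to the first-forward snapshot
recomputed = torch.rand(4)           # recomputation sees identical randomness
torch.set_rng_state(cur_state)       # restore so later RNG ops are unaffected

assert torch.equal(first, recomputed)
```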
Alpa and PyTorch both use a stateful RNG, so Megatron-LM implements the same save/restore functionality. The official JAX, on the other hand, uses a stateless RNG, so it does not have this problem.
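For contrast, a minimal JAX sketch of why a stateless RNG side-steps the issue: randomness is a pure function of an explicit key, so recomputing with the same key reproduces the same values without any state to save or restore.

```python
import jax

key = jax.random.PRNGKey(0)

first = jax.random.normal(key, (4,))       # first forward
recomputed = jax.random.normal(key, (4,))  # recomputation with the same key

assert (first == recomputed).all()         # identical by construction
```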
Implementation
I recently added some RNG thunks and exposed the rng_state, which is a global variable on the CPU.
We need to implement two additional thunks, get_rng_state and set_rng_state, to manipulate this state. For example, get_rng_state can copy the state from CPU to GPU as a tensor. We then need to insert custom calls following the pseudocode above.
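A rough sketch of how the pieces would fit together. The get_rng_state / set_rng_state thunks are only proposed above and do not exist yet, so they are simulated here with NumPy's stateful RNG; everything except the numpy calls is a hypothetical stand-in:

```python
import numpy as np

def get_rng_state():
    return np.random.get_state()          # stand-in for the proposed thunk

def set_rng_state(state):
    np.random.set_state(state)            # stand-in for the proposed thunk

def layer(x):
    return x + np.random.rand(*x.shape)   # a layer with an RNG computation

def remat_forward(x):
    old_state = get_rng_state()           # custom call: snapshot before forward
    y = layer(x)                          # first forward; activations discarded
    return y, old_state

def remat_recompute(x, old_state):
    cur_state = get_rng_state()           # remember where the stream is now
    set_rng_state(old_state)              # rewind to the first-forward snapshot
    y = layer(x)                          # recomputation sees identical randomness
    set_rng_state(cur_state)              # restore so later RNG ops are unaffected
    return y

x = np.zeros(4)
y1, saved = remat_forward(x)
_ = np.random.rand(3)                     # other layers advance the RNG state
y2 = remat_recompute(x, saved)
assert np.allclose(y1, y2)
```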