This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

[PERF] Save/Restore RNG States for Rematerialization #517

Open
merrymercy opened this issue Jun 16, 2022 · 0 comments
Labels
enhancement New feature

merrymercy commented Jun 16, 2022

Background

If we apply tensor rematerialization to a layer that contains RNG computations (for example, dropout), we must ensure those computations produce the same results in the first forward pass and in the recomputation. Otherwise the backward pass would use activations that differ from the ones the forward pass actually produced.

Because Alpa uses a stateful RNG, we have to save the RNG state during the first forward pass and restore it during the second one. The process looks like this:

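# First forward
Forward_1_begin
old_state = get_rng_state()   # save the state before the layer draws random numbers
…
Forward_1_end

# Other layers (these advance the RNG state)
...

# Recomputation
Forward_2_begin
cur_state = get_rng_state()   # remember where the RNG currently is
set_rng_state(old_state)      # rewind to the state of the first forward
…
set_rng_state(cur_state)      # restore, so later RNG computations are unaffected
Forward_2_end

# Backward
...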
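
The flow above can be tried out with any stateful generator. The following is only a minimal sketch that uses Python's standard `random` module as a stand-in for Alpa's RNG; `dropout_mask` is a hypothetical helper, and `random.getstate()` / `random.setstate()` play the roles of the proposed `get_rng_state` / `set_rng_state`.

```python
import random


def dropout_mask(n, p=0.5):
    """Draw a keep/drop mask from the global (stateful) RNG."""
    return [random.random() >= p for _ in range(n)]


random.seed(0)

# First forward: save the state before the layer draws random numbers.
old_state = random.getstate()
mask_forward = dropout_mask(8)

# Other layers advance the global RNG state.
_ = dropout_mask(32)

# Recomputation: stash the current state, rewind, recompute, then restore.
cur_state = random.getstate()
random.setstate(old_state)
mask_recompute = dropout_mask(8)
random.setstate(cur_state)

# After the restore, the global stream continues as if no recomputation happened.
state_restored = random.getstate() == cur_state
```

The recomputed mask matches the original one only because of the rewind; without `random.setstate(old_state)` the second draw would come from a later point in the stream.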

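PyTorch, like Alpa, uses a stateful RNG, so Megatron-LM (which is built on PyTorch) implements the same functionality. The official JAX, on the other hand, uses a stateless RNG in which keys are passed explicitly, so it does not have this problem: recomputation with the same key yields the same random numbers.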
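
To illustrate why a stateless design avoids the issue, here is a small sketch in plain Python. It does not use JAX's API; it only mimics the idea of an explicit key by building a local generator from it. `dropout_mask` and `layer_key` are hypothetical names chosen for the illustration.

```python
import random


def dropout_mask(key, n, p=0.5):
    """Stateless-style draw: the output depends only on the explicit key."""
    rng = random.Random(key)  # local generator, no global state is touched
    return [rng.random() >= p for _ in range(n)]


layer_key = 1234  # hypothetical per-layer key

# First forward.
mask_forward = dropout_mask(layer_key, 8)

# Other layers use their own keys and cannot disturb this layer's randomness.
_ = dropout_mask(999, 32)

# Recomputation: passing the same key again is all that is needed.
mask_recompute = dropout_mask(layer_key, 8)
```

Because nothing is stored between calls, there is no state to save or restore around the recomputation.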

Implementation

I recently added an RNG thunk that exposes the rng_state. The rng_state is a global variable on the CPU.
We need to implement two additional thunks, get_rng_state and set_rng_state, to manipulate the state. For example, get_rng_state can copy the state from CPU to GPU as a tensor. We then need to insert custom calls that follow the pseudocode above.
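
As a rough model of how the two thunks and the inserted calls could fit together, the sketch below wraps a layer so that its recomputation reuses the saved state. It is an assumption-laden illustration, not Alpa's implementation: `_rng` stands in for the global CPU-side rng_state, `remat` and `noisy_layer` are hypothetical, and the CPU-to-GPU copy mentioned above is not modeled.

```python
import random

# Stand-in for the global CPU-side rng_state (the real one lives in the runtime).
_rng = random.Random(0)


def get_rng_state():
    """Hypothetical thunk: return a snapshot of the global RNG state."""
    return _rng.getstate()


def set_rng_state(state):
    """Hypothetical thunk: overwrite the global RNG state."""
    _rng.setstate(state)


def remat(layer_fn):
    """Return (forward, recompute) functions that draw identical random numbers."""
    saved = {}

    def forward(x):
        saved["state"] = get_rng_state()  # inserted at Forward_1_begin
        return layer_fn(x)

    def recompute(x):
        cur_state = get_rng_state()       # inserted at Forward_2_begin
        set_rng_state(saved["state"])
        try:
            return layer_fn(x)
        finally:
            set_rng_state(cur_state)      # inserted at Forward_2_end

    return forward, recompute


def noisy_layer(x):
    """A layer with an RNG computation: adds random noise to each element."""
    return [v + _rng.random() for v in x]


forward, recompute = remat(noisy_layer)
out_forward = forward([1.0, 2.0, 3.0])
_rng.random()  # other layers advance the state
out_recompute = recompute([1.0, 2.0, 3.0])
```

Restoring `cur_state` in a `finally` block keeps the global stream consistent even if the recomputed layer raises.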

@merrymercy merrymercy added the known bug Something isn't working label Jun 17, 2022
@merrymercy merrymercy changed the title Save/Restore RNG States for Rematerialization [PERF] Save/Restore RNG States for Rematerialization Sep 9, 2022
@merrymercy merrymercy added enhancement New feature and removed known bug Something isn't working labels Sep 9, 2022