This repository has been archived by the owner on Oct 19, 2024. It is now read-only.
[BUG] Inaccurate memory profiling during pipeline stage construction #684
Labels: known bug
When constructing pipeline stages with Alpa's auto inter-operator parallel algorithm, we need to accurately profile the memory usage of a constructed pipeline stage to determine whether a solution is valid or not. A pipeline stage can be decomposed into the following 3 parts: `forward_layers`, `backward_layers`, and `update_layers` (`apply_grad`).

Given `forward_layers`, `backward_layers`, and `update_layers`, the current profiling workflow is:

1. Merge `forward_layers` and `backward_layers` into a `compute_part`. There is a special hook between the forward and backward layers that marks the variables passed between them as `intermediate_variables`. These variables need to be stored during the execution of `forward_layers` for the execution of `backward_layers`, and can only be deleted after `backward_layers` finishes. During pipeline execution, we need to store multiple sets of `intermediate_variables` because there are multiple micro-batches on-the-fly.
2. Compile the `compute_part` and the `update_layers`.
3. Profile the `compute_part` and `update_layers`.
4. Measure the memory usage of the `compute_part`.

Currently, we measure the following memory:

- `peak_memory`: the peak memory reached in the `compute_part`.
- `intermediate_size`: the size of all `intermediate_variables`.
- `initial_size`: the size of the input tensors to `update_layers`, typically optimizer states.
- `available_memory`: the size of the GPU memory.

And we calculate the maximal number of micro-batches we can keep on-the-fly with:
`max_n_micro_batches = (available_memory - peak_memory - initial_size) / intermediate_size`

(And we set `max_n_succ_stages = max_n_micro_batches` per the 1F1B pipeline schedule.) Note that this is an under-estimate of how many micro-batches we can fit: `peak_memory` actually already includes one copy of `intermediate_variables`. We do not count this copy because, when profiling for `peak_memory`, the memory reserved for `intermediate_variables` is freed in the backward pass as the variables become inactive.

To fix this issue, there are two solutions:

1. Keep the `intermediate_variables` in the `compute_part`, which can force these variables not to be freed.
2. Instead of profiling the `compute_part` with shard parallel, profile the `forward_part` and `backward_part` separately with the pipeshard runtime.

cc @ZYHowell
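For reference, the memory estimate and the proposed correction can be sketched as follows. This is a hypothetical helper for illustration, not Alpa's actual API; the `count_profiled_copy` flag and all sizes are assumptions made up for the example.

```python
def max_n_micro_batches(available_memory, peak_memory, initial_size,
                        intermediate_size, count_profiled_copy=False):
    """Estimate how many micro-batches' intermediate_variables fit in memory.

    With count_profiled_copy=False this reproduces the current
    (under-estimating) formula; with True it credits back the one copy of
    intermediate_variables that is already included in peak_memory.
    """
    free = available_memory - peak_memory - initial_size
    if count_profiled_copy:
        # peak_memory was profiled while one set of intermediate_variables
        # was still alive, so that copy is already paid for.
        free += intermediate_size
    return free // intermediate_size

# Example with made-up sizes (bytes): 16 GB GPU, 6 GB peak,
# 4 GB optimizer states, 1 GB of intermediates per micro-batch.
GB = 1 << 30
current = max_n_micro_batches(16 * GB, 6 * GB, 4 * GB, 1 * GB)    # 6
fixed = max_n_micro_batches(16 * GB, 6 * GB, 4 * GB, 1 * GB,
                            count_profiled_copy=True)             # 7
```

The one-batch difference is exactly the under-estimate described above: the current formula leaves one micro-batch's worth of capacity unused.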