
[BUG] Inaccurate memory profiling during pipeline stage construction #684

Open
zhuohan123 opened this issue Aug 31, 2022 · 2 comments

@zhuohan123
Member

When constructing pipeline stages with Alpa's auto inter-operator parallel algorithm, we need to accurately profile the memory usage of a constructed pipeline stage to determine whether a solution is valid. A pipeline stage can be decomposed into the following three parts (sketched below):

  1. Forward
  2. Backward
  3. Update (apply_grad)
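
As a concrete (non-Alpa) illustration, here is a minimal JAX sketch of these three parts of one training step; the model and names are invented for illustration only:

```python
import jax
import jax.numpy as jnp

def forward(params, x, y):                        # 1. Forward
    h = jnp.tanh(x @ params["w1"])                #    `h` is an activation the backward pass needs
    return jnp.mean((h @ params["w2"] - y) ** 2)

def backward(params, x, y):                       # 2. Backward
    return jax.grad(forward)(params, x, y)

def update(params, grads, lr=1e-2):               # 3. Update (apply_grad), plain SGD here
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = {"w1": jnp.ones((8, 16)), "w2": jnp.ones((16, 1))}
x, y = jnp.ones((4, 8)), jnp.ones((4, 1))
params = update(params, backward(params, x, y))
```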

Given forward_layers, backward_layers, and update_layers, the current profiling workflow is:

  1. Merge forward_layers and backward_layers into a compute_part. A special hook between the forward and backward layers marks the variables that cross the boundary as intermediate_variables (see the illustration after this list). These variables must be stored during the execution of forward_layers so that backward_layers can use them, and they can only be deleted after backward_layers finishes. During pipeline execution, we need to store multiple sets of intermediate_variables because multiple micro-batches are on the fly.
  2. Merge the compute_part and the update_layers.
  3. Run the auto-sharding pass to shard the stage and obtain the sharding specs of all tensors.
  4. Decouple the sharded compute_part and update_layers.
  5. Compile and profile the compute cost of the compute_part.
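
As a plain-JAX illustration of the intermediate_variables in step 1 (this is not Alpa's internal code): jax.vjp runs the forward pass and returns a backward closure that holds on to the forward residuals until it is called.

```python
import jax
import jax.numpy as jnp

def forward(w, x):
    return jnp.tanh(x @ w)

w, x = jnp.ones((8, 8)), jnp.ones((4, 8))
# The residuals ("intermediate variables") become live here ...
out, backward_fn = jax.vjp(forward, w, x)
# ... and can only be freed after the backward closure has been applied.
grads = backward_fn(jnp.ones_like(out))
```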

Currently, we measure the following memory quantities:

  1. peak_memory: The peak memory reached while executing the compute_part.
  2. intermediate_size: The total size of all intermediate_variables.
  3. initial_size: The size of the input tensors to update_layers, typically the optimizer states.
  4. available_memory: The total amount of GPU memory available.

We then calculate the maximal number of micro-batches we can store on the fly with: max_n_micro_batches = (available_memory - peak_memory - initial_size) / intermediate_size. (We also set max_n_succ_stages = max_n_micro_batches, as required by the 1F1B pipeline schedule.) Note that this is an under-estimate of how many micro-batches we can fit: peak_memory actually already contains one copy of intermediate_variables. We do not count this copy because, when profiling peak_memory, the memory reserved for intermediate_variables is freed during the backward pass as the variables become inactive.
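
For concreteness, a back-of-the-envelope sketch of this calculation; the byte counts below are invented example values, not real measurements:

```python
GB = 1024 ** 3

available_memory  = 16 * GB  # total GPU memory
peak_memory       = 6 * GB   # peak of the profiled compute_part
initial_size      = 4 * GB   # inputs to update_layers (optimizer states)
intermediate_size = 2 * GB   # one set of intermediate_variables

max_n_micro_batches = (available_memory - peak_memory - initial_size) // intermediate_size
max_n_succ_stages = max_n_micro_batches  # per the 1F1B pipeline schedule

print(max_n_micro_batches)  # 3 -> three sets of intermediates can be kept on the fly
```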

To fix this issue, there are two solutions:

  1. Hot fix (fast but dirty): output intermediate_variables from the compute_part, which forces these variables not to be freed (see the sketch after this list).
  2. Clean fix (clean but slow): instead of profiling a single compute_part with shard parallel, profile the forward_part and backward_part separately with the pipeshard runtime.
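
A minimal JAX sketch of the idea behind the hot fix (illustrative only, not Alpa code): returning the intermediate as an extra output of the compiled computation makes it a live output, so its buffer cannot be freed during the backward pass.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    h = jnp.tanh(x @ params["w1"])  # intermediate needed by the backward pass
    return jnp.mean((h @ params["w2"] - y) ** 2), h

@jax.jit
def compute_part(params, x, y):
    # has_aux=True lets us return the intermediate `h` next to the gradients,
    # keeping it alive until the end of the compiled computation.
    grads, h = jax.grad(loss_fn, has_aux=True)(params, x, y)
    return grads, h

params = {"w1": jnp.ones((8, 16)), "w2": jnp.ones((16, 1))}
grads, h = compute_part(params, jnp.ones((4, 8)), jnp.ones((4, 1)))
```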

cc @ZYHowell

@zhuohan123 added the known bug label on Aug 31, 2022
@merrymercy
Member

Note that this is an under-estimate of how many batches we can fit:

Do you mean over-estimation? Because we don't count intermediate_variables, the value we compute is higher than the actual value.

Another question: can we just use max_n_micro_batches = (available_memory - peak_memory - intermediate_size - initial_size) / intermediate_size? This should be a safe underestimation, but I don't know whether it is too inaccurate.
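
A quick numeric check of the difference between the two formulas, reusing the invented sizes from the sketch above (not real measurements); with floor division the safer formula yields exactly one fewer micro-batch:

```python
GB = 1024 ** 3
available_memory, peak_memory = 16 * GB, 6 * GB
initial_size, intermediate_size = 4 * GB, 2 * GB

current = (available_memory - peak_memory - initial_size) // intermediate_size
safer   = (available_memory - peak_memory - intermediate_size - initial_size) // intermediate_size
print(current, safer)  # 3 2
```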

@zhuohan123
Member Author

No, we count roughly one extra copy of intermediate_variables, so the memory we estimate is more than the actual memory used. So it's an under-estimate of how many batches we can fit.
