-
-
Notifications
You must be signed in to change notification settings - Fork 899
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
26 changed files
with
1,281 additions
and
1,052 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
--- | ||
title: Batch size vs Gradient accumulation | ||
description: Understanding of batch size and gradient accumulation steps | ||
--- | ||
|
||
Gradient accumulation means accumulating gradients over several mini-batches and updating the model weights afterward. When the samples in each batch are diverse, this technique doesn't significantly impact learning. | ||
|
||
This method allows for effective training with larger effective batch sizes without needing proportionally larger memory. Here's why: | ||
|
||
1. **Memory Consumption with Batch Size**: The primary reason increasing the batch size impacts memory is due to the storage requirements for intermediate activations. When you forward propagate a batch through a network, you have to store the activations at each layer for each sample in the batch, because these activations are used during backpropagation to compute gradients. Therefore, larger batches mean more activations, leading to greater GPU memory consumption. | ||
|
||
2. **Gradient Accumulation**: With gradient accumulation, you're effectively simulating a larger batch size by accumulating gradients over several smaller batches (or micro-batches). However, at any given time, you're only forward and backward propagating a micro-batch. This means you only store activations for the micro-batch, not the full accumulated batch. As a result, you can simulate the effect of a larger batch size without the memory cost of storing activations for a large batch. | ||
|
||
**Example 1:** | ||
Micro batch size: 3 | ||
Gradient accumulation steps: 2 | ||
Number of GPUs: 3 | ||
Total batch size = 3 * 2 * 3 = 18 | ||
|
||
``` | ||
| GPU 1 | GPU 2 | GPU 3 | | ||
|----------------|----------------|----------------| | ||
| S1, S2, S3 | S4, S5, S6 | S7, S8, S9 | | ||
| e1, e2, e3 | e4, e5, e6 | e7, e8, e9 | | ||
|----------------|----------------|----------------| | ||
| → (accumulate) | → (accumulate) | → (accumulate) | | ||
|----------------|----------------|----------------| | ||
| S10, S11, S12 | S13, S14, S15 | S16, S17, S18 | | ||
| e10, e11, e12 | e13, e14, e15 | e16, e17, e18 | | ||
|----------------|----------------|----------------| | ||
| → (apply) | → (apply) | → (apply) | | ||
Accumulated gradient for the weight w1 after the second iteration (considering all GPUs): | ||
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6 + e7 + e8 + e9 + e10 + e11 + e12 + e13 + e14 + e15 + e16 + e17 + e18 | ||
Weight update for w1: | ||
w1_new = w1_old - learning rate x (Total gradient for w1 / 18) | ||
``` | ||
|
||
**Example 2:** | ||
Micro batch size: 2 | ||
Gradient accumulation steps: 1 | ||
Number of GPUs: 3 | ||
Total batch size = 2 * 1 * 3 = 6 | ||
|
||
``` | ||
| GPU 1 | GPU 2 | GPU 3 | | ||
|-----------|-----------|-----------| | ||
| S1, S2 | S3, S4 | S5, S6 | | ||
| e1, e2 | e3, e4 | e5, e6 | | ||
|-----------|-----------|-----------| | ||
| → (apply) | → (apply) | → (apply) | | ||
Accumulated gradient for the weight w1 (considering all GPUs): | ||
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6 | ||
Weight update for w1: | ||
w1_new = w1_old - learning rate × (Total gradient for w1 / 6) | ||
``` |
Oops, something went wrong.