
Commit

update doc
Fridge003 committed Oct 13, 2023
1 parent bd56281 commit 722f43d
Showing 3 changed files with 26 additions and 1 deletion.
13 changes: 13 additions & 0 deletions docs/source/en/features/gradient_accumulation_with_booster.md
@@ -126,6 +126,19 @@ for idx, (img, label) in enumerate(train_dataloader):

```

Currently, the plugins that support the `no_sync()` method are `TorchDDPPlugin` and `LowLevelZeroPlugin` (with `stage` set to 1). `GeminiPlugin` does not support `no_sync()`, but it can still perform synchronized gradient accumulation in a PyTorch-like way. The following code snippet shows how to enable gradient accumulation with `GeminiPlugin`:
<!--- doc-test-ignore-start -->
```python
output = gemini_model(input)
train_loss = criterion(output, label)
train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, gemini_optimizer)

# Update parameters once every GRADIENT_ACCUMULATION micro-batches.
if (idx + 1) % GRADIENT_ACCUMULATION == 0:
    gemini_optimizer.step()  # zero_grad is done automatically
```
<!--- doc-test-ignore-end -->
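
For readers who want to see the snippet above in context, here is a minimal sketch that embeds it in a training loop, assuming the `gemini_model`, `gemini_optimizer`, `criterion`, and `train_dataloader` names prepared in the earlier steps of this document; the `.cuda()` transfers and the accumulation window of 4 are illustrative assumptions:
<!--- doc-test-ignore-start -->
```python
GRADIENT_ACCUMULATION = 4  # illustrative accumulation window

for idx, (img, label) in enumerate(train_dataloader):
    img, label = img.cuda(), label.cuda()

    output = gemini_model(img)
    train_loss = criterion(output, label)
    # Scale the loss so the accumulated gradient matches a single large-batch update.
    train_loss = train_loss / GRADIENT_ACCUMULATION
    booster.backward(train_loss, gemini_optimizer)

    # Update parameters once every GRADIENT_ACCUMULATION micro-batches;
    # GeminiPlugin clears the gradients automatically after the step.
    if (idx + 1) % GRADIENT_ACCUMULATION == 0:
        gemini_optimizer.step()
```
<!--- doc-test-ignore-end -->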

### Step 6. Invoke Training Scripts
To verify gradient accumulation, we can simply check how the parameter values change. With gradient accumulation enabled, parameters are only updated in the last step. You can run the script with the following command:
```shell
13 changes: 13 additions & 0 deletions docs/source/zh-Hans/features/gradient_accumulation_with_booster.md
@@ -93,6 +93,19 @@ model, optimizer, criterion, train_dataloader, _ = booster.boost(model=model,
dataloader=train_dataloader)
```

Currently, the plugins that support the `no_sync()` method are `TorchDDPPlugin` and `LowLevelZeroPlugin` (with the `stage` parameter set to 1). `GeminiPlugin` does not support the `no_sync()` method, but it can still perform synchronized gradient accumulation in a PyTorch-like way. The following is a code snippet for gradient accumulation with `GeminiPlugin`:
<!--- doc-test-ignore-start -->
```python
output = gemini_model(input)
train_loss = criterion(output, label)
train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, gemini_optimizer)

# Update parameters once every GRADIENT_ACCUMULATION micro-batches.
if (idx + 1) % GRADIENT_ACCUMULATION == 0:
    gemini_optimizer.step()  # zero_grad is done automatically
```
<!--- doc-test-ignore-end -->

### Step 5. Train with Booster
Build a normal training loop with the booster to verify gradient accumulation. `param_by_iter` records the information of distributed training.
```python
1 change: 0 additions & 1 deletion tests/test_zero/test_gemini/test_grad_accum.py
@@ -38,7 +38,6 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):

    # Compare gradients.
    for p0, p1 in zip(model.parameters(), torch_model.parameters()):
        print(p0, p1.grad)
        assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5)

    # Release gradient chunks and move them to gradient device.
