Skip to content

Commit

Permalink
[example] update resnet example result
Browse files Browse the repository at this point in the history
  • Loading branch information
ver217 committed Aug 16, 2023
1 parent 27dc751 commit 2020a9a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions examples/images/resnet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ python eval.py -c ./ckpt-low_level_zero -e 80

Expected accuracy performance will be:

| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero |
| --------- | ------------------------ | --------------------- | --------------------- | ---------------------- |
| ResNet-18 | 85.85% | 84.91% | 85.46% | 84.50% |
| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero | Booster Gemini |
| --------- | ------------------------ | --------------------- | --------------------- | ---------------------- | -------------- |
| ResNet-18 | 85.85% | 84.91% | 85.46% | 84.50% | 84.60% |

**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`**
4 changes: 2 additions & 2 deletions examples/images/resnet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def main():
'--plugin',
type=str,
default='torch_ddp',
choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero'],
choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero', 'gemini'],
help="plugin to use")
parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint")
parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory")
Expand Down Expand Up @@ -141,7 +141,7 @@ def main():
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
plugin = GeminiPlugin(initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)

Expand Down

0 comments on commit 2020a9a

Please sign in to comment.