diff --git a/examples/images/resnet/README.md b/examples/images/resnet/README.md index c69828637269..9a7493ea31a6 100644 --- a/examples/images/resnet/README.md +++ b/examples/images/resnet/README.md @@ -49,8 +49,8 @@ python eval.py -c ./ckpt-low_level_zero -e 80 Expected accuracy performance will be: -| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero | -| --------- | ------------------------ | --------------------- | --------------------- | ---------------------- | -| ResNet-18 | 85.85% | 84.91% | 85.46% | 84.50% | +| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero | Booster Gemini | +| --------- | ------------------------ | --------------------- | --------------------- | ---------------------- | -------------- | +| ResNet-18 | 85.85% | 84.91% | 85.46% | 84.50% | 84.60% | **Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`** diff --git a/examples/images/resnet/train.py b/examples/images/resnet/train.py index fe0dabf08377..fa300395c9f3 100644 --- a/examples/images/resnet/train.py +++ b/examples/images/resnet/train.py @@ -104,7 +104,7 @@ def main(): '--plugin', type=str, default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero'], + choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero', 'gemini'], help="plugin to use") parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint") parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") @@ -141,7 +141,7 @@ def main(): if args.plugin.startswith('torch_ddp'): plugin = TorchDDPPlugin() elif args.plugin == 'gemini': - plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5) + plugin = GeminiPlugin(initial_scale=2**5) elif args.plugin == 'low_level_zero': plugin = LowLevelZeroPlugin(initial_scale=2**5)