Skip to content

Commit

Permalink
Use SGD for list of models from PyTorch. (#6324)
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi authored and zpcore committed Jan 22, 2024
1 parent 047ea00 commit 514b190
Showing 1 changed file with 33 additions and 3 deletions.
36 changes: 33 additions & 3 deletions benchmarks/torchbench_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,35 @@
"detectron2_fcos_r_50_fpn",
}

# torchbench models that might OOM using Adam.
# This list was extracted from PyTorch's repository: benchmarks/dynamo/common.py
TRAIN_WITH_SGD = {
"BERT_pytorch",
"LearningToPaint",
"alexnet",
"dcgan",
"demucs",
"densenet121",
"dlrm",
"fastNLP_Bert",
"mobilenet_v2",
"phlippe_densenet",
"phlippe_resnet",
"pytorch_stargan",
"resnet18",
"shufflenet_v2_x1_0",
"speech_transformer",
"squeezenet1_1",
"stable_diffusion_text_encoder",
"timm_efficientdet",
"timm_nfnet",
"timm_regnet",
"timm_vision_transformer",
"timm_vovnet",
"vgg16",
"hf_T5",
}

# Skip the experiment of a model if any of the experiment configs in the list is fully matched
DENY_LIST = {
"doctr_det_predictor": [{
Expand Down Expand Up @@ -179,7 +208,10 @@ def set_up(self):
This is model suite specific.
"""
self.optimizer_class = torch.optim.Adam
if self.benchmark_experiment.test == "train" and self.model_name in TRAIN_WITH_SGD:
self.optimizer_class = torch.optim.SGD
else:
self.optimizer_class = torch.optim.Adam

benchmark = self.load_benchmark()

Expand All @@ -205,8 +237,6 @@ def set_up(self):
if self.model_name == "yolov3":
self.example_inputs = (torch.rand(self.benchmark_experiment.batch_size, 3,
384, 512),)
if self.benchmark_experiment.test == "train" and self.model_name in DETECTRON2_MODELS:
self.optimizer = benchmark.optimizer

del benchmark
self._cleanup()
Expand Down

0 comments on commit 514b190

Please sign in to comment.