Skip to content

Commit

Permalink
[benchmarks] Small fixes for benchmarking script. (#6632)
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi authored Feb 29, 2024
1 parent e385c2f commit 4cde625
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
18 changes: 9 additions & 9 deletions benchmarks/torchbench_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import importlib
import logging
import os
from os.path import abspath, exists
import sys
import torch
import torch.amp
Expand Down Expand Up @@ -306,18 +305,19 @@ def batch_size(self):
def load_benchmark(self):
cant_change_batch_size = (
not getattr(self.benchmark_cls(), "ALLOW_CUSTOMIZE_BSIZE", True) or
model_name in config_data()["dont_change_batch_size"])
self.model_name in config_data()["dont_change_batch_size"])

if cant_change_batch_size:
self.benchmark_experiment.batch_size = None

if self.benchmark_experiment.batch_size is not None:
batch_size = self.benchmark_experiment.batch_size
elif self.is_training() and self.model_name in self.batch_size["training"]:
batch_size = self.batch_size["training"][self.model_name]
elif self.is_inference(
) and self.model_name in self.batch_size["inference"]:
batch_size = self.batch_size["inference"][self.model_name]
batch_size = self.benchmark_experiment.batch_size

if batch_size is None:
if self.is_training() and self.model_name in self.batch_size["training"]:
batch_size = self.batch_size["training"][self.model_name]
elif self.is_inference(
) and self.model_name in self.batch_size["inference"]:
batch_size = self.batch_size["inference"][self.model_name]

# workaround "RuntimeError: not allowed to set torch.backends.cudnn flags"
# torch.backends.__allow_nonbracketed_mutation_flag = True
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import numpy as np
import os
from os.path import abspath
from os.path import abspath, exists
import random
import subprocess
import torch
Expand Down Expand Up @@ -158,7 +158,7 @@ def get_torchbench_test_name(test):
return {"train": "training", "eval": "inference"}[test]


def find_near_file(self, names):
def find_near_file(names):
"""Find a file near the current directory.
Looks for `names` in the current directory, up to its two direct parents.
Expand Down

0 comments on commit 4cde625

Please sign in to comment.