-
Notifications
You must be signed in to change notification settings - Fork 487
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding the benchmarking with TorchBench (#5788)
* Initial commit with dummy model benchmark * add XRT support * Add torchbench benchmark models * add randomize_input * add model set up for torchbench model * update ExperimentLoader * Add saving results * minor args update * update style * add experiment name * add grad context for eval and train * minor user config update * fix train() return item * minor refactor * add dynamo options * add column in result for dynamo setting * using to capture output and error * Fix some failure cases for dynamo * reduce eval result size by returning eval loss * minor refactor * revert eval result change * minor fix * Change output format to jsonl * Add accelerator model nname * add skipping finished experiments * main process needs to remove PJRT_DEVICE env var that is automatically added * Add a simple result analyzer * Result analyzer save to database csv with historical data * Handle detectron2 models * minor update * add deny list
- Loading branch information
Showing
6 changed files
with
1,354 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
from collections import OrderedDict | ||
import logging | ||
import os | ||
import torch | ||
import torch._dynamo as dynamo | ||
|
||
try: | ||
from .util import is_xla_device_available, get_accelerator_model | ||
except ImportError: | ||
from util import is_xla_device_available, get_accelerator_model | ||
|
||
try: | ||
import torch_xla.core.xla_model as xm | ||
except ImportError: | ||
# ignore the error if torch_xla is not installed | ||
pass | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ExperimentLoader: | ||
|
||
def __init__(self, args): | ||
self._args = args | ||
self.experiment_name = self._args.experiment_name | ||
|
||
def expand_config_choices(self, config_choices): | ||
configs = [{}] | ||
|
||
for key, choices in config_choices.items(): | ||
tmp_configs = [] | ||
for config in configs: | ||
for choice in choices: | ||
tmp_config = config.copy() | ||
tmp_config[key] = choice | ||
tmp_configs.append(tmp_config) | ||
configs = tmp_configs | ||
|
||
return configs | ||
|
||
def list_experiment_configs(self): | ||
if self.experiment_name == "run_all": | ||
config_choices = { | ||
"accelerator": ["cpu", "gpu", "tpu"], | ||
"xla": [None, "PJRT", "XRT"], | ||
"dynamo": [ | ||
None, "inductor", "torchxla_trace_once", "aot_torchxla_trace_once" | ||
], | ||
"test": ["eval", "train"], | ||
} | ||
|
||
if self._args.accelerator: | ||
config_choices["accelerator"] = list(set(self._args.accelerator)) | ||
if self._args.xla: | ||
config_choices["xla"] = list(set(self._args.xla)) | ||
config_choices["xla"] = [ | ||
x if x != "None" else None for x in config_choices["xla"] | ||
] | ||
if self._args.dynamo: | ||
config_choices["dynamo"] = list(set(self._args.dynamo)) | ||
config_choices["dynamo"] = [ | ||
x if x != "None" else None for x in config_choices["dynamo"] | ||
] | ||
if self._args.test: | ||
config_choices["test"] = list(set(self._args.test)) | ||
else: | ||
raise NotImplementedError | ||
|
||
experiment_configs = [] | ||
for experiment_config in self.expand_config_choices(config_choices): | ||
if not self.is_available(experiment_config): | ||
continue | ||
|
||
self._add_experiment_env(experiment_config) | ||
experiment_configs.append(experiment_config) | ||
return experiment_configs | ||
|
||
def is_available(self, experiment_config): | ||
if experiment_config["dynamo"] and experiment_config[ | ||
"dynamo"] not in dynamo.list_backends(exclude_tags=()): | ||
return False | ||
if experiment_config["dynamo"] == "inductor" and not ( | ||
experiment_config["accelerator"] == "gpu" and | ||
not experiment_config["xla"]): | ||
return False | ||
if experiment_config["dynamo"] == "torchxla_trace_once" and not ( | ||
experiment_config["xla"] and experiment_config["test"] == "eval"): | ||
return False | ||
if experiment_config["dynamo"] == "aot_torchxla_trace_once" and not ( | ||
experiment_config["xla"] and experiment_config["test"] == "train"): | ||
return False | ||
if (experiment_config["xla"] and | ||
not is_xla_device_available(experiment_config["accelerator"].upper())): | ||
return False | ||
if (experiment_config["accelerator"] == "tpu" and | ||
not experiment_config["xla"]): | ||
return False | ||
if (experiment_config["accelerator"] == "gpu" and | ||
not experiment_config["xla"] and not torch.cuda.is_available()): | ||
return False | ||
return True | ||
|
||
def _add_experiment_env(self, experiment_config): | ||
process_env = None | ||
if experiment_config["xla"]: | ||
# remove env vars that would interfere with subprocess settings | ||
os.environ.pop("PJRT_DEVICE", None) | ||
os.environ.pop("XRT_TPU_CONFIG", None) | ||
if experiment_config["xla"] == "PJRT": | ||
process_env = os.environ.copy() | ||
process_env["PJRT_DEVICE"] = experiment_config["accelerator"].upper() | ||
elif experiment_config["xla"] == "XRT": | ||
process_env = os.environ.copy() | ||
if is_xla_device_available("TPU"): | ||
process_env["TPU_NUM_DEVICES"] = "1" | ||
process_env["XRT_TPU_CONFIG"] = "localservice;0;localhost:51011" | ||
elif is_xla_device_available("GPU"): | ||
process_env["GPU_NUM_DEVICES"] = "1" | ||
elif not experiment_config["xla"] and is_xla_device_available( | ||
experiment_config["accelerator"].upper()): | ||
# In non-xla CPU training experiments, an env var is still needed if an | ||
# xla device exists, or there will be "Missing XLA configuration" error. | ||
process_env = os.environ.copy() | ||
process_env["PJRT_DEVICE"] = experiment_config["accelerator"].upper() | ||
|
||
experiment_config["process_env"] = process_env | ||
|
||
def load_experiment(self, experiment_config, dummy=False): | ||
experiment_name = self.experiment_name | ||
accelerator = experiment_config["accelerator"] | ||
xla = experiment_config["xla"] | ||
dynamo = experiment_config["dynamo"] | ||
test = experiment_config["test"] | ||
batch_size = experiment_config.get("batch_size", self._args.batch_size) | ||
benchmark_experiment = BenchmarkExperiment( | ||
experiment_name=experiment_name, | ||
accelerator=accelerator, | ||
xla=xla, | ||
dynamo=dynamo, | ||
test=test, | ||
batch_size=batch_size) | ||
|
||
return benchmark_experiment | ||
|
||
|
||
class BenchmarkExperiment: | ||
|
||
def __init__(self, experiment_name, accelerator, xla, dynamo, test, | ||
batch_size): | ||
self.experiment_name = experiment_name | ||
self.accelerator = accelerator | ||
self.xla = xla | ||
self.dynamo = dynamo | ||
self.test = test | ||
self.batch_size = batch_size | ||
self.accelerator_model = get_accelerator_model(self.accelerator) | ||
|
||
def get_device(self): | ||
if self.xla: | ||
device = xm.xla_device(devkind=self.accelerator.upper()) | ||
elif self.accelerator == "cpu": | ||
device = torch.device("cpu") | ||
elif self.accelerator == "gpu": | ||
device = torch.device("cuda") | ||
else: | ||
raise NotImplementedError | ||
|
||
return device | ||
|
||
@property | ||
def filename_str(self): | ||
return "-".join(self.to_dict().values()) | ||
|
||
def to_dict(self): | ||
d = OrderedDict() | ||
d["experiment_name"] = self.experiment_name | ||
d["accelerator"] = self.accelerator | ||
d["accelerator_model"] = self.accelerator_model | ||
d["xla"] = self.xla | ||
d["dynamo"] = self.dynamo | ||
d["test"] = self.test | ||
d["batch_size"] = self.batch_size | ||
return d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
from collections import OrderedDict | ||
import logging | ||
import re | ||
import torch | ||
import torch.nn as nn | ||
import torch._dynamo as dynamo | ||
from torch._dynamo.testing import collect_results | ||
import types | ||
|
||
try: | ||
from .util import move_to_device | ||
except ImportError: | ||
from util import move_to_device | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ModelLoader: | ||
|
||
def __init__(self, args): | ||
self._args = args | ||
self.suite_name = self._args.suite_name | ||
self.benchmark_model_class = BenchmarkModel | ||
|
||
def list_model_configs(self): | ||
model_configs = [ | ||
{ | ||
"model_name": "dummy" | ||
}, | ||
] | ||
|
||
return model_configs | ||
|
||
def is_compatible(self, dummy_benchmark_model, benchmark_experiment): | ||
return True | ||
|
||
def get_benchmark_indices(self, length): | ||
start = self._args.partition_id * (length // self._args.total_partitions) | ||
end = ((self._args.partition_id + 1) * | ||
(length // self._args.total_partitions) | ||
if self._args.partition_id < self._args.total_partitions - 1 else | ||
length) | ||
return start, end | ||
|
||
def skip_model(self, model_name): | ||
return (not re.search("|".join(self._args.filter), model_name, re.I) or | ||
re.search("|".join(self._args.exclude), model_name, re.I)) | ||
|
||
def load_model(self, model_config, benchmark_experiment, dummy=False): | ||
suite_name = self.suite_name | ||
model_name = model_config["model_name"] | ||
benchmark_model = self.benchmark_model_class( | ||
suite_name=suite_name, | ||
model_name=model_name, | ||
benchmark_experiment=benchmark_experiment, | ||
) | ||
|
||
if not dummy: | ||
benchmark_model.set_up() | ||
benchmark_model.prepare_for_experiment() | ||
|
||
return benchmark_model | ||
|
||
|
||
class BenchmarkModel: | ||
|
||
def __init__(self, suite_name, model_name, benchmark_experiment): | ||
self.suite_name = suite_name | ||
self.model_name = model_name | ||
self.benchmark_experiment = benchmark_experiment | ||
|
||
def set_up(self): | ||
"""Set up module, actual batch_size, example_inputs, and optimizer_class | ||
This is model suite specific. | ||
""" | ||
if self.model_name != "dummy": | ||
raise NotImplementedError | ||
|
||
self.module = nn.Sequential( | ||
nn.Linear(32, 512), | ||
nn.ReLU(), | ||
nn.Linear(512, 512), | ||
nn.ReLU(), | ||
nn.Linear(512, 32), | ||
nn.Softmax(dim=1), | ||
) | ||
|
||
self.benchmark_experiment.batch_size = 16 | ||
self.example_inputs = (torch.rand(self.benchmark_experiment.batch_size, | ||
32),) | ||
self.optimizer_class = torch.optim.Adam | ||
|
||
def _prepare_for_eval(self): | ||
self.module.eval() | ||
self.model_iter_fn = self.eval | ||
|
||
def _prepare_for_train(self): | ||
self.module.train() | ||
self.model_iter_fn = self.train | ||
if self.benchmark_experiment.dynamo == "aot_torchxla_trace_once": | ||
# TODO: dynamo aot_torchxla_trace_once would fail if there is an | ||
# optimizer. | ||
# This makes the aot_torchxla_trace_once results not comparable | ||
# with other training results | ||
self.optimizer = None | ||
else: | ||
if not hasattr(self, "optimizer"): | ||
# For some special models, self.set_up() may have initialized an | ||
# optimizer to use. So only initialize it when there is none existing. | ||
self.optimizer = self.optimizer_class(self.module.parameters(), lr=0.01) | ||
|
||
def prepare_for_experiment(self): | ||
self.device = self.benchmark_experiment.get_device() | ||
self.module = self.module.to(self.device) | ||
self.example_inputs = move_to_device(self.example_inputs, self.device) | ||
|
||
if self.benchmark_experiment.test == "eval": | ||
self._prepare_for_eval() | ||
elif self.benchmark_experiment.test == "train": | ||
self._prepare_for_train() | ||
else: | ||
raise NotImplementedError | ||
|
||
if self.benchmark_experiment.dynamo: | ||
self.model_iter_fn = torch.compile(self.model_iter_fn, | ||
backend=self.benchmark_experiment.dynamo) | ||
|
||
def pick_grad(self): | ||
if self.benchmark_experiment.test == "eval": | ||
return torch.no_grad() | ||
elif self.benchmark_experiment.test == "train": | ||
return torch.enable_grad() | ||
|
||
def _optimizer_zero_grad(self): | ||
if self.optimizer is not None: | ||
self.optimizer.zero_grad(True) | ||
else: | ||
self.module.zero_grad(True) | ||
|
||
def _optimizer_step(self): | ||
if self.optimizer is not None: | ||
self.optimizer.step() | ||
|
||
def compute_loss(self, pred): | ||
raise NotImplementedError | ||
|
||
def train(self, inputs, collect_full_output=False): | ||
self._optimizer_zero_grad() | ||
pred = self.module(*inputs) | ||
loss = self.compute_loss(pred) | ||
loss.backward() | ||
self._optimizer_step() | ||
if collect_full_output: | ||
return collect_results(self.module, pred, loss, inputs) | ||
# return loss.detach() | ||
# TODO: dynamo inductor would fail if .detach() is used | ||
return None | ||
|
||
def eval(self, inputs, collect_full_output=False): | ||
pred = self.module(*inputs) | ||
return pred | ||
|
||
@property | ||
def filename_str(self): | ||
return "-".join(self.to_dict().values()) | ||
|
||
def to_dict(self): | ||
d = OrderedDict() | ||
d["suite_name"] = self.suite_name | ||
d["model_name"] = self.model_name | ||
return d |
Oops, something went wrong.