Adding the benchmarking with TorchBench #5788

Merged
36 commits merged on Nov 13, 2023
Changes from all commits (36 commits)
f2ed93b
Initial commit with dummy model benchmark
Liyang90 Jan 10, 2023
5984de6
add XRT support
Liyang90 Jan 10, 2023
8d0f78e
Add torchbench benchmark models
Liyang90 Jan 12, 2023
ca20c6f
add randomize_input
Liyang90 Jan 12, 2023
4ae79e2
add model set up for torchbench model
Liyang90 Jan 13, 2023
35d446f
update ExperimentLoader
Liyang90 Jan 13, 2023
4bb173d
Add saving results
Liyang90 Jan 18, 2023
980cd5b
minor args update
Liyang90 Jan 18, 2023
7cc45fb
sync with master
Liyang90 Jan 18, 2023
a410666
update style
Liyang90 Jan 18, 2023
9e1506e
add experiment name
Liyang90 Jan 19, 2023
7591cbc
add grad context for eval and train
Liyang90 Jan 19, 2023
6e9a327
minor user config update
Liyang90 Jan 20, 2023
1939fe0
fix train() return item
Liyang90 Jan 20, 2023
8a50910
minor refactor
Liyang90 Jan 21, 2023
bacda04
add dynamo options
Liyang90 Jan 23, 2023
0bc642e
add column in result for dynamo setting
Liyang90 Jan 23, 2023
3851180
using to capture output and error
Liyang90 Jan 24, 2023
83c674c
Fix some failure cases for dynamo
Liyang90 Jan 25, 2023
113830f
reduce eval result size by returning eval loss
Liyang90 Jan 26, 2023
de609b8
minor refactor
Liyang90 Jan 26, 2023
2d94836
revert eval result change
Liyang90 Jan 26, 2023
c2ad278
minor fix
Liyang90 Feb 6, 2023
88fee74
Change output format to jsonl
Liyang90 Feb 10, 2023
deb1482
Add accelerator model nname
Liyang90 Feb 10, 2023
4993534
sync with master
Liyang90 Feb 10, 2023
10c52a7
add skipping finished experiments
Liyang90 Feb 11, 2023
3b3724c
main process needs to remove PJRT_DEVICE env var that is automaticall…
Liyang90 Feb 14, 2023
668f289
Add a simple result analyzer
Liyang90 Feb 15, 2023
1e787a7
Result analyzer save to database csv with historical data
Liyang90 Feb 16, 2023
80a3fd6
Handle detectron2 models
Liyang90 Mar 15, 2023
ad12d41
Merge branch 'master' into benchmark
Liyang90 Mar 15, 2023
a756875
minor update
Liyang90 Mar 15, 2023
aec61fb
Merge branch 'master' into benchmark
Liyang90 May 1, 2023
307578c
add deny list
Liyang90 May 19, 2023
84dff00
Merge branch 'pytorch:benchmark' into benchmark
Liyang90 Nov 10, 2023
183 changes: 183 additions & 0 deletions benchmarks/benchmark_experiment.py
@@ -0,0 +1,183 @@
from collections import OrderedDict
import logging
import os
import torch
import torch._dynamo as dynamo

try:
from .util import is_xla_device_available, get_accelerator_model
except ImportError:
from util import is_xla_device_available, get_accelerator_model

try:
import torch_xla.core.xla_model as xm
except ImportError:
# ignore the error if torch_xla is not installed
pass

logger = logging.getLogger(__name__)


class ExperimentLoader:

def __init__(self, args):
self._args = args
self.experiment_name = self._args.experiment_name

def expand_config_choices(self, config_choices):
configs = [{}]

for key, choices in config_choices.items():
tmp_configs = []
for config in configs:
for choice in choices:
tmp_config = config.copy()
tmp_config[key] = choice
tmp_configs.append(tmp_config)
configs = tmp_configs

return configs

def list_experiment_configs(self):
if self.experiment_name == "run_all":
config_choices = {
"accelerator": ["cpu", "gpu", "tpu"],
"xla": [None, "PJRT", "XRT"],
"dynamo": [
None, "inductor", "torchxla_trace_once", "aot_torchxla_trace_once"
],
"test": ["eval", "train"],
}

if self._args.accelerator:
config_choices["accelerator"] = list(set(self._args.accelerator))
if self._args.xla:
config_choices["xla"] = list(set(self._args.xla))
config_choices["xla"] = [
x if x != "None" else None for x in config_choices["xla"]
]
if self._args.dynamo:
config_choices["dynamo"] = list(set(self._args.dynamo))
config_choices["dynamo"] = [
x if x != "None" else None for x in config_choices["dynamo"]
]
if self._args.test:
config_choices["test"] = list(set(self._args.test))
else:
raise NotImplementedError

experiment_configs = []
for experiment_config in self.expand_config_choices(config_choices):
if not self.is_available(experiment_config):
continue

self._add_experiment_env(experiment_config)
experiment_configs.append(experiment_config)
return experiment_configs

def is_available(self, experiment_config):
if experiment_config["dynamo"] and experiment_config[
"dynamo"] not in dynamo.list_backends(exclude_tags=()):
return False
if experiment_config["dynamo"] == "inductor" and not (
experiment_config["accelerator"] == "gpu" and
not experiment_config["xla"]):
return False
if experiment_config["dynamo"] == "torchxla_trace_once" and not (
experiment_config["xla"] and experiment_config["test"] == "eval"):
return False
if experiment_config["dynamo"] == "aot_torchxla_trace_once" and not (
experiment_config["xla"] and experiment_config["test"] == "train"):
return False
if (experiment_config["xla"] and
not is_xla_device_available(experiment_config["accelerator"].upper())):
return False
if (experiment_config["accelerator"] == "tpu" and
not experiment_config["xla"]):
return False
if (experiment_config["accelerator"] == "gpu" and
not experiment_config["xla"] and not torch.cuda.is_available()):
return False
return True

def _add_experiment_env(self, experiment_config):
process_env = None
if experiment_config["xla"]:
# remove env vars that would interfere with subprocess settings
os.environ.pop("PJRT_DEVICE", None)
os.environ.pop("XRT_TPU_CONFIG", None)
if experiment_config["xla"] == "PJRT":
process_env = os.environ.copy()
process_env["PJRT_DEVICE"] = experiment_config["accelerator"].upper()
elif experiment_config["xla"] == "XRT":
process_env = os.environ.copy()
if is_xla_device_available("TPU"):
process_env["TPU_NUM_DEVICES"] = "1"
process_env["XRT_TPU_CONFIG"] = "localservice;0;localhost:51011"
elif is_xla_device_available("GPU"):
process_env["GPU_NUM_DEVICES"] = "1"
elif not experiment_config["xla"] and is_xla_device_available(
experiment_config["accelerator"].upper()):
      # In non-xla CPU training experiments, an env var is still needed if an
      # xla device exists, or there will be a "Missing XLA configuration" error.
process_env = os.environ.copy()
process_env["PJRT_DEVICE"] = experiment_config["accelerator"].upper()

experiment_config["process_env"] = process_env

def load_experiment(self, experiment_config, dummy=False):
experiment_name = self.experiment_name
accelerator = experiment_config["accelerator"]
xla = experiment_config["xla"]
dynamo = experiment_config["dynamo"]
test = experiment_config["test"]
batch_size = experiment_config.get("batch_size", self._args.batch_size)
benchmark_experiment = BenchmarkExperiment(
experiment_name=experiment_name,
accelerator=accelerator,
xla=xla,
dynamo=dynamo,
test=test,
batch_size=batch_size)

return benchmark_experiment


class BenchmarkExperiment:

def __init__(self, experiment_name, accelerator, xla, dynamo, test,
batch_size):
self.experiment_name = experiment_name
self.accelerator = accelerator
self.xla = xla
self.dynamo = dynamo
self.test = test
self.batch_size = batch_size
self.accelerator_model = get_accelerator_model(self.accelerator)

def get_device(self):
if self.xla:
device = xm.xla_device(devkind=self.accelerator.upper())
elif self.accelerator == "cpu":
device = torch.device("cpu")
elif self.accelerator == "gpu":
device = torch.device("cuda")
else:
raise NotImplementedError

return device

@property
def filename_str(self):
return "-".join(self.to_dict().values())

def to_dict(self):
d = OrderedDict()
d["experiment_name"] = self.experiment_name
d["accelerator"] = self.accelerator
d["accelerator_model"] = self.accelerator_model
d["xla"] = self.xla
d["dynamo"] = self.dynamo
d["test"] = self.test
d["batch_size"] = self.batch_size
return d
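
For illustration, here is a minimal sketch (not part of this diff) of how ExperimentLoader expands configuration choices into concrete experiment configs. The argparse.Namespace below is a hypothetical stand-in for the CLI arguments a runner script would normally parse, and it assumes the PR's benchmarks/ directory is on the import path with torch_xla installed:

import argparse
from benchmark_experiment import ExperimentLoader

# Hypothetical arguments; only the fields ExperimentLoader actually reads are set.
args = argparse.Namespace(
    experiment_name="run_all",
    accelerator=["cpu"],
    xla=["None", "PJRT"],
    dynamo=["None"],
    test=["eval", "train"],
    batch_size=None,
)
loader = ExperimentLoader(args)

# expand_config_choices builds the Cartesian product of the per-key choices.
print(loader.expand_config_choices({"xla": [None, "PJRT"], "test": ["eval"]}))
# [{'xla': None, 'test': 'eval'}, {'xla': 'PJRT', 'test': 'eval'}]

# list_experiment_configs applies the same expansion to the CLI choices, drops
# combinations rejected by is_available(), and attaches a process_env per config.
for config in loader.list_experiment_configs():
    print(config["accelerator"], config["xla"], config["dynamo"], config["test"])
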
172 changes: 172 additions & 0 deletions benchmarks/benchmark_model.py
@@ -0,0 +1,172 @@
from collections import OrderedDict
import logging
import re
import torch
import torch.nn as nn
import torch._dynamo as dynamo
from torch._dynamo.testing import collect_results
import types

try:
from .util import move_to_device
except ImportError:
from util import move_to_device

logger = logging.getLogger(__name__)


class ModelLoader:

def __init__(self, args):
self._args = args
self.suite_name = self._args.suite_name
self.benchmark_model_class = BenchmarkModel

def list_model_configs(self):
model_configs = [
{
"model_name": "dummy"
},
]

return model_configs

def is_compatible(self, dummy_benchmark_model, benchmark_experiment):
return True

def get_benchmark_indices(self, length):
start = self._args.partition_id * (length // self._args.total_partitions)
end = ((self._args.partition_id + 1) *
(length // self._args.total_partitions)
if self._args.partition_id < self._args.total_partitions - 1 else
length)
return start, end

def skip_model(self, model_name):
return (not re.search("|".join(self._args.filter), model_name, re.I) or
re.search("|".join(self._args.exclude), model_name, re.I))

def load_model(self, model_config, benchmark_experiment, dummy=False):
suite_name = self.suite_name
model_name = model_config["model_name"]
benchmark_model = self.benchmark_model_class(
suite_name=suite_name,
model_name=model_name,
benchmark_experiment=benchmark_experiment,
)

if not dummy:
benchmark_model.set_up()
benchmark_model.prepare_for_experiment()

return benchmark_model


class BenchmarkModel:

def __init__(self, suite_name, model_name, benchmark_experiment):
self.suite_name = suite_name
self.model_name = model_name
self.benchmark_experiment = benchmark_experiment

def set_up(self):
"""Set up module, actual batch_size, example_inputs, and optimizer_class

This is model suite specific.
"""
if self.model_name != "dummy":
raise NotImplementedError

self.module = nn.Sequential(
nn.Linear(32, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 32),
nn.Softmax(dim=1),
)

self.benchmark_experiment.batch_size = 16
self.example_inputs = (torch.rand(self.benchmark_experiment.batch_size,
32),)
self.optimizer_class = torch.optim.Adam

def _prepare_for_eval(self):
self.module.eval()
self.model_iter_fn = self.eval

def _prepare_for_train(self):
self.module.train()
self.model_iter_fn = self.train
if self.benchmark_experiment.dynamo == "aot_torchxla_trace_once":
      # TODO: dynamo aot_torchxla_trace_once fails if an optimizer is used, so
      # training runs with this backend skip the optimizer step. This makes the
      # aot_torchxla_trace_once results not comparable with other training
      # results.
self.optimizer = None
else:
if not hasattr(self, "optimizer"):
        # For some special models, self.set_up() may have already initialized an
        # optimizer to use, so only create one here if none exists.
self.optimizer = self.optimizer_class(self.module.parameters(), lr=0.01)

def prepare_for_experiment(self):
self.device = self.benchmark_experiment.get_device()
self.module = self.module.to(self.device)
self.example_inputs = move_to_device(self.example_inputs, self.device)

if self.benchmark_experiment.test == "eval":
self._prepare_for_eval()
elif self.benchmark_experiment.test == "train":
self._prepare_for_train()
else:
raise NotImplementedError

if self.benchmark_experiment.dynamo:
self.model_iter_fn = torch.compile(self.model_iter_fn,
backend=self.benchmark_experiment.dynamo)

def pick_grad(self):
if self.benchmark_experiment.test == "eval":
return torch.no_grad()
elif self.benchmark_experiment.test == "train":
return torch.enable_grad()

def _optimizer_zero_grad(self):
if self.optimizer is not None:
self.optimizer.zero_grad(True)
else:
self.module.zero_grad(True)

def _optimizer_step(self):
if self.optimizer is not None:
self.optimizer.step()

def compute_loss(self, pred):
raise NotImplementedError

def train(self, inputs, collect_full_output=False):
self._optimizer_zero_grad()
pred = self.module(*inputs)
loss = self.compute_loss(pred)
loss.backward()
self._optimizer_step()
if collect_full_output:
return collect_results(self.module, pred, loss, inputs)
# return loss.detach()
# TODO: dynamo inductor would fail if .detach() is used
return None

def eval(self, inputs, collect_full_output=False):
pred = self.module(*inputs)
return pred

@property
def filename_str(self):
return "-".join(self.to_dict().values())

def to_dict(self):
d = OrderedDict()
d["suite_name"] = self.suite_name
d["model_name"] = self.model_name
return d
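
For illustration, a minimal sketch of how the two classes above are meant to compose. The actual driver script is not shown in this excerpt, so the loop below is a hypothetical stand-in; it assumes the PR's benchmarks/ directory (including its util helpers) is importable and uses the built-in "dummy" model on CPU without dynamo:

import argparse
from benchmark_experiment import ExperimentLoader
from benchmark_model import ModelLoader

# Hypothetical arguments covering the fields both loaders read.
args = argparse.Namespace(
    experiment_name="run_all",
    suite_name="dummy",
    accelerator=["cpu"],
    xla=["None"],
    dynamo=["None"],
    test=["eval"],
    batch_size=None,
)

experiment_loader = ExperimentLoader(args)
model_loader = ModelLoader(args)

experiment_config = experiment_loader.list_experiment_configs()[0]
experiment = experiment_loader.load_experiment(experiment_config)

model_config = model_loader.list_model_configs()[0]  # {"model_name": "dummy"}
# load_model runs set_up() and prepare_for_experiment(): it builds the module,
# moves it and the example inputs to experiment.get_device(), and binds
# model_iter_fn to eval() or train() (optionally wrapped in torch.compile).
model = model_loader.load_model(model_config, experiment)

# One measured iteration: pick_grad() selects no_grad for eval, enable_grad for train.
with model.pick_grad():
    output = model.model_iter_fn(model.example_inputs)
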