From ada54e5c0eb633cb0d612f4709873376afc15f69 Mon Sep 17 00:00:00 2001 From: frgossen Date: Wed, 22 Nov 2023 18:12:17 -0500 Subject: [PATCH] Merge benchmark scripts into the main branch (#5913) * Adding the benchmarking with TorchBench (#5788) * Initial commit with dummy model benchmark * add XRT support * Add torchbench benchmark models * add randomize_input * add model set up for torchbench model * update ExperimentLoader * Add saving results * minor args update * update style * add experiment name * add grad context for eval and train * minor user config update * fix train() return item * minor refactor * add dynamo options * add column in result for dynamo setting * using to capture output and error * Fix some failure cases for dynamo * reduce eval result size by returning eval loss * minor refactor * revert eval result change * minor fix * Change output format to jsonl * Add accelerator model nname * add skipping finished experiments * main process needs to remove PJRT_DEVICE env var that is automatically added * Add a simple result analyzer * Result analyzer save to database csv with historical data * Handle detectron2 models * minor update * add deny list * Create run_benchmark * Rename run_benchmark to run_benchmark.sh * Fix device names and dynamo backend names in benchmark runner (#5806) * update optimizer for openxla * Add benchmark selection by tier 1-3 (#5808) * Apply Pytorch/XLA formatting style (#5816) * Add top tier benchmark runner (#5809) * Add profiling capabilities to experiment_runner.py script (#5812) * update run model config call interface, optimizer and result analyze script * update dependency errir * Add profiling capabilties --------- Co-authored-by: zpcore * benchmarks: add script to aggregate results from result_analyzer (#5829) * benchmarks: extract tiers into their own file So that they can be reused in other files. The second user is coming next. * benchmarks: add aggregate.py This script processes output CSV files from results_analyzer to generate CSV/plots. Example: $ for fmt in csv png; do \ for acc in v100 a6000; do \ for report in latest histogram speedup; do \ for test in training inference; do \ FILENAME=/tmp/png/$acc-$test-$report.$fmt; \ python3 aggregate.py \ --accelerator=$acc \ --test=$test \ -i /tmp/csv-depot \ --report=$report \ --title="All benchmarks" \ --format=$fmt > $FILENAME || break; \ chmod 644 $FILENAME; \ done; \ done; \ done; \ done This generates plots and CSV files to summarize the latest performance vs. Inductor, as well as a histogram and a geomean speedup over time for all the input CSV data in /tmp/csv-depot. Results are broken down per accelerator and either inference or training. To generate results per tier, we just have to pass --filter-by-tier to the above and update the title to --title="Tier 1". * Fix syntax in experiment_runner.py (#5827) * Add flag to forward XLA flags and allow for experiment expansion (#5828) * Add hide-errors flag to result analyzer (#5836) * Add readme and linting * Fix ClusterResolver --------- Co-authored-by: Liyang90 Co-authored-by: Manfei <41607353+ManfeiBai@users.noreply.github.com> Co-authored-by: zpcore Co-authored-by: Grzegorz Olechwierowicz Co-authored-by: Emilio Cota --- .github/workflows/lintercheck.yml | 2 +- .vscode/settings.json | 2 +- benchmarks/README.md | 55 +++ benchmarks/aggregate.py | 348 ++++++++++++++++++ benchmarks/benchmark_experiment.py | 193 ++++++++++ benchmarks/benchmark_model.py | 165 +++++++++ benchmarks/experiment_runner.py | 547 +++++++++++++++++++++++++++++ benchmarks/result_analyzer.py | 191 ++++++++++ benchmarks/run_benchmark.sh | 79 +++++ benchmarks/run_top_tier_bm.sh | 25 ++ benchmarks/tiers.py | 15 + benchmarks/torchbench_model.py | 247 +++++++++++++ benchmarks/util.py | 153 ++++++++ 13 files changed, 2020 insertions(+), 2 deletions(-) create mode 100644 benchmarks/README.md create mode 100755 benchmarks/aggregate.py create mode 100644 benchmarks/benchmark_experiment.py create mode 100644 benchmarks/benchmark_model.py create mode 100644 benchmarks/experiment_runner.py create mode 100644 benchmarks/result_analyzer.py create mode 100644 benchmarks/run_benchmark.sh create mode 100755 benchmarks/run_top_tier_bm.sh create mode 100644 benchmarks/tiers.py create mode 100644 benchmarks/torchbench_model.py create mode 100644 benchmarks/util.py diff --git a/.github/workflows/lintercheck.yml b/.github/workflows/lintercheck.yml index fd50411a8a62..aebae600b9dd 100644 --- a/.github/workflows/lintercheck.yml +++ b/.github/workflows/lintercheck.yml @@ -66,7 +66,7 @@ jobs: exit 1 fi - yapf -i -r *.py test/ scripts/ torch_xla/ + yapf -i -r *.py test/ scripts/ torch_xla/ benchmarks/ git_status=$(git status --porcelain) if [[ $git_status ]]; then git diff diff --git a/.vscode/settings.json b/.vscode/settings.json index 59b86e622e78..7e1221a18343 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -19,4 +19,4 @@ ], "python.formatting.provider": "yapf", "editor.formatOnSave": true -} +} \ No newline at end of file diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 000000000000..51ab754400bf --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,55 @@ +# Benchmarking + +The two main benchmarking scripts are + - `experiment_runner.py` to run benchmark experiments, and + - `result_analyzer.py` to aggregate the benchmark result in CSV form. + + +## Experiment runner + +Run the `experiment_runner.py` from the `pytorch` directory, which should be the +parent of the `xla` directory. + +The following example runs the alexnet benchmark on GPU through the +Pytorch/XLA-dynamo path and through the Inductor-dynamo with 5 repetitions each. +The results will be stored in a json file in `experiment_results`. + +``` +cd pytorch +python xla/benchmarks/experiment_runner.py \ + --dynamo=openxla_eval --dynamo=openxla --dynamo=inductor \ + --xla=PJRT --xla=None \ + --test=eval --test=train \ + --suite-name=torchbench \ + --accelerator=cuda \ + --output-dirname=experiment_results \ + --repeat=5 \ + --print-subprocess \ + --no-resume \ + --filter="^alexnet$" +``` + +You can change the flags to add the configurations you are interested in. The +`experiment_runner.py` will expand the options to all supported configurations. +For example, in the case above, it will consider all the possible combinations +among the flags `--dynamo`, `--xla`, and `--test`, 4 of which are supported: + + - `dynamo=openxla_eval`, `xla=PJRT`, `test=eval` + - `dynamo=openxla`, `xla=PJRT`, `test=train` + - `dynamo=inductor`, `xla=None`, `test=eval` + - `dynamo=inductor`, `xla=None`, `test=train` + + +## Result analyzer + +Run the `result_analyzer.py` from the `pytorch` directory, which should be the +parent of the `xla` directory. + +The following example analyzes the results generated by the above invocation of +`experiment_runner.py`. The aggregates are saved in CSV format in +`experiment_results/metric_report.csv`. + +``` +cd pytorch +python xla/benchmarks/result_analyzer.py --output-dirname=experiment_results +``` diff --git a/benchmarks/aggregate.py b/benchmarks/aggregate.py new file mode 100755 index 000000000000..e32f4ae30d91 --- /dev/null +++ b/benchmarks/aggregate.py @@ -0,0 +1,348 @@ +"""Processes .csv result files and aggregates them.""" +# TODO: support more plots: +# - Speedup of Inductor and PytorchXLA over the oldest Inductor data set. +# This will allow us to have a sense of how fast Inductor is improving +# as well as PytorchXLA. +# - Number of working Inductor and PytorchXLA workloads. + +import argparse +import csv +from datetime import date +import logging +import os +import re +import sys +import matplotlib.pyplot as plt +from typing import Any +import numpy as np +from scipy.stats.mstats import gmean + +try: + from .tiers import append_filter_by_tier +except ImportError: + from tiers import append_filter_by_tier + +logger = logging.getLogger(__name__) + + +def find_files(input_dirname: str) -> list[str]: + files = [] + for root, _, filenames in os.walk(input_dirname): + for filename in filenames: + match = re.search(r'.*\.csv$', filename) + if match: + files.append(os.path.join(root, filename)) + return files + + +def clean_up_accelerator_model(model: str) -> str: + if re.search(r'One of Tesla V100', model): + return 'v100' + if re.search(r'One of Quadro P1000, NVIDIA RTX A6000', model): + return 'a6000' + if re.search(r'NVIDIA A100-SXM4-40GB', model): + return 'a100' + sys.exit(f"fatal: cannot recognize accelerator model: '{model}'.") + + +def skip_model(args, model_name: str): + return (not re.search("|".join(args.filter), model_name, re.I) or + re.search("|".join(args.exclude), model_name, re.I)) + + +def process_file(args, results_map: dict[str, Any], filename: str): + with open(filename) as check_header_file: + try: + has_header = csv.Sniffer().has_header(check_header_file.read(1024)) + except csv.Error: + logger.error('Cannot read CSV in %s, skipping.', filename) + return + if not has_header: + logger.error('Cannot interpret %s: missing headers.', filename) + return + fields = ( + 'model_name', + 'accelerator_model', + 'dynamo', + 'test', + 'batch_size', + 'median_total_time', + ) + with open(filename) as read_file: + reader = csv.reader(read_file) + headers = next(reader) + if headers[0] != 'timestamp': + logger.error('Missing timestamp in CSV in %s, skipping.', filename) + return + field2index = {} + for i, header in enumerate(headers): + for field in fields: + if field == header: + field2index[field] = i + for row in reader: + timestamp = row[0] + model_name = row[field2index['model_name']] + if skip_model(args, model_name): + continue + accelerator_model = clean_up_accelerator_model( + row[field2index['accelerator_model']]) + dynamo = row[field2index['dynamo']] + test = row[field2index['test']] + batch_size = row[field2index['batch_size']] + median_total_time = row[field2index['median_total_time']] + if timestamp not in results_map: + results_map[timestamp] = {} + if accelerator_model not in results_map[timestamp]: + results_map[timestamp][accelerator_model] = {} + if dynamo not in results_map[timestamp][accelerator_model]: + results_map[timestamp][accelerator_model][dynamo] = {} + if test not in results_map[timestamp][accelerator_model][dynamo]: + results_map[timestamp][accelerator_model][dynamo][test] = {} + if (model_name + not in results_map[timestamp][accelerator_model][dynamo][test]): + results_map[timestamp][accelerator_model][dynamo][test][model_name] = {} + if (batch_size not in results_map[timestamp][accelerator_model][dynamo] + [test][model_name]): + results_map[timestamp][accelerator_model][dynamo][test][model_name][ + batch_size] = {} + results_map[timestamp][accelerator_model][dynamo][test][model_name][ + batch_size] = median_total_time + + +def summarize_speedups(acc_map: dict[str, Any], label: str): + if label not in acc_map: + return + acc_map[f'{label}:gmean'] = gmean(acc_map[label]) + for p in (5, 50, 95): + percentile = float(np.percentile(acc_map[label], p)) + acc_map[f'{label}:p{p}'] = percentile + + +# The speedup values are stored in acc_map[label]; the corresponding +# model names are stored in acc_map[f'{label}:model_name']. +def compute_speedups(acc_map: dict[str, Any], label: str, xla_label, + inductor_label, test_label): + model_label = f'{label}:model_name' + if xla_label not in acc_map: + return + if inductor_label not in acc_map: + return + if (test_label not in acc_map[xla_label] or + test_label not in acc_map[inductor_label]): + return + for model_name, v in acc_map[xla_label][test_label].items(): + if model_name not in acc_map[inductor_label][test_label]: + continue + speedups = [] + # If we are running several batch sizes, keep the geomean of their speedups. + for batch_size in v: + xla_time = v[batch_size] + inductor_time = acc_map[inductor_label][test_label][model_name].get( + batch_size, None) + if not xla_time or not inductor_time: + continue + speedups.append(float(inductor_time) / float(xla_time)) + if speedups: + if label not in acc_map: + acc_map[label] = [] + acc_map[label].append(gmean(speedups)) + if model_label not in acc_map: + acc_map[model_label] = [] + acc_map[model_label].append(model_name) + summarize_speedups(acc_map, label) + + +def process_results(results_map: dict[str, Any]): + for timestamp in results_map: + for accelerator in results_map[timestamp]: + acc_map = results_map[timestamp][accelerator] + + compute_speedups(acc_map, 'speedups:inference', 'openxla_eval', + 'inductor', 'eval') + compute_speedups(acc_map, 'speedups:training', 'openxla', 'inductor', + 'train') + + +def maketitle(args, title: str): + if args.title: + title += f'\n{args.title}' + return title + + +def pr_latest(results_map: dict[str, Any], args, timestamps: list[str]): + label = f'speedups:{args.test}' + model_label = f'{label}:model_name' + + for timestamp in reversed(timestamps): + if label not in results_map[timestamp][args.accelerator]: + continue + acc_map = results_map[timestamp][args.accelerator] + (speedups, + model_names) = map(list, + zip(*sorted(zip(acc_map[label], acc_map[model_label])))) + + if args.format == 'csv': + print('# WorkloadNumber,Speedup,ModelName') + for i, speedup in enumerate(speedups): + print(','.join(map(str, [i, speedup, model_names[i]]))) + else: + plt.axhline(y=1.0, color='lightgray') + plt.plot(speedups, marker='o') + plt.title( + maketitle( + args, + f'Speedup of Pytorch/XLA over Inductor\n{date.fromtimestamp(float(timestamp))}' + )) + plt.xlabel('Workload Number') + plt.ylabel(f'Speedup') + plt.savefig(sys.stdout.buffer) + return + logger.warning(f'cannot find data for accelerator {args.accelerator}') + + +def pr_histogram(results_map: dict[str, Any], args, timestamps: list[str]): + percentiles = [f'p{p}' for p in (5, 50, 95)] + labels = [f'speedups:{args.test}:{p}' for p in percentiles] + x = [] + y = [[] for i in range(len(percentiles))] + for timestamp in timestamps: + if labels[0] in results_map[timestamp][args.accelerator]: + for label in labels: + assert label in results_map[timestamp][args.accelerator] + x.append(date.fromtimestamp(float(timestamp))) + for i, label in enumerate(labels): + y[i].append(results_map[timestamp][args.accelerator][label]) + if args.format == 'csv': + titles = ['# Datetime'] + percentiles + print(','.join(titles)) + for i, datetime in enumerate(x): + print(','.join([str(datetime)] + + [str(y[j][i]) for j in range(len(percentiles))])) + else: + plt.axhline(y=1.0, color='lightgray') + for i, p in enumerate(percentiles): + plt.plot(x, y[i], label=p, marker='^') + plt.legend() + plt.xlabel("Date") + plt.ylabel("Geomean Speedup") + plt.title( + maketitle(args, f"Histogram of Pytorch/XLA's Speedup over Inductor")) + plt.savefig(sys.stdout.buffer) + + +def pr_gmean(results_map: dict[str, Any], args, timestamps: list[str]): + label = f'speedups:{args.test}:gmean' + x = [] + y = [] + for timestamp in timestamps: + if label not in results_map[timestamp][args.accelerator]: + continue + x.append(date.fromtimestamp(float(timestamp))) + gmean = results_map[timestamp][args.accelerator][label] + y.append(gmean) + if args.format == 'csv': + print('# Datetime,Speedup') + for a, b in zip(x, y): + print(','.join(map(str, [a, b]))) + else: + plt.axhline(y=1.0, color='lightgray') + plt.plot(x, y, marker='^') + plt.xlabel("Date") + plt.ylabel("Geomean Speedup") + plt.title(maketitle(args, f"Pytorch/XLA's Speedup over Inductor")) + plt.savefig(sys.stdout.buffer) + + +def pr_results(results_map: dict[str, Any], args): + timestamp_list = list(results_map.keys()) + timestamps = [ + ts for ts in timestamp_list if args.accelerator in results_map[ts] + ] + timestamps.sort() + + if args.report == 'latest': + return pr_latest(results_map, args, timestamps) + elif args.report == 'histogram': + return pr_histogram(results_map, args, timestamps) + elif args.report == 'speedup': + return pr_gmean(results_map, args, timestamps) + else: + sys.exit('unreachable') + + +def parse_args(args=None): + parser = argparse.ArgumentParser() + parser.add_argument( + '--accelerator', + default='v100', + choices=['a100', 'v100', 'a6000'], + help='Accelerator.') + parser.add_argument( + "--exclude", + "-x", + action="append", + default=[], + help="filter out benchmarks with regexp") + parser.add_argument( + "--exclude-by-tier", + type=int, + action="append", + default=[], + help="filter out benchmarks by predefined tier 1-3", + ) + parser.add_argument( + "--filter", + "-k", + action="append", + default=[], + help="filter benchmarks with regexp") + parser.add_argument( + "--filter-by-tier", + type=int, + action="append", + default=[], + help="filter benchmarks by predefined tier 1-3", + ) + parser.add_argument( + "--format", default='csv', choices=['csv', 'png'], help='Output format') + parser.add_argument( + '--input-dirname', '-i', required=True, type=str, help='Input directory.') + parser.add_argument( + '--report', + default='speedup', + choices=['latest', 'histogram', 'speedup'], + help='What report to generate.') + parser.add_argument( + '--test', + default='inference', + choices=['inference', 'training'], + help='Test mode.') + parser.add_argument('--title', type=str, help="Plot title.") + args = parser.parse_args(args) + + append_filter_by_tier(args.filter, args.filter_by_tier) + append_filter_by_tier(args.exclude, args.exclude_by_tier) + args.filter = args.filter or [r"."] + args.exclude = args.exclude or [r"^$"] + + return args + + +def main(): + args = parse_args() + filenames = find_files(args.input_dirname) + results_map = {} + + # Some CSV files have lots of errors from execution; expand CSV's size limit. + csv.field_size_limit(1024 * 1024) + + for filename in filenames: + process_file(args, results_map, filename) + process_results(results_map) + if not results_map: + sys.exit('no results found') + pr_results(results_map, args) + + +if __name__ == '__main__': + main() diff --git a/benchmarks/benchmark_experiment.py b/benchmarks/benchmark_experiment.py new file mode 100644 index 000000000000..21679594d2fa --- /dev/null +++ b/benchmarks/benchmark_experiment.py @@ -0,0 +1,193 @@ +from collections import OrderedDict +import logging +import os +import torch +import torch._dynamo as dynamo + +try: + from .util import is_xla_device_available, get_accelerator_model +except ImportError: + from util import is_xla_device_available, get_accelerator_model + +try: + import torch_xla.core.xla_model as xm +except ImportError: + # ignore the error if torch_xla is not installed + pass + +logger = logging.getLogger(__name__) + + +class ExperimentLoader: + + def __init__(self, args): + self._args = args + self.experiment_name = self._args.experiment_name + + def expand_config_choices(self, config_choices): + configs = [{}] + + for key, choices in config_choices.items(): + tmp_configs = [] + for config in configs: + for choice in choices: + tmp_config = config.copy() + tmp_config[key] = choice + tmp_configs.append(tmp_config) + configs = tmp_configs + + return configs + + def list_experiment_configs(self): + if self.experiment_name == "run_all": + config_choices = { + "accelerator": ["cpu", "cuda", "tpu"], + "xla": [None, "PJRT", "XRT"], + "xla_flags": [None], + "dynamo": [None, "inductor", "openxla_eval", "openxla"], + "test": ["eval", "train"], + } + + if self._args.accelerator: + config_choices["accelerator"] = list(set(self._args.accelerator)) + if self._args.xla: + config_choices["xla"] = [ + x if x != "None" else None for x in list(set(self._args.xla)) + ] + if self._args.dynamo: + config_choices["dynamo"] = [ + x if x != "None" else None for x in list(set(self._args.dynamo)) + ] + if self._args.test: + config_choices["test"] = list(set(self._args.test)) + if self._args.xla_flags: + config_choices["xla_flags"] = [ + x if x != "None" else None for x in list(set(self._args.xla_flags)) + ] + + else: + raise NotImplementedError + + experiment_configs = [] + for experiment_config in self.expand_config_choices(config_choices): + if not self.is_available(experiment_config): + continue + + self._add_experiment_env(experiment_config) + experiment_configs.append(experiment_config) + return experiment_configs + + def is_available(self, experiment_config): + if experiment_config["dynamo"] and experiment_config[ + "dynamo"] not in dynamo.list_backends(exclude_tags=()): + return False + if experiment_config["dynamo"] == "inductor" and not ( + experiment_config["accelerator"] == "cuda" and + not experiment_config["xla"]): + return False + if experiment_config["dynamo"] == "openxla_eval" and not ( + experiment_config["xla"] and experiment_config["test"] == "eval"): + return False + if experiment_config["dynamo"] == "openxla" and not ( + experiment_config["xla"] and experiment_config["test"] == "train"): + return False + if (experiment_config["xla"] and + not is_xla_device_available(experiment_config["accelerator"].upper())): + return False + if (experiment_config["accelerator"] == "tpu" and + not experiment_config["xla"]): + return False + if (experiment_config["accelerator"] == "cuda" and + not experiment_config["xla"] and not torch.cuda.is_available()): + return False + return True + + def _add_experiment_env(self, experiment_config): + process_env = None + + if experiment_config["xla"]: + # remove env vars that would interfere with subprocess settings + os.environ.pop("PJRT_DEVICE", None) + os.environ.pop("XRT_TPU_CONFIG", None) + os.environ.pop("XLA_FLAGS", None) + + process_env = os.environ.copy() + if experiment_config["xla"] == "PJRT": + process_env["PJRT_DEVICE"] = experiment_config["accelerator"].upper() + elif experiment_config["xla"] == "XRT": + if is_xla_device_available("TPU"): + process_env["TPU_NUM_DEVICES"] = "1" + process_env["XRT_TPU_CONFIG"] = "localservice;0;localhost:51011" + elif is_xla_device_available("CUDA"): + process_env["GPU_NUM_DEVICES"] = "1" + elif not experiment_config["xla"] and is_xla_device_available( + experiment_config["accelerator"].upper()): + # In non-xla CPU training experiments, an env var is still needed if an + # xla device exists, or there will be "Missing XLA configuration" error. + process_env["PJRT_DEVICE"] = experiment_config["accelerator"].upper() + + if experiment_config["xla_flags"]: + process_env["XLA_FLAGS"] = experiment_config["xla_flags"] + + experiment_config["process_env"] = process_env + + def load_experiment(self, experiment_config, dummy=False): + experiment_name = self.experiment_name + accelerator = experiment_config["accelerator"] + xla = experiment_config["xla"] + xla_flags = experiment_config["xla_flags"] + dynamo = experiment_config["dynamo"] + test = experiment_config["test"] + batch_size = experiment_config.get("batch_size", self._args.batch_size) + benchmark_experiment = BenchmarkExperiment( + experiment_name=experiment_name, + accelerator=accelerator, + xla=xla, + xla_flags=xla_flags, + dynamo=dynamo, + test=test, + batch_size=batch_size) + + return benchmark_experiment + + +class BenchmarkExperiment: + + def __init__(self, experiment_name, accelerator, xla, xla_flags, dynamo, test, + batch_size): + self.experiment_name = experiment_name + self.accelerator = accelerator + self.xla = xla + self.xla_flags = xla_flags + self.dynamo = dynamo + self.test = test + self.batch_size = batch_size + self.accelerator_model = get_accelerator_model(self.accelerator) + + def get_device(self): + if self.xla: + device = xm.xla_device(devkind=self.accelerator.upper()) + elif self.accelerator == "cpu": + device = torch.device("cpu") + elif self.accelerator == "cuda": + device = torch.device("cuda") + else: + raise NotImplementedError + + return device + + @property + def filename_str(self): + return "-".join(self.to_dict().values()) + + def to_dict(self): + d = OrderedDict() + d["experiment_name"] = self.experiment_name + d["accelerator"] = self.accelerator + d["accelerator_model"] = self.accelerator_model + d["xla"] = self.xla + d["xla_flags"] = self.xla_flags + d["dynamo"] = self.dynamo + d["test"] = self.test + d["batch_size"] = self.batch_size + return d diff --git a/benchmarks/benchmark_model.py b/benchmarks/benchmark_model.py new file mode 100644 index 000000000000..5f9492413ea8 --- /dev/null +++ b/benchmarks/benchmark_model.py @@ -0,0 +1,165 @@ +from collections import OrderedDict +import logging +import re +import torch +import torch.nn as nn +import torch._dynamo as dynamo +from torch._dynamo.testing import collect_results +import types + +try: + from .util import move_to_device +except ImportError: + from util import move_to_device + +logger = logging.getLogger(__name__) + + +class ModelLoader: + + def __init__(self, args): + self._args = args + self.suite_name = self._args.suite_name + self.benchmark_model_class = BenchmarkModel + + def list_model_configs(self): + model_configs = [ + { + "model_name": "dummy" + }, + ] + + return model_configs + + def is_compatible(self, dummy_benchmark_model, benchmark_experiment): + return True + + def get_benchmark_indices(self, length): + start = self._args.partition_id * (length // self._args.total_partitions) + end = ((self._args.partition_id + 1) * + (length // self._args.total_partitions) + if self._args.partition_id < self._args.total_partitions - 1 else + length) + return start, end + + def skip_model(self, model_name): + return (not re.search("|".join(self._args.filter), model_name, re.I) or + re.search("|".join(self._args.exclude), model_name, re.I)) + + def load_model(self, model_config, benchmark_experiment, dummy=False): + suite_name = self.suite_name + model_name = model_config["model_name"] + benchmark_model = self.benchmark_model_class( + suite_name=suite_name, + model_name=model_name, + benchmark_experiment=benchmark_experiment, + ) + + if not dummy: + benchmark_model.set_up() + benchmark_model.prepare_for_experiment() + + return benchmark_model + + +class BenchmarkModel: + + def __init__(self, suite_name, model_name, benchmark_experiment): + self.suite_name = suite_name + self.model_name = model_name + self.benchmark_experiment = benchmark_experiment + + def set_up(self): + """Set up module, actual batch_size, example_inputs, and optimizer_class + + This is model suite specific. + """ + if self.model_name != "dummy": + raise NotImplementedError + + self.module = nn.Sequential( + nn.Linear(32, 512), + nn.ReLU(), + nn.Linear(512, 512), + nn.ReLU(), + nn.Linear(512, 32), + nn.Softmax(dim=1), + ) + + self.benchmark_experiment.batch_size = 16 + self.example_inputs = (torch.rand(self.benchmark_experiment.batch_size, + 32),) + self.optimizer_class = torch.optim.Adam + + def _prepare_for_eval(self): + self.module.eval() + self.model_iter_fn = self.eval + + def _prepare_for_train(self): + self.module.train() + self.model_iter_fn = self.train + if not hasattr(self, "optimizer"): + # For some special models, self.set_up() may have initialized an + # optimizer to use. So only initialize it when there is none existing. + self.optimizer = self.optimizer_class(self.module.parameters(), lr=0.01) + + def prepare_for_experiment(self): + self.device = self.benchmark_experiment.get_device() + self.module = self.module.to(self.device) + self.example_inputs = move_to_device(self.example_inputs, self.device) + + if self.benchmark_experiment.test == "eval": + self._prepare_for_eval() + elif self.benchmark_experiment.test == "train": + self._prepare_for_train() + else: + raise NotImplementedError + + if self.benchmark_experiment.dynamo: + self.model_iter_fn = torch.compile( + self.model_iter_fn, backend=self.benchmark_experiment.dynamo) + + def pick_grad(self): + if self.benchmark_experiment.test == "eval": + return torch.no_grad() + elif self.benchmark_experiment.test == "train": + return torch.enable_grad() + + def _optimizer_zero_grad(self): + if self.optimizer is not None: + self.optimizer.zero_grad(True) + else: + self.module.zero_grad(True) + + def _optimizer_step(self): + if self.optimizer is not None: + self.optimizer.step() + + def compute_loss(self, pred): + raise NotImplementedError + + def train(self, inputs, collect_full_output=False): + self._optimizer_zero_grad() + pred = self.module(*inputs) + loss = self.compute_loss(pred) + loss.backward() + self._optimizer_step() + if collect_full_output: + return collect_results(self.module, pred, loss, inputs) + # return loss.detach() + # TODO: dynamo inductor would fail if .detach() is used + return None + + def eval(self, inputs, collect_full_output=False): + pred = self.module(*inputs) + return pred + + @property + def filename_str(self): + return "-".join(self.to_dict().values()) + + def to_dict(self): + d = OrderedDict() + d["suite_name"] = self.suite_name + d["model_name"] = self.model_name + return d diff --git a/benchmarks/experiment_runner.py b/benchmarks/experiment_runner.py new file mode 100644 index 000000000000..23e128f11319 --- /dev/null +++ b/benchmarks/experiment_runner.py @@ -0,0 +1,547 @@ +import argparse +from collections import OrderedDict +import copy +import csv +import io +import json +import logging +import numpy as np +import os +import subprocess +import sys +import time +import torch +from tqdm import tqdm +from torch.profiler import profile, record_function, ProfilerActivity + +try: + from .benchmark_model import ModelLoader + from .torchbench_model import TorchBenchModelLoader + from .benchmark_experiment import ExperimentLoader + from .util import patch_torch_manual_seed, reset_rng_state, move_to_device, randomize_input + from .tiers import append_filter_by_tier +except ImportError: + from benchmark_model import ModelLoader + from torchbench_model import TorchBenchModelLoader + from benchmark_experiment import ExperimentLoader + from util import patch_torch_manual_seed, reset_rng_state, move_to_device, randomize_input + from tiers import append_filter_by_tier + +try: + import torch_xla.core.xla_model as xm +except ImportError: + # ignore the error if torch_xla is not installed + pass + +logger = logging.getLogger(__name__) + + +class ExperimentRunner: + + def __init__(self, args): + self._args = args + self.suite_name = self._args.suite_name + + self.experiment_loader = ExperimentLoader(self._args) + + if self.suite_name == "torchbench": + self.model_loader = TorchBenchModelLoader(self._args) + elif self.suite_name == "dummy": + self.model_loader = ModelLoader(self._args) + else: + raise NotImplementedError + + self.output_dir = os.path.abspath(self._args.output_dirname) + os.makedirs(self.output_dir, exist_ok=True) + self.output_file = os.path.join(self.output_dir, self._args.output_basename) + + def run(self): + if self._args.experiment_config and self._args.model_config: + if self._args.dry_run: + logger.warning(f"Dry run with {[sys.executable] + sys.argv}") + return + experiment_config = json.loads(self._args.experiment_config) + model_config = json.loads(self._args.model_config) + self.run_single_experiment(experiment_config, model_config) + else: + assert not self._args.experiment_config and not self._args.model_config + finished_experiments = set() + if os.path.exists(self.output_file): + if self._args.no_resume: + os.unlink(self.output_file) + else: + with open(self.output_file, mode="r", encoding="utf-8") as f: + jsonlines = f.read().splitlines() + for jsonline in jsonlines: + tmp = json.loads(jsonline) + if self._args.experiment_name == "run_all": + # the finished experiment batch_size may be altered by model set_up(), + # so the dummy experiment will not match it + tmp["experiment"]["batch_size"] = self._args.batch_size + finished_experiments.add("-".join( + str(item) for item in (list(tmp["model"].values()) + + list(tmp["experiment"].values())))) + + experiment_configs = self.experiment_loader.list_experiment_configs() + model_configs = self.model_loader.list_model_configs() + logger.warning( + f"Number of selected experiment configs: {len(experiment_configs)}") + logger.warning(f"Number of selected model configs: {len(model_configs)}") + for model_config in tqdm( + model_configs, + desc="model configs", + disable=not self._args.progress_bar): + for experiment_config in experiment_configs: + process_env = experiment_config.pop("process_env") + experiment_config_str = json.dumps(experiment_config) + model_config_str = json.dumps(model_config) + dummy_benchmark_experiment = self.experiment_loader.load_experiment( + experiment_config, dummy=True) + dummy_benchmark_model = self.model_loader.load_model( + model_config, dummy_benchmark_experiment, dummy=True) + experiment_config["process_env"] = process_env + command = ([sys.executable] + sys.argv + + [f"--experiment-config={experiment_config_str}"] + + [f"--model-config={model_config_str}"]) + if self._args.dry_run: + logger.warning(f"Dry run with {command}") + continue + if "-".join( + str(item) + for item in (list(dummy_benchmark_model.to_dict().values()) + + list(dummy_benchmark_experiment.to_dict().values()) + )) in finished_experiments: + continue + if self.model_loader.is_compatible(dummy_benchmark_model, + dummy_benchmark_experiment): + try: + completed_process = subprocess.run( + command, + timeout=60 * 30, + env=process_env, + check=True, + capture_output=True, + encoding="utf-8", + ) + except subprocess.TimeoutExpired as e: + logger.error("TIMEOUT") + self.save_results(dummy_benchmark_experiment, + dummy_benchmark_model, {"error": str(e)}, None) + except subprocess.CalledProcessError as e: + logger.error("ERROR") + self.save_results(dummy_benchmark_experiment, + dummy_benchmark_model, {"error": e.stderr}, + None) + except subprocess.SubprocessError as e: + logger.error("ERROR") + self.save_results(dummy_benchmark_experiment, + dummy_benchmark_model, {"error": str(e)}, None) + else: + if self._args.print_subprocess: + logger.info(completed_process.stdout) + logger.warning(completed_process.stderr) + + else: + e = "SKIP because of incompatible model and experiment configs." + logger.warning(e) + self.save_results(dummy_benchmark_experiment, dummy_benchmark_model, + {"error": str(e)}, None) + + def run_single_experiment(self, experiment_config, model_config): + benchmark_experiment = self.experiment_loader.load_experiment( + experiment_config) + reset_rng_state(benchmark_experiment) + benchmark_model = self.model_loader.load_model(model_config, + benchmark_experiment) + + with benchmark_model.pick_grad(): + metrics = OrderedDict() + outputs = [] + for i in range(self._args.repeat): + run_metrics, output = self.timed_run(benchmark_experiment, + benchmark_model) + output = move_to_device(output, 'cpu') + outputs.append(output) + for key, val in run_metrics.items(): + # metrics from repeated runs are formed into lists in the metrics dict + if i == 0: + metrics[key] = [] + metrics[key].append(val) + + # additional experiment metrics can be added here + + self.save_results(benchmark_experiment, benchmark_model, metrics, outputs) + + def save_results(self, benchmark_experiment, benchmark_model, metrics, + outputs): + if self._args.save_output and outputs is not None: + outputs_file_name = f"{benchmark_model.filename_str}-{benchmark_experiment.filename_str}.pt" + torch.save(outputs, os.path.join(self.output_dir, outputs_file_name)) + else: + outputs_file_name = None + + results = OrderedDict() + results["model"] = benchmark_model.to_dict() + results["experiment"] = benchmark_experiment.to_dict() + results["repeat"] = self._args.repeat + results["iterations_per_run"] = self._args.iterations_per_run + + results["metrics"] = metrics + results["outputs_file"] = outputs_file_name + + self.output_jsonl(results) + + def output_jsonl(self, obj, file_path=None): + if not file_path: + file_path = self.output_file + json_str = json.dumps(obj, ensure_ascii=False) + with open(file_path, mode="a", encoding="utf-8") as f: + f.write(f"{json_str}\n") + + def output_csv(self, headers, row, file_path=None): + if not file_path: + file_path = self.output_file + existed = os.path.exists(file_path) + output = csv.writer( + io.TextIOWrapper( + open(file_path, "ab", buffering=0), + "utf-8", + write_through=True, + ), + lineterminator="\n", + ) + if not existed: + output.writerow(headers) + output.writerow([(f"{x:.8e}" if isinstance(x, float) else x) for x in row]) + + def _mark_step(self, benchmark_experiment): + if benchmark_experiment.xla: + xm.mark_step() + + def _synchronize(self, benchmark_experiment): + if benchmark_experiment.xla: + xm.wait_device_ops() + elif benchmark_experiment.accelerator == "cuda": + torch.cuda.synchronize() + else: + pass + + def prepare_inputs(self, example_inputs, should_randomize_input): + inputs_list = [] + for i in range(self._args.iterations_per_run): + inputs = copy.deepcopy(example_inputs) + if should_randomize_input: + inputs = randomize_input(inputs) + inputs_list.append(inputs) + return inputs_list + + def dump_profile_info(self, prof, model_name): + assert prof is not None, 'Expecting profiler to be defined!' + if not self._args.profile_cuda_dump: + logger.warning( + 'Profiling enabled, but dumping tracing/kernel summary disabled.') + return + + file_path = f"/tmp/{model_name}-profile" + os.makedirs(file_path, exist_ok=True) + prof.export_chrome_trace(os.path.join(file_path, "trace.json")) + + kernel_dump = prof.key_averages().table( + sort_by="cuda_time_total", row_limit=500) + with open(os.path.join(file_path, "kernel_dump.txt"), "a") as f: + f.write(kernel_dump) + + def timed_run(self, benchmark_experiment, benchmark_model): + + reset_rng_state(benchmark_experiment) + + inputs_list = self.prepare_inputs(benchmark_model.example_inputs, + self._args.randomize_input) + + reset_rng_state(benchmark_experiment) + self._mark_step(benchmark_experiment) + self._synchronize(benchmark_experiment) + + enable_prof = self._args.profile_cuda + metrics = OrderedDict() + t_start = time.perf_counter() + if benchmark_experiment.xla: + t_trace = 0 + + def loop(prof=None): + nonlocal t_trace + for i in range(self._args.iterations_per_run): + if benchmark_experiment.xla: + t_trace_start = time.perf_counter() + + output = benchmark_model.model_iter_fn( + inputs_list[i], collect_full_output=self._args.collect_full_output) + + if benchmark_experiment.xla: + t_trace += time.perf_counter() - t_trace_start + + self._mark_step(benchmark_experiment) + + if prof: + prof.step() + return output + + if enable_prof: + with profile( + activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU]) as prof: + output = loop(prof) + else: + output = loop() + + self._synchronize(benchmark_experiment) + + t_end = time.perf_counter() + if enable_prof: + self.dump_profile_info(prof, benchmark_model.model_name) + + metrics["total_time"] = t_end - t_start + metrics[ + "per_iter_time"] = metrics["total_time"] / self._args.iterations_per_run + if benchmark_experiment.xla: + metrics["trace_per_iter_time"] = t_trace / self._args.iterations_per_run + + return metrics, output + + +def parse_args(args=None): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--suite-name", + required=True, + choices=["dummy", "torchbench"], + help="Suite name for the model garden.", + ) + + parser.add_argument( + "--filter", + "-k", + action="append", + default=[], + help="filter benchmarks with regexp") + + parser.add_argument( + "--exclude", + "-x", + action="append", + default=[], + help="filter out benchmarks with regexp") + + parser.add_argument( + "--filter-by-tier", + type=int, + action="append", + default=[], + help="filter benchmarks by predefined tier 1-3", + ) + + parser.add_argument( + "--exclude-by-tier", + type=int, + action="append", + default=[], + help="filter out benchmarks by predefined tier 1-3", + ) + + parser.add_argument( + "--log-level", + default="warning", + choices=["info", "warning"], + help="Specify the logging level.", + ) + + parser.add_argument( + "--experiment-name", + default="run_all", + choices=["run_all"], + help="Experiment name to run.", + ) + + parser.add_argument( + "--accelerator", + choices=["cpu", "cuda", "tpu"], + action="append", + help="Specify an accelerator to use.", + ) + + parser.add_argument( + "--xla", + choices=["None", "PJRT", "XRT"], + action="append", + help="Specify an xla option to use.", + ) + + parser.add_argument( + "--dynamo", + choices=["None", "inductor", "openxla_eval", "openxla"], + action="append", + help="Specify an xla option to use.", + ) + + parser.add_argument( + "--test", + choices=["eval", "train"], + action="append", + help="Specify a test to run.", + ) + + parser.add_argument( + "--repeat", + type=int, + default=10, + help="Number of times to repeat the timed run in a single experiment.", + ) + + parser.add_argument( + "--iterations-per-run", + type=int, + default=1, + help="Number of times to repeat the model iteration inside a timed run.", + ) + + parser.add_argument( + "--batch-size", + type=int, + help="Batch size to be used. If not provided, it depends on the model suites to determine it.", + ) + + parser.add_argument( + "--total-partitions", + type=int, + default=1, + choices=range(1, 10), + help="Total number of partitions we want to divide the benchmark suite into", + ) + + parser.add_argument( + "--partition-id", + type=int, + default=0, + help="ID of the benchmark suite partition to be run. Used to divide CI tasks", + ) + + parser.add_argument( + "--dry-run", + action="store_true", + help="Do a dry run to only print the benchmark commands.", + ) + + parser.add_argument( + "--print-subprocess", + action="store_true", + help="Print subprocess stdout.", + ) + + parser.add_argument( + "--progress-bar", + action="store_true", + help="Display progress bar.", + ) + + parser.add_argument( + "--randomize-input", + action="store_true", + help="Whether to randomize the input values. Dimensions will be kept the same.", + ) + + parser.add_argument( + "--collect-full-output", + action="store_true", + help="""Whether to collect full output for training. Set this to true if we + want to verify the numerical correctness of graidents. But that may + cause time measurement not accurate""", + ) + + parser.add_argument( + "--save-output", + action="store_true", + help="Whether to save the model output to disk", + ) + + parser.add_argument( + "--output-dirname", + type=str, + default="./output/", + help="Overrides the directory to place output files.", + ) + + parser.add_argument( + "--output-basename", + type=str, + default="results.jsonl", + help="Overrides the basename of output files.", + ) + + parser.add_argument( + "--no-resume", + action="store_true", + help="""By default, the runner would skip the finished experiments that + exist in the output-basename file. If --no-resume is set, the previous + output-basename file will be deleted and all experiment will run""", + ) + + parser.add_argument( + "--experiment-config", + type=str, + help="JSON string of the experiment config dict.", + ) + + parser.add_argument( + "--model-config", + type=str, + help="JSON string of the model config dict.", + ) + + parser.add_argument( + "--profile-cuda", + action="store_true", + help="""Whether to profile CUDA or not. Note this does not do much except for + triggering a profiler. To get the profiling data use additionally --profile-cuda-dump""", + ) + + parser.add_argument( + "--profile-cuda-dump", + type=str, + default="./output/", + help="Directory specifying where to dump profiling information (summary, and trace)" + ), + + parser.add_argument( + "--xla-flags", + type=str, + action="append", + help="Flags to forward to XLA via `XLA_FLAGS` env var.", + ) + + return parser.parse_args(args) + + +def main(): + args = parse_args() + + # Expand filter/exclude by tier. + append_filter_by_tier(args.filter, args.filter_by_tier) + append_filter_by_tier(args.exclude, args.exclude_by_tier) + args.filter = args.filter or [r"."] + args.exclude = args.exclude or [r"^$"] + + if args.log_level == "info": + log_level = logging.INFO + elif args.log_level == "warning": + log_level = logging.WARNING + else: + log_level = None + logging.basicConfig(level=log_level, force=True) + + logger.info(args) + runner = ExperimentRunner(args) + runner.run() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/result_analyzer.py b/benchmarks/result_analyzer.py new file mode 100644 index 000000000000..7ad22d6692ea --- /dev/null +++ b/benchmarks/result_analyzer.py @@ -0,0 +1,191 @@ +import argparse +from collections import OrderedDict +import copy +import csv +import io +import json +import logging +import numpy as np +import os +import pandas as pd +import subprocess +import sys +import time +import torch +from tqdm import tqdm + +logger = logging.getLogger(__name__) + + +class ResultAnalyzer: + + def __init__(self, args): + self._args = args + self.timestamp = self._args.timestamp or time.time() + self.output_dir = os.path.abspath(self._args.output_dirname) + if not os.path.exists(self.output_dir): + raise ValueError("The output directory does not exist.") + self.output_file = os.path.join(self.output_dir, "metric_report.csv") + + self.database = os.path.abspath(self._args.database) + + def run(self): + jsonl_files = [] + for file in os.listdir(self.output_dir): + if file.endswith(".jsonl"): + jsonl_files.append(os.path.join(self.output_dir, file)) + + metric_df = pd.DataFrame({ + "timestamp": pd.Series(dtype="int"), + "suite_name": pd.Series(dtype="str"), + "model_name": pd.Series(dtype="str"), + "experiment_name": pd.Series(dtype="str"), + "accelerator": pd.Series(dtype="str"), + "accelerator_model": pd.Series(dtype="str"), + "xla": pd.Series(dtype="str"), + "xla_flags": pd.Series(dtype="str"), + "dynamo": pd.Series(dtype="str"), + "test": pd.Series(dtype="str"), + "batch_size": pd.Series(dtype="int"), + "repeat": pd.Series(dtype="int"), + "iterations_per_run": pd.Series(dtype="int"), + "error_message": pd.Series(dtype="str"), + "median_total_time": pd.Series(dtype="float"), + "median_per_iter_time": pd.Series(dtype="float"), + "xla_median_trace_per_iter_time": pd.Series(dtype="float"), + "xla_compile_time": pd.Series(dtype="float"), + "dynamo_compile_time": pd.Series(dtype="float"), + "outputs_file": pd.Series(dtype="str"), + }) + for file in jsonl_files: + metric_df = self.extract_metrics(file, metric_df) + + # additional processing of the metric_df can be done here + + self.export_metric_report(metric_df) + + def extract_metrics(self, file, metric_df): + with open(file, mode="r", encoding="utf-8") as f: + jsonlines = f.read().splitlines() + + for jsonline in jsonlines: + tmp = json.loads(jsonline) + d = { + "timestamp": self.timestamp, + "suite_name": tmp["model"]["suite_name"], + "model_name": tmp["model"]["model_name"], + "experiment_name": tmp["experiment"]["experiment_name"], + "accelerator": tmp["experiment"]["accelerator"], + "accelerator_model": tmp["experiment"]["accelerator_model"], + "xla": tmp["experiment"]["xla"], + "xla_flags": tmp["experiment"]["xla_flags"], + "dynamo": tmp["experiment"]["dynamo"], + "test": tmp["experiment"]["test"], + "batch_size": tmp["experiment"]["batch_size"], + "repeat": tmp["repeat"], + "iterations_per_run": tmp["iterations_per_run"], + "error_message": None, + "outputs_file": tmp["outputs_file"], + } + + if "error" in tmp["metrics"] and not self._args.hide_errors: + d["error_message"] = tmp["metrics"]["error"] + + if "error" not in tmp["metrics"]: + total_time = np.asarray(tmp["metrics"]["total_time"], dtype="float") + d["median_total_time"] = np.median(total_time) + per_iter_time = np.asarray( + tmp["metrics"]["per_iter_time"], dtype="float") + d["median_per_iter_time"] = np.median(per_iter_time) + if tmp["experiment"]["xla"]: + trace_per_iter_time = np.asarray( + tmp["metrics"]["trace_per_iter_time"], dtype="float") + d["xla_median_trace_per_iter_time"] = np.median(trace_per_iter_time) + d["xla_compile_time"] = np.max(total_time) - np.median(total_time) + if tmp["experiment"]["dynamo"]: + d["dynamo_compile_time"] = np.max(total_time) - np.median(total_time) + + new_row = pd.Series(d) + new_row.fillna(value=np.nan, inplace=True) + metric_df = pd.concat([metric_df, new_row.to_frame().T], + ignore_index=True) + + return metric_df + + def export_metric_report(self, metric_df): + metric_df.to_csv( + self.output_file, mode="w", encoding="utf-8", header=True, index=False) + + if not os.path.exists(self.database): + metric_df.to_csv( + self.database, mode="w", encoding="utf-8", header=True, index=False) + else: + metric_df.to_csv( + self.database, mode="a", encoding="utf-8", header=False, index=False) + + +def parse_args(args=None): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--log-level", + default="warning", + choices=["info", "warning"], + help="Specify the logging level.", + ) + + parser.add_argument( + "--experiment-name", + default="run_all", + choices=["run_all"], + help="Experiment name to run.", + ) + + parser.add_argument( + "--output-dirname", + type=str, + default="./output/", + help="Overrides the directory to place output files.", + ) + + parser.add_argument( + "--database", + type=str, + default="./output/database.csv", + help="Path to the database.", # for POC, database is a path to a csv file. + ) + + parser.add_argument( + "--timestamp", + type=int, + help="User provided timestamp. If not provided, get the timestamp in analyzer", + ) + + parser.add_argument( + "--hide-errors", + default=False, + action="store_true", + help="Hide errors to make the CSV more readable", + ) + + return parser.parse_args(args) + + +def main(): + args = parse_args() + + if args.log_level == "info": + log_level = logging.INFO + elif args.log_level == "warning": + log_level = logging.WARNING + else: + log_level = None + logging.basicConfig(level=log_level, force=True) + + logger.info(args) + analyzer = ResultAnalyzer(args) + analyzer.run() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/run_benchmark.sh b/benchmarks/run_benchmark.sh new file mode 100644 index 000000000000..fd8a055bccc5 --- /dev/null +++ b/benchmarks/run_benchmark.sh @@ -0,0 +1,79 @@ +#!/bin/bash +set -exo pipefail +CDIR="$(cd "$(dirname "$0")" ; pwd -P)" +LOGFILE=/tmp/benchmark_test.log + +# Note [Keep Going] +# +# Set the `CONTINUE_ON_ERROR` flag to `true` to make the CircleCI tests continue on error. +# This will allow you to see all the failures on your PR, not stopping with the first +# test failure like the default behavior. +CONTINUE_ON_ERROR="${CONTINUE_ON_ERROR:-0}" +if [[ "$CONTINUE_ON_ERROR" == "1" ]]; then + set +e +fi + +TESTGPUVM=None +TESTTPUVM=None +# NUMBER=0 + +while getopts 'G:T:' OPTION # N: +do + case $OPTION in + G) + TESTGPUVM=$OPTARG + ;; + T) + TESTTPUVM=$OPTARG + ;; + # N) + # NUMBER=$OPTARG + # ;; + esac +done +shift $(($OPTIND - 1)) + +# func for test after ssh to VM, create container and execute in container +function benchmarking_in_container { + sudo docker pull gcr.io/tpu-pytorch/xla:nightly_3.8_cuda_11.8 + sudo apt-get install -y apt-transport-https ca-certificates curl gnupg-agent software-properties-common + nvidia-smi + distribution=$(. /etc/os-release;echo $ID$VERSION_ID) + curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - + curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list + sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit + sudo systemctl restart docker + sudo docker run --gpus all -it -d gcr.io/tpu-pytorch/xla:nightly_3.8_cuda_11.8 bin/bash + sudo docker exec -it $(sudo docker ps | awk 'NR==2 { print $1 }') /bin/bash + # install torchbench + cd ~ + git clone -b xla_benchmark https://github.com/pytorch/benchmark.git + cd benchmark + # install deps + pip install --pre torchvision torchaudio -i https://download.pytorch.org/whl/nightly/cu118 + # git clone xla + cd ~ + git clone -b benchmark https://github.com/pytorch/xla.git xla + cd ~/xla/benchmarks + # dry run + python3 experiment_runner.py --suite-name=torchbench --accelerator=gpu --progress-bar --dry-run + # run bechmark + python3 experiment_runner.py --suite-name=torchbench --accelerator=gpu --progress-bar + # analyze result to csv + python3 result_analyzer.py +} + + + +if TESTGPUVM='1A100': + # ssh to 1-A100 GPUVM and test in container + gcloud compute ssh a100-manfei-1 --zone us-central1-c --project tpu-prod-env-one-vm -- -o ProxyCommand='corp-ssh-helper %h %p' --command=benchmarking_in_container +elif TESTGPUVM='8A100': + # SSH TO 8-A100 GPUVM and test in container + gcloud compute ssh manfei-a100-8-new --zone us-central1-c --project tpu-prod-env-one-vm -- -o ProxyCommand='corp-ssh-helper %h %p' --command=benchmarking_in_container +elif TESTGPUVM='4H100': + # ssh to 4-H100 GPUVM and test in container +elif TESTTPUVM='v5e8': + # ssh to v5e-8 TPUVM and test in container +elif TESTTPUVM='v5p8': + # ssh to v5p-8 TPUVM and test in container diff --git a/benchmarks/run_top_tier_bm.sh b/benchmarks/run_top_tier_bm.sh new file mode 100755 index 000000000000..ca67e361a6f2 --- /dev/null +++ b/benchmarks/run_top_tier_bm.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +set -ex + +DATE=$(date +"%Y_%m_%d_%H_%M") + +OUT_PATH=xla/benchmarks/bm_results/$DATE +mkdir -p $OUT_PATH + +python xla/benchmarks/experiment_runner.py \ + --dynamo=inductor --dynamo=openxla_eval --dynamo=openxla \ + --xla=None --xla=PJRT \ + --test=eval --test=train \ + --filter-by-tier=1 --filter-by-tier=2 --filter-by-tier=3 \ + --suite-name=torchbench \ + --accelerator=cuda \ + --output-dirname=$OUT_PATH \ + --repeat=5 \ + --print-subprocess \ + --no-resume \ + > $OUT_PATH/stdout.txt 2> $OUT_PATH/stderr.txt + +python3 xla/benchmarks/result_analyzer.py \ + --output-dirname=$OUT_PATH \ + --database=$OUT_PATH/$DATE.csv diff --git a/benchmarks/tiers.py b/benchmarks/tiers.py new file mode 100644 index 000000000000..7339cc746780 --- /dev/null +++ b/benchmarks/tiers.py @@ -0,0 +1,15 @@ +_FILTER_BY_TIER = { + 1: + r"^(BERT_pytorch|cm3leon_generate|DALLE2_pytorch|dlrm|hf_GPT2|hf_GPT2_large|GPT_3|hf_T5|hf_T5_base|hf_T5_generate|hf_T5_large|llama_v2_7b_16h|stable_diffusion_xl)$", + 2: + r"^(alexnet|attention_is_all_you_need_pytorch|Background_Matting|basic_gnn_gcn|basic_gnn_gin|basic_gnn_sage|dcgan|densenet121|detectron2_fasterrcnn_r_101_c4|detectron2_fasterrcnn_r_101_dc5|detectron2_fasterrcnn_r_101_fpn|detectron2_fasterrcnn_r_50_c4|detectron2_fasterrcnn_r_50_dc5|detectron2_fasterrcnn_r_50_fpn|detectron2_fcos_r_50_fpn|detectron2_maskrcnn|detectron2_maskrcnn_r_101_c4|detectron2_maskrcnn_r_101_fpn|detectron2_maskrcnn_r_50_c4|detectron2_maskrcnn_r_50_fpn|fastNLP_Bert|functorch_dp_cifar10|hf_Albert|hf_Bart|hf_Bert|hf_Bert_large|llama)$", + 3: + r"^(doctr_det_predictor|doctr_reco_predictor|drq|functorch_maml_omniglot)$", +} + + +def append_filter_by_tier(filter_list: list[str], filter_by_tier: list[int]): + for tier in filter_by_tier: + if tier not in _FILTER_BY_TIER: + continue + filter_list.append(_FILTER_BY_TIER[tier]) diff --git a/benchmarks/torchbench_model.py b/benchmarks/torchbench_model.py new file mode 100644 index 000000000000..60f4fc05aa8f --- /dev/null +++ b/benchmarks/torchbench_model.py @@ -0,0 +1,247 @@ +import gc +import importlib +import logging +import os +from os.path import abspath, exists +import sys +import torch +import torch.nn as nn +from torch._dynamo.testing import collect_results, reduce_to_scalar_loss +from torch._dynamo.utils import clone_inputs +import types + +try: + from .util import move_to_device, set_cwd + from .benchmark_model import ModelLoader, BenchmarkModel +except ImportError: + from util import move_to_device, set_cwd + from benchmark_model import ModelLoader, BenchmarkModel + +logger = logging.getLogger(__name__) + +DETECTRON2_MODELS = { + "detectron2_fasterrcnn_r_101_c4", + "detectron2_fasterrcnn_r_101_dc5", + "detectron2_fasterrcnn_r_101_fpn", + "detectron2_fasterrcnn_r_50_c4", + "detectron2_fasterrcnn_r_50_dc5", + "detectron2_fasterrcnn_r_50_fpn", + "detectron2_maskrcnn_r_101_c4", + "detectron2_maskrcnn_r_101_fpn", + "detectron2_maskrcnn_r_50_c4", + "detectron2_maskrcnn_r_50_fpn", + "detectron2_maskrcnn", + "detectron2_fcos_r_50_fpn", +} + +# Skip the experiment of a model if any of the experiment configs in the list is fully matched +DENY_LIST = { + "doctr_det_predictor": [{ + "test": "train" + },], # not implemented + "doctr_reco_predictor": [{ + "test": "train" + },], # not implemented + "detectron2_fcos_r_50_fpn": [{ + "test": "train" + },], # not implemented + # https://github.com/pytorch/torchdynamo/issues/145 + "fambench_xlmr": [{}], + "llama": [{ + "test": "train" + },], # not implemented + "mobilenet_v2_quantized_qat": [ + { + "test": "eval", + "accelerator": "cuda" + }, # not implemented + { + "test": "eval", + "accelerator": "tpu" + }, + ], # not implemented + "pyhpc_equation_of_state": [{ + "test": "train" + },], # not implemented + "pyhpc_isoneutral_mixing": [{ + "test": "train" + },], # not implemented + "pyhpc_turbulent_kinetic_energy": [{ + "test": "train" + },], # not implemented + "pytorch_struct": [{ + "test": "eval" + },], # not implemented + "resnet50_quantized_qat": [ + { + "test": "eval", + "accelerator": "cuda" + }, # not implemented + { + "test": "eval", + "accelerator": "tpu" + }, + ], # not implemented + # https://github.com/pytorch/pytorch/issues/99438 + "vision_maskrcnn": [{}], +} + + +class TorchBenchModelLoader(ModelLoader): + + def __init__(self, args): + super().__init__(args) + self.benchmark_model_class = TorchBenchModel + self.torchbench_dir = self.add_torchbench_dir() + + def add_torchbench_dir(self): + os.environ["KALDI_ROOT"] = "/tmp" # avoids some spam + for torchbench_dir in ( + "./torchbenchmark", + "./torchbench", + "./benchmark", + "../torchbenchmark", + "../torchbench", + "../benchmark", + "../../torchbenchmark", + "../../torchbench", + "../../benchmark", + ): + if exists(torchbench_dir): + break + + if exists(torchbench_dir): + torchbench_dir = abspath(torchbench_dir) + if torchbench_dir not in sys.path: + sys.path.append(torchbench_dir) + else: + raise Exception("Torch Benchmark folder not found.") + + return torchbench_dir + + def list_model_configs(self): + model_configs = [] + + from torchbenchmark import _list_model_paths + + models = _list_model_paths() + + start, end = self.get_benchmark_indices(len(models)) + models = models[start:end] + for model_path in models: + model_name = os.path.basename(model_path) + + if self.skip_model(model_name): + continue + + model_configs.append({"model_name": model_name}) + + return model_configs + + def is_compatible(self, dummy_benchmark_model, benchmark_experiment): + if dummy_benchmark_model.model_name in DENY_LIST: + for deny_experiment_config in DENY_LIST[dummy_benchmark_model.model_name]: + matched = True + for k, v in deny_experiment_config.items(): + if getattr(benchmark_experiment, k) != v: + matched = False + break + if matched: + return False + + return True + + +class TorchBenchModel(BenchmarkModel): + + def __init__(self, suite_name, model_name, benchmark_experiment): + super().__init__(suite_name, model_name, benchmark_experiment) + + def set_up(self): + """Set up module, actual batch_size, example_inputs, and optimizer_class + + This is model suite specific. + """ + self.optimizer_class = torch.optim.Adam + + try: + module = importlib.import_module( + f"torchbenchmark.models.{self.model_name}") + except ModuleNotFoundError: + module = importlib.import_module( + f"torchbenchmark.models.fb.{self.model_name}") + benchmark_cls = getattr(module, "Model", None) + + cant_change_batch_size = (not getattr(benchmark_cls, + "ALLOW_CUSTOMIZE_BSIZE", True)) + if cant_change_batch_size: + self.benchmark_experiment.batch_size = None + + # workaround "RuntimeError: not allowed to set torch.backends.cudnn flags" + # torch.backends.__allow_nonbracketed_mutation_flag = True + + if self.benchmark_experiment.accelerator == "cpu": + device = "cpu" + elif self.benchmark_experiment.accelerator == "cuda" and not self.benchmark_experiment.xla: + device = "cuda" + else: + device = str(self.benchmark_experiment.get_device()) + + benchmark = benchmark_cls( + test=self.benchmark_experiment.test, + device=device, + batch_size=self.benchmark_experiment.batch_size, + ) + + self.module, self.example_inputs = benchmark.get_module() + + self.benchmark_experiment.batch_size = benchmark.batch_size + + # Torchbench has quite different setup for yolov3, so directly passing + # the right example_inputs + if self.model_name == "yolov3": + self.example_inputs = (torch.rand(self.benchmark_experiment.batch_size, 3, + 384, 512),) + if self.benchmark_experiment.test == "train" and self.model_name in DETECTRON2_MODELS: + self.optimizer = benchmark.optimizer + + del benchmark + gc.collect() + + def pick_grad(self): + # special case + if self.model_name in ("maml",): + return torch.enable_grad() + + if self.benchmark_experiment.test == "eval": + return torch.no_grad() + elif self.benchmark_experiment.test == "train": + return torch.enable_grad() + + def compute_loss(self, pred): + """Reduce the output of a model to get scalar loss""" + if isinstance(pred, torch.Tensor): + # Mean does not work on integer tensors + return pred.sum() / pred.numel() + elif isinstance(pred, (list, tuple)): + return sum([reduce_to_scalar_loss(x) for x in pred]) / len(pred) + elif type(pred).__name__ in ( + "MaskedLMOutput", + "Seq2SeqLMOutput", + "CausalLMOutputWithCrossAttentions", + ): + return reduce_to_scalar_loss(pred.logits) + elif type(pred).__name__ == "SquashedNormal": + return pred.mean.sum() + elif isinstance(pred, dict): + return sum([reduce_to_scalar_loss(value) for value in pred.values() + ]) / len(pred.keys()) + raise NotImplementedError("Don't know how to reduce", type(pred)) + + def train(self, inputs, collect_full_output=False): + if self.model_name in DETECTRON2_MODELS: + from detectron2.utils.events import EventStorage + with EventStorage(): + super().train(inputs, collect_full_output=collect_full_output) + else: + super().train(inputs, collect_full_output=collect_full_output) \ No newline at end of file diff --git a/benchmarks/util.py b/benchmarks/util.py new file mode 100644 index 000000000000..0e62511d5a39 --- /dev/null +++ b/benchmarks/util.py @@ -0,0 +1,153 @@ +from contextlib import contextmanager +import functools +import logging +from multiprocessing import Process, Queue +import numpy as np +import os +from os.path import abspath +import queue +import random +import subprocess +import torch +import traceback + +try: + import torch_xla.core.xla_model as xm + from torch_xla._internal import tpu +except ImportError: + # ignore the error if torch_xla is not installed + pass + +logger = logging.getLogger(__name__) + + +@functools.lru_cache(None) +def patch_torch_manual_seed(): + """Make torch manual seed deterministic. Helps with accuracy testing.""" + + def deterministic_torch_manual_seed(*args, **kwargs): + from torch._C import default_generator + + seed = 1337 + import torch.cuda + + if not torch.cuda._is_in_bad_fork(): + torch.cuda.manual_seed_all(seed) + return default_generator.manual_seed(seed) + + torch.manual_seed = deterministic_torch_manual_seed + + +def reset_rng_state(benchmark_experiment=None): + torch.manual_seed(1337) + random.seed(1337) + np.random.seed(1337) + if benchmark_experiment and benchmark_experiment.xla: + device = benchmark_experiment.get_device() + xm.set_rng_state(1337, str(device)) + + +@functools.lru_cache(maxsize=3) +def is_xla_device_available(devkind): + if devkind not in ["CPU", "CUDA", "TPU"]: + raise ValueError(devkind) + + def _check_xla_device(q, devkind): + try: + import os + os.environ["PJRT_DEVICE"] = devkind + + import torch_xla.core.xla_model as xm + + q.put(bool(xm.get_xla_supported_devices(devkind=devkind))) + except Exception: + traceback.print_exc() + q.put(False) + + q = Queue() + process = Process(target=_check_xla_device, args=(q, devkind)) + process.start() + process.join(60) + try: + return q.get_nowait() + except queue.Empty: + traceback.print_exc() + return False + + +def move_to_device(item, device): + if isinstance(item, torch.Tensor): + return item.to(device=device) + elif isinstance(item, list): + return [move_to_device(t, device) for t in item] + elif isinstance(item, tuple): + return tuple(move_to_device(t, device) for t in item) + elif isinstance(item, dict): + return dict((k, move_to_device(t, device)) for k, t in item.items()) + else: + return item + + +def randomize_input(inputs): + if isinstance(inputs, torch.Tensor): + if inputs.dtype in (torch.float32, torch.float64): + torch._dynamo.utils.counters["randomize_input"]["times"] += 1 + return torch.randn_like(inputs) + elif inputs.dtype == torch.int64: + # Note: we can not simply tune integer tensors as follows + # `return torch.randint_like(inputs, high=inputs.max().item())` + # This may break some invariants between tensors. + # E.g. in embedding lookup case, one tensor is the length + # and another is an indices tensor. + return inputs + else: + raise RuntimeError( + f"randomize_input need support tensor of type {inputs.dtype}") + elif isinstance(inputs, (list, tuple)): + return type(inputs)([randomize_input(x) for x in inputs]) + elif isinstance(inputs, dict): + return dict((k, randomize_input(x)) for k, x in inputs.items()) + else: + logger.warning( + f"randomize_input can not handle input of type {type(inputs)}") + return inputs + + +@contextmanager +def set_cwd(path): + original_dir = abspath(os.getcwd()) + os.chdir(path) + try: + yield + finally: + os.chdir(original_dir) + + +def get_accelerator_model(accelerator): + if accelerator == "cpu": + return get_cpu_name() + elif accelerator == "cuda": + return get_gpu_name() + elif accelerator == "tpu": + return get_tpu_name() + else: + raise NotImplementedError + + +def get_cpu_name(): + return subprocess.check_output( + ["lscpu"], + encoding='utf-8').split("Model name:")[1].split("\n")[0].strip() + + +def get_gpu_name(): + gpu_names = subprocess.check_output( + ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv"], + encoding='utf-8').split("\n")[1:] + if len(gpu_names) == 1: + return gpu_names[0] + return "One of " + ", ".join(gpu_names) + + +def get_tpu_name(): + return tpu._get_metadata("accelerator-type")