diff --git a/benchmarks/benchmark_experiment.py b/benchmarks/benchmark_experiment.py new file mode 100644 index 00000000000..bd523acfb9e --- /dev/null +++ b/benchmarks/benchmark_experiment.py @@ -0,0 +1,183 @@ +from collections import OrderedDict +import logging +import os +import torch +import torch._dynamo as dynamo + +try: + from .util import is_xla_device_available, get_accelerator_model +except ImportError: + from util import is_xla_device_available, get_accelerator_model + +try: + import torch_xla.core.xla_model as xm +except ImportError: + # ignore the error if torch_xla is not installed + pass + +logger = logging.getLogger(__name__) + + +class ExperimentLoader: + + def __init__(self, args): + self._args = args + self.experiment_name = self._args.experiment_name + + def expand_config_choices(self, config_choices): + configs = [{}] + + for key, choices in config_choices.items(): + tmp_configs = [] + for config in configs: + for choice in choices: + tmp_config = config.copy() + tmp_config[key] = choice + tmp_configs.append(tmp_config) + configs = tmp_configs + + return configs + + def list_experiment_configs(self): + if self.experiment_name == "run_all": + config_choices = { + "accelerator": ["cpu", "gpu", "tpu"], + "xla": [None, "PJRT", "XRT"], + "dynamo": [ + None, "inductor", "torchxla_trace_once", "aot_torchxla_trace_once" + ], + "test": ["eval", "train"], + } + + if self._args.accelerator: + config_choices["accelerator"] = list(set(self._args.accelerator)) + if self._args.xla: + config_choices["xla"] = list(set(self._args.xla)) + config_choices["xla"] = [ + x if x != "None" else None for x in config_choices["xla"] + ] + if self._args.dynamo: + config_choices["dynamo"] = list(set(self._args.dynamo)) + config_choices["dynamo"] = [ + x if x != "None" else None for x in config_choices["dynamo"] + ] + if self._args.test: + config_choices["test"] = list(set(self._args.test)) + else: + raise NotImplementedError + + experiment_configs = [] + for experiment_config in self.expand_config_choices(config_choices): + if not self.is_available(experiment_config): + continue + + self._add_experiment_env(experiment_config) + experiment_configs.append(experiment_config) + return experiment_configs + + def is_available(self, experiment_config): + if experiment_config["dynamo"] and experiment_config[ + "dynamo"] not in dynamo.list_backends(exclude_tags=()): + return False + if experiment_config["dynamo"] == "inductor" and not ( + experiment_config["accelerator"] == "gpu" and + not experiment_config["xla"]): + return False + if experiment_config["dynamo"] == "torchxla_trace_once" and not ( + experiment_config["xla"] and experiment_config["test"] == "eval"): + return False + if experiment_config["dynamo"] == "aot_torchxla_trace_once" and not ( + experiment_config["xla"] and experiment_config["test"] == "train"): + return False + if (experiment_config["xla"] and + not is_xla_device_available(experiment_config["accelerator"].upper())): + return False + if (experiment_config["accelerator"] == "tpu" and + not experiment_config["xla"]): + return False + if (experiment_config["accelerator"] == "gpu" and + not experiment_config["xla"] and not torch.cuda.is_available()): + return False + return True + + def _add_experiment_env(self, experiment_config): + process_env = None + if experiment_config["xla"]: + # remove env vars that would interfere with subprocess settings + os.environ.pop("PJRT_DEVICE", None) + os.environ.pop("XRT_TPU_CONFIG", None) + if experiment_config["xla"] == "PJRT": + process_env = 
os.environ.copy() + process_env["PJRT_DEVICE"] = experiment_config["accelerator"].upper() + elif experiment_config["xla"] == "XRT": + process_env = os.environ.copy() + if is_xla_device_available("TPU"): + process_env["TPU_NUM_DEVICES"] = "1" + process_env["XRT_TPU_CONFIG"] = "localservice;0;localhost:51011" + elif is_xla_device_available("GPU"): + process_env["GPU_NUM_DEVICES"] = "1" + elif not experiment_config["xla"] and is_xla_device_available( + experiment_config["accelerator"].upper()): + # In non-xla CPU training experiments, an env var is still needed if an + # xla device exists, or there will be "Missing XLA configuration" error. + process_env = os.environ.copy() + process_env["PJRT_DEVICE"] = experiment_config["accelerator"].upper() + + experiment_config["process_env"] = process_env + + def load_experiment(self, experiment_config, dummy=False): + experiment_name = self.experiment_name + accelerator = experiment_config["accelerator"] + xla = experiment_config["xla"] + dynamo = experiment_config["dynamo"] + test = experiment_config["test"] + batch_size = experiment_config.get("batch_size", self._args.batch_size) + benchmark_experiment = BenchmarkExperiment( + experiment_name=experiment_name, + accelerator=accelerator, + xla=xla, + dynamo=dynamo, + test=test, + batch_size=batch_size) + + return benchmark_experiment + + +class BenchmarkExperiment: + + def __init__(self, experiment_name, accelerator, xla, dynamo, test, + batch_size): + self.experiment_name = experiment_name + self.accelerator = accelerator + self.xla = xla + self.dynamo = dynamo + self.test = test + self.batch_size = batch_size + self.accelerator_model = get_accelerator_model(self.accelerator) + + def get_device(self): + if self.xla: + device = xm.xla_device(devkind=self.accelerator.upper()) + elif self.accelerator == "cpu": + device = torch.device("cpu") + elif self.accelerator == "gpu": + device = torch.device("cuda") + else: + raise NotImplementedError + + return device + + @property + def filename_str(self): + return "-".join(self.to_dict().values()) + + def to_dict(self): + d = OrderedDict() + d["experiment_name"] = self.experiment_name + d["accelerator"] = self.accelerator + d["accelerator_model"] = self.accelerator_model + d["xla"] = self.xla + d["dynamo"] = self.dynamo + d["test"] = self.test + d["batch_size"] = self.batch_size + return d diff --git a/benchmarks/benchmark_model.py b/benchmarks/benchmark_model.py new file mode 100644 index 00000000000..dec1dde017a --- /dev/null +++ b/benchmarks/benchmark_model.py @@ -0,0 +1,172 @@ +from collections import OrderedDict +import logging +import re +import torch +import torch.nn as nn +import torch._dynamo as dynamo +from torch._dynamo.testing import collect_results +import types + +try: + from .util import move_to_device +except ImportError: + from util import move_to_device + +logger = logging.getLogger(__name__) + + +class ModelLoader: + + def __init__(self, args): + self._args = args + self.suite_name = self._args.suite_name + self.benchmark_model_class = BenchmarkModel + + def list_model_configs(self): + model_configs = [ + { + "model_name": "dummy" + }, + ] + + return model_configs + + def is_compatible(self, dummy_benchmark_model, benchmark_experiment): + return True + + def get_benchmark_indices(self, length): + start = self._args.partition_id * (length // self._args.total_partitions) + end = ((self._args.partition_id + 1) * + (length // self._args.total_partitions) + if self._args.partition_id < self._args.total_partitions - 1 else + length) + return start, 
end + + def skip_model(self, model_name): + return (not re.search("|".join(self._args.filter), model_name, re.I) or + re.search("|".join(self._args.exclude), model_name, re.I)) + + def load_model(self, model_config, benchmark_experiment, dummy=False): + suite_name = self.suite_name + model_name = model_config["model_name"] + benchmark_model = self.benchmark_model_class( + suite_name=suite_name, + model_name=model_name, + benchmark_experiment=benchmark_experiment, + ) + + if not dummy: + benchmark_model.set_up() + benchmark_model.prepare_for_experiment() + + return benchmark_model + + +class BenchmarkModel: + + def __init__(self, suite_name, model_name, benchmark_experiment): + self.suite_name = suite_name + self.model_name = model_name + self.benchmark_experiment = benchmark_experiment + + def set_up(self): + """Set up module, actual batch_size, example_inputs, and optimizer_class + + This is model suite specific. + """ + if self.model_name != "dummy": + raise NotImplementedError + + self.module = nn.Sequential( + nn.Linear(32, 512), + nn.ReLU(), + nn.Linear(512, 512), + nn.ReLU(), + nn.Linear(512, 32), + nn.Softmax(dim=1), + ) + + self.benchmark_experiment.batch_size = 16 + self.example_inputs = (torch.rand(self.benchmark_experiment.batch_size, + 32),) + self.optimizer_class = torch.optim.Adam + + def _prepare_for_eval(self): + self.module.eval() + self.model_iter_fn = self.eval + + def _prepare_for_train(self): + self.module.train() + self.model_iter_fn = self.train + if self.benchmark_experiment.dynamo == "aot_torchxla_trace_once": + # TODO: dynamo aot_torchxla_trace_once would fail if there is an + # optimizer. + # This makes the aot_torchxla_trace_once results not comparable + # with other training results + self.optimizer = None + else: + if not hasattr(self, "optimizer"): + # For some special models, self.set_up() may have initialized an + # optimizer to use. So only initialize it when there is none existing. 
+ self.optimizer = self.optimizer_class(self.module.parameters(), lr=0.01) + + def prepare_for_experiment(self): + self.device = self.benchmark_experiment.get_device() + self.module = self.module.to(self.device) + self.example_inputs = move_to_device(self.example_inputs, self.device) + + if self.benchmark_experiment.test == "eval": + self._prepare_for_eval() + elif self.benchmark_experiment.test == "train": + self._prepare_for_train() + else: + raise NotImplementedError + + if self.benchmark_experiment.dynamo: + self.model_iter_fn = torch.compile(self.model_iter_fn, + backend=self.benchmark_experiment.dynamo) + + def pick_grad(self): + if self.benchmark_experiment.test == "eval": + return torch.no_grad() + elif self.benchmark_experiment.test == "train": + return torch.enable_grad() + + def _optimizer_zero_grad(self): + if self.optimizer is not None: + self.optimizer.zero_grad(True) + else: + self.module.zero_grad(True) + + def _optimizer_step(self): + if self.optimizer is not None: + self.optimizer.step() + + def compute_loss(self, pred): + raise NotImplementedError + + def train(self, inputs, collect_full_output=False): + self._optimizer_zero_grad() + pred = self.module(*inputs) + loss = self.compute_loss(pred) + loss.backward() + self._optimizer_step() + if collect_full_output: + return collect_results(self.module, pred, loss, inputs) + # return loss.detach() + # TODO: dynamo inductor would fail if .detach() is used + return None + + def eval(self, inputs, collect_full_output=False): + pred = self.module(*inputs) + return pred + + @property + def filename_str(self): + return "-".join(self.to_dict().values()) + + def to_dict(self): + d = OrderedDict() + d["suite_name"] = self.suite_name + d["model_name"] = self.model_name + return d diff --git a/benchmarks/experiment_runner.py b/benchmarks/experiment_runner.py new file mode 100644 index 00000000000..953092d89c3 --- /dev/null +++ b/benchmarks/experiment_runner.py @@ -0,0 +1,463 @@ +import argparse +from collections import OrderedDict +import copy +import csv +import io +import json +import logging +import numpy as np +import os +import subprocess +import sys +import time +import torch +from tqdm import tqdm + +try: + from .benchmark_model import ModelLoader + from .torchbench_model import TorchBenchModelLoader + from .benchmark_experiment import ExperimentLoader + from .util import patch_torch_manual_seed, reset_rng_state, move_to_device, randomize_input +except ImportError: + from benchmark_model import ModelLoader + from torchbench_model import TorchBenchModelLoader + from benchmark_experiment import ExperimentLoader + from util import patch_torch_manual_seed, reset_rng_state, move_to_device, randomize_input + +try: + import torch_xla.core.xla_model as xm +except ImportError: + # ignore the error if torch_xla is not installed + pass + +logger = logging.getLogger(__name__) + + +class ExperimentRunner: + + def __init__(self, args): + self._args = args + self.suite_name = self._args.suite_name + + self.experiment_loader = ExperimentLoader(self._args) + + if self.suite_name == "torchbench": + self.model_loader = TorchBenchModelLoader(self._args) + elif self.suite_name == "dummy": + self.model_loader = ModelLoader(self._args) + else: + raise NotImplementedError + + self.output_dir = os.path.abspath(self._args.output_dirname) + os.makedirs(self.output_dir, exist_ok=True) + self.output_file = os.path.join(self.output_dir, self._args.output_basename) + + def run(self): + if self._args.experiment_config and self._args.model_config: + if 
self._args.dry_run: + logger.warning(f"Dry run with {[sys.executable] + sys.argv}") + return + experiment_config = json.loads(self._args.experiment_config) + model_config = json.loads(self._args.model_config) + self.run_single_experiment(experiment_config, model_config) + else: + assert not self._args.experiment_config and not self._args.model_config + finished_experiments = set() + if os.path.exists(self.output_file): + if self._args.no_resume: + os.unlink(self.output_file) + else: + with open(self.output_file, mode="r", encoding="utf-8") as f: + jsonlines = f.read().splitlines() + for jsonline in jsonlines: + tmp = json.loads(jsonline) + if self._args.experiment_name == "run_all": + # the finished experiment batch_size may be altered by model set_up(), + # so the dummy experiment will not match it + tmp["experiment"]["batch_size"] = self._args.batch_size + finished_experiments.add("-".join( + str(item) for item in (list(tmp["model"].values()) + + list(tmp["experiment"].values())))) + + experiment_configs = self.experiment_loader.list_experiment_configs() + model_configs = self.model_loader.list_model_configs() + logger.warning( + f"Number of selected experiment configs: {len(experiment_configs)}") + logger.warning(f"Number of selected model configs: {len(model_configs)}") + for model_config in tqdm( + model_configs, + desc="model configs", + disable=not self._args.progress_bar): + for experiment_config in experiment_configs: + process_env = experiment_config.pop("process_env") + experiment_config_str = json.dumps(experiment_config) + model_config_str = json.dumps(model_config) + dummy_benchmark_experiment = self.experiment_loader.load_experiment( + experiment_config, dummy=True) + dummy_benchmark_model = self.model_loader.load_model( + model_config, dummy_benchmark_experiment, dummy=True) + experiment_config["process_env"] = process_env + command = ([sys.executable] + sys.argv + + [f"--experiment-config={experiment_config_str}"] + + [f"--model-config={model_config_str}"]) + if self._args.dry_run: + logger.warning(f"Dry run with {command}") + continue + if "-".join( + str(item) + for item in (list(dummy_benchmark_model.to_dict().values()) + + list(dummy_benchmark_experiment.to_dict().values()) + )) in finished_experiments: + continue + if self.model_loader.is_compatible(dummy_benchmark_model, + dummy_benchmark_experiment): + try: + completed_process = subprocess.run( + command, + timeout=60 * 30, + env=process_env, + check=True, + capture_output=True, + encoding="utf-8", + ) + except subprocess.TimeoutExpired as e: + logger.error("TIMEOUT") + self.save_results(dummy_benchmark_experiment, + dummy_benchmark_model, {"error": str(e)}, None) + except subprocess.CalledProcessError as e: + logger.error("ERROR") + self.save_results(dummy_benchmark_experiment, + dummy_benchmark_model, {"error": e.stderr}, + None) + except subprocess.SubprocessError as e: + logger.error("ERROR") + self.save_results(dummy_benchmark_experiment, + dummy_benchmark_model, {"error": str(e)}, None) + else: + if self._args.print_subprocess: + logger.info(completed_process.stdout) + logger.warning(completed_process.stderr) + + else: + e = "SKIP because of incompatible model and experiment configs." 
+ logger.warning(e) + self.save_results(dummy_benchmark_experiment, dummy_benchmark_model, + {"error": str(e)}, None) + + def run_single_experiment(self, experiment_config, model_config): + benchmark_experiment = self.experiment_loader.load_experiment( + experiment_config) + reset_rng_state(benchmark_experiment) + benchmark_model = self.model_loader.load_model(model_config, + benchmark_experiment) + + with benchmark_model.pick_grad(): + metrics = OrderedDict() + outputs = [] + for i in range(self._args.repeat): + run_metrics, output = self.timed_run(benchmark_experiment, + benchmark_model) + output = move_to_device(output, 'cpu') + outputs.append(output) + for key, val in run_metrics.items(): + # metrics from repeated runs are formed into lists in the metrics dict + if i == 0: + metrics[key] = [] + metrics[key].append(val) + + # additional experiment metrics can be added here + + self.save_results(benchmark_experiment, benchmark_model, metrics, outputs) + + def save_results(self, benchmark_experiment, benchmark_model, metrics, + outputs): + if self._args.save_output and outputs is not None: + outputs_file_name = f"{benchmark_model.filename_str}-{benchmark_experiment.filename_str}.pt" + torch.save(outputs, os.path.join(self.output_dir, outputs_file_name)) + else: + outputs_file_name = None + + results = OrderedDict() + results["model"] = benchmark_model.to_dict() + results["experiment"] = benchmark_experiment.to_dict() + results["repeat"] = self._args.repeat + results["iterations_per_run"] = self._args.iterations_per_run + + results["metrics"] = metrics + results["outputs_file"] = outputs_file_name + + self.output_jsonl(results) + + def output_jsonl(self, obj, file_path=None): + if not file_path: + file_path = self.output_file + json_str = json.dumps(obj, ensure_ascii=False) + with open(file_path, mode="a", encoding="utf-8") as f: + f.write(f"{json_str}\n") + + def output_csv(self, headers, row, file_path=None): + if not file_path: + file_path = self.output_file + existed = os.path.exists(file_path) + output = csv.writer( + io.TextIOWrapper( + open(file_path, "ab", buffering=0), + "utf-8", + write_through=True, + ), + lineterminator="\n", + ) + if not existed: + output.writerow(headers) + output.writerow([(f"{x:.8e}" if isinstance(x, float) else x) for x in row]) + + def _mark_step(self, benchmark_experiment): + if benchmark_experiment.xla: + xm.mark_step() + + def _synchronize(self, benchmark_experiment): + if benchmark_experiment.xla: + xm.wait_device_ops() + elif benchmark_experiment.accelerator == "gpu": + torch.cuda.synchronize() + else: + pass + + def prepare_inputs(self, example_inputs, should_randomize_input): + inputs_list = [] + for i in range(self._args.iterations_per_run): + inputs = copy.deepcopy(example_inputs) + if should_randomize_input: + inputs = randomize_input(inputs) + inputs_list.append(inputs) + return inputs_list + + def timed_run(self, benchmark_experiment, benchmark_model): + reset_rng_state(benchmark_experiment) + + inputs_list = self.prepare_inputs(benchmark_model.example_inputs, + self._args.randomize_input) + + reset_rng_state(benchmark_experiment) + self._mark_step(benchmark_experiment) + self._synchronize(benchmark_experiment) + + metrics = OrderedDict() + t_start = time.perf_counter() + if benchmark_experiment.xla: + t_trace = 0 + + for i in range(self._args.iterations_per_run): + if benchmark_experiment.xla: + t_trace_start = time.perf_counter() + + output = benchmark_model.model_iter_fn( + inputs_list[i], collect_full_output=self._args.collect_full_output) 
+ + if benchmark_experiment.xla: + t_trace += time.perf_counter() - t_trace_start + + self._mark_step(benchmark_experiment) + + self._synchronize(benchmark_experiment) + + t_end = time.perf_counter() + + metrics["total_time"] = t_end - t_start + metrics[ + "per_iter_time"] = metrics["total_time"] / self._args.iterations_per_run + if benchmark_experiment.xla: + metrics["trace_per_iter_time"] = t_trace / self._args.iterations_per_run + + return metrics, output + + +def parse_args(args=None): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--suite-name", + required=True, + choices=["dummy", "torchbench"], + help="Suite name for the model garden.", + ) + + parser.add_argument( + "--filter", "-k", action="append", help="filter benchmarks with regexp") + parser.add_argument( + "--exclude", "-x", action="append", help="filter benchmarks with regexp") + + parser.add_argument( + "--log-level", + default="warning", + choices=["info", "warning"], + help="Specify the logging level.", + ) + + parser.add_argument( + "--experiment-name", + default="run_all", + choices=["run_all"], + help="Experiment name to run.", + ) + + parser.add_argument( + "--accelerator", + choices=["cpu", "gpu", "tpu"], + action="append", + help="Specify an accelerator to use.", + ) + + parser.add_argument( + "--xla", + choices=["None", "PJRT", "XRT"], + action="append", + help="Specify an xla option to use.", + ) + + parser.add_argument( + "--dynamo", + choices=[ + "None", "inductor", "torchxla_trace_once", "aot_torchxla_trace_once" + ], + action="append", + help="Specify an xla option to use.", + ) + + parser.add_argument( + "--test", + choices=["eval", "train"], + action="append", + help="Specify a test to run.", + ) + + parser.add_argument( + "--repeat", + type=int, + default=10, + help="Number of times to repeat the timed run in a single experiment.", + ) + + parser.add_argument( + "--iterations-per-run", + type=int, + default=1, + help="Number of times to repeat the model iteration inside a timed run.", + ) + + parser.add_argument( + "--batch-size", + type=int, + help="Batch size to be used. If not provided, it depends on the model suites to determine it.", + ) + + parser.add_argument( + "--total-partitions", + type=int, + default=1, + choices=range(1, 10), + help="Total number of partitions we want to divide the benchmark suite into", + ) + parser.add_argument( + "--partition-id", + type=int, + default=0, + help="ID of the benchmark suite partition to be run. Used to divide CI tasks", + ) + + parser.add_argument( + "--dry-run", + action="store_true", + help="Do a dry run to only print the benchmark commands.", + ) + + parser.add_argument( + "--print-subprocess", + action="store_true", + help="Print subprocess stdout.", + ) + + parser.add_argument( + "--progress-bar", + action="store_true", + help="Display progress bar.", + ) + + parser.add_argument( + "--randomize-input", + action="store_true", + help="Whether to randomize the input values. Dimensions will be kept the same.", + ) + + parser.add_argument( + "--collect-full-output", + action="store_true", + help="""Whether to collect full output for training. Set this to true if we + want to verify the numerical correctness of graidents. 
But that may + cause time measurement not accurate""", + ) + + parser.add_argument( + "--save-output", + action="store_true", + help="Whether to save the model output to disk", + ) + + parser.add_argument( + "--output-dirname", + type=str, + default="./output/", + help="Overrides the directory to place output files.", + ) + + parser.add_argument( + "--output-basename", + type=str, + default="results.jsonl", + help="Overrides the basename of output files.", + ) + + parser.add_argument( + "--no-resume", + action="store_true", + help="""By default, the runner would skip the finished experiments that + exist in the output-basename file. If --no-resume is set, the previous + output-basename file will be deleted and all experiment will run""", + ) + + parser.add_argument( + "--experiment-config", + type=str, + help="JSON string of the experiment config dict.", + ) + + parser.add_argument( + "--model-config", + type=str, + help="JSON string of the model config dict.", + ) + + return parser.parse_args(args) + + +def main(): + args = parse_args() + + args.filter = args.filter or [r"."] + args.exclude = args.exclude or [r"^$"] + + if args.log_level == "info": + log_level = logging.INFO + elif args.log_level == "warning": + log_level = logging.WARNING + else: + log_level = None + logging.basicConfig(level=log_level, force=True) + + logger.info(args) + runner = ExperimentRunner(args) + runner.run() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/result_analyzer.py b/benchmarks/result_analyzer.py new file mode 100644 index 00000000000..e1849886b6a --- /dev/null +++ b/benchmarks/result_analyzer.py @@ -0,0 +1,169 @@ +import argparse +from collections import OrderedDict +import copy +import csv +import io +import json +import logging +import numpy as np +import os +import pandas as pd +import subprocess +import sys +import time +import torch +from tqdm import tqdm + +logger = logging.getLogger(__name__) + + +class ResultAnalyzer: + + def __init__(self, args): + self._args = args + self.timestamp = self._args.timestamp or time.time() + self.output_dir = os.path.abspath(self._args.output_dirname) + if not os.path.exists(self.output_dir): + raise ValueError("The output directory does not exist.") + self.output_file = os.path.join(self.output_dir, "metric_report.csv") + + self.database = os.path.abspath(self._args.database) + + def run(self): + jsonl_files = [] + for file in os.listdir(self.output_dir): + if file.endswith(".jsonl"): + jsonl_files.append(os.path.join(self.output_dir, file)) + + metric_df = pd.DataFrame({"timestamp": pd.Series(dtype="int"), + "suite_name": pd.Series(dtype="str"), + "model_name": pd.Series(dtype="str"), + "experiment_name": pd.Series(dtype="str"), + "accelerator": pd.Series(dtype="str"), + "accelerator_model": pd.Series(dtype="str"), + "xla": pd.Series(dtype="str"), + "dynamo": pd.Series(dtype="str"), + "test": pd.Series(dtype="str"), + "batch_size": pd.Series(dtype="int"), + "repeat": pd.Series(dtype="int"), + "iterations_per_run": pd.Series(dtype="int"), + "error_message": pd.Series(dtype="str"), + "median_total_time": pd.Series(dtype="float"), + "median_per_iter_time": pd.Series(dtype="float"), + "xla_median_trace_per_iter_time": pd.Series(dtype="float"), + "xla_compile_time": pd.Series(dtype="float"), + "dynamo_compile_time": pd.Series(dtype="float"), + "outputs_file": pd.Series(dtype="str"), + }) + for file in jsonl_files: + metric_df = self.extract_metrics(file, metric_df) + + # additional processing of the metric_df can be done here + + 
self.export_metric_report(metric_df) + + def extract_metrics(self, file, metric_df): + with open(file, mode="r", encoding="utf-8") as f: + jsonlines = f.read().splitlines() + + for jsonline in jsonlines: + tmp = json.loads(jsonline) + d = {"timestamp": self.timestamp, + "suite_name": tmp["model"]["suite_name"], + "model_name": tmp["model"]["model_name"], + "experiment_name": tmp["experiment"]["experiment_name"], + "accelerator": tmp["experiment"]["accelerator"], + "accelerator_model": tmp["experiment"]["accelerator_model"], + "xla": tmp["experiment"]["xla"], + "dynamo": tmp["experiment"]["dynamo"], + "test": tmp["experiment"]["test"], + "batch_size": tmp["experiment"]["batch_size"], + "repeat": tmp["repeat"], + "iterations_per_run": tmp["iterations_per_run"], + "error_message": tmp["metrics"].get("error", None), + "outputs_file": tmp["outputs_file"], + } + if "error" not in tmp["metrics"]: + total_time = np.asarray(tmp["metrics"]["total_time"], dtype="float") + d["median_total_time"] = np.median(total_time) + per_iter_time = np.asarray(tmp["metrics"]["per_iter_time"], dtype="float") + d["median_per_iter_time"] = np.median(per_iter_time) + if tmp["experiment"]["xla"]: + trace_per_iter_time = np.asarray(tmp["metrics"]["trace_per_iter_time"], dtype="float") + d["xla_median_trace_per_iter_time"] = np.median(trace_per_iter_time) + d["xla_compile_time"] = np.max(total_time) - np.median(total_time) + if tmp["experiment"]["dynamo"]: + d["dynamo_compile_time"] = np.max(total_time) - np.median(total_time) + + new_row = pd.Series(d) + new_row.fillna(value=np.nan, inplace=True) + metric_df = pd.concat([metric_df, new_row.to_frame().T], ignore_index=True) + + return metric_df + + def export_metric_report(self, metric_df): + metric_df.to_csv(self.output_file, mode="w", encoding="utf-8", header=True, index=False) + + if not os.path.exists(self.database): + metric_df.to_csv(self.database, mode="w", encoding="utf-8", header=True, index=False) + else: + metric_df.to_csv(self.database, mode="a", encoding="utf-8", header=False, index=False) + +def parse_args(args=None): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--log-level", + default="warning", + choices=["info", "warning"], + help="Specify the logging level.", + ) + + parser.add_argument( + "--experiment-name", + default="run_all", + choices=["run_all"], + help="Experiment name to run.", + ) + + parser.add_argument( + "--output-dirname", + type=str, + default="./output/", + help="Overrides the directory to place output files.", + ) + + parser.add_argument( + "--database", + type=str, + default="./output/database.csv", + help="Path to the database.", # for POC, database is a path to a csv file. + ) + + parser.add_argument( + "--timestamp", + type=int, + help="User provided timestamp. 
If not provided, get the timestamp in analyzer", + ) + + return parser.parse_args(args) + + +def main(): + args = parse_args() + + if args.log_level == "info": + log_level = logging.INFO + elif args.log_level == "warning": + log_level = logging.WARNING + else: + log_level = None + logging.basicConfig(level=log_level, force=True) + + logger.info(args) + analyzer = ResultAnalyzer(args) + analyzer.run() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/benchmarks/torchbench_model.py b/benchmarks/torchbench_model.py new file mode 100644 index 00000000000..99f1ceab62b --- /dev/null +++ b/benchmarks/torchbench_model.py @@ -0,0 +1,216 @@ +import gc +import importlib +import logging +import os +from os.path import abspath, exists +import sys +import torch +import torch.nn as nn +from torch._dynamo.testing import collect_results, reduce_to_scalar_loss +from torch._dynamo.utils import clone_inputs +import types + +try: + from .util import move_to_device, set_cwd + from .benchmark_model import ModelLoader, BenchmarkModel +except ImportError: + from util import move_to_device, set_cwd + from benchmark_model import ModelLoader, BenchmarkModel + +logger = logging.getLogger(__name__) + + +DETECTRON2_MODELS = { + "detectron2_fasterrcnn_r_101_c4", + "detectron2_fasterrcnn_r_101_dc5", + "detectron2_fasterrcnn_r_101_fpn", + "detectron2_fasterrcnn_r_50_c4", + "detectron2_fasterrcnn_r_50_dc5", + "detectron2_fasterrcnn_r_50_fpn", + "detectron2_maskrcnn_r_101_c4", + "detectron2_maskrcnn_r_101_fpn", + "detectron2_maskrcnn_r_50_c4", + "detectron2_maskrcnn_r_50_fpn", + "detectron2_maskrcnn", + "detectron2_fcos_r_50_fpn", +} + +# Skip the experiment of a model if any of the experiment configs in the list is fully matched +DENY_LIST = { + "doctr_det_predictor": [{"test": "train"},], # not implemented + "doctr_reco_predictor": [{"test": "train"},], # not implemented + "detectron2_fcos_r_50_fpn": [{"test": "train"},], # not implemented + # https://github.com/pytorch/torchdynamo/issues/145 + "fambench_xlmr": [{}], + "llama": [{"test": "train"},], # not implemented + "mobilenet_v2_quantized_qat": [{"test": "eval", "accelerator": "gpu"}, # not implemented + {"test": "eval", "accelerator": "tpu"},], # not implemented + "pyhpc_equation_of_state": [{"test": "train"},], # not implemented + "pyhpc_isoneutral_mixing": [{"test": "train"},], # not implemented + "pyhpc_turbulent_kinetic_energy": [{"test": "train"},], # not implemented + "pytorch_struct": [{"test": "eval"},], # not implemented + "resnet50_quantized_qat": [{"test": "eval", "accelerator": "gpu"}, # not implemented + {"test": "eval", "accelerator": "tpu"},], # not implemented + # https://github.com/pytorch/pytorch/issues/99438 + "vision_maskrcnn": [{}], +} + +class TorchBenchModelLoader(ModelLoader): + + def __init__(self, args): + super().__init__(args) + self.benchmark_model_class = TorchBenchModel + self.torchbench_dir = self.add_torchbench_dir() + + def add_torchbench_dir(self): + os.environ["KALDI_ROOT"] = "/tmp" # avoids some spam + for torchbench_dir in ( + "./torchbenchmark", + "./torchbench", + "./benchmark", + "../torchbenchmark", + "../torchbench", + "../benchmark", + "../../torchbenchmark", + "../../torchbench", + "../../benchmark", + ): + if exists(torchbench_dir): + break + + if exists(torchbench_dir): + torchbench_dir = abspath(torchbench_dir) + if torchbench_dir not in sys.path: + sys.path.append(torchbench_dir) + else: + raise Exception("Torch Benchmark folder not found.") + + return torchbench_dir + + def 
list_model_configs(self): + model_configs = [] + + from torchbenchmark import _list_model_paths + + models = _list_model_paths() + + start, end = self.get_benchmark_indices(len(models)) + models = models[start:end] + for model_path in models: + model_name = os.path.basename(model_path) + + if self.skip_model(model_name): + continue + + model_configs.append({"model_name": model_name}) + + return model_configs + + def is_compatible(self, dummy_benchmark_model, benchmark_experiment): + if dummy_benchmark_model.model_name in DENY_LIST: + for deny_experiment_config in DENY_LIST[dummy_benchmark_model.model_name]: + matched = True + for k, v in deny_experiment_config.items(): + if getattr(benchmark_experiment, k) != v: + matched = False + break + if matched: + return False + + return True + + +class TorchBenchModel(BenchmarkModel): + + def __init__(self, suite_name, model_name, benchmark_experiment): + super().__init__(suite_name, model_name, benchmark_experiment) + + def set_up(self): + """Set up module, actual batch_size, example_inputs, and optimizer_class + + This is model suite specific. + """ + self.optimizer_class = torch.optim.Adam + + try: + module = importlib.import_module( + f"torchbenchmark.models.{self.model_name}") + except ModuleNotFoundError: + module = importlib.import_module( + f"torchbenchmark.models.fb.{self.model_name}") + benchmark_cls = getattr(module, "Model", None) + + cant_change_batch_size = (not getattr(benchmark_cls, + "ALLOW_CUSTOMIZE_BSIZE", True)) + if cant_change_batch_size: + self.benchmark_experiment.batch_size = None + + # workaround "RuntimeError: not allowed to set torch.backends.cudnn flags" + # torch.backends.__allow_nonbracketed_mutation_flag = True + + if self.benchmark_experiment.accelerator == "cpu": + device = "cpu" + elif self.benchmark_experiment.accelerator == "gpu" and not self.benchmark_experiment.xla: + device = "cuda" + else: + device = str(self.benchmark_experiment.get_device()) + + benchmark = benchmark_cls( + test=self.benchmark_experiment.test, + device=device, + jit=False, + batch_size=self.benchmark_experiment.batch_size, + ) + + self.module, self.example_inputs = benchmark.get_module() + + self.benchmark_experiment.batch_size = benchmark.batch_size + + # Torchbench has quite different setup for yolov3, so directly passing + # the right example_inputs + if self.model_name == "yolov3": + self.example_inputs = (torch.rand(self.benchmark_experiment.batch_size, 3, + 384, 512),) + if self.benchmark_experiment.test == "train" and self.model_name in DETECTRON2_MODELS: + self.optimizer = benchmark.optimizer + + del benchmark + gc.collect() + + def pick_grad(self): + # special case + if self.model_name in ("maml",): + return torch.enable_grad() + + if self.benchmark_experiment.test == "eval": + return torch.no_grad() + elif self.benchmark_experiment.test == "train": + return torch.enable_grad() + + def compute_loss(self, pred): + """Reduce the output of a model to get scalar loss""" + if isinstance(pred, torch.Tensor): + # Mean does not work on integer tensors + return pred.sum() / pred.numel() + elif isinstance(pred, (list, tuple)): + return sum([reduce_to_scalar_loss(x) for x in pred]) / len(pred) + elif type(pred).__name__ in ( + "MaskedLMOutput", + "Seq2SeqLMOutput", + "CausalLMOutputWithCrossAttentions", + ): + return reduce_to_scalar_loss(pred.logits) + elif type(pred).__name__ == "SquashedNormal": + return pred.mean.sum() + elif isinstance(pred, dict): + return sum([reduce_to_scalar_loss(value) for value in pred.values() + ]) / 
len(pred.keys()) + raise NotImplementedError("Don't know how to reduce", type(pred)) + + def train(self, inputs, collect_full_output=False): + if self.model_name in DETECTRON2_MODELS: + from detectron2.utils.events import EventStorage + with EventStorage(): + super().train(inputs, collect_full_output=collect_full_output) + else: + super().train(inputs, collect_full_output=collect_full_output) diff --git a/benchmarks/util.py b/benchmarks/util.py new file mode 100644 index 00000000000..59cdd96af45 --- /dev/null +++ b/benchmarks/util.py @@ -0,0 +1,151 @@ +from contextlib import contextmanager +import functools +import logging +from multiprocessing import Process, Queue +import numpy as np +import os +from os.path import abspath +import queue +import random +import subprocess +import torch +import traceback + +try: + import torch_xla.core.xla_model as xm + from torch_xla.distributed.cluster import ClusterResolver +except ImportError: + # ignore the error if torch_xla is not installed + pass + +logger = logging.getLogger(__name__) + + +@functools.lru_cache(None) +def patch_torch_manual_seed(): + """Make torch manual seed deterministic. Helps with accuracy testing.""" + + def deterministic_torch_manual_seed(*args, **kwargs): + from torch._C import default_generator + + seed = 1337 + import torch.cuda + + if not torch.cuda._is_in_bad_fork(): + torch.cuda.manual_seed_all(seed) + return default_generator.manual_seed(seed) + + torch.manual_seed = deterministic_torch_manual_seed + + +def reset_rng_state(benchmark_experiment=None): + torch.manual_seed(1337) + random.seed(1337) + np.random.seed(1337) + if benchmark_experiment and benchmark_experiment.xla: + device = benchmark_experiment.get_device() + xm.set_rng_state(1337, str(device)) + + +@functools.lru_cache(maxsize=3) +def is_xla_device_available(devkind): + if devkind not in ["CPU", "GPU", "TPU"]: + raise ValueError(devkind) + + def _check_xla_device(q, devkind): + try: + import os + os.environ["PJRT_DEVICE"] = devkind + + import torch_xla.core.xla_model as xm + + q.put(bool(xm.get_xla_supported_devices(devkind=devkind))) + except Exception: + traceback.print_exc() + q.put(False) + + q = Queue() + process = Process(target=_check_xla_device, args=(q, devkind)) + process.start() + process.join(60) + try: + return q.get_nowait() + except queue.Empty: + traceback.print_exc() + return False + + +def move_to_device(item, device): + if isinstance(item, torch.Tensor): + return item.to(device=device) + elif isinstance(item, list): + return [move_to_device(t, device) for t in item] + elif isinstance(item, tuple): + return tuple(move_to_device(t, device) for t in item) + elif isinstance(item, dict): + return dict((k, move_to_device(t, device)) for k, t in item.items()) + else: + return item + + +def randomize_input(inputs): + if isinstance(inputs, torch.Tensor): + if inputs.dtype in (torch.float32, torch.float64): + torch._dynamo.utils.counters["randomize_input"]["times"] += 1 + return torch.randn_like(inputs) + elif inputs.dtype == torch.int64: + # Note: we can not simply tune integer tensors as follows + # `return torch.randint_like(inputs, high=inputs.max().item())` + # This may break some invariants between tensors. + # E.g. in embedding lookup case, one tensor is the length + # and another is an indices tensor. 
+ return inputs + else: + raise RuntimeError( + f"randomize_input need support tensor of type {inputs.dtype}") + elif isinstance(inputs, (list, tuple)): + return type(inputs)([randomize_input(x) for x in inputs]) + elif isinstance(inputs, dict): + return dict((k, randomize_input(x)) for k, x in inputs.items()) + else: + logger.warning( + f"randomize_input can not handle input of type {type(inputs)}") + return inputs + + +@contextmanager +def set_cwd(path): + original_dir = abspath(os.getcwd()) + os.chdir(path) + try: + yield + finally: + os.chdir(original_dir) + + +def get_accelerator_model(accelerator): + if accelerator == "cpu": + return get_cpu_name() + elif accelerator == "gpu": + return get_gpu_name() + elif accelerator == "tpu": + return get_tpu_name() + else: + raise NotImplementedError + + +def get_cpu_name(): + return subprocess.check_output( + ["lscpu"], + encoding='utf-8').split("Model name:")[1].split("\n")[0].strip() + + +def get_gpu_name(): + return subprocess.check_output( + ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv"], + encoding='utf-8').split("\n")[1] + + +def get_tpu_name(): + return ClusterResolver.get_instance_metadata( + 'instance/attributes/accelerator-type')
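Usage notes and illustrative sketches for the benchmark scripts added above.

ExperimentLoader.expand_config_choices turns a dict of per-key choices into the full cross product of concrete experiment configs. A minimal standalone restatement of that loop (pure Python, no dependencies on the rest of the suite):

# Standalone sketch of ExperimentLoader.expand_config_choices: expand a dict of
# per-key choices into the cross product of concrete configs.
def expand_config_choices(config_choices):
  configs = [{}]
  for key, choices in config_choices.items():
    tmp_configs = []
    for config in configs:
      for choice in choices:
        tmp_config = config.copy()
        tmp_config[key] = choice
        tmp_configs.append(tmp_config)
    configs = tmp_configs
  return configs

if __name__ == "__main__":
  choices = {"xla": [None, "PJRT"], "dynamo": [None, "inductor"]}
  for config in expand_config_choices(choices):
    print(config)
  # Prints four configs, from {'xla': None, 'dynamo': None}
  # through {'xla': 'PJRT', 'dynamo': 'inductor'}.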
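The dummy suite can also be exercised in-process, without the subprocess machinery, as a quick sanity check that BenchmarkExperiment and BenchmarkModel fit together. A sketch, assuming it is run from the benchmarks/ directory on a Linux host with torch installed (get_accelerator_model shells out to lscpu for the CPU name); the argument values are illustrative:

# Sketch: one eager CPU eval step of the "dummy" model, in-process.
from benchmark_experiment import BenchmarkExperiment
from benchmark_model import BenchmarkModel

experiment = BenchmarkExperiment(
    experiment_name="run_all",
    accelerator="cpu",
    xla=None,
    dynamo=None,
    test="eval",
    batch_size=None)
model = BenchmarkModel(
    suite_name="dummy", model_name="dummy", benchmark_experiment=experiment)
model.set_up()                  # builds the small MLP and example inputs
model.prepare_for_experiment()  # moves to device, selects eval as model_iter_fn
with model.pick_grad():
  pred = model.model_iter_fn(model.example_inputs)
print(pred.shape)               # torch.Size([16, 32])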
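ModelLoader.get_benchmark_indices splits the model list evenly across --total-partitions, with the last partition absorbing any remainder, and --partition-id selects the slice. A standalone restatement of the arithmetic with a worked example:

# Standalone sketch of ModelLoader.get_benchmark_indices.
def get_benchmark_indices(length, partition_id, total_partitions):
  start = partition_id * (length // total_partitions)
  end = ((partition_id + 1) * (length // total_partitions)
         if partition_id < total_partitions - 1 else length)
  return start, end

if __name__ == "__main__":
  # 10 models over 3 partitions -> slices [0, 3), [3, 6), [6, 10).
  for pid in range(3):
    print(pid, get_benchmark_indices(10, pid, 3))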
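experiment_runner.py is normally launched without --experiment-config/--model-config; it then enumerates every compatible (model, experiment) pair and re-invokes itself once per pair in a subprocess with those two flags appended, so each child produces one result line. Passing --dry-run only prints the child commands, e.g. something like: python benchmarks/experiment_runner.py --suite-name=dummy --accelerator=cpu --dry-run. A sketch of the child command shape the parent builds (the config values here are illustrative; the real parent forwards its own sys.argv):

# Sketch of the per-pair child command built in ExperimentRunner.run.
import json
import sys

experiment_config = {"accelerator": "cpu", "xla": None, "dynamo": None, "test": "eval"}
model_config = {"model_name": "dummy"}

command = ([sys.executable, "benchmarks/experiment_runner.py", "--suite-name=dummy"] +
           [f"--experiment-config={json.dumps(experiment_config)}"] +
           [f"--model-config={json.dumps(model_config)}"])
print(command)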
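Each child run appends one JSON line to ./output/results.jsonl containing the model and experiment dicts, the repeat settings, per-run metric lists, and an optional outputs file. On a later invocation, unless --no-resume is set, the parent rebuilds a key from those dicts and skips pairs it has already finished. A simplified sketch of that key (the real code also resets the recorded batch_size before comparing):

# Sketch: rebuild the resume keys used to skip finished experiments, assuming a
# results file already exists at the default ./output/results.jsonl location.
import json

finished = set()
with open("output/results.jsonl", mode="r", encoding="utf-8") as f:
  for line in f.read().splitlines():
    record = json.loads(line)
    finished.add("-".join(
        str(item) for item in (list(record["model"].values()) +
                               list(record["experiment"].values()))))
print(len(finished), "finished experiments")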
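result_analyzer.py reads every *.jsonl file in the output directory and reduces each metric list: it reports the median of total_time and per_iter_time, and estimates compile time as the gap between the slowest and the median run (the first repeat typically pays for tracing/compilation). A small numpy sketch of that reduction with made-up numbers:

# Sketch of the reductions in ResultAnalyzer.extract_metrics; the sample
# total_time values are made up.
import numpy as np

total_time = np.asarray([12.4, 1.1, 1.0, 1.2, 1.1], dtype="float")
median_total_time = np.median(total_time)                        # ~1.1
compile_time_estimate = np.max(total_time) - median_total_time   # ~11.3
print(median_total_time, compile_time_estimate)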
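In torchbench_model.py, an experiment is skipped when any DENY_LIST entry for the model is fully contained in the experiment config, so an empty dict denies every experiment for that model. A standalone restatement of the matching rule (the real is_compatible reads attributes off a BenchmarkExperiment rather than a plain dict):

# Standalone sketch of the DENY_LIST matching rule from
# TorchBenchModelLoader.is_compatible; {} matches every experiment.
DENY_LIST = {
    "llama": [{"test": "train"}],
    "fambench_xlmr": [{}],
}

def is_compatible(model_name, experiment):
  for deny_config in DENY_LIST.get(model_name, []):
    if all(experiment.get(k) == v for k, v in deny_config.items()):
      return False
  return True

if __name__ == "__main__":
  print(is_compatible("llama", {"test": "eval"}))           # True
  print(is_compatible("llama", {"test": "train"}))          # False
  print(is_compatible("fambench_xlmr", {"test": "eval"}))   # False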