Forward subprocess stdout and stderr in all cases #6154

Merged 1 commit on Dec 15, 2023
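The core pattern this PR puts in place in benchmarks/experiment_runner.py is to capture each per-config subprocess's output and re-emit it on the parent's own streams in every outcome, including timeouts and non-zero exits. Below is a minimal standalone sketch of that pattern; the helper names are illustrative, not the PR's exact code, and the real changes are in the diff that follows.

import subprocess
import sys


def forward_captured(stdout_data, stderr_data):
    # Re-emit whatever the child printed on the parent's own streams. The
    # exception attributes may be None (nothing captured) or, for
    # TimeoutExpired on some platforms, raw bytes, so normalize first.
    for data, stream in ((stdout_data, sys.stdout), (stderr_data, sys.stderr)):
        if not data:
            continue
        if isinstance(data, bytes):
            data = data.decode("utf-8", errors="replace")
        print(data, file=stream, end="", flush=True)


def run_config_subprocess(command, env=None, timeout_s=60 * 30):
    try:
        child = subprocess.run(
            command,
            timeout=timeout_s,
            env=env,
            check=True,
            capture_output=True,
            text=True,
        )
        forward_captured(child.stdout, child.stderr)  # success path
    except subprocess.TimeoutExpired as e:
        forward_captured(e.stdout, e.stderr)  # partial output before the kill
        raise
    except subprocess.CalledProcessError as e:
        forward_captured(e.stdout, e.stderr)  # output of the failed run
        raise

Compared with the previous behavior, which only printed the captured streams on success, the timeout and failure paths now surface the child's output as well.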
4 changes: 2 additions & 2 deletions benchmarks/benchmark_experiment.py
@@ -40,12 +40,12 @@ def list_experiment_configs(self):
map(parse_none_str, set(self._args.xla_flags)))

# Expand experiment configs and add env vars.
logger.info(f"Expand experiment configs:")
logger.debug(f"Expand experiment configs")
experiment_configs = []
for cfg in self._expand_config_choices(config_choices):
if not self._is_available(cfg):
continue
logger.info(f"Experiment config (w/o env vars): {cfg}")
logger.debug(f"Experiment config (w/o env vars): {cfg}")
self._add_experiment_env(cfg)
experiment_configs.append(cfg)
return experiment_configs
181 changes: 99 additions & 82 deletions benchmarks/experiment_runner.py
@@ -44,15 +44,18 @@ def __init__(self, args):
self.output_file = os.path.join(self.output_dir, self._args.output_basename)

def run(self):
is_main_process = self._args.experiment_config is None and self._args.model_config is None
is_main_process = self._args.experiment_config is None and \
self._args.model_config is None
if is_main_process:
self.generate_and_run_all_configs()
else:
assert self._args.experiment_config is not None and self._args.model_config is not None
assert self._args.experiment_config is not None and \
self._args.model_config is not None
self.run_single_config()

def generate_and_run_all_configs(self):
assert self._args.experiment_config is None and self._args.model_config is None
assert self._args.experiment_config is None and \
self._args.model_config is None

# Collect fingerprints for configs to skip. These are configs for which we
# already have results. The derived fingerprints uniquely identify the
@@ -70,89 +70,100 @@ def generate_and_run_all_configs(self):
self._get_config_fingerprint(ln_dict["experiment"],
ln_dict["model"]))

# Enumerate experiment and model configs and launch subprocesses.
experiment_configs = self.experiment_loader.list_experiment_configs()
model_configs = self.model_loader.list_model_configs()
logger.warning(
logger.info(
f"Number of selected experiment configs: {len(experiment_configs)}")
logger.warning(f"Number of selected model configs: {len(model_configs)}")
for model_config in tqdm(
logger.info(f"Number of selected model configs: {len(model_configs)}")
for model_cfg in tqdm(
model_configs,
desc="model configs",
desc="Running benchmark configs by model",
disable=not self._args.progress_bar):
for experiment_config in experiment_configs:
process_env = experiment_config.pop("process_env")
experiment_config_str = json.dumps(experiment_config)
model_config_str = json.dumps(model_config)
dummy_benchmark_experiment = self.experiment_loader.load_experiment(
experiment_config)
dummy_benchmark_model = self.model_loader.load_model(
model_config, dummy_benchmark_experiment, dummy=True)
process_env = dummy_benchmark_model.extend_process_env(process_env)
experiment_config["process_env"] = process_env
command = ([sys.executable] + sys.argv +
[f"--experiment-config={experiment_config_str}"] +
[f"--model-config={model_config_str}"])
# TODO: Actually run this and rely on dry running in subprocess.
for experiment_cfg in experiment_configs:

# Log run and configs.
experiment_cfg_wo_env = experiment_cfg.copy()
process_env = experiment_cfg_wo_env.pop("process_env")
logger.info(f"Run with --model-config={json.dumps(model_cfg)} "
f"--experiment-config={json.dumps(experiment_cfg_wo_env)}")

# Move on if dry running.
if self._args.dry_run:
logger.warning(f"Dry run with {command}")
continue

# TODO: See if we can pass experiment_cfg to `load_experiment`.
benchmark_experiment = self.experiment_loader.load_experiment(
experiment_cfg_wo_env)
benchmark_model = self.model_loader.load_model(
model_cfg, benchmark_experiment, dummy=True)

# Skip already completed benchmark.
fingerprint = self._get_config_fingerprint(
dummy_benchmark_experiment.to_dict(),
dummy_benchmark_model.to_dict())
benchmark_experiment.to_dict(), benchmark_model.to_dict())
if fingerprint in skip_fingerprints:
logger.info(f"Skipping {fingerprint}")
logger.info(f"SKIP already completed benchmark")
continue
if self.model_loader.is_compatible(dummy_benchmark_model,
dummy_benchmark_experiment):
try:
completed_process = subprocess.run(
command,
timeout=60 * 30,
env=process_env,
check=True,
capture_output=True,
encoding="utf-8",
)
except subprocess.TimeoutExpired as e:
logger.error("TIMEOUT")
self.save_results(dummy_benchmark_experiment, dummy_benchmark_model,
{"error": str(e)}, None)
except subprocess.CalledProcessError as e:
logger.error("ERROR")
self.save_results(dummy_benchmark_experiment, dummy_benchmark_model,
{"error": e.stderr}, None)
except subprocess.SubprocessError as e:
logger.error("ERROR")
self.save_results(dummy_benchmark_experiment, dummy_benchmark_model,
{"error": str(e)}, None)
else:
if self._args.print_subprocess:
logger.info(completed_process.stdout)
logger.warning(completed_process.stderr)

else:
e = "SKIP because of incompatible model and experiment configs."
logger.warning(e)
self.save_results(dummy_benchmark_experiment, dummy_benchmark_model,

# Skip unsupported config.
if not self.model_loader.is_compatible(benchmark_model,
benchmark_experiment):
logger.warning("SKIP incompatible model and experiment configs.")
self.save_results(benchmark_experiment, benchmark_model,
{"error": "SKIP"}, None)
continue

# Launch subprocess.
try:
process_env = benchmark_model.extend_process_env(process_env)
command = [sys.executable] + sys.argv + [
f"--experiment-config={json.dumps(experiment_cfg)}"
] + [f"--model-config={json.dumps(model_cfg)}"]
command_str = " ".join(command)
logger.debug(f"Run `{command_str}`")
child_process = subprocess.run(
command,
timeout=self._args.subprocess_timeout,
env=process_env,
check=True,
capture_output=True,
text=True,
)
self._fwd_captured_stdout_stderr(child_process.stdout,
child_process.stderr)
except subprocess.TimeoutExpired as e:
self._fwd_captured_stdout_stderr(e.stdout, e.stderr)
logger.error("TIMEOUT")
self.save_results(benchmark_experiment, benchmark_model,
{"error": str(e)}, None)
except subprocess.CalledProcessError as e:
self._fwd_captured_stdout_stderr(e.stdout, e.stderr)
logger.error("ERROR in subprocess")
self.save_results(benchmark_experiment, benchmark_model,
{"error": e.stderr}, None)
except subprocess.SubprocessError as e:
logger.error("ERROR when launching child process")
self.save_results(benchmark_experiment, benchmark_model,
{"error": str(e)}, None)

def _get_config_fingerprint(self, experiment_config: OrderedDict,
model_config: OrderedDict) -> str:
# Experiment `batch_size` may be altered by model in `set_up`, so we will ignore that.
# Experiment `batch_size` may be altered by model in `set_up`, so we will
# ignore that.
return "-".join(
list(map(str, model_config.values())) +
[str(v) for k, v in experiment_config.items() if k != "batch_size"] +
[str(self._args.batch_size)])

def _fwd_captured_stdout_stderr(self, stdout_text: str, stderr_text: str):
if not self._args.print_subprocess:
return
print(stdout_text, file=sys.stdout, end='', flush=True)
print(stderr_text, file=sys.stderr, end='', flush=True)

def run_single_config(self):
experiment_config = json.loads(self._args.experiment_config)
model_config = json.loads(self._args.model_config)

# Log and return if dry run.
if self._args.dry_run:
logger.info(f"Dry run with {[sys.executable] + sys.argv}")
return

benchmark_experiment = self.experiment_loader.load_experiment(
experiment_config)
reset_rng_state(benchmark_experiment)
@@ -193,13 +207,8 @@ def save_results(self, benchmark_experiment, benchmark_model, metrics,
results["metrics"] = metrics
results["outputs_file"] = outputs_file_name

self.output_jsonl(results)

def output_jsonl(self, obj, file_path=None):
if not file_path:
file_path = self.output_file
json_str = json.dumps(obj, ensure_ascii=False)
with open(file_path, mode="a", encoding="utf-8") as f:
json_str = json.dumps(results, ensure_ascii=False)
with open(self.output_file, mode="a", encoding="utf-8") as f:
f.write(f"{json_str}\n")

def _mark_step(self, benchmark_experiment):
@@ -439,7 +448,7 @@ def parse_log_level(level: str):

parser.add_argument(
"--log-level",
default="info",
default=logging.INFO,
choices=[
logging.CRITICAL,
logging.ERROR,
@@ -489,7 +498,8 @@ def parse_log_level(level: str):
parser.add_argument(
"--batch-size",
type=int,
help="Batch size to be used. If not provided, it depends on the model suites to determine it.",
help="""Batch size to be used. If not provided, it depends on the model
suites to determine it.""",
)
parser.add_argument(
"--total-partitions",
@@ -512,7 +522,13 @@
parser.add_argument(
"--print-subprocess",
action="store_true",
help="Print subprocess stdout.",
help="Forward subprocess stdout and stderr.",
)
parser.add_argument(
"--subprocess-timeout",
type=int,
default=60 * 30,
help="Timeout per launched config subprocess.",
)
parser.add_argument(
"--progress-bar",
@@ -522,13 +538,14 @@
parser.add_argument(
"--randomize-input",
action="store_true",
help="Whether to randomize the input values. Dimensions will be kept the same.",
help="""Whether to randomize the input values. Dimensions will be kept
the same.""",
)
parser.add_argument(
"--collect-full-output",
action="store_true",
help="""Whether to collect full output for training. Set this to true if we
want to verify the numerical correctness of gradients. But that may
help="""Whether to collect full output for training. Set this to true if
we want to verify the numerical correctness of gradients. But that may
cause time measurement not accurate""",
)
parser.add_argument(
@@ -553,13 +570,14 @@
action="store_true",
help="""By default, the runner would skip the finished experiments that
exist in the output-basename file. If --no-resume is set, the previous
output-basename file will be deleted and all experiment will run""",
output-basename file will be deleted and all experiment will run.""",
)
parser.add_argument(
"--profile-cuda",
action="store_true",
help="""Whether to profile CUDA or not. Note this does not do much except for
triggering a profiler. To get the profiling data use additionally --profile-cuda-dump""",
help="""Whether to profile CUDA or not. Note this does not do much except
for triggering a profiler. To get the profiling data use additionally
--profile-cuda-dump""",
)
parser.add_argument(
"--profile-cuda-dump",
@@ -607,8 +625,7 @@ def main():
args.exclude = args.exclude or [r"^$"]

logging.basicConfig(level=args.log_level, force=True)

logger.info(args)
logger.debug(f"Parsed args: {args}")

if not args.disable_tf32:
logger.warning('Enabling fast F32 multiplication for PyTorch')
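As a usage note, the forwarding is still opt-in via --print-subprocess, and the per-config timeout that used to be hard-coded at 30 minutes is now configurable with --subprocess-timeout. A hedged sketch of driving the runner programmatically, in the same spirit as the tests below; any model or experiment selection flags the suite requires are not shown and would have to be appended.

    import subprocess
    import sys

    cmd = [
        sys.executable,
        "benchmarks/experiment_runner.py",
        "--print-subprocess",                  # forward per-config stdout/stderr
        "--subprocess-timeout", str(60 * 10),  # 10-minute cap per config
    ]
    runner = subprocess.run(cmd, capture_output=True, text=True)
    # Everything the per-config children printed is forwarded into the runner's
    # own streams, so it is visible here as well.
    print(runner.stdout, end="")
    print(runner.stderr, file=sys.stderr, end="", flush=True)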
26 changes: 13 additions & 13 deletions test/benchmarks/test_experiment_runner.py
@@ -22,8 +22,8 @@ def test_dummy_dry_run(self):
expected_in_stderr = [
"Number of selected experiment configs: 2",
"Number of selected model configs: 1",
"'--experiment-config={\"accelerator\": \"cpu\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"test\": \"eval\"}', '--model-config={\"model_name\": \"dummy\"}'",
"'--experiment-config={\"accelerator\": \"cpu\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"test\": \"train\"}', '--model-config={\"model_name\": \"dummy\"}'",
"--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"test\": \"eval\"}",
"--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"test\": \"train\"}",
]
for expected in expected_in_stderr:
self.assertIn(expected, child.stderr)
@@ -40,10 +40,10 @@ def test_dummy_dry_run_cuda(self):
expected_in_stderr = [
"Number of selected experiment configs: 4",
"Number of selected model configs: 1",
"'--experiment-config={\"accelerator\": \"cuda\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"test\": \"eval\"}', '--model-config={\"model_name\": \"dummy\"}'",
"'--experiment-config={\"accelerator\": \"cuda\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"test\": \"train\"}', '--model-config={\"model_name\": \"dummy\"}'",
"'--experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"test\": \"eval\"}', '--model-config={\"model_name\": \"dummy\"}'",
"'--experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"test\": \"train\"}', '--model-config={\"model_name\": \"dummy\"}'",
"--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"test\": \"eval\"}",
"--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"test\": \"train\"}",
"--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"test\": \"eval\"}",
"--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"test\": \"train\"}",
]
for expected in expected_in_stderr:
self.assertIn(expected, child.stderr)
@@ -60,8 +60,8 @@ def test_dummy_dry_run_inductor_cuda(self):
expected_in_stderr = [
"Number of selected experiment configs: 2",
"Number of selected model configs: 1",
"'--experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"test\": \"eval\"}', '--model-config={\"model_name\": \"dummy\"}'",
"'--experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"test\": \"train\"}', '--model-config={\"model_name\": \"dummy\"}'",
"--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"test\": \"eval\"}",
"--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"test\": \"train\"}",
]
for expected in expected_in_stderr:
self.assertIn(expected, child.stderr)
@@ -79,11 +79,11 @@ def test_dummy_openxla_eval_train_cuda(self):
expected_in_stderr = [
"Number of selected experiment configs: 5",
"Number of selected model configs: 1",
"'--experiment-config={\"accelerator\": \"cuda\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla_eval\", \"test\": \"eval\"}', '--model-config={\"model_name\": \"dummy\"}'",
"'--experiment-config={\"accelerator\": \"cuda\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"test\": \"train\"}', '--model-config={\"model_name\": \"dummy\"}'",
"'--experiment-config={\"accelerator\": \"cuda\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"test\": \"eval\"}', '--model-config={\"model_name\": \"dummy\"}'",
"'--experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"test\": \"eval\"}', '--model-config={\"model_name\": \"dummy\"}'",
"'--experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"test\": \"train\"}', '--model-config={\"model_name\": \"dummy\"}'",
"--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla_eval\", \"test\": \"eval\"}",
"--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"test\": \"train\"}",
"--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"test\": \"eval\"}",
"--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"test\": \"eval\"}",
"--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"test\": \"train\"}",
]
for expected in expected_in_stderr:
self.assertIn(expected, child.stderr)