From 97fc8ea5f4cb9970c1b8a7cae3fe85c06fe261e3 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 30 Mar 2023 10:16:07 -0400 Subject: [PATCH] Run the benchmark suite with dynamic batch only (#97912) Symbolic shapes compile time on full CI with inductor is horribly long (even though our aot_eager local runs seemed to suggest that the added latency was only 10s per model.) To patch over the problem for now, run the benchmark suite with dynamic batch only. This should absolve a lot of sins. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/97912 Approved by: https://github.com/janeyx99, https://github.com/desertfire --- .ci/pytorch/test.sh | 4 ++-- benchmarks/dynamo/common.py | 24 ++++++++++++++++++++++++ torch/_dynamo/config.py | 7 +++++++ torch/_dynamo/guards.py | 7 ++++--- torch/_dynamo/variables/builder.py | 5 ++++- torch/fx/experimental/symbolic_shapes.py | 1 + 6 files changed, 42 insertions(+), 6 deletions(-) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 5fe1f916dc3598..45a8bfb511ad89 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -302,7 +302,7 @@ test_perf_for_dashboard() { --accuracy --"$dtype" --backend "$backend" "$@" \ --output "$TEST_REPORTS_DIR/${backend}_with_cudagraphs_${suite}_${dtype}_training_cuda_accuracy.csv" python "benchmarks/dynamo/$suite.py" \ - --accuracy --"$dtype" --backend "$backend" --dynamic-shapes --disable-cudagraphs "$@" \ + --accuracy --"$dtype" --backend "$backend" --dynamic-shapes --dynamic-batch-only --disable-cudagraphs "$@" \ --output "$TEST_REPORTS_DIR/${backend}_dynamic_${suite}_${dtype}_training_cuda_accuracy.csv" # Run performance test @@ -316,7 +316,7 @@ test_perf_for_dashboard() { --performance --cold-start-latency --"$dtype" --backend "$backend" "$@" \ --output "$TEST_REPORTS_DIR/${backend}_with_cudagraphs_${suite}_${dtype}_training_cuda_performance.csv" python "benchmarks/dynamo/$suite.py" \ - --performance --cold-start-latency --"$dtype" --backend "$backend" --dynamic-shapes --disable-cudagraphs "$@" \ + --performance --cold-start-latency --"$dtype" --backend "$backend" --dynamic-shapes --dynamic-batch-only --disable-cudagraphs "$@" \ --output "$TEST_REPORTS_DIR/${backend}_dynamic_${suite}_${dtype}_training_cuda_performance.csv" done } diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 12365561ceda54..19b8c0d5db0dc6 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -1693,6 +1693,11 @@ def get_example_inputs(self): action="store_true", help="Runs a dynamic shapes version of the benchmark, if available.", ) + parser.add_argument( + "--dynamic-batch-only", + action="store_true", + help="Only assume batch dimension is dynamic. Implies --dynamic-shapes", + ) parser.add_argument( "--specialize-int", action="store_true", help="Run with specialize_int=True." ) @@ -1956,6 +1961,10 @@ def run(runner, args, original_dir=None): if args.dynamic_ci_skips_only: args.dynamic_shapes = True args.ci = True + if args.dynamic_batch_only: + args.dynamic_shapes = True + torch._dynamo.config.assume_static_by_default = True + torch._dynamo.config.allow_ignore_mark_dynamic = True if args.dynamic_shapes: torch._dynamo.config.dynamic_shapes = True if args.specialize_int: @@ -2329,6 +2338,21 @@ def run(runner, args, original_dir=None): elif args.bfloat16: model, example_inputs = cast_to_bf16(model, example_inputs) + # Look for stuff that looks like batch size, and mark it dynamic. 
+ # Better integration would integrate directly with benchmark suite + # but cannot conveniently do this + # NB: This must be done late enough so that we don't do more + # conversions on the inputs + # NB: Assumes only the first batch-y like dimension is the batch + def detect_and_mark_batch(t): + for i, s in enumerate(t.size()): + if s == batch_size: + torch._dynamo.mark_dynamic(t, i) + break + + if args.dynamic_batch_only: + tree_map(detect_and_mark_batch, example_inputs) + if args.log_operator_inputs: log_operator_inputs( model, example_inputs, runner.model_iter_fn, name, args diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 4a962920ff0dd5..c8433901301350 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -68,6 +68,13 @@ # see [Note - on the state of mark_dynamic] assume_static_by_default = False +# Typically, if you mark_dynamic a dimension, we will error if the dimension +# actually ended up getting specialized. This knob changes the behavior so +# that we don't error at all. This is helpful for our CI where I'm using a +# heuristic to mark batch dimensions as dynamic and the heuristic may get it +# wrong. +allow_ignore_mark_dynamic = False + # Set this to False to assume nn.Modules() contents are immutable (similar assumption as freezing) guard_nn_modules = False diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 42b653e88cc10e..4a94cc6fa11fe8 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -536,9 +536,10 @@ def TENSOR_MATCH(self, guard: Guard): f"hasattr({tensor_name}, '_dynamo_dynamic_indices') == False" ) else: - assert not hasattr( - value, "_dynamo_dynamic_indices" - ), f"Illegal Unreachable state, guard accumulation for dynamic tensor that should have been static. Initial static message: {tensor_static_reason_to_message(reason)}" # noqa: B950 + if not config.allow_ignore_mark_dynamic: + assert not hasattr( + value, "_dynamo_dynamic_indices" + ), f"Illegal Unreachable state, guard accumulation for dynamic tensor that should have been static. Initial static message: {tensor_static_reason_to_message(reason)}" # noqa: B950 if len(code) > 0: self._produce_guard_code(guard, code) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 001eec9eddf8c5..897c597d5ffcbd 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1165,7 +1165,10 @@ def wrap_to_fake_tensor_and_record( # Precedence: export constraints > eager constraints constraint = dim2constraint.get(i) if constraint is None: - if i in getattr(e, "_dynamo_dynamic_indices", set()): + if ( + i in getattr(e, "_dynamo_dynamic_indices", set()) + and not config.allow_ignore_mark_dynamic + ): constraint = RelaxedUnspecConstraint() constraint_dims.append(constraint) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 0c0b9e6a4db604..960ec2c92aea66 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1579,6 +1579,7 @@ def create_symbol( # Even if we're duck shaping, if we haven't seen this particular # value before, we also create a new symbol sympy_expr = sympy.Symbol(f"s{len(self.var_to_val)}", positive=True, integer=True) + log.info("create_symbol %s = %s", sympy_expr, val) # We always associate vars to vals self.var_to_val[sympy_expr] = sympy.Integer(val) # Do the appending later, because we always want to populate this
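
For reference, the sketch below (not part of the patch) illustrates what the new --dynamic-batch-only path boils down to: walk the example-input pytree, mark the first batch-sized dimension of each tensor as dynamic with torch._dynamo.mark_dynamic, and compile with assume_static_by_default so every other dimension stays static. The toy model, batch_size, and example_inputs are hypothetical stand-ins for what the benchmark harness actually supplies.

    # Minimal sketch of the --dynamic-batch-only heuristic (illustration only;
    # the model, batch_size, and inputs below are hypothetical stand-ins).
    import torch
    import torch._dynamo
    from torch.utils._pytree import tree_map

    batch_size = 8
    model = torch.nn.Linear(16, 4)                  # stand-in benchmark model
    example_inputs = (torch.randn(batch_size, 16),)

    def detect_and_mark_batch(t):
        # Mark the first dimension whose size matches the batch size as
        # dynamic; leave non-tensor pytree leaves untouched.
        if isinstance(t, torch.Tensor):
            for i, s in enumerate(t.size()):
                if s == batch_size:
                    torch._dynamo.mark_dynamic(t, i)
                    break
        return t

    tree_map(detect_and_mark_batch, example_inputs)

    # --dynamic-batch-only also sets assume_static_by_default, so only the
    # marked batch dimension is treated as dynamic when compiling.
    torch._dynamo.config.assume_static_by_default = True
    compiled = torch.compile(model)
    out = compiled(*example_inputs)

Because the heuristic can mislabel a dimension that later ends up specialized, the patch also adds config.allow_ignore_mark_dynamic so the guard and constraint code paths do not hard-error in that case.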