
[FA] fix an assertion failure due to refactoring in PR54 #69

Closed
wants to merge 3 commits

Conversation

@manman-ren commented Nov 21, 2024

We move the static_assert to the top-level kernel. After the move, a failing static_assert is caught by the autotuner:

        try:
            return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
        except (OutOfResources, CompileTimeAssertionFailure, PTXASError):
            return [float("inf"), float("inf"), float("inf")]

Prior to this change, the CompileTimeAssertionFailure was somehow not caught; it was reported and failed the build.

Verified with:
python run.py --op fp8_attention
python run.py --op flash_attention --only triton_tutorial_flash_v2 --num-inputs 1 --metrics tflops --num-inputs 1
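
For illustration, here is a minimal sketch of the pattern this relies on; the kernel name, argument list, and config values below are hypothetical and not the PR's actual flash-attention kernel:

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 64}),   # satisfies BLOCK_N <= HEAD_DIM when HEAD_DIM=128
        triton.Config({"BLOCK_N": 256}),  # violates the assert when HEAD_DIM=128
    ],
    key=["N_CTX"],
)
@triton.jit
def _toy_attn_fwd(Q, Out, N_CTX, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr):
    # Because the assert sits in the top-level jitted kernel, compiling a bad
    # config raises CompileTimeAssertionFailure during the autotuner's
    # benchmarking call, and the try/except above maps it to infinite latency
    # instead of failing the whole run.
    tl.static_assert(BLOCK_N <= HEAD_DIM)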

xuzhao9 commented Nov 21, 2024

Thanks! The CI signals are broken right now, so we can ignore them.

Just to confirm: this only fixes the base variant, and the opt/ws variants are still not working. Is that expected?

$ python run.py --op flash_attention --num-inputs 1 --metrics tflops --num-inputs 1
TMA benchmarks will be running with experimental grid constant TMA descriptor.
xformers import built-in _C_flashattention3
  0%|                                                                                                                                                                                                                                     | 0/1 [00:04<?, ?it/s]
Caught exception, terminating early with partial results
Traceback (most recent call last):
  File "/data/users/xzhao9/tritonbench/tritonbench/utils/triton_op.py", line 740, in run
    y_vals: Dict[str, BenchmarkOperatorMetrics] = functools.reduce(
                                                  ^^^^^^^^^^^^^^^^^
  File "/data/users/xzhao9/tritonbench/tritonbench/utils/triton_op.py", line 728, in _reduce_benchmarks
    acc[bm_name] = self._do_bench(
                   ^^^^^^^^^^^^^^^
  File "/data/users/xzhao9/tritonbench/tritonbench/utils/triton_op.py", line 957, in _do_bench
    raise e
  File "/data/users/xzhao9/tritonbench/tritonbench/utils/triton_op.py", line 948, in _do_bench
    metrics.latency = triton.testing.do_bench(
                      ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/xzhao9/triton/python/triton/testing.py", line 117, in do_bench
    fn()
  File "/data/users/xzhao9/tritonbench/tritonbench/operators/flash_attention/operator.py", line 258, in <lambda>
    return lambda: triton_tutorial_FA2_opt(
                   ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzhao9/.conda/envs/py312/lib/python3.12/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/xzhao9/tritonbench/tritonbench/kernels/triton_fused_attention.py", line 1783, in forward
    _attn_fwd_opt[grid_tma](
  File "/data/users/xzhao9/triton/python/triton/runtime/jit.py", line 330, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/xzhao9/triton/python/triton/runtime/autotuner.py", line 206, in run
    ret = self.fn.run(
          ^^^^^^^^^^^^
  File "/data/users/xzhao9/triton/python/triton/runtime/jit.py", line 623, in run
    kernel = self.compile(
             ^^^^^^^^^^^^^
  File "/data/users/xzhao9/triton/python/triton/compiler/compiler.py", line 280, in compile
    module = src.make_ir(options, codegen_fns, module_map, context)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/xzhao9/triton/python/triton/compiler/compiler.py", line 85, in make_ir
    return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
triton.compiler.errors.CompileTimeAssertionFailure: at 39:4:
    Z,
    H,
    N_CTX,  #: tl.constexpr,  #
    BLOCK_M: tl.constexpr,  #
    BLOCK_N: tl.constexpr,  #
    HEAD_DIM: tl.constexpr,  #
    STAGE: tl.constexpr,  #
    ENABLE_TMA: tl.constexpr,
    LOOP_SCHEDULE: tl.constexpr,
    ENABLE_WS: tl.constexpr,
):
    tl.static_assert(BLOCK_N <= HEAD_DIM)
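
For reference, the traceback shows the assert firing in the autotuner's final kernel launch (ret = self.fn.run(...)) rather than inside the benchmarking try/except quoted in the description, so it propagates up to tritonbench. A hedged sketch of how a caller could guard such a launch if desired; the helper below is illustrative, not an existing tritonbench or Triton API:

from triton.compiler.errors import CompileTimeAssertionFailure

def launch_or_skip(kernel_launch):
    # Run a Triton kernel launch callable; treat a compile-time assertion
    # failure as a skipped configuration instead of a hard error.
    try:
        kernel_launch()
        return True
    except CompileTimeAssertionFailure as exc:
        print(f"Skipping config due to compile-time assert: {exc}")
        return False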

@facebook-github-bot
@manman-ren has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@adamomainz left a comment

Accepted internally after reviewing, so accepting here as well. Please be sure to respond to Xu's question before shipping :)

@xuzhao9 left a comment

LGTM. Thanks Manman for fixing this!

@facebook-github-bot

@manman-ren merged this pull request in 8f8db26.

xuzhao9 pushed a commit that referenced this pull request Nov 22, 2024
Pull Request resolved: #69

Reviewed By: xuzhao9, adamomainz

Differential Revision: D66336174

Pulled By: manman-ren

fbshipit-source-id: 95d29821e6cba45af535b11020aa51424a408789