
(torch-to-onnx) FLUX.1 - bf16 onnx.LayerNormalization failing to legalize #888

Open · monorimet opened this issue Nov 21, 2024 · 3 comments · May be fixed by llvm/torch-mlir#3888

monorimet (Contributor) commented Nov 21, 2024

Hi all,
I'm trying to compile the bf16 FLUX MMDiT from an ONNX export.

I'm running into the following torch-to-onnx legalization error with this compile command:

iree-compile --iree-hal-target-device=amdgpu --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --iree-execution-model=async-external flux_1_dev_static_bf16.mlir -o flux-dev_sampler_bs1_512_1024x1024_bf16_amdgpu-gfx942.vmfb


flux_1_dev_static_bf16.mlir:3873:13: error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
    %1315 = torch.operator "onnx.LayerNormalization"(%1161, %1313, %1314) {torch.onnx.axis = -1 : si64, torch.onnx.epsilon = 9.99999997E-7 : f32} : (!torch.vtensor<[1,4096,3072],bf16>, !torch.vtensor<[3072],bf16>, !torch.vtensor<[3072],bf16>) -> !torch.vtensor<[1,4096,3072],bf16> 
            ^
flux_1_dev_static_bf16.mlir:3873:13: note: see current operation: %3010 = "torch.operator"(%2325, %3007, %3009) <{name = "onnx.LayerNormalization"}> {torch.onnx.axis = -1 : si64, torch.onnx.epsilon = 9.99999997E-7 : f32} : (!torch.vtensor<[1,4096,3072],bf16>, !torch.vtensor<[3072],bf16>, !torch.vtensor<[3072],bf16>) -> !torch.vtensor<[1,4096,3072],bf16>

This is reproducible with the following MLIR and compile command:

onnxln_test.mlir

func.func @test_layer_norm_single_result(%arg0: !torch.vtensor<[1,4,768],bf16>, %arg1: !torch.vtensor<[768],bf16>, %arg2: !torch.vtensor<[768],bf16>) -> (!torch.vtensor<[1,4,768], bf16>) 
                           attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { 
  %0 = torch.operator "onnx.LayerNormalization"(%arg0, %arg1, %arg2) {torch.onnx.axis = -1 : si64, torch.onnx.epsilon = 9.99999974E-6 : bf16} : (!torch.vtensor<[1,4,768],bf16>, !torch.vtensor<[768],bf16>, !torch.vtensor<[768],bf16>) -> !torch.vtensor<[1,4,768],bf16>
  return %0 : !torch.vtensor<[1,4,768],bf16>
}

compile command:

iree-compile onnxln_test.mlir -o ln.vmfb --iree-hal-target-device=hip --iree-hip-target=gfx942

The minimized reproducer may take some liberties with "correct" bf16 layernorm usage -- I took our f32 test in torch-mlir and find-and-replaced "f32" with "bf16", which I'm not confident in, but it does reproduce the same error.
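
For scripting the repro, here is a minimal sketch using the iree-compiler Python package; the compile_file/output_file/extra_args usage is my assumption about the iree.compiler.tools API, and the flags mirror the CLI command above:

# Sketch: drive the same compile from Python so the failure can be
# scripted or bisected. Assumes the iree-compiler pip package is installed
# and that compile_file accepts output_file and extra_args options.
from iree.compiler import compile_file

compile_file(
    "onnxln_test.mlir",
    output_file="ln.vmfb",
    extra_args=[
        "--iree-hal-target-device=hip",
        "--iree-hip-target=gfx942",
    ],
)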

jinchen62 (Contributor) commented Nov 22, 2024

There are two issues:

  1. We are missing TorchToLinalg lowering support for the LayerNorm op. I will work on it.
  2. To make the bf16 case lower to the torch level like the f32 case, the reproducer should look like this:
func.func @test_layer_norm_single_result(%arg0: !torch.vtensor<[1,4,768],bf16>, %arg1: !torch.vtensor<[768],bf16>, %arg2: !torch.vtensor<[768],bf16>) -> (!torch.vtensor<[1,4,768], bf16>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { 
  %0 = torch.operator "onnx.LayerNormalization"(%arg0, %arg1, %arg2) {torch.onnx.axis = -1 : si64, torch.onnx.epsilon = 9.99999974E-6 : f32, torch.onnx.stash_type = 16 : si64} : (!torch.vtensor<[1,4,768],bf16>, !torch.vtensor<[768],bf16>, !torch.vtensor<[768],bf16>) -> !torch.vtensor<[1,4,768],bf16>
  return %0 : !torch.vtensor<[1,4,768],bf16>
}

The epsilon attribute should be f32, and the op should carry a stash_type attribute to match the bf16 type.
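
For context, here is a hedged sketch of building that node with the onnx helper API; it shows where the stash_type = 16 value comes from (TensorProto.BFLOAT16 == 16 in the ONNX dtype enum):

# Sketch: construct an ONNX LayerNormalization node with the attributes
# described above. epsilon is always an f32 attribute in ONNX; stash_type
# selects the dtype used for the internal mean/variance computation.
from onnx import TensorProto, helper

node = helper.make_node(
    "LayerNormalization",
    inputs=["x", "scale", "bias"],
    outputs=["y"],
    axis=-1,
    epsilon=1e-6,                     # f32 attribute
    stash_type=TensorProto.BFLOAT16,  # == 16, i.e. stash_type = 16 : si64
)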

jinchen62 self-assigned this Nov 22, 2024
jinchen62 (Contributor) commented

@monorimet Actually, we do have a decomposition for torch.aten.native_layer_norm, which is what onnx.LayerNormalization lowers to. I was able to compile the bf16 test I posted above. So for your case, does it not come with the stash_type attribute?

monorimet (Contributor, Author) commented

@jinchen62 Thanks, I forgot to include the MLIR: https://sharkpublic.blob.core.windows.net/sharkpublic/flux.1/flux_1_dev_static_bf16.mlir
It does not seem to come with the stash_type attribute, only {torch.onnx.axis = -1 : si64, torch.onnx.epsilon = 9.99999997E-7 : f32}.
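
Since the exported model lacks the attribute, one hypothetical workaround (not from this thread; the filenames below are illustrative placeholders) would be to patch stash_type onto the LayerNormalization nodes before import:

# Hypothetical patch script: add stash_type = BFLOAT16 to any
# LayerNormalization node that lacks it, then re-save the model.
# The filenames are placeholders, not from the original issue.
import onnx
from onnx import TensorProto, helper

model = onnx.load("flux_1_dev_static_bf16.onnx")
for node in model.graph.node:
    if node.op_type == "LayerNormalization" and not any(
        attr.name == "stash_type" for attr in node.attribute
    ):
        node.attribute.append(
            helper.make_attribute("stash_type", TensorProto.BFLOAT16)
        )
onnx.save(model, "flux_1_dev_static_bf16_patched.onnx")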
