Support stash_type attribute for onnx.LayerNormalization #3888

Merged
merged 2 commits into llvm:main from layer_norm on Nov 27, 2024

Conversation

@jinchen62 (Collaborator) commented Nov 22, 2024

Fixes nod-ai/SHARK-ModelDev#888

If stash_type differs from the input/result dtype:

  1. convert x to stash_type
  2. calculate mean and var in stash_type (x is already in stash_type)
  3. convert back to result_dtype before the stage-two calculation
  4. convert the mean and var outputs if their dtypes differ from stash_type
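
A minimal NumPy sketch of this dtype flow (illustrative only: the function name, eps default, and var-based formulation are placeholders, not the actual torch-mlir lowering):

```python
import numpy as np

def layer_norm_with_stash(x, scale, bias, stash_dtype=np.float32, eps=1e-5):
    result_dtype = x.dtype
    x_stash = x.astype(stash_dtype)                # 1. x -> stash_type
    mean = x_stash.mean(axis=-1, keepdims=True)    # 2. stats in stash_type
    var = x_stash.var(axis=-1, keepdims=True)
    normalized = (x_stash - mean) / np.sqrt(var + eps)
    # 3. cast back to result_dtype before stage two (scale/bias)
    y = normalized.astype(result_dtype) * scale + bias
    return y, mean, var                            # 4. mean/var kept in stash_type
```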

e2e test added in nod-ai/SHARK-TestSuite#399

@zjgarvey (Collaborator)

I think we should probably support the stash_type arg by separating the two stages of computation, as suggested by the ONNX docs: https://onnx.ai/onnx/operators/onnx__LayerNormalization.html. If an onnx op actually has different result and stash types, we would likely see numeric mismatches in those situations unless we perform the computation correctly.
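
For intuition, a tiny NumPy repro of that concern, using float16 in place of bf16 (stock NumPy has no bf16) and an explicit accumulator dtype to force low-precision accumulation:

```python
import numpy as np

x = (np.random.randn(4, 4096) * 8).astype(np.float16)
n = x.shape[-1]
mean_lo = x.sum(axis=-1, dtype=np.float16) / np.float16(n)  # stats in result dtype
mean_hi = x.astype(np.float32).sum(axis=-1) / n             # stats "stashed" in f32
print(np.abs(mean_lo.astype(np.float32) - mean_hi).max())   # generally nonzero
```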

Another option is to allow LayerNormalization to be function-expanded on import via

```python
function_expansion_allowlists_by_domain: Optional[Dict[str, set[str]]] = field(...)
```
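
A hypothetical usage sketch, assuming this field lives on the Config dataclass in torch_mlir.extras.onnx_importer (the module path is my assumption; the field name is from the snippet above):

```python
from torch_mlir.extras.onnx_importer import Config  # assumed module path

config = Config()
# Expand LayerNormalization from the default ("") ONNX domain into its
# function body on import, instead of lowering the op directly.
config.function_expansion_allowlists_by_domain = {
    "": {"LayerNormalization"},
}
```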

In any case, we should put together a few e2e tests for this op:

  1. With bf16 result type and bf16 stash type
  2. With bf16 result type and unspecified stash type
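
Case 1 could be exercised with a small hand-built model along these lines (a sketch using the standard onnx.helper API; shapes and names are arbitrary):

```python
from onnx import helper, TensorProto

node = helper.make_node(
    "LayerNormalization", ["x", "scale", "bias"], ["y"],
    axis=-1,
    stash_type=TensorProto.BFLOAT16,  # case 1: explicit bf16 stash type
)
graph = helper.make_graph(
    [node], "ln_bf16",
    [helper.make_tensor_value_info("x", TensorProto.BFLOAT16, [2, 4]),
     helper.make_tensor_value_info("scale", TensorProto.BFLOAT16, [4]),
     helper.make_tensor_value_info("bias", TensorProto.BFLOAT16, [4])],
    [helper.make_tensor_value_info("y", TensorProto.BFLOAT16, [2, 4])],
)
model = helper.make_model(graph)  # for case 2, omit stash_type entirely
```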

@jinchen62 changed the title from "Remove stash_type check for onnx.LayerNormalization lowering" to "Support stash_type attribute for onnx.LayerNormalization" on Nov 24, 2024
@zjgarvey (Collaborator) left a comment


I suppose I still don't get why this works, since the two stages mentioned in the ONNX docs aren't separated in the torch op. Technically, aren't we supposed to cast back to the original result type before the final mul and add?

I think it is fine to merge this as-is, considering the examples we tested seem to give correct numerics. If we end up seeing numeric failures in models with layer normalization and unusual dtypes, we can always fall back on the function expander on import.

Can you mention the relevant e2e tests in the commit message?

@jinchen62 (Collaborator, Author)

@zjgarvey Yeah, I convert the dtype back before stage two in the decomposition if the dtypes differ. Added the link to the e2e test in the commit message.

@jinchen62 merged commit 7452460 into llvm:main on Nov 27, 2024 (3 checks passed).
@jinchen62 deleted the layer_norm branch on Nov 27, 2024 at 00:47.

Successfully merging this pull request may close these issues:

(torch-to-onnx) FLUX.1 - bf16 onnx.LayerNormalization failing to legalize