Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fp8 bwd #108

Draft
wants to merge 10 commits into
base: main_perf
Choose a base branch
from
Draft

fp8 bwd #108

wants to merge 10 commits into from

Conversation

micmelesse
Copy link
Collaborator

No description provided.

alexkranias-amd and others added 10 commits December 9, 2024 10:09
feat: added fp32 output to input_helper

passing

feat: fp8 tests. small amount of error

added fp8e5m2 type

note: RuntimeError: "abs_cuda" not implemented for 'Float8_e4m3fnuz'

enabled fp8 GEMMs

fix: error down to < 0.1

added another fp8 dtype

best accuracy is with no scaling

improved accuracy to within < 0.02. issue related to torch side casting

fix: passes if we allow v to be fp16 instead of fp8. otherwise we have error < 0.1

all error is < 0.07

feat: added per head scaling tensors

progress towards implementing scaling tensors in kernel

save

issue: error caused by acc += tl.dot(p.to(v.type.element_ty), v)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants