Fixes flash_attn + cascade attention_code to decoder Transformer bloc… #71

Merged

Conversation


@colon3ltocard (Contributor) commented Oct 10, 2024

  • Cascades attention_code down to the decoder Transformer blocks (this wasn't the case before!); see the first sketch after this list.
  • Fixes the flash_attn dim ordering. It still won't work on grids larger than 64x64 because unetrpp splits attention heads across channels rather than across the flattened physical (pixel, voxel) dims; see the second sketch after this list.
  • Adds num_heads_decoder to allow changing the number of attention heads in the decoder.
  • Adds a Dockerfile for working on EWC with A100 GPUs and flash_attn; an older CUDA version is needed, and documentation for that has been added.
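For context, a minimal, hypothetical sketch of what "cascading" the settings means: the decoder only honours attention_code and num_heads_decoder if its constructor forwards them to every block. Class and argument names below are illustrative assumptions, not the real unetrpp/py4cast API.

```python
import torch
import torch.nn as nn

class DecoderBlock(nn.Module):
    """One decoder Transformer block; the attention settings must be passed in."""
    def __init__(self, dim: int, num_heads: int, attention_code: str = "torch"):
        super().__init__()
        self.attention_code = attention_code  # e.g. "torch" or "flash" (assumed values)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.attn(x, x, x, need_weights=False)
        return x + out

class Decoder(nn.Module):
    def __init__(self, dim: int, num_heads_decoder: int, attention_code: str, depth: int = 2):
        super().__init__()
        # The point of the fix: forward attention_code and num_heads_decoder
        # to every block instead of letting each block fall back to defaults.
        self.blocks = nn.ModuleList(
            DecoderBlock(dim, num_heads_decoder, attention_code) for _ in range(depth)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return x

decoder = Decoder(dim=64, num_heads_decoder=4, attention_code="flash")
y = decoder(torch.randn(2, 128, 64))  # (batch, tokens, dim)
```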
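And a minimal sketch of the dim-ordering point, assuming the usual PyTorch layout of (batch, nheads, seqlen, headdim) inside the attention module: flash_attn_func expects (batch, seqlen, nheads, headdim), so the tensors must be transposed before the call and back afterwards. Shapes and tensor names are illustrative, not the actual unetrpp code; flash-attn requires a CUDA build and fp16/bf16 inputs.

```python
import torch
from flash_attn import flash_attn_func  # requires a CUDA build of flash-attn

batch, nheads, seqlen, headdim = 2, 8, 4096, 64

# Hypothetical projections already split into heads, in "PyTorch-style" layout.
q = torch.randn(batch, nheads, seqlen, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# flash_attn_func wants (batch, seqlen, nheads, headdim): swap the head and
# sequence dims before the call, then swap back to the original layout.
out = flash_attn_func(
    q.transpose(1, 2),
    k.transpose(1, 2),
    v.transpose(1, 2),
)
out = out.transpose(1, 2)  # back to (batch, nheads, seqlen, headdim)
```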

@colon3ltocard colon3ltocard requested a review from LBerth October 10, 2024 12:24
@LBerth LBerth merged commit 456461b into meteofrance:main Oct 10, 2024
1 check passed