This repository has been archived by the owner on Oct 19, 2024. It is now read-only.
[FEATURE] Current gradient accumulation only works for jnp.mean loss #268
Labels: enhancement, New feature
The current gradient accumulation only works for a `jnp.mean` loss, because it always combines the micro-batch results with mean reduction.
For other losses (for example, a `jnp.sum` loss) or for auxiliary states, we should support other reduction types, such as sum reduction and concatenation reduction.
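Here is a minimal sketch of what configurable reductions could look like. It assumes a toy linear model with a squared-error loss, equal-size micro-batches processed in a plain Python loop, and the hypothetical names `accumulate_grads`, `loss_reduction` and `grad_reduction`. None of these are this project's API. The sketch shows that averaging micro-batch gradients reproduces the full-batch gradient only for a mean loss. A sum loss needs sum reduction, and per-example auxiliary outputs (here, predictions) need concatenation.

```python
import jax
import jax.numpy as jnp


def loss_and_aux(w, x, y, loss_reduction):
    """Squared error of a linear model; returns (loss, per-example predictions)."""
    pred = x @ w
    per_example = (pred - y) ** 2
    if loss_reduction == "mean":
        loss = jnp.mean(per_example)
    else:  # "sum"
        loss = jnp.sum(per_example)
    return loss, pred  # pred is an auxiliary output that must be concatenated


def accumulate_grads(w, x, y, num_micro_batches, loss_reduction, grad_reduction):
    """Runs micro-batches sequentially and reduces their gradients and aux outputs.

    grad_reduction: "mean" or "sum" (hypothetical option names).
    Aux outputs are always combined by concatenation along the batch axis.
    """
    grad_fn = jax.grad(
        lambda w_, xb, yb: loss_and_aux(w_, xb, yb, loss_reduction), has_aux=True
    )
    grads, aux = [], []
    # jnp.split requires the batch to divide evenly into micro-batches.
    for xb, yb in zip(jnp.split(x, num_micro_batches), jnp.split(y, num_micro_batches)):
        g, a = grad_fn(w, xb, yb)
        grads.append(g)
        aux.append(a)
    stacked = jnp.stack(grads)
    if grad_reduction == "mean":
        g = jnp.mean(stacked, axis=0)  # matches a mean loss over equal-size micro-batches
    else:  # "sum"
        g = jnp.sum(stacked, axis=0)  # matches a sum loss
    return g, jnp.concatenate(aux)


w = jnp.array([1.0, -2.0])
x = jnp.arange(16.0).reshape(8, 2)
y = jnp.ones(8)

# Compare accumulated results against a single full-batch gradient.
for red in ("mean", "sum"):
    full_g, _ = jax.grad(lambda w_: loss_and_aux(w_, x, y, red), has_aux=True)(w)
    acc_g, acc_pred = accumulate_grads(w, x, y, 4, red, red)
    print(red, jnp.allclose(full_g, acc_g), jnp.allclose(acc_pred, x @ w))
```

If you call `accumulate_grads(w, x, y, 4, "sum", "mean")`, the result is one quarter of the full-batch gradient. This is the mismatch the issue describes. A production version would more likely iterate with `jax.lax.scan` and keep a running sum instead of stacking every gradient. The loop and `jnp.stack` are used here only for readability.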