Skip to content

Commit

Permalink
Adding new gradient checkpointing policies
Browse files Browse the repository at this point in the history
  • Loading branch information
erfanzar committed Oct 31, 2023
1 parent 86c31e3 commit 1574c31
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions lib/python/EasyDel/modules/flax_modelling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,19 @@ def with_sharding_constraint(x, partition_specs):


def get_gradient_checkpoint_policy(name):
return {
'everything_saveable': jax.checkpoint_policies.everything_saveable,
'nothing_saveable': jax.checkpoint_policies.nothing_saveable,
'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots,
'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
}[name]
gradients = dict(
everything_saveable=jax.checkpoint_policies.everything_saveable,
nothing_saveable=jax.checkpoint_policies.nothing_saveable,
dots_saveable=jax.checkpoint_policies.dots_saveable,
checkpoint_dots=jax.checkpoint_policies.dots_saveable,
dots_with_no_batch_dims_saveable=jax.checkpoint_policies.dot_with_no_batch_dims_saveable,
checkpoint_dots_with_no_batch_dims=jax.checkpoint_policies.dot_with_no_batch_dims_saveable,
save_anything_except_these_names=jax.checkpoint_policies.save_anything_except_these_names,
save_any_names_but_these=jax.checkpoint_policies.save_any_names_but_these,
save_only_these_names=jax.checkpoint_policies.save_only_these_names,
save_from_both_policies=jax.checkpoint_policies.save_from_both_policies
)
return gradients[name]


def repeat_kv_bnsh(x: chex.Array, n_rep: int) -> chex.Array:
Expand Down

0 comments on commit 1574c31

Please sign in to comment.