From 1574c3123dd0be5c542ae2a79ccaa7fe4f87a4ec Mon Sep 17 00:00:00 2001 From: Erfan Zare Chavoshi <59269023+erfanzar@users.noreply.github.com> Date: Tue, 31 Oct 2023 12:21:50 +0330 Subject: [PATCH] Adding new `gradient checkpointing policies` --- .../EasyDel/modules/flax_modelling_utils.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/lib/python/EasyDel/modules/flax_modelling_utils.py b/lib/python/EasyDel/modules/flax_modelling_utils.py index e6a39aa07..91196489a 100644 --- a/lib/python/EasyDel/modules/flax_modelling_utils.py +++ b/lib/python/EasyDel/modules/flax_modelling_utils.py @@ -42,12 +42,19 @@ def with_sharding_constraint(x, partition_specs): def get_gradient_checkpoint_policy(name): - return { - 'everything_saveable': jax.checkpoint_policies.everything_saveable, - 'nothing_saveable': jax.checkpoint_policies.nothing_saveable, - 'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots, - 'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, - }[name] + gradients = dict( + everything_saveable=jax.checkpoint_policies.everything_saveable, + nothing_saveable=jax.checkpoint_policies.nothing_saveable, + dots_saveable=jax.checkpoint_policies.dots_saveable, + checkpoint_dots=jax.checkpoint_policies.dots_saveable, + dots_with_no_batch_dims_saveable=jax.checkpoint_policies.dot_with_no_batch_dims_saveable, + checkpoint_dots_with_no_batch_dims=jax.checkpoint_policies.dot_with_no_batch_dims_saveable, + save_anything_except_these_names=jax.checkpoint_policies.save_anything_except_these_names, + save_any_names_but_these=jax.checkpoint_policies.save_any_names_but_these, + save_only_these_names=jax.checkpoint_policies.save_only_these_names, + save_from_both_policies=jax.checkpoint_policies.save_from_both_policies + ) + return gradients[name] def repeat_kv_bnsh(x: chex.Array, n_rep: int) -> chex.Array: