From 6ad37c97914da287d3740cea5df33313e485275f Mon Sep 17 00:00:00 2001
From: David Hall
Date: Tue, 21 Jan 2025 16:12:12 -0800
Subject: [PATCH] Partial checkpoints (#861)

---
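Usage sketch (illustrative, not taken from the patch itself: `MyModule`, `init_fn`, and
the checkpoint path are made up; the call pattern mirrors the new test in
tests/test_checkpoint.py below). With allow_partial=True, leaves that are missing from
the on-disk checkpoint keep their freshly initialized values instead of raising
FileNotFoundError:

    import equinox as eqx
    import haliax as hax
    import jax

    from levanter.checkpoint import load_checkpoint_or_initialize

    In, Out = hax.Axis("in", 2), hax.Axis("out", 1)

    class MyModule(eqx.Module):
        a: hax.NamedArray
        b: hax.NamedArray | None

    def init_fn(key):
        k_a, k_b = jax.random.split(key)
        return MyModule(a=hax.random.normal(k_a, (In, Out)), b=hax.random.normal(k_b, (In, Out)))

    # If the checkpoint was written before `b` existed, `a` is loaded from disk
    # while `b` keeps the value produced by init_fn.
    with jax.sharding.Mesh(jax.devices(), ("devices",)):
        model = load_checkpoint_or_initialize(
            init_fn,
            "checkpoints/run0",  # illustrative: a parent directory or a specific checkpoint
            allow_partial=True,
        )(jax.random.PRNGKey(0))
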
 config/harness/eval_llama3.yaml           | 52 +++++++++++------------
 src/levanter/checkpoint.py                | 11 ++++-
 src/levanter/tensorstore_serialization.py | 50 ++++++++++++++++-----
 src/levanter/trainer.py                   |  5 +++
 tests/test_checkpoint.py                  | 37 ++++++++++++++++
 tests/test_tensorstore_serialization.py   | 21 ++++++++-
 6 files changed, 139 insertions(+), 37 deletions(-)

diff --git a/config/harness/eval_llama3.yaml b/config/harness/eval_llama3.yaml
index 260620102..cef182f4e 100644
--- a/config/harness/eval_llama3.yaml
+++ b/config/harness/eval_llama3.yaml
@@ -2,32 +2,32 @@ eval_harness:
   task_spec:
     - task: commonsense_qa  # 5-way multiple-choice questions based on common-sense, everyday scenarios
       num_fewshot: 10
-    - task: agieval_lsat_ar  # 3-shot tests in legal domain
-      num_fewshot: 3
-    - task: arc_easy  # 10-shot, four-way MCQ questions involving grade 3-9 basic science
-      num_fewshot: 10
-    - task: arc_challenge  # a (harder) version of arc_easy
-      num_fewshot: 10
-    - task: boolq  # answer yes/no questions based on a passage
-      num_fewshot: 10
-    - task: copa  # use causal reasoning to predict the correct outcome of a given scenario
-      num_fewshot: 0
-    - task: hellaswag  # 4-way multiple choice commonsense reasoning dataset
-      num_fewshot: 0
-      task_alias: hellaswag_0shot
-    - task: hellaswag  # 4-way multiple choice commonsense reasoning dataset
-      num_fewshot: 10
-      task_alias: hellaswag_10shot
-    - task: lambada  # predict the endings of text passages
-      num_fewshot: 0
-    - task: openbookqa  # 4-way multiple choice question answering task that requires multi-step reasoning
-      num_fewshot: 0
-    - task: piqa  # answer questions based on a passage
-      num_fewshot: 10
-    - task: wsc273  # Winograd Schema Challenge
-      num_fewshot: 0
-    - task: winogrande  # Winograd challenge, extended to more domains
-      num_fewshot: 0
+#    - task: agieval_lsat_ar  # 3-shot tests in legal domain
+#      num_fewshot: 3
+#    - task: arc_easy  # 10-shot, four-way MCQ questions involving grade 3-9 basic science
+#      num_fewshot: 10
+#    - task: arc_challenge  # a (harder) version of arc_easy
+#      num_fewshot: 10
+#    - task: boolq  # answer yes/no questions based on a passage
+#      num_fewshot: 10
+#    - task: copa  # use causal reasoning to predict the correct outcome of a given scenario
+#      num_fewshot: 0
+#    - task: hellaswag  # 4-way multiple choice commonsense reasoning dataset
+#      num_fewshot: 0
+#      task_alias: hellaswag_0shot
+#    - task: hellaswag  # 4-way multiple choice commonsense reasoning dataset
+#      num_fewshot: 10
+#      task_alias: hellaswag_10shot
+#    - task: lambada  # predict the endings of text passages
+#      num_fewshot: 0
+#    - task: openbookqa  # 4-way multiple choice question answering task that requires multi-step reasoning
+#      num_fewshot: 0
+#    - task: piqa  # answer questions based on a passage
+#      num_fewshot: 10
+#    - task: wsc273  # Winograd Schema Challenge
+#      num_fewshot: 0
+#    - task: winogrande  # Winograd challenge, extended to more domains
+#      num_fewshot: 0
 # requires generation
 ## - task: squadv2  # reading comprehension benchmark
 #     num_fewshot: 10
diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py
index ed537c927..b9eabe922 100644
--- a/src/levanter/checkpoint.py
+++ b/src/levanter/checkpoint.py
@@ -353,6 +353,7 @@ def load_checkpoint(
     discover_latest=True,
     axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None,
     mesh: Optional[jax.sharding.Mesh] = None,
+    allow_partial: bool = False,
 ) -> M:
     """
     Load a checkpoint from a given path. If discover_latest is True, then the latest checkpoint
@@ -367,6 +368,7 @@
         discover_latest: whether to discover the latest checkpoint in the given path
         axis_mapping: the axis mapping to use for loading the checkpoint
         mesh: the mesh to use for loading the checkpoint
+        allow_partial: if True, allow partial loading of the checkpoint. If False, all parameters must be present in the checkpoint.

     Returns:
         the loaded checkpoint, with the same structure as the exemplar tree
@@ -397,7 +399,9 @@
     ser, non_ser = equinox.partition(tree, is_jax_array_like)

     try:
-        tree = tree_deserialize_leaves_tensorstore(checkpoint_path, ser, axis_mapping=axis_mapping, mesh=mesh)
+        tree = tree_deserialize_leaves_tensorstore(
+            checkpoint_path, ser, axis_mapping=axis_mapping, mesh=mesh, allow_missing=allow_partial
+        )
         tree = equinox.combine(tree, non_ser)
         return tree
     except:  # noqa
@@ -445,6 +449,7 @@ def load_checkpoint_or_initialize(
     donate_args: FilterSpec = True,
     donate_kwargs: Optional[FilterSpec] = None,
     do_load: Optional[bool] = None,
+    allow_partial: bool = False,
 ) -> Callable[Sig, M]:
     """
     Load a checkpoint from a given path. If discover_latest is True, then the latest checkpoint
@@ -476,6 +481,7 @@
         donate_args: a FilterSpec that specifies which arguments to donate to init_fn if we need to initialize
         donate_kwargs: a FilterSpec that specifies which kwargs to donate to init_fn if we need to initialize
         do_load: if True, always load the checkpoint. If False, always initialize. If None, load if the checkpoint exists, otherwise initialize
+        allow_partial: if True, allow partial loading of the checkpoint. If False, all parameters must be present in the checkpoint.

     Returns:
         A function that takes the same arguments as init_fn, but loads the checkpoint if it exists and returns the
@@ -493,6 +499,8 @@ def load_checkpoint_or_initialize(
         )
     def init_and_merge(state, *args, **kwargs):
         init_state = init_fn(*args, **kwargs)
+        # remove all ShapeDtypeStruct placeholders from the state
+        state = equinox.filter(state, lambda x: not isinstance(x, jax.ShapeDtypeStruct))
         return equinox.combine(state, init_state)

     def load_or_init(*args, **kwargs):
@@ -516,6 +524,7 @@ def load_or_init(*args, **kwargs):
                 discover_latest=discover_latest,
                 axis_mapping=axis_mapping,
                 mesh=mesh,
+                allow_partial=allow_partial,
             )
         except FileNotFoundError:
             if do_load is True:
diff --git a/src/levanter/tensorstore_serialization.py b/src/levanter/tensorstore_serialization.py
index fc9155cd1..09c66e19d 100644
--- a/src/levanter/tensorstore_serialization.py
+++ b/src/levanter/tensorstore_serialization.py
@@ -20,7 +20,7 @@
 from haliax.partitioning import ResourceMapping
 from haliax.util import is_named_array

-from levanter.utils import jax_utils
+from levanter.utils import fsspec_utils, jax_utils


 logger = logging.getLogger(__name__)
@@ -119,6 +119,8 @@ def tree_deserialize_leaves_tensorstore(
     axis_mapping: Optional[ResourceMapping] = None,
     mesh: Optional[Mesh] = None,
     manager: Optional[array_ser.GlobalAsyncCheckpointManager] = None,
+    *,
+    allow_missing: bool = False,
 ):
     """
     Deserializes a PyTree of Arrays and NamedArrays from a Tensorstore checkpoint, returning a pytree with the same shape
@@ -132,6 +134,7 @@
         axis_mapping: optional, the axis mapping for the NamedArrays (if they are not yet arrays)
         mesh: optional, the mesh for the NamedArrays (if they are not yet arrays)
         manager: optional, the checkpoint manager to use. If not provided, a new one will be created
+        allow_missing: if True, missing leaves will be allowed and kept as-is

     Returns:
         A pytree with the same shape as the exemplar pytree, but with the arrays deserialized from the checkpoint
@@ -151,26 +154,55 @@
     shardings_leaves, shardings_structure = jtu.tree_flatten(shardings, is_leaf=_is_named_or_none)
     assert len(shardings_leaves) == len(paths)

-    # ok, so, jax really doesn't want any Nones in the leaves here, so we need to temporarily partition the pytree
     real_indices = [i for i, x in enumerate(shardings_leaves) if x is not None]
-    real_leaves = [x for x in shardings_leaves if x is not None]
-    real_paths = [paths[i] for i in real_indices]
+    paths_to_load = []
+    indices_to_load = []
+    shardings_to_load = []
+
+    missing_paths = []
+    missing_indices = []
+
+    for i in real_indices:
+        path = paths[i]
+
+        if not fsspec_utils.exists(path):
+            missing_paths.append(path)
+            missing_indices.append(i)
+            continue

-    assert len(real_leaves) == len(real_paths), f"{len(real_leaves)} != {len(real_paths)}"
+        paths_to_load.append(path)
+        indices_to_load.append(i)
+        shardings_to_load.append(shardings_leaves[i])
+
+    # ok now check for missing paths
+    if missing_paths:
+        if not allow_missing:
+            raise FileNotFoundError(f"Missing paths: {missing_paths}")
+        else:
+            to_log = f"Some keys were missing from the checkpoint directory {checkpoint_dir}:"
+            leaf_paths = jtu.tree_leaves(leaf_key_paths, is_leaf=_is_named_or_none)
+            for i in missing_indices:
+                to_log += f"\n - {leaf_paths[i]}"
+            logger.warning(to_log)
+
+    deser_leaves = manager.deserialize_with_paths(shardings=shardings_to_load, paths=paths_to_load)

-    deser_leaves = manager.deserialize_with_paths(shardings=real_leaves, paths=real_paths)
     # now we need to recreate the original structure
-    out_leaves = [None] * len(shardings_leaves)
-    for i, x in zip(real_indices, deser_leaves):
+    out_leaves = jtu.tree_leaves(pytree, is_leaf=_is_named_or_none)
+    assert len(out_leaves) == len(shardings_leaves)
+    for i, x in zip(indices_to_load, deser_leaves):
         out_leaves[i] = x

     deser_arrays = jtu.tree_unflatten(shardings_structure, out_leaves)

-    # deser_arrays only has arrays, but we need named arrays for at least some.
+    # deser_arrays only has plain arrays for the deserialized leaves, but we need named arrays for at least some.
     # The original pytree has the structure we want, so we'll use that to rebuild the named arrays
     def _rebuild_named_array(like, array):
+        if is_named_array(array):
+            return array
+
         if is_named_array(like):
             return hax.NamedArray(array, like.axes)
         else:
diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py
index 82f32422a..7984e59b7 100644
--- a/src/levanter/trainer.py
+++ b/src/levanter/trainer.py
@@ -343,6 +343,7 @@ def initial_state(
                 mesh=self.device_mesh,
                 subpath="model",
                 do_load=True,
+                allow_partial=self.config.allow_partial_checkpoint,
             )()
             model_init = jax.tree_util.Partial(lambda m: m, loaded_model)

@@ -369,6 +370,7 @@ def init_state_and_model(model_init, training_key):
             mesh=self.device_mesh,
             is_checkpointed=saveable_train_state,
             do_load=load_checkpoint,
+            allow_partial=self.config.allow_partial_checkpoint,
         )(model_init, training_key)
         return state

@@ -629,6 +631,9 @@ class TrainerConfig:
     load_checkpoint_path: Optional[str] = None
     """can be a parent (to find latest) or a specific checkpoint. if None, will set to checkpointer.base_path."""
     initialize_from: Optional[str] = None  # Levanter trainer checkpoint to initialize from
+    allow_partial_checkpoint: bool = False
+    """If True, we allow loading a checkpoint that doesn't have all the parameters in the model.
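Note: when training via `Trainer`, this behavior is opted into with the
`allow_partial_checkpoint` flag added above; `initial_state` simply forwards it as
`allow_partial=` to `load_checkpoint_or_initialize`, as the first two hunks show. In a
YAML run config this would typically sit under the `trainer:` section, e.g.
`allow_partial_checkpoint: true` (illustrative key placement; the field itself is what
this patch adds).
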
+    Missing parameters are initialized from the model_init function."""

     jax_config: Mapping[str, JsonAtom] = field(
         default_factory=lambda: copy.deepcopy(DEFAULT_JAX_CONFIG)
diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py
index 037384c51..272ff085e 100644
--- a/tests/test_checkpoint.py
+++ b/tests/test_checkpoint.py
@@ -329,3 +329,40 @@ def init_fn(key):
         jax.tree_util.tree_leaves(arrays_only(eqx.filter(loaded, is_checkpointed))),
         jax.tree_util.tree_leaves(arrays_only(eqx.filter(model1, is_checkpointed))),
     )
+
+
+def test_load_from_checkpoint_allows_partial_checkpoints():
+    In = Axis("in", 2)
+    Out = Axis("out", 1)
+
+    class MyModule(eqx.Module):
+        a: hax.NamedArray
+        b: hax.NamedArray | None
+
+    def init_fn(key, use_b):
+        k_a, k_b = jax.random.split(key)
+        return MyModule(a=hax.random.normal(k_a, (In, Out)), b=hax.random.normal(k_b, (In, Out)) if use_b else None)
+
+    k0 = jax.random.PRNGKey(0)
+    k1 = jax.random.PRNGKey(1)
+
+    model0 = init_fn(k0, False)
+    model1 = init_fn(k1, True)
+
+    is_checkpointed = True
+
+    with jax.sharding.Mesh(jax.devices(), ("devices",)), tempfile.TemporaryDirectory() as tmpdir:
+
+        save_checkpoint(eqx.filter(model0, is_checkpointed), step=0, checkpoint_path=tmpdir)
+
+        loaded = load_checkpoint_or_initialize(
+            init_fn,
+            tmpdir,
+            is_checkpointed=is_checkpointed,
+            allow_partial=True,
+        )(k1, True)
+
+        assert not any(jax.tree_util.tree_leaves(eqx.filter(loaded, lambda x: isinstance(x, ShapeDtypeStruct))))
+        assert hax.all(hax.equal(loaded.a, model0.a))
+        assert loaded.b is not None
+        assert hax.all(hax.equal(loaded.b, model1.b))
diff --git a/tests/test_tensorstore_serialization.py b/tests/test_tensorstore_serialization.py
index 77d63d656..37f4721dd 100644
--- a/tests/test_tensorstore_serialization.py
+++ b/tests/test_tensorstore_serialization.py
@@ -156,5 +156,24 @@ class MyModule(eqx.Module):
     m3 = MyModule(a=hax.zeros(A), b=hax.ones(A))
     with TemporaryDirectory() as tmpdir:
         tree_serialize_leaves_tensorstore(tmpdir, m2)
-        with pytest.raises(ValueError):
+        with pytest.raises(FileNotFoundError):
             tree_deserialize_leaves_tensorstore(tmpdir, m3)
+
+
+def test_tensorstore_ok_with_missing():
+    mesh = jax.sharding.Mesh(jax.devices(), ("device",))
+    with mesh:
+        A = hax.Axis("A", 10)
+
+        class MyModule(eqx.Module):
+            a: Any
+            b: Any
+
+        m = MyModule(a=None, b=hax.zeros(A))
+        m2 = MyModule(a=hax.full(A, 4), b=hax.ones(A))
+
+        with TemporaryDirectory() as tmpdir:
+            tree_serialize_leaves_tensorstore(tmpdir, m)
+            m3 = tree_deserialize_leaves_tensorstore(tmpdir, m2, allow_missing=True)
+            assert hax.all(m3.a == hax.full(A, 4))
+            assert hax.all(m3.b == hax.zeros(A))