From 0df6f814be43491b4d921d1f9d3c3a6ea0a81941 Mon Sep 17 00:00:00 2001
From: Yifu Wang
Date: Fri, 28 Oct 2022 15:06:26 -0700
Subject: [PATCH] Polish API documentation (#123)

Summary:
As title.

Pull Request resolved: https://github.com/pytorch/torchsnapshot/pull/123

Reviewed By: ananthsub

Differential Revision: D40815842

Pulled By: yifuwang

fbshipit-source-id: ed10e82e0d79cec7f0a5ea307f400d69b3e09752
---
 docs/source/api_reference.rst   |   5 +-
 docs/source/conf.py             |   1 +
 docs/source/getting_started.rst |   5 +-
 docs/source/index.rst           |   1 -
 docs/source/utilities.rst       |  17 ---
 torchsnapshot/rng_state.py      |  31 +++--
 torchsnapshot/snapshot.py       | 228 ++++++++++++--------------------
 torchsnapshot/state_dict.py     |  24 +---
 8 files changed, 116 insertions(+), 196 deletions(-)
 delete mode 100644 docs/source/utilities.rst

diff --git a/docs/source/api_reference.rst b/docs/source/api_reference.rst
index bdbb8e9..467a12e 100644
--- a/docs/source/api_reference.rst
+++ b/docs/source/api_reference.rst
@@ -3,4 +3,7 @@ API Reference
 
 .. autoclass:: torchsnapshot.Snapshot
    :members:
-   :undoc-members:
+
+.. autoclass:: torchsnapshot.StateDict
+
+.. autoclass:: torchsnapshot.RNGState
diff --git a/docs/source/conf.py b/docs/source/conf.py
index fc6b048..f61347f 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -82,3 +82,4 @@
 }
 
 add_module_names = False
+autodoc_member_order = "bysource"
diff --git a/docs/source/getting_started.rst b/docs/source/getting_started.rst
index 1334996..952fd37 100644
--- a/docs/source/getting_started.rst
+++ b/docs/source/getting_started.rst
@@ -99,7 +99,7 @@ Objects within a snapshot can be efficiently accessed without fetching the entir
 Taking a Snapshot Asynchronously
 --------------------------------
 
-When host memory is abundant, users can leverage it with :func:`Snapshot.async_take() <torchsnapshot.Snapshot.async_take>` to allow training to resume before all storage I/O completes. :func:`Snapshot.async_take() <torchsnapshot.Snapshot.async_take>` return as soon as it stages the snapshot content in host RAM and schedules storage I/O in background. This can drastically reduce the time blocked for checkpointing especially when the underly storage is slow.
+When host memory is abundant, users can leverage it with :func:`Snapshot.async_take() <torchsnapshot.Snapshot.async_take>` to allow training to resume before all storage I/O completes. :func:`Snapshot.async_take() <torchsnapshot.Snapshot.async_take>` returns as soon as it stages the snapshot content in host RAM and schedules storage I/O in the background. This can drastically reduce the time blocked on checkpointing, especially when the underlying storage is slow.
 
 .. code-block:: Python
@@ -124,8 +124,7 @@ When host memory is abundant, users can leverage it with :func:`Snapshot.async_t
 Reproducibility
 ---------------
 
-TorchSnapshot provides a utility called :class:`RNGState <torchsnapshot.RNGState>` to help users manage reproducibility. If an :class:`RNGState <torchsnapshot.RNGState>` object is captured in the application state, TorchSnapshot ensures that the global RNG state is set to the same values after taking the snapshot and after restoring from the snapshot.
-
+TorchSnapshot provides a utility called :class:`RNGState <torchsnapshot.RNGState>` to help users manage reproducibility. If an :class:`RNGState <torchsnapshot.RNGState>` object is captured in the application state, TorchSnapshot ensures that the global RNG state is set to the same values after restoring from the snapshot as it was after taking the snapshot.
 
 .. code-block:: Python
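A minimal sketch of the asynchronous flow documented in the first hunk above; the module, the path, and the surrounding training loop are illustrative placeholders rather than part of the patch::

    import torch
    import torchsnapshot

    model = torch.nn.Linear(4, 2)
    app_state = {"model": model}

    # Returns once the snapshot content is staged in host RAM;
    # storage I/O continues in the background.
    pending = torchsnapshot.Snapshot.async_take(
        path="/tmp/example_snapshot",
        app_state=app_state,
    )

    # ... keep training while the write completes ...

    if not pending.done():
        pending.wait()  # optionally block until the snapshot is committed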
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 8a488f1..2491cb0 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -15,7 +15,6 @@ TorchSnapshot API
 
    getting_started.rst
    api_reference.rst
-   utilities.rst
 
 Examples
 --------
diff --git a/docs/source/utilities.rst b/docs/source/utilities.rst
deleted file mode 100644
index 8285997..0000000
--- a/docs/source/utilities.rst
+++ /dev/null
@@ -1,17 +0,0 @@
-Utilities
-=========
-
-StateDict
----------
-
-.. autoclass:: torchsnapshot.StateDict
-   :members:
-   :undoc-members:
-
-
-RNGState
---------
-
-.. autoclass:: torchsnapshot.RNGState
-   :members:
-   :undoc-members:
diff --git a/torchsnapshot/rng_state.py b/torchsnapshot/rng_state.py
index 3ac9ce1..68b8c36 100644
--- a/torchsnapshot/rng_state.py
+++ b/torchsnapshot/rng_state.py
@@ -12,25 +12,32 @@
 
 class RNGState:
     """
-    When captured in app state, it is guaranteed that rng states will be the
-    same after ``Snapshot.take`` and ``Snapshot.restore``.
+    A special stateful object for saving and restoring global RNG state.
 
-    ::
+    When captured in the application state, it is guaranteed that the global
+    RNG state is set to the same values after restoring from the snapshot as it
+    was after taking the snapshot.
+
+    Example:
 
-        app_state = {
-            "rng_state": RNGState(),
-        }
-        snapshot = Snapshot.take("foo/bar", app_state, backend=...)
-        after_take = torch.rand(1)
+    ::
 
-        snapshot.restore(app_state)
-        after_restore = torch.rand(1)
+        >>> app_state = {"rng_state": RNGState()}
+        >>> Snapshot.take(
+        >>>     path="foo/bar",
+        >>>     app_state=app_state,
+        >>> )
+        >>> after_take = torch.rand(1)
 
-        torch.testing.assert_close(after_take, after_restore)
+        >>> # In the same process or in another process
+        >>> snapshot = Snapshot(path="foo/bar")
+        >>> snapshot.restore(app_state)
+        >>> after_restore = torch.rand(1)
 
-    TODO augment this to capture rng states other than torch.get_rng_state().
+        >>> torch.testing.assert_close(after_take, after_restore)
     """
 
+    # TODO: augment this to capture rng states other than torch.get_rng_state()
+
     def state_dict(self) -> Dict[str, torch.Tensor]:
         return {"rng_state": torch.get_rng_state()}
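Only the save half of the protocol appears in the hunk above. The restore half is outside this diff; assuming it mirrors ``state_dict()`` with :func:`torch.set_rng_state`, a sketch of the full object would look like::

    from typing import Dict

    import torch

    class RNGStateSketch:
        """Illustrative stand-in for torchsnapshot.RNGState, not the file's actual code."""

        def state_dict(self) -> Dict[str, torch.Tensor]:
            # Save half, as shown in the hunk above.
            return {"rng_state": torch.get_rng_state()}

        def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
            # Assumed restore half: feed the saved state back to the global RNG.
            torch.set_rng_state(state_dict["rng_state"])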
diff --git a/torchsnapshot/snapshot.py b/torchsnapshot/snapshot.py
index 6c8398a..22d9918 100644
--- a/torchsnapshot/snapshot.py
+++ b/torchsnapshot/snapshot.py
@@ -65,82 +65,20 @@
 class Snapshot:
     """
-    Snapshot represents the persisted program state at one point in time.
+    Create a reference to an existing snapshot.
 
-    Basic usage:
-    ::
+    Args:
+        path (str): The path to the snapshot. This should be the same as the
+            ``path`` argument used for :func:`Snapshot.take` when the snapshot
+            was taken.
 
-        # Define the program state
-        app_state = {"model": model, "optimizer": optimizer"}
+        pg (ProcessGroup, optional): The process group for the participants of
+            :meth:`Snapshot.restore`. If none, the default process group will be
+            used.
 
-        # At an appropriate time, persist the program state as a snapshot
-        snapshot = Snapshot.take(path=path, app_state=app_state)
-
-        # On resuming, restore the program state from a snapshot
-        snapshot.restore(app_state)
-
-    Overview:
-
-        At high level, torchsnapshot saves each value in state dicts as a
-        file/object in the corresponding storage system. It also saves a manifest
-        describing the persisted values and the structure of the original state
-        dict.
-
-        Comparing with :py:func:`torch.save` and :py:func:`torch.load`, torchsnapshot:
-
-        - Enables efficient random access of persisted model weights.
-
-        - Accelerates persistence by parallelizing writes.
-
-            - For replicated values, persistence is parallelized across ranks.
-
-        - Enables flexible yet robust elasticity (changing world size on
-          restore).
-
-    Elasticity:
-
-        Elasticity is implemented via correctly making persisted values
-        available to a newly joined rank, and having it correctly restores the
-        corresponding runtime objects from the persisted values.
-
-        For the purpose of elasticity, all persisted values fall into one of
-        the categories in [per-rank, replicated, sharded].
-
-        per-rank:
-
-            By default, all non-sharded values are treated as per-rank.
-
-            On save, the value is only saved by the owning rank.
-
-            On load, the value is only made available to the same rank.
-
-        replicated:
-
-            A user can suggest any non-sharded value as replicated via glob
-            patterns.
-
-            On save, the value is only saved once (can be by any rank).
-
-            On load, the value is made available to all ranks, including newly
-            joined ranks.
-
-        sharded:
-
-            Specific types are always treated as sharded (e.g. ShardedTensor).
-
-            On save, all shard-owning ranks save their shards.
-
-            On load, all shards are made available to all ranks, including
-            newly joined rank. All ranks can read from all shards for
-            restoring the runtime object from persisted values.
-            (ShardedTensor resharding is powered by torch.dist.checkpoint).
-
-    If all values within a snapshot are either replicated or sharded, the
-    snapshot is automatically reshard-able.
-
-    If a snapshot contains per-rank values, it cannot be resharded unless
-    the per-rank values are explicitly coerced to replicated on load.
+
+        storage_options (Dict[str, Any], optional): Additional keyword options
+            for the storage plugin to use. See each storage plugin's
+            documentation for customizations.
     """
 
     def __init__(
@@ -149,18 +87,6 @@ def __init__(
         self,
         path: str,
         pg: Optional[dist.ProcessGroup] = None,
         storage_options: Optional[Dict[str, Any]] = None,
     ) -> None:
-        """
-        Initializes the reference to an existing snapshot.
-
-        Args:
-            path: The location of the snapshot.
-            pg: The process group for the processes restoring from the snapshot.
-                When unspecified:
-                - If distributed is initialized, the global process group will be used.
-                - If distributed is not initialized, single process is assumed.
-            storage_options: Additional keyword options for the StoragePlugin to use.
-                See each StoragePlugin's documentation for customizations.
-        """
         self.path: str = path
         self.pg: Optional[dist.ProcessGroup] = pg
         self._metadata: Optional[SnapshotMetadata] = None
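The constructor documented above pairs with :func:`Snapshot.take` roughly as follows; the path and the stateful objects are illustrative::

    import torch
    import torchsnapshot

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    app_state = {"model": model, "optimizer": optimizer}

    # Persist the application state.
    torchsnapshot.Snapshot.take(path="/tmp/example_snapshot", app_state=app_state)

    # Later, possibly in another process: reference the existing snapshot
    # by path and restore the application state in-place.
    snapshot = torchsnapshot.Snapshot(path="/tmp/example_snapshot")
    snapshot.restore(app_state)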
@@ -179,20 +105,43 @@ def take(
         ] = None,
     ) -> "Snapshot":
         """
-        Take a snapshot from the program state.
+        Takes a snapshot of the application state.
 
         Args:
-            app_state: The program state to take the snapshot from.
-            path: The location to save the snapshot.
-            pg: The process group for the processes taking the snapshot.
-                When unspecified:
-                - If distributed is initialized, the global process group will be used.
-                - If distributed is not initialized, single process is assumed.
-            replicated: A list of glob patterns for hinting the matching paths
-                as replicated. Note that patterns not specified by all ranks
-                are ignored.
-            storage_options: Additional keyword options for the StoragePlugin to use.
-                See each StoragePlugin's documentation for customizations.
+            app_state (Dict[str, Stateful]): The application state to persist.
+                It takes the form of a dictionary, with the keys being
+                user-defined strings and the values being stateful objects.
+                Stateful objects are objects that expose ``.state_dict()`` and
+                ``.load_state_dict()`` methods. Common PyTorch objects such as
+                :class:`torch.nn.Module`, :class:`torch.optim.Optimizer`, and
+                LR schedulers all qualify as stateful objects.
+
+            path (str): The location to save the snapshot. ``path`` can have a
+                URI prefix (e.g. ``s3://``) that specifies a storage backend.
+                If no URI prefix is supplied, ``path`` is assumed to be a file
+                system location. For distributed snapshots, if ``path`` is
+                inconsistent across participating ranks, the value specified by
+                rank 0 will be used. For multi-host snapshots, ``path`` needs to
+                be a location accessible by all hosts.
+
+                .. note:: ``path`` must **not** point to an existing snapshot.
+
+            pg (ProcessGroup, optional): The process group for the participants
+                of :meth:`Snapshot.take`. If none, the default process group will
+                be used.
+
+            replicated (List[str], optional): Glob patterns for marking
+                checkpoint content as replicated. Matching objects will be deduped
+                and load-balanced across ranks.
+
+                .. note:: The replication property is automatically inferred
+                    for ``DistributedDataParallel``. Only specify this argument
+                    if your model has fully replicated states but does not use
+                    ``DistributedDataParallel``.
+
+            storage_options (Dict[str, Any], optional): Additional keyword
+                options for the storage plugin to use. See each storage plugin's
+                documentation for customizations.
 
         Returns:
             The newly taken snapshot.
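A sketch exercising the arguments documented above; the bucket name is made up, and marking the entire application state as replicated via ``"**"`` is an assumption about the glob syntax rather than a recommendation::

    import torch
    import torchsnapshot

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    snapshot = torchsnapshot.Snapshot.take(
        # An "s3://" URI prefix would select the S3 storage backend;
        # a plain path is treated as a file system location.
        path="s3://example-bucket/example-snapshot",
        app_state={"model": model, "optimizer": optimizer},
        # Only needed when the states are fully replicated but the model
        # does not use DistributedDataParallel.
        replicated=["**"],
    )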
""" torch._C._log_api_usage_once("torchsnapshot.Snapshot.async_take") cls._validate_app_state(app_state) @@ -436,11 +377,13 @@ def _take_impl( def restore(self, app_state: AppState) -> None: """ - Restores the program state from the snapshot. + Restores the application state from the snapshot. Args: - app_state: The program state to restore from the snapshot. - + app_state (Dict[str, Stateful]): The application state to restore. + ``app_state`` needs to be either identical to or a subset of the + ``app_state`` used for :func:`Snapshot.take` when the snapshot was + taken. """ torch._C._log_api_usage_once("torchsnapshot.Snapshot.restore") self._validate_app_state(app_state) @@ -505,31 +448,25 @@ def read_object( memory_budget_bytes: Optional[int] = None, ) -> T: """ - Read a persisted object from the snapshot's content. - - The persisted object to read is specified by its path in the snapshot - metadata. Available paths can be obtained via `snapshot.get_manifest()`. + Reads an object from the snapshot's content. - A path in snapshot metadata follows the following format: - - ``RANK/STATEFUL_NAME/STATE_DICT_KEY[/NESTED_CONTAINER_KEY...]`` + Args: + path (str): The path to the target object within the snapshot. + ``path`` is equivalent to the target object's key in the + snapshot manifest and can be obtained via + :meth:`Snapshot.get_manifest`. - The rank only matters when the persisted object is "per-rank". - Arbitrary rank can be used when the persisted object is "replicated" or - "sharded". + obj_out (Any, optional): When specified, load the object in-place + into ``obj_out`` if in-place load is supported for the object's + type. Otherwise, ``obj_out`` is ignored. - If the persisted object is a sharded tensor, `obj_out` must be - supplied. The supplied tensor can be either a tensor or sharded tensor. - `read_object` will correctly populate `obj_out`'s data according to - sharding spec. + .. note:: + When the target object is a ``ShardedTensor``, ``obj_out`` + must be specified. - Args: - path: The path to the persisted object. - obj_out: If specified and the object type supports in-place load, - `read_object` will directly read the persisted object into - `obj_out`'s buffer. - memory_budget_bytes: When specified, the read operation will keep - the temporary memory buffer size below this threshold. + memory_budget_bytes (int, optional): When specified, the read + operation will keep the temporary memory buffer size below this + threshold. Returns: The object read from the snapshot's content. @@ -595,10 +532,15 @@ def read_object( def get_manifest(self) -> Dict[str, Entry]: """ - Returns the snapshot's manifest. + Returns the snapshot manifest. + + Each entry in the dictionary corresponds to an object in the snapshot, + with the keys being the logical paths to the objects and the values + being the metadata describing the object. For distributed snapshots, + the manifest contain entries for objects saved by all ranks. Returns: - The snapshot's manifest. + The snapshot manifest. """ return copy.deepcopy(self.metadata.manifest) diff --git a/torchsnapshot/state_dict.py b/torchsnapshot/state_dict.py index c3b91c5..3195735 100644 --- a/torchsnapshot/state_dict.py +++ b/torchsnapshot/state_dict.py @@ -12,26 +12,12 @@ # pyre-fixme[24]: Python <3.9 doesn't support typing on UserDict class StateDict(UserDict): """ - A dict that implements the Stateful protocol. It is handy for capturing - stateful objects that do not already implement the Stateful protocol or - can't implement the protocol (i.e. 
diff --git a/torchsnapshot/state_dict.py b/torchsnapshot/state_dict.py
index c3b91c5..3195735 100644
--- a/torchsnapshot/state_dict.py
+++ b/torchsnapshot/state_dict.py
@@ -12,26 +12,12 @@
 
 # pyre-fixme[24]: Python <3.9 doesn't support typing on UserDict
 class StateDict(UserDict):
     """
-    A dict that implements the Stateful protocol. It is handy for capturing
-    stateful objects that do not already implement the Stateful protocol or
-    can't implement the protocol (i.e. primitive types).
+    A dictionary that exposes ``.state_dict()`` and ``.load_state_dict()``
+    methods.
 
-    ::
-
-        model = Model()
-        progress = StateDict(current_epoch=0)
-        app_state = {"model": model, "progress": progress}
-
-        # Load from the last snapshot if available
-        ...
-
-        while progress["current_epoch"] < NUM_EPOCHS:
-            # Train for an epoch
-            ...
-            progress["current_epoch"] += 1
-
-            # progress is captured by the snapshot
-            Snapshot.take("foo/bar", app_state, backend=...)
+    It can be used to capture objects that do not expose ``.state_dict()`` and
+    ``.load_state_dict()`` methods (e.g. Tensors, Python primitive types) as
+    part of the application state.
     """
 
     def state_dict(self) -> Dict[str, Any]:
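The example removed above still conveys the intended usage; a condensed, self-contained variant (names and the epoch count are illustrative)::

    import torch
    import torchsnapshot

    model = torch.nn.Linear(4, 2)
    # Wrapping plain Python values in a StateDict makes them snapshot-able.
    progress = torchsnapshot.StateDict(current_epoch=0)
    app_state = {"model": model, "progress": progress}

    for _ in range(3):
        # ... train for an epoch ...
        progress["current_epoch"] += 1

    # "progress" is captured alongside the model.
    torchsnapshot.Snapshot.take(path="/tmp/example_snapshot", app_state=app_state)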