Polish API documentation (#123)
Summary:
As title.

Pull Request resolved: #123

Reviewed By: ananthsub

Differential Revision: D40815842

Pulled By: yifuwang

fbshipit-source-id: ed10e82e0d79cec7f0a5ea307f400d69b3e09752
yifuwang authored and facebook-github-bot committed Oct 28, 2022
1 parent c728806 commit 0df6f81
Showing 8 changed files with 116 additions and 196 deletions.
5 changes: 4 additions & 1 deletion docs/source/api_reference.rst
@@ -3,4 +3,7 @@ API Reference

.. autoclass:: torchsnapshot.Snapshot
   :members:
-   :undoc-members:
+
+.. autoclass:: torchsnapshot.StateDict
+
+.. autoclass:: torchsnapshot.RNGState
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -82,3 +82,4 @@
}

add_module_names = False
+autodoc_member_order = "bysource"
5 changes: 2 additions & 3 deletions docs/source/getting_started.rst
@@ -99,7 +99,7 @@ Objects within a snapshot can be efficiently accessed without fetching the entir
Taking a Snapshot Asynchronously
--------------------------------

-When host memory is abundant, users can leverage it with :func:`Snapshot.async_take() <torchsnapshot.Snapshot.async_take>` to allow training to resume before all storage I/O completes. :func:`Snapshot.async_take() <torchsnapshot.Snapshot.async_take>` return as soon as it stages the snapshot content in host RAM and schedules storage I/O in background. This can drastically reduce the time blocked for checkpointing especially when the underly storage is slow.
+When host memory is abundant, users can leverage it with :func:`Snapshot.async_take() <torchsnapshot.Snapshot.async_take>` to allow training to resume before all storage I/O completes. :func:`Snapshot.async_take() <torchsnapshot.Snapshot.async_take>` returns as soon as it stages the snapshot content in host RAM and schedules storage I/O in the background. This can drastically reduce the time blocked for checkpointing, especially when the underlying storage is slow.


.. code-block:: Python
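
   # A minimal illustrative sketch of the pattern described above. Assumes
   # `model` and `optimizer` already exist; the path is hypothetical.
   from torchsnapshot import Snapshot

   app_state = {"model": model, "optimizer": optimizer}
   pending_snapshot = Snapshot.async_take(path="/tmp/my_snapshot", app_state=app_state)

   # Training can proceed here while storage I/O completes in the background.
   # Waiting on the handle is optional; the snapshot is committed regardless.
   snapshot = pending_snapshot.wait()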
@@ -124,8 +124,7 @@ When host memory is abundant, users can leverage it with :func:`Snapshot.async_t
Reproducibility
---------------

-TorchSnapshot provides a utility called :class:`RNGState <torchsnapshot.rng_state.RNGState>` to help users manage reproducibility. If an :class:`RNGState <torchsnapshot.rng_state.RNGState>` object is captured in the application state, TorchSnapshot ensures that the global RNG state is set to the same values after taking the snapshot and after restoring from the snapshot.
-
+TorchSnapshot provides a utility called :class:`RNGState <torchsnapshot.rng_state.RNGState>` to help users manage reproducibility. If an :class:`RNGState <torchsnapshot.rng_state.RNGState>` object is captured in the application state, TorchSnapshot ensures that the global RNG state is set to the same values after restoring from the snapshot as it was after taking the snapshot.

.. code-block:: Python
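
   # A minimal illustrative sketch, mirroring the RNGState docstring in this
   # commit; the path is hypothetical.
   import torch
   from torchsnapshot import RNGState, Snapshot

   app_state = {"rng_state": RNGState()}
   Snapshot.take(path="/tmp/my_snapshot", app_state=app_state)
   after_take = torch.rand(1)

   # In the same process or in another process
   snapshot = Snapshot(path="/tmp/my_snapshot")
   snapshot.restore(app_state)
   after_restore = torch.rand(1)

   torch.testing.assert_close(after_take, after_restore)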
1 change: 0 additions & 1 deletion docs/source/index.rst
@@ -15,7 +15,6 @@ TorchSnapshot API

getting_started.rst
api_reference.rst
-utilities.rst

Examples
--------
17 changes: 0 additions & 17 deletions docs/source/utilities.rst

This file was deleted.

31 changes: 19 additions & 12 deletions torchsnapshot/rng_state.py
@@ -12,25 +12,32 @@

class RNGState:
"""
-    When captured in app state, it is guaranteed that rng states will be the
-    same after ``Snapshot.take`` and ``Snapshot.restore``.
-
-    ::
-
-        app_state = {
-            "rng_state": RNGState(),
-        }
-        snapshot = Snapshot.take("foo/bar", app_state, backend=...)
-        after_take = torch.rand(1)
-
-        snapshot.restore(app_state)
-        after_restore = torch.rand(1)
-
-        torch.testing.assert_close(after_take, after_restore)
-
-    TODO augment this to capture rng states other than torch.get_rng_state().
+    A special stateful object for saving and restoring global RNG state.
+
+    When captured in the application state, it is guaranteed that the global
+    RNG state is set to the same values after restoring from the snapshot as it
+    was after taking the snapshot.
+
+    Example:
+
+    ::
+
+        >>> Snapshot.take(
+        >>>     path="foo/bar",
+        >>>     app_state={"rng_state": RNGState()},
+        >>> )
+        >>> after_take = torch.rand(1)
+
+        >>> # In the same process or in another process
+        >>> snapshot = Snapshot(path="foo/bar")
+        >>> snapshot.restore(app_state)
+        >>> after_restore = torch.rand(1)
+
+        >>> torch.testing.assert_close(after_take, after_restore)
    """

+    # TODO: augment this to capture rng states other than torch.get_rng_state()

def state_dict(self) -> Dict[str, torch.Tensor]:
return {"rng_state": torch.get_rng_state()}

228 changes: 85 additions & 143 deletions torchsnapshot/snapshot.py
@@ -65,82 +65,20 @@

class Snapshot:
"""
-    Snapshot represents the persisted program state at one point in time.
-
-    Basic usage:
-
-    ::
-
-        # Define the program state
-        app_state = {"model": model, "optimizer": optimizer"}
-
-        # At an appropriate time, persist the program state as a snapshot
-        snapshot = Snapshot.take(path=path, app_state=app_state)
-
-        # On resuming, restore the program state from a snapshot
-        snapshot.restore(app_state)
-
-    Overview:
-
-        At high level, torchsnapshot saves each value in state dicts as a
-        file/object in the corresponding storage system. It also saves a manifest
-        describing the persisted values and the structure of the original state
-        dict.
-
-        Comparing with :py:func:`torch.save` and :py:func:`torch.load`, torchsnapshot:
-
-        - Enables efficient random access of persisted model weights.
-        - Accelerates persistence by parallelizing writes.
-            - For replicated values, persistence is parallelized across ranks.
-        - Enables flexible yet robust elasticity (changing world size on
-          restore).
-
-    Elasticity:
-
-        Elasticity is implemented via correctly making persisted values
-        available to a newly joined rank, and having it correctly restores the
-        corresponding runtime objects from the persisted values.
-
-        For the purpose of elasticity, all persisted values fall into one of
-        the categories in [per-rank, replicated, sharded].
-
-        per-rank:
-            By default, all non-sharded values are treated as per-rank.
-            On save, the value is only saved by the owning rank.
-            On load, the value is only made available to the same rank.
-
-        replicated:
-            A user can suggest any non-sharded value as replicated via glob
-            patterns.
-            On save, the value is only saved once (can be by any rank).
-            On load, the value is made available to all ranks, including newly
-            joined ranks.
-
-        sharded:
-            Specific types are always treated as sharded (e.g. ShardedTensor).
-            On save, all shard-owning ranks save their shards.
-            On load, all shards are made available to all ranks, including
-            newly joined rank. All ranks can read from all shards for
-            restoring the runtime object from persisted values.
-            (ShardedTensor resharding is powered by torch.dist.checkpoint).
-
-        If all values within a snapshot are either replicated or sharded, the
-        snapshot is automatically reshard-able.
-
-        If a snapshot contains per-rank values, it cannot be resharded unless
-        the per-rank values are explicitly coerced to replicated on load.
+    Create a reference to an existing snapshot.
+
+    Args:
+        path (str): The path to the snapshot. This should be the same as the
+            ``path`` argument used for :func:`Snapshot.take` when the snapshot
+            was taken.
+        pg (ProcessGroup, optional): The process group for the participants of
+            :meth:`Snapshot.restore`. If none, the default process group will be
+            used.
+        storage_options (Dict[str, Any], optional): Additional keyword options
+            for the storage plugin to use. See each storage plugin's documentation
+            for customizations.
"""

def __init__(
@@ -149,18 +149,6 @@ def __init__(
pg: Optional[dist.ProcessGroup] = None,
storage_options: Optional[Dict[str, Any]] = None,
) -> None:
"""
Initializes the reference to an existing snapshot.
Args:
path: The location of the snapshot.
pg: The process group for the processes restoring from the snapshot.
When unspecified:
- If distributed is initialized, the global process group will be used.
- If distributed is not initialized, single process is assumed.
storage_options: Additional keyword options for the StoragePlugin to use.
See each StoragePlugin's documentation for customizations.
"""
self.path: str = path
self.pg: Optional[dist.ProcessGroup] = pg
self._metadata: Optional[SnapshotMetadata] = None
@@ -179,20 +105,43 @@ def take(
] = None,
) -> "Snapshot":
"""
-        Take a snapshot from the program state.
+        Takes a snapshot of the application state.

        Args:
-            app_state: The program state to take the snapshot from.
-            path: The location to save the snapshot.
-            pg: The process group for the processes taking the snapshot.
-                When unspecified:
-                - If distributed is initialized, the global process group will be used.
-                - If distributed is not initialized, single process is assumed.
-            replicated: A list of glob patterns for hinting the matching paths
-                as replicated. Note that patterns not specified by all ranks
-                are ignored.
-            storage_options: Additional keyword options for the StoragePlugin to use.
-                See each StoragePlugin's documentation for customizations.
+            app_state (Dict[str, Stateful]): The application state to persist.
+                It takes the form of a dictionary, with the keys being
+                user-defined strings and the values being stateful objects.
+                Stateful objects are objects that expose ``.state_dict()`` and
+                ``.load_state_dict()`` methods. Common PyTorch objects such as
+                :class:`torch.nn.Module`, :class:`torch.optim.Optimizer`, and
+                LR schedulers all qualify as stateful objects.
+            path (str): The location to save the snapshot. ``path`` can have a
+                URI prefix (e.g. ``s3://``) that specifies a storage backend.
+                If no URI prefix is supplied, ``path`` is assumed to be a file
+                system location. For distributed snapshot, if ``path`` is
+                inconsistent across participating ranks, the value specified by
+                rank 0 will be used. For multi-host snapshot, ``path`` needs to
+                be a location accessible by all hosts.
+
+                .. note:: ``path`` must **not** point to an existing snapshot.
+
+            pg (ProcessGroup, optional): The process group for the participants
+                of :meth:`Snapshot.take`. If none, the default process group will
+                be used.
+            replicated (List[str], optional): Glob patterns for marking
+                checkpoint content as replicated. Matching objects will be deduped
+                and load-balanced across ranks.
+
+                .. note:: The replication property is automatically inferred
+                    for ``DistributedDataParallel``. Only specify this argument
+                    if your model has fully replicated states but does not use
+                    ``DistributedDataParallel``.
+
+            storage_options (Dict[str, Any], optional): Additional keyword
+                options for the storage plugin to use. See each storage plugin's
+                documentation for customizations.
Returns:
The newly taken snapshot.
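
        For illustration, a minimal sketch of the arguments described above
        (the S3 URI and the glob pattern are hypothetical)::

            snapshot = Snapshot.take(
                path="s3://bucket/path/to/snapshot",  # URI prefix selects the storage backend
                app_state={"model": model, "optimizer": optimizer},
                replicated=["model/**"],  # only for fully replicated models not using DDP
            )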
@@ -252,31 +201,23 @@ def async_take(
] = None,
) -> "PendingSnapshot":
"""
-        Asynchronously take a snapshot from the program state.
+        Asynchronously takes a snapshot from the application state.

        This method creates a consistent snapshot of the app state (i.e.
        changes to the app state after this method returns have no effect on
        the snapshot). The asynchronicity is a result of performing storage I/O
        in the background.

+        This function is identical to :func:`Snapshot.take`, except that it
+        returns early and performs as much I/O in the background as possible,
+        allowing training to resume early.
+
        Args:
-            app_state: The program state to take the snapshot from.
-            path: The location to save the snapshot.
-            pg: The process group for the processes taking the snapshot.
-                When unspecified:
-                - If distributed is initialized, the global process group will be used.
-                - If distributed is not initialized, single process is assumed.
-            replicated: A list of glob patterns for hinting the matching paths
-                as replicated. Note that patterns not specified by all ranks
-                are ignored.
-            storage_options: Additional keyword options for the StoragePlugin to use.
-                See each StoragePlugin's documentation for customizations.
+            app_state (Dict[str, Stateful]): Same as the ``app_state`` argument of :func:`Snapshot.take`.
+            path (str): Same as the ``path`` argument of :func:`Snapshot.take`.
+            pg (ProcessGroup, optional): Same as the ``pg`` argument of :func:`Snapshot.take`.
+            replicated (List[str], optional): Same as the ``replicated`` argument of :func:`Snapshot.take`.
+            storage_options (Dict[str, Any], optional): Same as the ``storage_options`` argument of :func:`Snapshot.take`.

        Returns:
-            A handle with which the newly taken snapshot can be obtained via
-            `.wait()`. Note that waiting on the handle is optional. The
-            snapshot will be committed regardless of whether `.wait()` is
-            invoked.
+            A handle to the pending snapshot. The handle exposes a
+            ``.done()`` method for querying the progress and a ``.wait()``
+            method for waiting for the snapshot's completion.
"""
torch._C._log_api_usage_once("torchsnapshot.Snapshot.async_take")
cls._validate_app_state(app_state)
@@ -436,11 +377,13 @@ def _take_impl(

def restore(self, app_state: AppState) -> None:
"""
-        Restores the program state from the snapshot.
+        Restores the application state from the snapshot.

        Args:
-            app_state: The program state to restore from the snapshot.
+            app_state (Dict[str, Stateful]): The application state to restore.
+                ``app_state`` needs to be either identical to or a subset of the
+                ``app_state`` used for :func:`Snapshot.take` when the snapshot was
+                taken.
"""
torch._C._log_api_usage_once("torchsnapshot.Snapshot.restore")
self._validate_app_state(app_state)
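        # For illustration, a minimal sketch of restoring a subset of the
        # application state, as permitted above (path and state names are
        # hypothetical):
        #
        #   snapshot = Snapshot(path="/tmp/my_snapshot")
        #   snapshot.restore(app_state={"model": model})  # subset of the app_state used at take time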
@@ -505,31 +448,25 @@ def read_object(
memory_budget_bytes: Optional[int] = None,
) -> T:
"""
-        Read a persisted object from the snapshot's content.
-
-        The persisted object to read is specified by its path in the snapshot
-        metadata. Available paths can be obtained via `snapshot.get_manifest()`.
-
-        A path in snapshot metadata follows the following format:
-
-            ``RANK/STATEFUL_NAME/STATE_DICT_KEY[/NESTED_CONTAINER_KEY...]``
-
-        The rank only matters when the persisted object is "per-rank".
-        Arbitrary rank can be used when the persisted object is "replicated" or
-        "sharded".
-
-        If the persisted object is a sharded tensor, `obj_out` must be
-        supplied. The supplied tensor can be either a tensor or sharded tensor.
-        `read_object` will correctly populate `obj_out`'s data according to
-        sharding spec.
-
-        Args:
-            path: The path to the persisted object.
-            obj_out: If specified and the object type supports in-place load,
-                `read_object` will directly read the persisted object into
-                `obj_out`'s buffer.
-            memory_budget_bytes: When specified, the read operation will keep
-                the temporary memory buffer size below this threshold.
+        Reads an object from the snapshot's content.
+
+        Args:
+            path (str): The path to the target object within the snapshot.
+                ``path`` is equivalent to the target object's key in the
+                snapshot manifest and can be obtained via
+                :meth:`Snapshot.get_manifest`.
+            obj_out (Any, optional): When specified, load the object in-place
+                into ``obj_out`` if in-place load is supported for the object's
+                type. Otherwise, ``obj_out`` is ignored.
+
+                .. note::
+                    When the target object is a ``ShardedTensor``, ``obj_out``
+                    must be specified.
+
+            memory_budget_bytes (int, optional): When specified, the read
+                operation will keep the temporary memory buffer size below this
+                threshold.
Returns:
The object read from the snapshot's content.
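
        For illustration, a minimal sketch (the manifest key is hypothetical;
        real keys follow ``RANK/STATEFUL_NAME/STATE_DICT_KEY`` and can be listed
        via :meth:`Snapshot.get_manifest`)::

            snapshot = Snapshot(path="/tmp/my_snapshot")
            weight = snapshot.read_object(path="0/model/fc.weight")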
@@ -595,10 +532,15 @@ def read_object(

def get_manifest(self) -> Dict[str, Entry]:
"""
-        Returns the snapshot's manifest.
+        Returns the snapshot manifest.
+
+        Each entry in the dictionary corresponds to an object in the snapshot,
+        with the keys being the logical paths to the objects and the values
+        being the metadata describing the object. For distributed snapshots,
+        the manifest contains entries for objects saved by all ranks.

        Returns:
-            The snapshot's manifest.
+            The snapshot manifest.
"""
return copy.deepcopy(self.metadata.manifest)
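    # For illustration, a minimal sketch of listing a snapshot's content
    # (the path is hypothetical):
    #
    #   snapshot = Snapshot(path="/tmp/my_snapshot")
    #   for logical_path, entry in snapshot.get_manifest().items():
    #       print(logical_path, type(entry).__name__)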
