
Commit

Enable shared file system with save and save_state via ProjectConfiguration (#1953)

* Support shared storage, start

* Pass use_local_node_storage

* Reverse and different namings

* Not global only

* Address comments

* Clean

* Apply suggestions from code review

Co-authored-by: Sourab Mangrulkar <[email protected]>

* Save on each node as explicit arg

* More explicit

---------

Co-authored-by: Sourab Mangrulkar <[email protected]>
muellerzr and pacman100 authored Oct 3, 2023
1 parent 76ee7f2 commit 956114a
Showing 4 changed files with 53 additions and 19 deletions.
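In practice, the new flag travels from `ProjectConfiguration` through `Accelerator.save`/`Accelerator.save_state` down to `accelerate.utils.save`. A minimal usage sketch (directory names are illustrative, assuming a multi-node job launched with `accelerate launch`):

```python
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration

# save_on_each_node=True: each node's local main process writes checkpoints,
# for clusters whose nodes do not share a filesystem.
# The default (False) writes only once, on the global main process.
project_config = ProjectConfiguration(project_dir="outputs", save_on_each_node=True)
accelerator = Accelerator(project_config=project_config)
```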
21 changes: 18 additions & 3 deletions src/accelerate/accelerator.py
@@ -2508,6 +2508,10 @@ def save(self, obj, f, safe_serialization=False):
f (`str` or `os.PathLike`): Where to save the content of `obj`.
safe_serialization (`bool`, *optional*, defaults to `False`): Whether to save `obj` using `safetensors`
+        Note:
+            If `save_on_each_node` was set in the `ProjectConfiguration`, the object is saved once per node
+            rather than only once on the main node.
Example:
```python
@@ -2518,7 +2522,12 @@ def save(self, obj, f, safe_serialization=False):
>>> accelerator.save(arr, "array.pkl")
```
"""
-        save(obj, f, safe_serialization=safe_serialization)
+        save(
+            obj,
+            f,
+            save_on_each_node=self.project_configuration.save_on_each_node,
+            safe_serialization=safe_serialization,
+        )
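`Accelerator.save` keeps its public signature; the per-node behavior is read from the project configuration rather than passed as an argument. A short sketch, reusing the `accelerator` constructed above:

```python
import torch

tensors = {"weight": torch.ones(2, 2)}
# The save_on_each_node flag comes from accelerator.project_configuration,
# not from an argument to save().
accelerator.save(tensors, "weights.bin")
# safe_serialization=True routes through safetensors (tensor-only payloads).
accelerator.save(tensors, "weights.safetensors", safe_serialization=True)
```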

def save_model(
self,
@@ -2793,10 +2802,16 @@ def _inner(folder):
            hook(self._models, weights, output_dir)

        save_location = save_accelerator_state(
-            output_dir, weights, optimizers, schedulers, self.state.process_index, self.scaler
+            output_dir,
+            weights,
+            optimizers,
+            schedulers,
+            self.state.process_index,
+            self.scaler,
+            save_on_each_node=self.project_configuration.save_on_each_node,
        )
        for i, obj in enumerate(self._custom_objects):
-            save_custom_state(obj, output_dir, i)
+            save_custom_state(obj, output_dir, i, save_on_each_node=self.project_configuration.save_on_each_node)
        self.project_configuration.iteration += 1
        return save_location
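For context, `save_state` now forwards the flag to both `save_accelerator_state` and `save_custom_state`. A minimal call, continuing the sketch above:

```python
# Models, optimizers, schedulers, RNG states, and registered custom objects
# are written to the given directory: once per node when
# save_on_each_node=True, once in total otherwise.
save_location = accelerator.save_state("outputs/checkpoints/step_100")
```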

15 changes: 9 additions & 6 deletions src/accelerate/checkpointing.py
@@ -51,6 +51,7 @@ def save_accelerator_state(
    schedulers: list,
    process_index: int,
    scaler: GradScaler = None,
+    save_on_each_node: bool = False,
):
"""
Saves the current states of the models, optimizers, scaler, and RNG generators to a given directory.
@@ -68,32 +69,34 @@
            The current process index in the Accelerator state
        scaler (`torch.cuda.amp.GradScaler`, *optional*):
            An optional gradient scaler instance to save
+        save_on_each_node (`bool`, *optional*):
+            Whether to save on every node, or only on the main node.
    """
    # Model states
    for i, state in enumerate(model_states):
        weights_name = f"{MODEL_NAME}.bin" if i == 0 else f"{MODEL_NAME}_{i}.bin"
        output_model_file = os.path.join(output_dir, weights_name)
-        save(state, output_model_file)
+        save(state, output_model_file, save_on_each_node=save_on_each_node)
        logger.info(f"Model weights saved in {output_model_file}")
    # Optimizer states
    for i, opt in enumerate(optimizers):
        state = opt.state_dict()
        optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
        output_optimizer_file = os.path.join(output_dir, optimizer_name)
-        save(state, output_optimizer_file)
+        save(state, output_optimizer_file, save_on_each_node=save_on_each_node)
        logger.info(f"Optimizer state saved in {output_optimizer_file}")
    # Scheduler states
    for i, scheduler in enumerate(schedulers):
        state = scheduler.state_dict()
        scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
        output_scheduler_file = os.path.join(output_dir, scheduler_name)
-        save(state, output_scheduler_file)
+        save(state, output_scheduler_file, save_on_each_node=save_on_each_node)
        logger.info(f"Scheduler state saved in {output_scheduler_file}")
    # GradScaler state
    if scaler is not None:
        state = scaler.state_dict()
        output_scaler_file = os.path.join(output_dir, SCALER_NAME)
-        torch.save(state, output_scaler_file)
+        save(state, output_scaler_file, save_on_each_node=save_on_each_node)
        logger.info(f"Gradient scaler state saved in {output_scaler_file}")
    # Random number generator states
    states = {}
@@ -197,14 +200,14 @@ def load_accelerator_state(
logger.info("Could not load random states")


-def save_custom_state(obj, path, index: int = 0):
+def save_custom_state(obj, path, index: int = 0, save_on_each_node: bool = False):
    """
    Saves the state of `obj` to `{path}/custom_checkpoint_{index}.pkl`
    """
    # Should this be the right way to get a qual_name type value from `obj`?
    save_location = Path(path) / f"custom_checkpoint_{index}.pkl"
    logger.info(f"Saving the state of {get_pretty_name(obj)} to {save_location}")
-    torch.save(obj.state_dict(), save_location)
+    save(obj.state_dict(), save_location, save_on_each_node=save_on_each_node)
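`save_custom_state` backs `Accelerator.register_for_checkpointing`, so the flag also covers user-registered objects. A sketch, with `my_scheduler` standing in for any object that exposes `state_dict`/`load_state_dict`:

```python
# Hypothetical custom object: anything with state_dict()/load_state_dict().
accelerator.register_for_checkpointing(my_scheduler)
# Writes custom_checkpoint_0.pkl alongside the other states, per node
# when save_on_each_node=True.
accelerator.save_state("outputs/checkpoints/step_100")
```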


def load_custom_state(obj, path, index: int = 0):
10 changes: 10 additions & 0 deletions src/accelerate/utils/dataclasses.py
@@ -420,6 +420,16 @@ class ProjectConfiguration:
metadata={"help": "The current save iteration."},
)

+    save_on_each_node: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "When doing multi-node distributed training, whether to save models and checkpoints on each node, or"
+                " only on the main one"
+            )
+        },
+    )

def set_directories(self, project_dir: str = None):
"Sets `self.project_dir` and `self.logging_dir` to the appropriate values."
self.project_dir = project_dir
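The new field composes with the existing `ProjectConfiguration` options; a sketch (`automatic_checkpoint_naming` is pre-existing behavior, shown here only for context):

```python
from accelerate.utils import ProjectConfiguration

# With automatic checkpoint naming, save_state() chooses
# {project_dir}/checkpoints/checkpoint_{iteration} as the output directory,
# now written on every node when save_on_each_node=True.
config = ProjectConfiguration(
    project_dir="runs/exp1",
    automatic_checkpoint_naming=True,
    save_on_each_node=True,  # new in this commit
)
```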
26 changes: 16 additions & 10 deletions src/accelerate/utils/other.py
@@ -15,6 +15,7 @@
import os
import socket
from contextlib import contextmanager
+from functools import partial
from types import MethodType

import torch
@@ -109,22 +110,27 @@ def wait_for_everyone():
PartialState().wait_for_everyone()


-def save(obj, f, safe_serialization=False):
+def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False):
"""
Save the data to disk. Use in place of `torch.save()`.
Args:
obj: The data to save
f: The file (or file-like object) to use to save the data
safe_serialization (`bool`, *optional*, defaults to `False`): Whether to save `obj` using `safetensors`
"""
obj:
The data to save
f:
The file (or file-like object) to use to save the data
save_on_each_node (`bool`, *optional*, defaults to `False`):
Whether to only save on the global main process
safe_serialization (`bool`, *optional*, defaults to `False`):
Whether to save `obj` using `safetensors`
"""
+    save_func = torch.save if not safe_serialization else partial(safe_save_file, metadata={"format": "pt"})
    if PartialState().distributed_type == DistributedType.TPU:
        xm.save(obj, f)
-    elif PartialState().local_process_index == 0:
-        if safe_serialization:
-            safe_save_file(obj, f, metadata={"format": "pt"})
-        else:
-            torch.save(obj, f)
+    elif PartialState().is_main_process and not save_on_each_node:
+        save_func(obj, f)
+    elif PartialState().is_local_main_process and save_on_each_node:
+        save_func(obj, f)
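The rewritten dispatch makes the two scopes explicit: `is_main_process` is true on exactly one process in the whole job, while `is_local_main_process` is true on one process per node. A standalone sketch of calling the utility directly:

```python
from accelerate.utils import save

payload = {"step": 100}

# Default: written once for the whole job, by the global main process.
save(payload, "global_state.bin")

# Once per node: each node's local main process writes its own copy,
# for clusters whose nodes do not share a filesystem.
save(payload, "node_state.bin", save_on_each_node=True)
```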


@contextmanager
