Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deal with shared memory scenarios #2136

Merged
merged 6 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/accelerate/utils/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
from contextlib import contextmanager
from functools import partial
from types import MethodType
from typing import OrderedDict

import torch
from packaging.version import Version
from safetensors.torch import _remove_duplicate_names
from safetensors.torch import save_file as safe_save_file

from ..commands.config.default import write_basic_config # noqa: F401
Expand Down Expand Up @@ -129,7 +131,21 @@ def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = Tru
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
"""
save_func = torch.save if not safe_serialization else partial(safe_save_file, metadata={"format": "pt"})
# Check if it's a model and remove duplicates
if safe_serialization:
metadata = {"format": "pt"}
if isinstance(obj, OrderedDict):
to_removes = _remove_duplicate_names(obj)
for kept_name, to_remove_group in to_removes.items():
for to_remove in to_remove_group:
if to_remove not in metadata:
metadata[to_remove] = kept_name
del obj[to_remove]
obj = {k: v.contiguous() for k, v in obj.items()}
save_func = partial(safe_save_file, metadata=metadata)
else:
save_func = torch.save

if PartialState().distributed_type == DistributedType.TPU:
xm.save(obj, f)
elif PartialState().is_main_process and not save_on_each_node:
Expand Down
18 changes: 18 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@

import os
import pickle
import tempfile
import unittest
import warnings
from collections import UserDict, namedtuple
from unittest.mock import Mock, patch

import torch
from torch import nn

from accelerate.state import PartialState
from accelerate.test_utils.testing import require_cuda, require_torch_min_version
Expand All @@ -32,6 +34,7 @@
listify,
patch_environment,
recursively_apply,
save,
send_to_device,
)

Expand Down Expand Up @@ -205,3 +208,18 @@ def test_check_os_kernel_warning_when_release_lt_min(self):
self.assertEqual(ctx.records[0].levelname, "WARNING")
self.assertIn("5.4.0", ctx.records[0].msg)
self.assertIn("5.5.0", ctx.records[0].msg)

def test_save_safetensor_shared_memory(self):
    """Saving a state dict containing tied (shared-storage) tensors with
    `safe_serialization=True` must succeed instead of raising, since
    `save` deduplicates shared weights before handing off to safetensors."""

    class TiedNet(nn.Module):
        # Two attributes alias the same Linear, so the state dict holds
        # duplicate tensors sharing one storage — the case safetensors
        # rejects unless duplicates are removed first.
        def __init__(self):
            super().__init__()
            self.a = nn.Linear(100, 100)
            self.b = self.a  # tie: same module, shared storage

        def forward(self, x):
            return self.b(self.a(x))

    net = TiedNet()
    with tempfile.TemporaryDirectory() as workdir:
        target = os.path.join(workdir, "model.safetensors")
        save(net.state_dict(), target, safe_serialization=True)
Loading