From 5f20885249b463a3d5f17471961503917565462f Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 9 Nov 2023 12:47:01 -0500 Subject: [PATCH 1/5] Deal with duplicates --- src/accelerate/utils/other.py | 18 +++++++++++++++++- tests/test_utils.py | 18 ++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/src/accelerate/utils/other.py b/src/accelerate/utils/other.py index 285dd0a5ad9..34766ba7f46 100644 --- a/src/accelerate/utils/other.py +++ b/src/accelerate/utils/other.py @@ -19,9 +19,11 @@ from contextlib import contextmanager from functools import partial from types import MethodType +from typing import OrderedDict import torch from packaging.version import Version +from safetensors.torch import _remove_duplicate_names from safetensors.torch import save_file as safe_save_file from ..commands.config.default import write_basic_config # noqa: F401 @@ -129,7 +131,21 @@ def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = Tru safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). """ - save_func = torch.save if not safe_serialization else partial(safe_save_file, metadata={"format": "pt"}) + # Check if it's a model and remove duplicates + if safe_serialization: + metadata = {"format": "pt"} + if isinstance(obj, OrderedDict): + to_removes = _remove_duplicate_names(obj) + for kept_name, to_remove_group in to_removes.items(): + for to_remove in to_remove_group: + if to_remove not in metadata: + metadata[to_remove] = kept_name + del obj[to_remove] + obj = {k: v.contiguous() for k, v in obj.items()} + save_func = partial(safe_save_file, metadata=metadata) + else: + save_func = torch.save + if PartialState().distributed_type == DistributedType.TPU: xm.save(obj, f) elif PartialState().is_main_process and not save_on_each_node: diff --git a/tests/test_utils.py b/tests/test_utils.py index 2bf86314eed..fe04f3fcfce 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -14,12 +14,14 @@ import os import pickle +import tempfile import unittest import warnings from collections import UserDict, namedtuple from unittest.mock import Mock, patch import torch +from torch import nn from accelerate.state import PartialState from accelerate.test_utils.testing import require_cuda, require_torch_min_version @@ -32,6 +34,7 @@ listify, patch_environment, recursively_apply, + save, send_to_device, ) @@ -205,3 +208,18 @@ def test_check_os_kernel_warning_when_release_lt_min(self): self.assertEqual(ctx.records[0].levelname, "WARNING") self.assertIn("5.4.0", ctx.records[0].msg) self.assertIn("5.5.0", ctx.records[0].msg) + + def test_save_safetensor_shared_memory(self): + class Model(nn.Module): + def __init__(self): + super().__init__() + self.a = nn.Linear(100, 100) + self.b = self.a + + def forward(self, x): + return self.b(self.a(x)) + + model = Model() + with tempfile.TemporaryDirectory() as tmp_dir: + save_path = os.path.join(tmp_dir, "model.safetensors") + save(model.state_dict(), save_path, safe_serialization=True) From 96a7698e1eea3ce93086d8388773db405a46c15b Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Fri, 10 Nov 2023 10:04:05 -0500 Subject: [PATCH 2/5] refactor --- src/accelerate/accelerator.py | 33 ++------------------- src/accelerate/utils/__init__.py | 1 + src/accelerate/utils/other.py | 49 +++++++++++++++++++++++++------- 3 files changed, 42 insertions(+), 41 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 5407395a0df..c464cbd6f28 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -14,7 +14,6 @@ from __future__ import annotations -import collections import contextlib import functools import json @@ -64,6 +63,7 @@ RNGType, TorchDynamoPlugin, check_os_kernel, + clean_state_dict_for_safetensors, compare_versions, convert_model, convert_outputs_to_fp32, @@ -73,7 +73,6 @@ get_mixed_precision_context_manager, get_pretty_name, has_transformer_engine_layers, - id_tensor_storage, is_bf16_available, is_deepspeed_available, is_fp8_available, @@ -2583,35 +2582,7 @@ def save_model( state_dict = self.get_state_dict(model) if safe_serialization: - # Safetensors does not allow tensor aliasing. - # We're going to remove aliases before saving - ptrs = collections.defaultdict(list) - # when bnb serialization is used the weights in the state dict can be strings - for name, tensor in state_dict.items(): - if not isinstance(tensor, str): - ptrs[id_tensor_storage(tensor)].append(name) - - # These are all the pointers of shared tensors. - shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1} - warn_names = set() - for names in shared_ptrs.values(): - # When not all duplicates have been cleaned, still remove those keys, but put a clear warning. - # If the link between tensors was done at runtime then `from_pretrained` will not get - # the key back leading to random tensor. A proper warning will be shown - # during reload (if applicable), but since the file is not necessarily compatible with - # the config, better show a proper warning. - found = 0 - for name in names: - if name in state_dict: - found += 1 - if found > 1: - del state_dict[name] - warn_names.add(name) - if len(warn_names) > 0: - logger.warning( - f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading", - ) - + state_dict = clean_state_dict_for_safetensors(state_dict) weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME # Shard the model if it is too big. diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index 96e3fe61035..88cc927f001 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -168,6 +168,7 @@ from .memory import find_executable_batch_size, release_memory from .other import ( check_os_kernel, + clean_state_dict_for_safetensors, clear_environment, convert_bytes, extract_model_from_parallel, diff --git a/src/accelerate/utils/other.py b/src/accelerate/utils/other.py index 34766ba7f46..c1c4b4e95be 100644 --- a/src/accelerate/utils/other.py +++ b/src/accelerate/utils/other.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections import os import platform import re @@ -23,7 +24,6 @@ import torch from packaging.version import Version -from safetensors.torch import _remove_duplicate_names from safetensors.torch import save_file as safe_save_file from ..commands.config.default import write_basic_config # noqa: F401 @@ -32,6 +32,7 @@ from .constants import FSDP_PYTORCH_VERSION from .dataclasses import DistributedType from .imports import is_deepspeed_available, is_torch_distributed_available, is_tpu_available +from .modeling import id_tensor_storage from .transformer_engine import convert_model from .versions import is_torch_version @@ -117,6 +118,41 @@ def wait_for_everyone(): PartialState().wait_for_everyone() +def clean_state_dict_for_safetensors(state_dict: dict): + """ + Cleans the state dictionary from a model and removes tensor aliasing if present. + + Args: + state_dict (`dict`): + The state dictionary from a model + """ + ptrs = collections.defaultdict(list) + # When bnb serialization is used, weights in state dict can eb strings + for name, tensor in state_dict.items(): + if not isinstance(tensor, str): + ptrs[id_tensor_storage(tensor)].append(name) + + # These are all pointers of tensors with shared memory + shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1} + warn_names = set() + for names in shared_ptrs.values(): + # When not all duplicates have been cleaned, we still remove those keys but put a clear warning. + # If the link between tensors was done at runtime then `from_pretrained` will not get + # the key back leading to random tensor. A proper warning will be shown + # during reload (if applicable), but since the file is not necessarily compatible with + # the config, better show a proper warning. + found_names = [name for name in names if name in state_dict] + warn_names.update(found_names[1:]) + for name in found_names[1:]: + del state_dict[name] + if len(warn_names) > 0: + logger.warning( + f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading", + ) + state_dict = {k: v.contiguous() for k, v in state_dict.items()} + return state_dict + + def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = True): """ Save the data to disk. Use in place of `torch.save()`. @@ -133,16 +169,9 @@ def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = Tru """ # Check if it's a model and remove duplicates if safe_serialization: - metadata = {"format": "pt"} + save_func = partial(safe_save_file, metadata={"format": "pt"}) if isinstance(obj, OrderedDict): - to_removes = _remove_duplicate_names(obj) - for kept_name, to_remove_group in to_removes.items(): - for to_remove in to_remove_group: - if to_remove not in metadata: - metadata[to_remove] = kept_name - del obj[to_remove] - obj = {k: v.contiguous() for k, v in obj.items()} - save_func = partial(safe_save_file, metadata=metadata) + obj = clean_state_dict_for_safetensors(obj) else: save_func = torch.save From aff488bb0bee8ed2316a32039d9d29074b47fc0e Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Fri, 10 Nov 2023 10:12:46 -0500 Subject: [PATCH 3/5] Keep false for save --- src/accelerate/utils/other.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/accelerate/utils/other.py b/src/accelerate/utils/other.py index 5be44a01600..026e16166e1 100644 --- a/src/accelerate/utils/other.py +++ b/src/accelerate/utils/other.py @@ -152,7 +152,7 @@ def clean_state_dict_for_safetensors(state_dict: dict): state_dict = {k: v.contiguous() for k, v in state_dict.items()} return state_dict -def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = True): +def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False): """ Save the data to disk. Use in place of `torch.save()`. @@ -163,7 +163,7 @@ def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = Tru The file (or file-like object) to use to save the data save_on_each_node (`bool`, *optional*, defaults to `False`): Whether to only save on the global main process - safe_serialization (`bool`, *optional*, defaults to `True`): + safe_serialization (`bool`, *optional*, defaults to `False`): Whether to save `obj` using `safetensors` or the traditional PyTorch way (that uses `pickle`). """ # Check if it's a model and remove duplicates From 2ca09738cd96ebe54604c33f9a777ebb1083263d Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Fri, 10 Nov 2023 10:23:39 -0500 Subject: [PATCH 4/5] Clean --- src/accelerate/utils/other.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/accelerate/utils/other.py b/src/accelerate/utils/other.py index 026e16166e1..b55b83565d6 100644 --- a/src/accelerate/utils/other.py +++ b/src/accelerate/utils/other.py @@ -152,6 +152,7 @@ def clean_state_dict_for_safetensors(state_dict: dict): state_dict = {k: v.contiguous() for k, v in state_dict.items()} return state_dict + def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False): """ Save the data to disk. Use in place of `torch.save()`. From 32f7bd0b29eec7e59c2a675c105182488f622a32 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Fri, 10 Nov 2023 10:29:29 -0500 Subject: [PATCH 5/5] Better test for logs --- tests/test_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index fe04f3fcfce..fa23e72986d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -222,4 +222,7 @@ def forward(self, x): model = Model() with tempfile.TemporaryDirectory() as tmp_dir: save_path = os.path.join(tmp_dir, "model.safetensors") - save(model.state_dict(), save_path, safe_serialization=True) + with self.assertLogs(level="WARNING") as log: + save(model.state_dict(), save_path, safe_serialization=True) + self.assertEqual(len(log.records), 1) + self.assertIn("Removed shared tensor", log.output[0])