Skip to content

Commit

Permalink
fix detection of duplicate tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
fxmarty committed Nov 6, 2023
1 parent 96061e9 commit f04d064
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion bindings/python/py_src/safetensors/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def storage_size(tensor: torch.Tensor) -> int:
return tensor.nelement() * _SIZE[tensor.dtype]


def storage_offset(tensor: torch.Tensor) -> int:
    """Return the element offset of ``tensor`` into its underlying storage.

    Views created by slicing share a storage (and hence a data pointer) with
    their base tensor but start at different offsets, so this value is used to
    tell such views apart when grouping shared tensors.
    """
    offset = tensor.storage_offset()
    return offset


def _filter_shared_not_shared(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
filtered_tensors = []
for shared in tensors:
Expand Down Expand Up @@ -71,7 +75,8 @@ def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
for k, v in state_dict.items():
if v.device != torch.device("meta") and storage_ptr(v) != 0 and storage_size(v) != 0:
# Need to add device as key because of multiple GPU.
tensors[(v.device, storage_ptr(v), storage_size(v))].add(k)
# Need to add storage_offset as key because views may share the same data_ptr.
tensors[(v.device, storage_ptr(v), storage_size(v), storage_offset(v))].add(k)
tensors = list(sorted(tensors.values()))
tensors = _filter_shared_not_shared(tensors, state_dict)
return tensors
Expand Down

0 comments on commit f04d064

Please sign in to comment.