diff --git a/bindings/python/py_src/safetensors/torch.py b/bindings/python/py_src/safetensors/torch.py
index 7fa59675..cdde1a23 100644
--- a/bindings/python/py_src/safetensors/torch.py
+++ b/bindings/python/py_src/safetensors/torch.py
@@ -41,6 +41,10 @@ def storage_size(tensor: torch.Tensor) -> int:
     return tensor.nelement() * _SIZE[tensor.dtype]
 
 
+def storage_offset(tensor: torch.Tensor) -> int:
+    return tensor.storage_offset()
+
+
 def _filter_shared_not_shared(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
     filtered_tensors = []
     for shared in tensors:
@@ -71,7 +75,8 @@ def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
     for k, v in state_dict.items():
         if v.device != torch.device("meta") and storage_ptr(v) != 0 and storage_size(v) != 0:
             # Need to add device as key because of multiple GPU.
-            tensors[(v.device, storage_ptr(v), storage_size(v))].add(k)
+            # Need to add storage_offset as key because views may share the same data_ptr.
+            tensors[(v.device, storage_ptr(v), storage_size(v), storage_offset(v))].add(k)
     tensors = list(sorted(tensors.values()))
     tensors = _filter_shared_not_shared(tensors, state_dict)
     return tensors