Commit
Showing 11 changed files with 208 additions and 148 deletions.
@@ -0,0 +1,3 @@
from .api import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor

__all__ = ["is_padded_tensor", "to_padded_tensor", "to_unpadded_tensor", "init_as_ptensor"]
@@ -0,0 +1,128 @@
import torch


def _hijack_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
    """
    Hijack the detach and clone methods of the tensor so that the padding metadata
    (padding_dim, origin_length, current_length) is carried over to the result.
    Args:
        ptensor (torch.Tensor): The tensor to be hijacked.
    Returns:
        torch.Tensor: The hijacked tensor.
    """
    ptensor._unpad_detach = ptensor.detach
    ptensor._unpad_clone = ptensor.clone

    def new_detach(self):
        t_ = self._unpad_detach()
        t_.padding_dim = self.padding_dim
        t_.origin_length = self.origin_length
        t_.current_length = self.current_length
        return t_

    def new_clone(self, *args, **kwargs):
        t_ = self._unpad_clone(*args, **kwargs)
        t_.padding_dim = self.padding_dim
        t_.origin_length = self.origin_length
        t_.current_length = self.current_length
        return t_

    # bind the new methods to the tensor
    ptensor.detach = new_detach.__get__(ptensor)
    ptensor.clone = new_clone.__get__(ptensor)
    return ptensor


def _hijack_back_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
    """
    Restore the original detach and clone methods of the tensor, undoing
    the replacements installed by _hijack_detach_and_clone.
    Args:
        ptensor (torch.Tensor): The tensor to be restored.
    Returns:
        torch.Tensor: The restored tensor.
    """
    ptensor.detach = ptensor._unpad_detach
    ptensor.clone = ptensor._unpad_clone

    delattr(ptensor, "_unpad_detach")
    delattr(ptensor, "_unpad_clone")

    return ptensor


def is_padded_tensor(tensor: torch.Tensor) -> bool:
    """
    Check whether the given tensor is a padded tensor.
    Args:
        tensor (torch.Tensor): The tensor to be checked.
    Returns:
        bool: Whether the given tensor is a padded tensor.
    """
    return hasattr(tensor, "padding_dim")


def to_padded_tensor(
    tensor: torch.Tensor,
    current_length: int,
    padding_dim: int,
) -> torch.Tensor:
    assert (
        padding_dim < tensor.dim()
    ), f"Please pass a valid padding_dim. The dimension of the tensor is {tensor.dim()}"

    if is_padded_tensor(tensor):
        return tensor

    origin_length = tensor.shape[padding_dim]
    padding_num = current_length - origin_length
    padding_data = torch.zeros(
        *tensor.shape[:padding_dim],
        padding_num,
        *tensor.shape[padding_dim + 1 :],
        device=tensor.device,
        dtype=tensor.dtype,
    )
    tensor.data = torch.cat((tensor.data, padding_data), dim=padding_dim).contiguous()

    setattr(tensor, "padding_dim", padding_dim)
    setattr(tensor, "origin_length", origin_length)
    setattr(tensor, "current_length", current_length)

    _hijack_detach_and_clone(tensor)

    return tensor


def to_unpadded_tensor(ptensor: torch.Tensor):
    if not is_padded_tensor(ptensor):
        return ptensor

    unpad_slices = [slice(None)] * ptensor.dim()
    unpad_slices[ptensor.padding_dim] = slice(None, ptensor.origin_length)
    ptensor.data = ptensor.data[tuple(unpad_slices)]

    delattr(ptensor, "padding_dim")
    delattr(ptensor, "origin_length")
    delattr(ptensor, "current_length")

    _hijack_back_detach_and_clone(ptensor)

    return ptensor


def init_as_ptensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int):
    if is_padded_tensor(tensor):
        return tensor

    setattr(tensor, "padding_dim", padding_dim)
    setattr(tensor, "origin_length", origin_length)
    setattr(tensor, "current_length", current_length)

    _hijack_detach_and_clone(tensor)

    return tensor
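
For context, a minimal usage sketch of the helpers added above, assuming they are importable as re-exported by the new `__init__.py`; the module path, tensor shapes, and variable names here are illustrative and not part of the commit:

import torch

# Hypothetical import path; only the package-relative re-exports appear in the diff.
from padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor

x = torch.randn(4, 10)                                     # original length 10 along dim 1
x = to_padded_tensor(x, current_length=16, padding_dim=1)  # zero-pad dim 1 up to length 16
assert is_padded_tensor(x) and x.shape[1] == 16

y = x.clone()                                              # hijacked clone carries the padding metadata
assert y.padding_dim == 1 and y.origin_length == 10

x = to_unpadded_tensor(x)                                  # slice off the padding and drop the metadata
assert not is_padded_tensor(x) and x.shape[1] == 10

Note that both helpers mutate the input tensor in place through its .data field and return the same object, so the padded and unpadded results share identity with the original tensor.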