Skip to content

Commit

Permalink
fix NestedTensor.__getitem__
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Sep 19, 2024
1 parent 5ec77bb commit d1042f7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
2 changes: 1 addition & 1 deletion danling/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class Metrics(Metric):
>>> metrics = Metrics(auroc=auroc, auprc=auprc, ignored_index=-100)
>>> metrics.update([[0.1, 0.4, 0.6, 0.8], [0.1, 0.4, 0.6]], [[0, -100, 1, 0], [0, -100, 1]])
>>> metrics.input, metrics.target
(PNTensor([0.1000, 0.6000, 0.8000, 0.1000, 0.6000]), PNTensor([0, 1, 0, 0, 1]))
(tensor([0.1000, 0.6000, 0.8000, 0.1000, 0.6000]), tensor([0, 1, 0, 0, 1]))
"""

metrics: FlatDict[str, Callable]
Expand Down
22 changes: 14 additions & 8 deletions danling/tensors/nested_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ class NestedTensor:
[4, 5, 0]]), tensor([[ True, True, True],
[ True, True, False]]))
>>> nested_tensor[1]
(tensor([4, 5]), tensor([True, True]))
tensor([4, 5])
>>> nested_tensor[:, 1:]
NestedTensor([[2, 3],
[5, 0]])
Expand Down Expand Up @@ -219,7 +219,7 @@ def _storage(self, tensors: Sequence):
if len(tensors) == 0:
raise ValueError("tensors must be a non-empty Iterable.")
if not isinstance(tensors[0], Tensor):
tensors = [tensor(t) for t in tensors]
tensors = [torch.tensor(t) for t in tensors]
# if drop_last=False, the last element is likely not a NestedTensor and has an extra batch dimension
ndims = {t.ndim for t in tensors[:-1]}
if len(ndims) == 1:
Expand Down Expand Up @@ -556,14 +556,20 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
return func(*args, **kwargs)
return NestedTensorFunc[func](*args, **kwargs)

def __getitem__(self, index: int | slice | tuple) -> tuple[Tensor, Tensor] | NestedTensor:
def __getitem__(self, index: int | slice | list | tuple) -> Tensor | tuple[Tensor, Tensor] | NestedTensor:
if isinstance(index, int):
return self._storage[index]
if isinstance(index, (slice, list)):
storage = tuple(self._storage[index] if isinstance(index, slice) else [self._storage[i] for i in index])
pad = pad_tensor(
storage, size=self._size(storage), batch_first=self.batch_first, padding_value=float(self.padding_value)
)
mask = mask_tensor(
storage, size=self._size(storage), batch_first=self.batch_first, mask_value=self.mask_value
)
return pad, mask
if isinstance(index, tuple):
return NestedTensor([t[index[0]][index[1:]] for t in self._storage])
if isinstance(index, (int, slice)):
ret = self._storage[index]
if isinstance(ret, Tensor):
return ret, torch.ones_like(ret, dtype=torch.bool)
return self.tensor, self.mask
raise ValueError(f"Unsupported index type {type(index)}")

def __getattr__(self, name: str) -> Any:
Expand Down

0 comments on commit d1042f7

Please sign in to comment.