From d1042f7975bf1cd7ea65de4135e640a3d4c0a482 Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Thu, 19 Sep 2024 15:04:02 +0800 Subject: [PATCH] fix NestedTensor.__getitem__ Signed-off-by: Zhiyuan Chen --- danling/metrics/metrics.py | 2 +- danling/tensors/nested_tensor.py | 22 ++++++++++++++-------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/danling/metrics/metrics.py b/danling/metrics/metrics.py index 958db33a..117b62b6 100644 --- a/danling/metrics/metrics.py +++ b/danling/metrics/metrics.py @@ -126,7 +126,7 @@ class Metrics(Metric): >>> metrics = Metrics(auroc=auroc, auprc=auprc, ignored_index=-100) >>> metrics.update([[0.1, 0.4, 0.6, 0.8], [0.1, 0.4, 0.6]], [[0, -100, 1, 0], [0, -100, 1]]) >>> metrics.input, metrics.target - (PNTensor([0.1000, 0.6000, 0.8000, 0.1000, 0.6000]), PNTensor([0, 1, 0, 0, 1])) + (tensor([0.1000, 0.6000, 0.8000, 0.1000, 0.6000]), tensor([0, 1, 0, 0, 1])) """ metrics: FlatDict[str, Callable] diff --git a/danling/tensors/nested_tensor.py b/danling/tensors/nested_tensor.py index 3457dfd4..11586da9 100644 --- a/danling/tensors/nested_tensor.py +++ b/danling/tensors/nested_tensor.py @@ -177,7 +177,7 @@ class NestedTensor: [4, 5, 0]]), tensor([[ True, True, True], [ True, True, False]])) >>> nested_tensor[1] - (tensor([4, 5]), tensor([True, True])) + tensor([4, 5]) >>> nested_tensor[:, 1:] NestedTensor([[2, 3], [5, 0]]) @@ -219,7 +219,7 @@ def _storage(self, tensors: Sequence): if len(tensors) == 0: raise ValueError("tensors must be a non-empty Iterable.") if not isinstance(tensors[0], Tensor): - tensors = [tensor(t) for t in tensors] + tensors = [torch.tensor(t) for t in tensors] # if drop_last=False, the last element is likely not a NestedTensor and has an extra batch dimension ndims = {t.ndim for t in tensors[:-1]} if len(ndims) == 1: @@ -556,14 +556,20 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): return func(*args, **kwargs) return NestedTensorFunc[func](*args, **kwargs) - def __getitem__(self, index: int | slice | tuple) -> tuple[Tensor, Tensor] | NestedTensor: + def __getitem__(self, index: int | slice | list | tuple) -> Tensor | tuple[Tensor, Tensor] | NestedTensor: + if isinstance(index, int): + return self._storage[index] + if isinstance(index, (slice, list)): + storage = tuple(self._storage[index] if isinstance(index, slice) else [self._storage[i] for i in index]) + pad = pad_tensor( + storage, size=self._size(storage), batch_first=self.batch_first, padding_value=float(self.padding_value) + ) + mask = mask_tensor( + storage, size=self._size(storage), batch_first=self.batch_first, mask_value=self.mask_value + ) + return pad, mask if isinstance(index, tuple): return NestedTensor([t[index[0]][index[1:]] for t in self._storage]) - if isinstance(index, (int, slice)): - ret = self._storage[index] - if isinstance(ret, Tensor): - return ret, torch.ones_like(ret, dtype=torch.bool) - return self.tensor, self.mask raise ValueError(f"Unsupported index type {type(index)}") def __getattr__(self, name: str) -> Any: