Skip to content

Commit

Permalink
fix return dtype of getitem Tensor indexing (tinygrad#4158)
Browse files Browse the repository at this point in the history
the use of sum can auto-upcast the result. fixed by using the data dtype as the acc_dtype
  • Loading branch information
chenyuxyz authored Apr 12, 2024
1 parent f6c8032 commit d9c5a2b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
7 changes: 7 additions & 0 deletions test/test_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,13 @@ def test_functions_return_index(self, dtype, default_int, default_float):
assert Tensor([0, 1], dtype=dtype).argmin().dtype == dtypes.int32
assert Tensor([0, 1], dtype=dtype).multinomial().dtype == dtypes.int32

@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints))
def test_tensor_indexing_returns_same_dtype(self, data_dtype, indices_dtype):
X_data = Tensor.rand(60000, 1, 28, 28, dtype=data_dtype)
indices = Tensor.randint(512, high=X_data.shape[0]).cast(indices_dtype)
X = X_data[indices]
assert X.dtype == X_data.dtype

class TestTypePromotion(unittest.TestCase):
@given(strat.sampled_from(core_dtypes))
def test_self_promo_to_self(self, dtype):
Expand Down
4 changes: 2 additions & 2 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,12 +502,12 @@ def calc_dim(tensor_dim:int) -> int:
masks.append(i == a)

# reduce masks to 1 mask
mask = functools.reduce(lambda x,y: x.mul(y), masks)
mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks)

# inject 1's for the extra dims added in create masks
sh = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:]
# sum reduce the extra dims introduced in create masks
ret = (ret.reshape(sh) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()))
ret = (ret.reshape(sh) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype)

# special permute case
if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
Expand Down

0 comments on commit d9c5a2b

Please sign in to comment.