diff --git a/fbgemm_gpu/src/sparse_ops/sparse_segment_sum_csr.cu b/fbgemm_gpu/src/sparse_ops/sparse_segment_sum_csr.cu index e2b0a30656..9ff9c23d5a 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_segment_sum_csr.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_segment_sum_csr.cu @@ -57,7 +57,14 @@ DLL_PUBLIC Tensor segment_sum_csr_cuda( CUDA_DEVICE_GUARD(values); + TORCH_CHECK(csr_seg.numel() >= 1, "The csr_seg tensor should not be empty") + auto output = at::empty(csr_seg.numel() - 1, values.options()); + + if (csr_seg.numel() == 1) { + return output; + } + constexpr uint32_t threads_per_block = 256; const uint32_t num_blocks = csr_seg.numel() - 1; diff --git a/fbgemm_gpu/test/sparse/misc_ops_test.py b/fbgemm_gpu/test/sparse/misc_ops_test.py index a209e7f0a0..11a0a0ea23 100644 --- a/fbgemm_gpu/test/sparse/misc_ops_test.py +++ b/fbgemm_gpu/test/sparse/misc_ops_test.py @@ -174,6 +174,24 @@ def test_segment_sum_csr(self) -> None: segment_sum_cuda.cpu(), torch.Tensor([10.0, 11.0, 34.0]), rtol=0, atol=0 ) + def test_segment_sum_csr_empty_input(self) -> None: + segment_sum_cpu = torch.ops.fbgemm.segment_sum_csr( + 0, + torch.IntTensor([0]), + torch.Tensor([]), + ) + torch.testing.assert_close(segment_sum_cpu.numel(), 0, rtol=0, atol=0) + + if torch.cuda.is_available(): + segment_sum_cuda = torch.ops.fbgemm.segment_sum_csr( + 0, + torch.IntTensor([0]).cuda(), + torch.Tensor([]).cuda(), + ) + torch.testing.assert_close( + segment_sum_cuda.cpu().numel(), 0, rtol=0, atol=0 + ) + @given( batch_size=st.just(2), m=st.just(3),