diff --git a/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py b/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py index bccc5faab28..465ead45424 100644 --- a/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py +++ b/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py @@ -236,6 +236,8 @@ def __sample_from_nodes_func( random_state: int, assume_equal_input_size: bool, ) -> Union[None, Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]]: + torch = import_optional("torch") + current_batches = torch.arange( batch_id_start + call_id * batches_per_call, batch_id_start