Skip to content

Commit

Permalink
import torch
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Aug 22, 2024
1 parent 1fcd982 commit a22ce90
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/cugraph/cugraph/gnn/data_loading/dist_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,8 @@ def __sample_from_nodes_func(
random_state: int,
assume_equal_input_size: bool,
) -> Union[None, Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]]:
torch = import_optional("torch")

current_batches = torch.arange(
batch_id_start + call_id * batches_per_call,
batch_id_start
Expand Down

0 comments on commit a22ce90

Please sign in to comment.