Skip to content

Commit

Permalink
node loader
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Nov 11, 2023
1 parent 0b05029 commit 801a302
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ def __init__(
feature_store: CuGraphStore,
graph_store: CuGraphStore,
input_nodes: InputNodes = None,
*,
batch_size: int = 0,
*,
shuffle: bool = False,
drop_last: bool = False,
drop_last: bool = True,
edge_types: Sequence[Tuple[str]] = None,
directory: Union[str, tempfile.TemporaryDirectory] = None,
input_files: List[str] = None,
Expand Down Expand Up @@ -211,7 +211,7 @@ def __init__(

# Truncate if we can't evenly divide the input array
stop = (len(input_nodes) // batch_size) * batch_size
input_nodes, remainder = cupy.array_split(stop)
input_nodes, remainder = cupy.array_split(input_nodes, [stop])

# Split into batches
input_nodes = cupy.split(input_nodes, len(input_nodes) // batch_size)
Expand All @@ -227,7 +227,7 @@ def __init__(
{
"start": batch_i,
"batch": cupy.full(
batch_size, batch_num + starting_batch_id, dtype="int32"
len(batch_i), batch_num + starting_batch_id, dtype="int32"
),
}
),
Expand Down

0 comments on commit 801a302

Please sign in to comment.