diff --git a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py index 79bc975ca64..bcfaf579820 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py @@ -151,10 +151,13 @@ def __init__( self.__input_files = iter(input_files) return - input_node_info = torch_geometric.loader.utils.get_input_nodes( - (feature_store, graph_store), - input_nodes, - input_id=None, + # To accommodate DLFW/PyG 2.5 + get_input_nodes = torch_geometric.loader.utils.get_input_nodes + get_input_nodes_kwargs = {} + if "input_id" in get_input_nodes.__annotations__: + get_input_nodes_kwargs["input_id"] = None + input_node_info = get_input_nodes( + (feature_store, graph_store), input_nodes, **get_input_nodes_kwargs ) # PyG 2.4