Skip to content

Commit

Permalink
fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Aug 21, 2024
1 parent da53c02 commit b6ecc14
Showing 1 changed file with 19 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import cugraph_dgl.dataloading
import pytest

Expand Down Expand Up @@ -50,7 +51,10 @@ def test_dataloader_basic_homogeneous():

def sample_dgl_graphs(g, train_nid, fanouts, batch_size=1, prob_attr=None):
# Single fanout to match cugraph
sampler = dgl.dataloading.NeighborSampler(fanouts, prob=prob_attr,)
sampler = dgl.dataloading.NeighborSampler(
fanouts,
prob=prob_attr,
)
dataloader = dgl.dataloading.DataLoader(
g,
train_nid,
Expand All @@ -71,8 +75,13 @@ def sample_dgl_graphs(g, train_nid, fanouts, batch_size=1, prob_attr=None):
return dgl_output


def sample_cugraph_dgl_graphs(cugraph_g, train_nid, fanouts, batch_size=1, prob_attr=None):
sampler = cugraph_dgl.dataloading.NeighborSampler(fanouts, prob=prob_attr,)
def sample_cugraph_dgl_graphs(
cugraph_g, train_nid, fanouts, batch_size=1, prob_attr=None
):
sampler = cugraph_dgl.dataloading.NeighborSampler(
fanouts,
prob=prob_attr,
)

dataloader = cugraph_dgl.dataloading.FutureDataLoader(
cugraph_g,
Expand Down Expand Up @@ -135,17 +144,19 @@ def test_dataloader_biased_homogeneous():
dst = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
wgt = torch.tensor([1, 1, 2, 0, 0, 0, 2, 1], dtype=torch.float32)

train_nid = torch.tensor([0,1])
train_nid = torch.tensor([0, 1])
# Create a heterograph with 3 node types and 3 edges types.
dgl_g = dgl.graph((src, dst))
dgl_g.edata['wgt'] = wgt
dgl_g.edata["wgt"] = wgt

cugraph_g = cugraph_dgl.Graph(is_multi_gpu=False)
cugraph_g.add_nodes(9)
cugraph_g.add_edges(u=src, v=dst, data={'wgt': wgt})
cugraph_g.add_edges(u=src, v=dst, data={"wgt": wgt})

dgl_output = sample_dgl_graphs(dgl_g, train_nid, [4], batch_size=2, prob_attr='wgt')
cugraph_output = sample_cugraph_dgl_graphs(cugraph_g, train_nid, [4], batch_size=2, prob_attr='wgt')
dgl_output = sample_dgl_graphs(dgl_g, train_nid, [4], batch_size=2, prob_attr="wgt")
cugraph_output = sample_cugraph_dgl_graphs(
cugraph_g, train_nid, [4], batch_size=2, prob_attr="wgt"
)

cugraph_output_nodes = cugraph_output[0]["output_nodes"].cpu().numpy()
dgl_output_nodes = dgl_output[0]["output_nodes"].cpu().numpy()
Expand Down

0 comments on commit b6ecc14

Please sign in to comment.