Commit 4a0bd9a

Merge branch 'main' into gsp_custom_split

jalencato authored May 20, 2024
2 parents 11b4e9f + 001b43a
Showing 4 changed files with 36 additions and 13 deletions.
2 changes: 1 addition & 1 deletion graphstorm-processing/docker/push_gsprocessing_image.sh
@@ -58,7 +58,7 @@ parse_params() {
         EXEC_ENV="${2-}"
         shift
         ;;
-    -a | --architecture)
+    -c | --architecture)
         ARCH="${2-}"
         shift
         ;;
15 changes: 13 additions & 2 deletions python/graphstorm/config/argument.py
@@ -1695,9 +1695,16 @@ def target_etype(self):
     def remove_target_edge_type(self):
         """ Whether to remove the training target edge type for message passing.
-        Will set the fanout of training target edge type as zero
+        Will set the fanout of the training target edge type to zero. Only used
+        with edge classification.
-        Only used with edge classification
+        If the edge classification task is to predict the existence of an edge
+        between two nodes, we should remove the target edge from message passing
+        to avoid information leakage.
+        If it is to predict attributes associated with an edge, we may not need
+        to remove the target edge.
+        Since we do not know what is being predicted, to be safe, we remove the
+        target edge from message passing by default.
         """
         # pylint: disable=no-member
         if hasattr(self, "_remove_target_edge_type"):
@@ -1706,6 +1713,10 @@ def remove_target_edge_type(self):

         # By default, remove training target etype during
         # message passing to avoid information leakage
+        logging.warning("remove_target_edge_type is set to True by default. "
+                        "If your edge classification task is not predicting "
+                        "the existence of the target edge, we suggest setting "
+                        "it to False.")
         return True

     @property
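For reference, the snippet below is a minimal, self-contained sketch of the default-with-warning pattern this hunk adds. The EdgeTaskConfig class and its constructor are illustrative assumptions, not GraphStorm's actual GSConfig API; only the property logic mirrors the diff above.

import logging

class EdgeTaskConfig:
    """Illustrative stand-in for a config object with the same default logic."""

    def __init__(self, remove_target_edge_type=None):
        # Store the user-provided value on an underscore-prefixed attribute
        # only when it was explicitly given, mirroring the pattern above.
        if remove_target_edge_type is not None:
            self._remove_target_edge_type = remove_target_edge_type

    @property
    def remove_target_edge_type(self):
        """Whether to drop the training target edge type from message passing."""
        if hasattr(self, "_remove_target_edge_type"):
            return self._remove_target_edge_type
        # Default to True (safe for existence-prediction tasks) and tell the
        # user how to opt out when predicting edge attributes instead.
        logging.warning("remove_target_edge_type defaults to True. Set it to "
                        "False if your task predicts edge attributes rather "
                        "than edge existence.")
        return True

print(EdgeTaskConfig().remove_target_edge_type)        # True, and logs a warning
print(EdgeTaskConfig(False).remove_target_edge_type)   # False, no warning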
8 changes: 4 additions & 4 deletions python/graphstorm/model/node_gnn.py
@@ -323,17 +323,17 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label=
     # TODO(zhengda) I need to check if the data loader only returns target nodes.
     model.eval()
     with th.no_grad():
-        for input_nodes, seeds, _ in loader:
-            for ntype, in_nodes in input_nodes.items():
+        for _, seeds, _ in loader:  # seeds are target nodes
+            for ntype, seed_nodes in seeds.items():
                 if isinstance(model.decoder, th.nn.ModuleDict):
                     assert ntype in model.decoder, f"Node type {ntype} not in decoder"
                     decoder = model.decoder[ntype]
                 else:
                     decoder = model.decoder
                 if return_proba:
-                    pred = decoder.predict_proba(emb[ntype][in_nodes].to(device))
+                    pred = decoder.predict_proba(emb[ntype][seed_nodes].to(device))
                 else:
-                    pred = decoder.predict(emb[ntype][in_nodes].to(device))
+                    pred = decoder.predict(emb[ntype][seed_nodes].to(device))
                 if ntype in preds:
                     preds[ntype].append(pred.cpu())
                 else:
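To make the intent of this change concrete, here is a minimal sketch of why predictions must be gathered with the seed (target) node IDs rather than the sampled input-node IDs. The tensors and the Linear decoder are illustrative stand-ins, not GraphStorm's dataloader or decoder API.

import torch as th

emb = {"paper": th.randn(10, 4)}            # precomputed embeddings for all 10 nodes
decoder = th.nn.Linear(4, 3)                # stand-in for a per-node-type decoder

seeds = th.tensor([2, 5])                   # target nodes of this mini-batch
input_nodes = th.tensor([2, 5, 0, 1, 7])    # targets plus their sampled neighbors

with th.no_grad():
    misaligned = decoder(emb["paper"][input_nodes])  # 5 rows, not aligned with labels
    aligned = decoder(emb["paper"][seeds])           # 2 rows, one per target node

print(misaligned.shape, aligned.shape)      # torch.Size([5, 3]) torch.Size([2, 3])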
24 changes: 18 additions & 6 deletions tests/unit-tests/test_gnn.py
@@ -275,15 +275,27 @@ def require_cache_embed(self):
     dataloader2 = GSgnnNodeDataLoader(data, target_nidx, fanout=[-1, -1],
                                       batch_size=10, label_field='label',
                                       node_feats='feat', train_task=False)
-    pred2, _, labels2 = node_mini_batch_gnn_predict(model, dataloader2, return_label=True)
+    # Call GNN mini-batch inference
+    pred2_gnn_pred, _, labels2_gnn_pred = node_mini_batch_gnn_predict(model, dataloader2, return_label=True)
+    # Call last-layer mini-batch inference with the GNN dataloader
+    pred2_pred, labels2_pred = node_mini_batch_predict(model, embs, dataloader2, return_label=True)
     if isinstance(pred1, dict):
-        assert len(pred1) == len(pred2) and len(labels1) == len(labels2)
+        assert len(pred1) == len(pred2_gnn_pred) and len(labels1) == len(labels2_gnn_pred)
+        assert len(pred1) == len(pred2_pred) and len(labels1) == len(labels2_pred)
         for ntype in pred1:
-            assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(), pred2[ntype][0:len(pred2)].numpy(), decimal=5)
-            assert_equal(labels1[ntype].numpy(), labels2[ntype].numpy())
+            assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(),
+                                pred2_gnn_pred[ntype][0:len(pred2_gnn_pred)].numpy(), decimal=5)
+            assert_equal(labels1[ntype].numpy(), labels2_gnn_pred[ntype].numpy())
+            assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(),
+                                pred2_pred[ntype][0:len(pred2_pred)].numpy(), decimal=5)
+            assert_equal(labels1[ntype].numpy(), labels2_pred[ntype].numpy())
     else:
-        assert_almost_equal(pred1[0:len(pred1)].numpy(), pred2[0:len(pred2)].numpy(), decimal=5)
-        assert_equal(labels1.numpy(), labels2.numpy())
+        assert_almost_equal(pred1[0:len(pred1)].numpy(),
+                            pred2_gnn_pred[0:len(pred2_gnn_pred)].numpy(), decimal=5)
+        assert_equal(labels1.numpy(), labels2_gnn_pred.numpy())
+        assert_almost_equal(pred1[0:len(pred1)].numpy(),
+                            pred2_pred[0:len(pred2_pred)].numpy(), decimal=5)
+        assert_equal(labels1.numpy(), labels2_pred.numpy())

     # Test the return_proba argument.
     pred3, labels3 = node_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True)
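The property this test relies on can be illustrated with a minimal, self-contained sketch: running the full model mini-batch by mini-batch should match applying only the decoder to embeddings cached up front. The Linear encoder/decoder stand-ins below are assumptions for illustration, not GraphStorm's model classes.

import torch as th
from numpy.testing import assert_almost_equal

th.manual_seed(0)
encoder = th.nn.Linear(16, 8)     # stand-in for the GNN encoder
decoder = th.nn.Linear(8, 3)      # stand-in for the node decoder
feats = th.randn(100, 16)

with th.no_grad():
    # Path 1: run encoder + decoder mini-batch by mini-batch (gnn_predict style).
    gnn_preds = th.cat([decoder(encoder(feats[batch]))
                        for batch in th.split(th.arange(100), 10)])

    # Path 2: cache embeddings once, then run only the decoder (predict style).
    embs = encoder(feats)
    cached_preds = th.cat([decoder(embs[batch])
                           for batch in th.split(th.arange(100), 10)])

# The two inference paths agree to the precision the unit test checks.
assert_almost_equal(gnn_preds.numpy(), cached_preds.numpy(), decimal=5)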
