Skip to content

Commit

Permalink
add for edge type
Browse files Browse the repository at this point in the history
  • Loading branch information
jalencato committed Dec 12, 2023
1 parent 8028eda commit 51879a2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
8 changes: 6 additions & 2 deletions python/graphstorm/config/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -1577,6 +1577,8 @@ def target_ntype(self):
if hasattr(self, "_target_ntype"):
return self._target_ntype
else:
logging.warning("There is not target ntype provided, "
"will treat as homogeneous graph")
return DEFAULT_NTYPE

@property
Expand Down Expand Up @@ -1650,8 +1652,10 @@ def target_etype(self):
classification/regression. Support multiple tasks when needed.
"""
# pylint: disable=no-member
assert hasattr(self, "_target_etype"), \
"Edge classification task needs a target etype"
if not hasattr(self, "_target_etype"):
logging.warning("There is not target etype provided, "
"will treat as homogeneous graph")
return DEFAULT_ETYPE
assert isinstance(self._target_etype, list), \
"target_etype must be a list in format: " \
"[\"query,clicks,asin\", \"query,search,asin\"]."
Expand Down
7 changes: 7 additions & 0 deletions python/graphstorm/run/gsgnn_ep/gsgnn_ep.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ def main(config_args):
label_field=config.label_field,
decoder_edge_feat=config.decoder_edge_feat,
lm_feat_ntypes=get_lm_ntypes(config.node_lm_configs))
if config.target_etype == DEFAULT_ETYPE:
assert train_data.g.ntypes == [DEFAULT_NTYPE] and \
train_data.g.etypes == [DEFAULT_ETYPE[1]], \
f"It is required to be a homogeneous graph when not providing " \
f"target_ntype on node task, expect node type {[DEFAULT_NTYPE]} and " \
f"edge type {[DEFAULT_ETYPE[1]]}, but get {train_data.g.ntypes} " \
f"and {train_data.g.etypes}"
model = gs.create_builtin_edge_gnn_model(train_data.g, config, train_task=True)
trainer = GSgnnEdgePredictionTrainer(model, topk_model_to_save=config.topk_model_to_save)
if config.restore_model_path is not None:
Expand Down

0 comments on commit 51879a2

Please sign in to comment.