Skip to content

Commit

Permalink
node type
Browse files Browse the repository at this point in the history
  • Loading branch information
jalencato committed Dec 12, 2023
1 parent 5e7cbb8 commit a02b41e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
9 changes: 6 additions & 3 deletions python/graphstorm/config/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def __init__(self, cmd_args):
# Override class attributes using command-line arguments
self.override_arguments(cmd_args)
self.local_rank = cmd_args.local_rank
self.is_homo = False

logging.debug(str(configuration))
cmd_args_dict = cmd_args.__dict__
Expand Down Expand Up @@ -1573,9 +1574,11 @@ def target_ntype(self):
""" The node type for prediction
"""
# pylint: disable=no-member
assert hasattr(self, "_target_ntype"), \
"Must provide the target ntype through target_ntype"
return self._target_ntype
if hasattr(self, "_target_ntype"):
return self._target_ntype
else:
self.is_homo = True
return "_N"

@property
def eval_target_ntype(self):
Expand Down
3 changes: 3 additions & 0 deletions python/graphstorm/run/gsgnn_np/gsgnn_np.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def main(config_args):
node_feat_field=config.node_feat_name,
label_field=config.label_field,
lm_feat_ntypes=get_lm_ntypes(config.node_lm_configs))
if config.is_homo:
assert train_data.g.ntypes == ["_N"], "It is required to be a homogeneous graph " \
"when not providing target_ntype on node task"
model = gs.create_builtin_node_gnn_model(train_data.g, config, train_task=True)

if config.training_method["name"] == "glem":
Expand Down

0 comments on commit a02b41e

Please sign in to comment.