From 51879a288944030a6cda1dc7651627eb165da534 Mon Sep 17 00:00:00 2001 From: JalenCato Date: Tue, 12 Dec 2023 23:47:10 +0000 Subject: [PATCH] add for edge type --- python/graphstorm/config/argument.py | 8 ++++++-- python/graphstorm/run/gsgnn_ep/gsgnn_ep.py | 7 +++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py index 494ddf8fbf..61fee036c1 100644 --- a/python/graphstorm/config/argument.py +++ b/python/graphstorm/config/argument.py @@ -1577,6 +1577,8 @@ def target_ntype(self): if hasattr(self, "_target_ntype"): return self._target_ntype else: + logging.warning("There is not target ntype provided, " + "will treat as homogeneous graph") return DEFAULT_NTYPE @property @@ -1650,8 +1652,10 @@ def target_etype(self): classification/regression. Support multiple tasks when needed. """ # pylint: disable=no-member - assert hasattr(self, "_target_etype"), \ - "Edge classification task needs a target etype" + if not hasattr(self, "_target_etype"): + logging.warning("There is not target etype provided, " + "will treat as homogeneous graph") + return DEFAULT_ETYPE assert isinstance(self._target_etype, list), \ "target_etype must be a list in format: " \ "[\"query,clicks,asin\", \"query,search,asin\"]." diff --git a/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py b/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py index 661c9e14d3..cc82cf3ce2 100644 --- a/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py +++ b/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py @@ -70,6 +70,13 @@ def main(config_args): label_field=config.label_field, decoder_edge_feat=config.decoder_edge_feat, lm_feat_ntypes=get_lm_ntypes(config.node_lm_configs)) + if config.target_etype == DEFAULT_ETYPE: + assert train_data.g.ntypes == [DEFAULT_NTYPE] and \ + train_data.g.etypes == [DEFAULT_ETYPE[1]], \ + f"It is required to be a homogeneous graph when not providing " \ + f"target_ntype on node task, expect node type {[DEFAULT_NTYPE]} and " \ + f"edge type {[DEFAULT_ETYPE[1]]}, but get {train_data.g.ntypes} " \ + f"and {train_data.g.etypes}" model = gs.create_builtin_edge_gnn_model(train_data.g, config, train_task=True) trainer = GSgnnEdgePredictionTrainer(model, topk_model_to_save=config.topk_model_to_save) if config.restore_model_path is not None: