diff --git a/python/graphstorm/inference/__init__.py b/python/graphstorm/inference/__init__.py
index aad67b6d78..92b4f7f965 100644
--- a/python/graphstorm/inference/__init__.py
+++ b/python/graphstorm/inference/__init__.py
@@ -19,3 +19,4 @@
 from .lp_infer import GSgnnLinkPredictionInferrer
 from .np_infer import GSgnnNodePredictionInferrer
 from .ep_infer import GSgnnEdgePredictionInferrer
+from .emb_infer import GSgnnEmbGenInferer
diff --git a/python/graphstorm/inference/emb_infer.py b/python/graphstorm/inference/emb_infer.py
new file mode 100644
index 0000000000..413c43eebe
--- /dev/null
+++ b/python/graphstorm/inference/emb_infer.py
@@ -0,0 +1,111 @@
+"""
+    Copyright 2023 Contributors
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
+
+    Inferrer wrapper for embedding generation.
+"""
+import logging
+from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION,
+                               BUILTIN_TASK_NODE_REGRESSION,
+                               BUILTIN_TASK_EDGE_CLASSIFICATION,
+                               BUILTIN_TASK_EDGE_REGRESSION,
+                               BUILTIN_TASK_LINK_PREDICTION)
+from .graphstorm_infer import GSInferrer
+from ..model.utils import save_embeddings as save_gsgnn_embeddings
+from ..model.gnn import do_full_graph_inference, do_mini_batch_inference
+from ..utils import sys_tracker, get_rank, get_world_size, barrier
+
+
+class GSgnnEmbGenInferer(GSInferrer):
+    """ Embedding generation inferrer.
+
+    This is a high-level inferrer wrapper that can be used directly
+    to generate node embeddings for inference.
+
+    Parameters
+    ----------
+    model : GSgnnNodeModel
+        The GNN model of a GraphStorm built-in task.
+    """
+    def infer(self, data, task_type, save_embed_path, eval_fanout,
+              use_mini_batch_infer=False,
+              node_id_mapping_file=None,
+              save_embed_format="pytorch"):
+        """ Do embedding generation.
+
+        Generate node embeddings and save them to disk.
+
+        Parameters
+        ----------
+        data: GSgnnData
+            The GraphStorm dataset.
+        task_type : str
+            The task type. It must be one of the GraphStorm built-in task types.
+        save_embed_path : str
+            The path where the GNN embeddings will be saved.
+        eval_fanout: list of int
+            The fanout of each GNN layer used in evaluation and inference.
+        use_mini_batch_infer : bool
+            Whether to use mini-batch inference when computing node embeddings.
+        node_id_mapping_file: str
+            Path to the file storing the node id mapping generated by the
+            graph partition algorithm.
+        save_embed_format : str
+            The format of the saved embeddings.
+        """
+
+        device = self.device
+
+        assert save_embed_path is not None, \
+            "save_embed_path must be provided for gs_gen_node_embedding"
+
+        sys_tracker.check('start generating embedding')
+        self._model.eval()
+
+        # infer_ntypes must be sorted for saving node embeddings
+        if task_type == BUILTIN_TASK_LINK_PREDICTION:
+            infer_ntypes = None
+        elif task_type in {BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_NODE_CLASSIFICATION}:
+            infer_ntypes = sorted(data.infer_idxs)
+        elif task_type in {BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION}:
+            infer_ntypes = set()
+            for etype in data.infer_idxs:
+                infer_ntypes.add(etype[0])
+                infer_ntypes.add(etype[2])
+            infer_ntypes = sorted(infer_ntypes)
+        else:
+            raise TypeError(f"Unsupported task type: {task_type}")
+
+        if use_mini_batch_infer:
+            embs = do_mini_batch_inference(self._model, data, fanout=eval_fanout,
+                                           edge_mask=None,
+                                           task_tracker=self.task_tracker,
+                                           infer_ntypes=infer_ntypes)
+        else:
+            embs = do_full_graph_inference(self._model, data, fanout=eval_fanout,
+                                           edge_mask=None,
+                                           task_tracker=self.task_tracker)
+            if infer_ntypes:
+                embs = {ntype: embs[ntype] for ntype in infer_ntypes}
+
+        if get_rank() == 0:
+            logging.info("Saving embeddings to %s", save_embed_path)
+
+        save_gsgnn_embeddings(save_embed_path, embs, get_rank(),
+                              get_world_size(),
+                              device=device,
+                              node_id_mapping_file=node_id_mapping_file,
+                              save_embed_format=save_embed_format)
+        barrier()
+        sys_tracker.check('save embeddings')
diff --git a/python/graphstorm/run/gs_gen_node_embedding.py b/python/graphstorm/run/gs_gen_node_embedding.py
new file mode 100644
index 0000000000..f879cfabf3
--- /dev/null
+++ b/python/graphstorm/run/gs_gen_node_embedding.py
@@ -0,0 +1,46 @@
+"""
+    Copyright 2023 Contributors
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
+
+    Entry point for running embedding generation tasks.
+
+    Run as:
+    python3 -m graphstorm.run.gs_gen_node_embedding
+"""
+import os
+import logging
+
+from .launch import get_argument_parser
+from .launch import check_input_arguments
+from .launch import submit_jobs
+
+def main():
+    """ Main function.
+    """
+    parser = get_argument_parser()
+    args, exec_script_args = parser.parse_known_args()
+    check_input_arguments(args)
+
+    lib_dir = os.path.abspath(os.path.dirname(__file__))
+    cmd = "gsgnn_emb/gsgnn_node_emb.py"
+    cmd_path = os.path.join(lib_dir, cmd)
+    exec_script_args = [cmd_path] + exec_script_args
+
+    submit_jobs(args, exec_script_args)
+
+if __name__ == "__main__":
+    FMT = "%(asctime)s %(levelname)s %(message)s"
+    logging.basicConfig(format=FMT, level=logging.INFO)
+    main()
diff --git a/python/graphstorm/run/gsgnn_emb/__init__.py b/python/graphstorm/run/gsgnn_emb/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py b/python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py
new file mode 100644
index 0000000000..191471d9f6
--- /dev/null
+++ b/python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py
@@ -0,0 +1,104 @@
+"""
+    Copyright 2023 Contributors
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
+
+    GSgnn node embedding generation (pure GPU).
+"""
+import graphstorm as gs
+from graphstorm.config import get_argument_parser
+from graphstorm.config import GSConfig
+from graphstorm.utils import rt_profiler, sys_tracker, setup_device, use_wholegraph
+from graphstorm.dataloading import (GSgnnEdgeInferData, GSgnnNodeInferData)
+from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION,
+                               BUILTIN_TASK_NODE_REGRESSION,
+                               BUILTIN_TASK_EDGE_CLASSIFICATION,
+                               BUILTIN_TASK_EDGE_REGRESSION,
+                               BUILTIN_TASK_LINK_PREDICTION)
+from graphstorm.inference import GSgnnEmbGenInferer
+
+
+def main(config_args):
+    """ Main function.
+    """
+    config = GSConfig(config_args)
+    config.verify_arguments(True)
+
+    gs.initialize(ip_config=config.ip_config, backend=config.backend,
+                  use_wholegraph=use_wholegraph(config.part_config))
+    rt_profiler.init(config.profile_path, rank=gs.get_rank())
+    sys_tracker.init(config.verbose, rank=gs.get_rank())
+    device = setup_device(config.local_rank)
+    tracker = gs.create_builtin_task_tracker(config)
+    if gs.get_rank() == 0:
+        tracker.log_params(config.__dict__)
+
+    if config.task_type == BUILTIN_TASK_LINK_PREDICTION:
+        input_data = GSgnnEdgeInferData(config.graph_name,
+                                        config.part_config,
+                                        eval_etypes=config.eval_etype,
+                                        node_feat_field=config.node_feat_name)
+    elif config.task_type in {BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_NODE_CLASSIFICATION}:
+        input_data = GSgnnNodeInferData(config.graph_name,
+                                        config.part_config,
+                                        eval_ntypes=config.target_ntype,
+                                        node_feat_field=config.node_feat_name)
+    elif config.task_type in {BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION}:
+        input_data = GSgnnEdgeInferData(config.graph_name,
+                                        config.part_config,
+                                        eval_etypes=config.target_etype,
+                                        node_feat_field=config.node_feat_name)
+    else:
+        raise TypeError(f"Unsupported task type: {config.task_type}")
+
+    # Check the settings required for GraphStorm embedding generation.
+    assert config.save_embed_path is not None, \
+        "save_embed_path cannot be None for gs_gen_node_embedding"
+    assert config.restore_model_path is not None, \
+        "restore_model_path cannot be None for gs_gen_node_embedding"
+
+    # load the model
+    if config.task_type == BUILTIN_TASK_LINK_PREDICTION:
+        model = gs.create_builtin_lp_gnn_model(input_data.g, config, train_task=False)
+    elif config.task_type in {BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_NODE_CLASSIFICATION}:
+        model = gs.create_builtin_node_gnn_model(input_data.g, config, train_task=False)
+    elif config.task_type in {BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION}:
+        model = gs.create_builtin_edge_gnn_model(input_data.g, config, train_task=False)
+    else:
+        raise TypeError(f"Unsupported task type: {config.task_type}")
+    model.restore_model(config.restore_model_path,
+                        model_layer_to_load=config.restore_model_layers)
+
+    # start to generate embeddings
+    emb_generator = GSgnnEmbGenInferer(model)
+    emb_generator.setup_device(device=device)
+
+    emb_generator.infer(input_data, config.task_type,
+                        save_embed_path=config.save_embed_path,
+                        eval_fanout=config.eval_fanout,
+                        use_mini_batch_infer=config.use_mini_batch_infer,
+                        node_id_mapping_file=config.node_id_mapping_file,
+                        save_embed_format=config.save_embed_format)
+
+def generate_parser():
+    """ Generate an argument parser.
+    """
+    parser = get_argument_parser()
+    return parser
+
+
+if __name__ == '__main__':
+    arg_parser = generate_parser()
+
+    args = arg_parser.parse_args()
+    main(args)
diff --git a/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py b/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py
index 7acb4cd134..20e3bbb745 100644
--- a/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py
+++ b/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py
@@ -156,7 +156,8 @@ def main(config_args):
         save_embeddings(config.save_embed_path, embs, gs.get_rank(),
                         gs.get_world_size(),
                         device=device,
-                        node_id_mapping_file=config.node_id_mapping_file)
+                        node_id_mapping_file=config.node_id_mapping_file,
+                        save_embed_format=config.save_embed_format)
 
 def generate_parser():
     """ Generate an argument parser
diff --git a/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py b/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py
index b89e568424..6ce7c76e52 100644
--- a/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py
+++ b/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py
@@ -202,7 +202,8 @@ def main(config_args):
         save_embeddings(config.save_embed_path, embeddings, gs.get_rank(),
                         gs.get_world_size(),
                         device=device,
-                        node_id_mapping_file=config.node_id_mapping_file)
+                        node_id_mapping_file=config.node_id_mapping_file,
+                        save_embed_format=config.save_embed_format)
 
 def generate_parser():
     """ Generate an argument parser
diff --git a/python/graphstorm/run/gsgnn_np/gsgnn_np.py b/python/graphstorm/run/gsgnn_np/gsgnn_np.py
index 42f78fed99..2d7e8e240e 100644
--- a/python/graphstorm/run/gsgnn_np/gsgnn_np.py
+++ b/python/graphstorm/run/gsgnn_np/gsgnn_np.py
@@ -156,7 +156,8 @@ def main(config_args):
         save_embeddings(config.save_embed_path, embeddings, gs.get_rank(),
                         gs.get_world_size(),
                         device=device,
-                        node_id_mapping_file=config.node_id_mapping_file)
+                        node_id_mapping_file=config.node_id_mapping_file,
+                        save_embed_format=config.save_embed_format)
 
 def generate_parser():
     """ Generate an argument parser
diff --git a/python/graphstorm/run/launch.py b/python/graphstorm/run/launch.py
index b23d19a263..6d7e1b2b70 100644
--- a/python/graphstorm/run/launch.py
+++ b/python/graphstorm/run/launch.py
@@ -908,7 +908,7 @@ def check_input_arguments(args):
     ), "--num-servers must be a positive number."
     assert (
         args.part_config is not None
-    ), "A user has to specify a partition configuration file with --part-onfig."
+    ), "A user has to specify a partition configuration file with --part-config."
     assert (
         args.ip_config is not None
     ), "A user has to specify an IP configuration file with --ip-config."
diff --git a/tests/end2end-tests/graphstorm-ec/mgpu_test.sh b/tests/end2end-tests/graphstorm-ec/mgpu_test.sh
index 71928c453a..562b4e3817 100644
--- a/tests/end2end-tests/graphstorm-ec/mgpu_test.sh
+++ b/tests/end2end-tests/graphstorm-ec/mgpu_test.sh
@@ -135,6 +135,15 @@ python3 check_infer.py --train_embout /data/gsgnn_ec/emb/ --infer_embout /data/g
 
 error_and_exit $?
 
+echo "**************dataset: Movielens, use gs_gen_node_embedding to generate embeddings on edge classification"
+python3 -m graphstorm.run.gs_gen_node_embedding --workspace $GS_HOME/training_scripts/gsgnn_ep/ --num-trainers $NUM_TRAINERS --use-mini-batch-infer false --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_label_ec/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_ec.yaml --exclude-training-targets True --multilabel true --num-classes 6 --node-feat-name movie:title user:feat --save-embed-path /data/gsgnn_ec/save-emb/ --restore-model-path /data/gsgnn_ec/epoch-$best_epoch/ --logging-file /tmp/train_log.txt --logging-level debug
+
+error_and_exit $?
+
+python3 $GS_HOME/tests/end2end-tests/check_infer.py --train_embout /data/gsgnn_ec/emb/ --infer_embout /data/gsgnn_ec/save-emb/
+
+error_and_exit $?
+
 echo "**************dataset: Generated multilabel MovieLens EC, do inference on saved model without test_mask"
 python3 -m graphstorm.run.gs_edge_classification --inference --workspace $GS_HOME/inference_scripts/ep_infer --num-trainers $NUM_INFO_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_ec_no_test_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_ec_infer.yaml --multilabel true --num-classes 6 --node-feat-name movie:title user:feat --use-mini-batch-infer false --save-embed-path /data/gsgnn_ec/infer-emb/ --restore-model-path /data/gsgnn_ec/epoch-$best_epoch/ --save-prediction-path /data/gsgnn_ec/prediction/ --no-validation true
diff --git a/tests/end2end-tests/graphstorm-lp/mgpu_test.sh b/tests/end2end-tests/graphstorm-lp/mgpu_test.sh
index 9ed2129c04..5484541e0e 100644
--- a/tests/end2end-tests/graphstorm-lp/mgpu_test.sh
+++ b/tests/end2end-tests/graphstorm-lp/mgpu_test.sh
@@ -192,6 +192,11 @@ then
 fi
 rm -fr /data/gsgnn_lp_ml_dot/infer-emb/
 
+echo "**************dataset: Movielens, use gs_gen_node_embedding to generate embeddings on link prediction"
+python3 -m graphstorm.run.gs_gen_node_embedding --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --eval-batch-size 1024 --use-node-embeddings true --exclude-training-targets True --reverse-edge-types-map user,rating,rating-rev,movie --restore-model-path /data/gsgnn_lp_ml_dot/epoch-$best_epoch_dot/ --save-embed-path /data/gsgnn_lp_ml_dot/save-emb/ --logging-file /tmp/train_log.txt --logging-level debug
+
+error_and_exit $?
+
 echo "**************dataset: Movielens, do mini-batch inference on saved model, decoder: dot"
 python3 -m graphstorm.run.gs_link_prediction --inference --workspace $GS_HOME/inference_scripts/lp_infer --num-trainers $NUM_INFO_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp_infer.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --use-node-embeddings true --eval-batch-size 1024 --save-embed-path /data/gsgnn_lp_ml_dot/infer-emb/ --restore-model-path /data/gsgnn_lp_ml_dot/epoch-$best_epoch_dot/ --use-mini-batch-infer true --logging-file /tmp/log.txt
diff --git a/tests/end2end-tests/graphstorm-nc/mgpu_test.sh b/tests/end2end-tests/graphstorm-nc/mgpu_test.sh
index 7a532d257f..4c6c56b19f 100644
--- a/tests/end2end-tests/graphstorm-nc/mgpu_test.sh
+++ b/tests/end2end-tests/graphstorm-nc/mgpu_test.sh
@@ -137,6 +137,15 @@ python3 $GS_HOME/tests/end2end-tests/check_infer.py --train_embout /data/gsgnn_n
 
 error_and_exit $?
 
+echo "**************dataset: Movielens, use gs_gen_node_embedding to generate embeddings on node classification"
+python3 -m graphstorm.run.gs_gen_node_embedding --workspace $GS_HOME/training_scripts/gsgnn_np/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc.yaml --restore-model-path /data/gsgnn_nc_ml/epoch-$best_epoch/ --save-embed-path /data/gsgnn_nc_ml/save-emb/ --use-mini-batch-infer false --logging-file /tmp/train_log.txt --logging-level debug
+
+error_and_exit $?
+
+python3 $GS_HOME/tests/end2end-tests/check_infer.py --train_embout /data/gsgnn_nc_ml/emb/ --infer_embout /data/gsgnn_nc_ml/save-emb/
+
+error_and_exit $?
+
 echo "**************dataset: Movielens, do inference on saved model with mini-batch-infer without test mask"
 python3 -m graphstorm.run.gs_node_classification --inference --workspace $GS_HOME/inference_scripts/np_infer/ --num-trainers $NUM_INFERs --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_train_notest_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_infer.yaml --use-mini-batch-infer true --save-embed-path /data/gsgnn_nc_ml/mini-infer-emb/ --restore-model-path /data/gsgnn_nc_ml/epoch-$best_epoch/ --save-prediction-path /data/gsgnn_nc_ml/prediction/ --no-validation true
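
Note for reviewers: the sketch below distills how the new GSgnnEmbGenInferer is meant to be driven, using only APIs that appear in the diff above. It is an illustrative, trimmed-down variant of gsgnn_node_emb.py that covers just the node-prediction case; the function name generate_embeddings is invented for the example, all configuration values still come from the usual GraphStorm YAML/launch arguments, and nothing here is an additional entry point added by this change.

# Illustrative sketch only: a single-task variant of gsgnn_node_emb.py for
# node classification/regression checkpoints. Not part of the patch.
import graphstorm as gs
from graphstorm.config import (get_argument_parser, GSConfig,
                               BUILTIN_TASK_NODE_CLASSIFICATION,
                               BUILTIN_TASK_NODE_REGRESSION)
from graphstorm.dataloading import GSgnnNodeInferData
from graphstorm.inference import GSgnnEmbGenInferer
from graphstorm.utils import setup_device

def generate_embeddings(config_args):
    config = GSConfig(config_args)
    config.verify_arguments(True)
    assert config.task_type in {BUILTIN_TASK_NODE_CLASSIFICATION,
                                BUILTIN_TASK_NODE_REGRESSION}, \
        "this sketch only covers node prediction checkpoints"

    gs.initialize(ip_config=config.ip_config, backend=config.backend)
    device = setup_device(config.local_rank)

    # Load the partitioned graph, restricted to the target node types.
    data = GSgnnNodeInferData(config.graph_name, config.part_config,
                              eval_ntypes=config.target_ntype,
                              node_feat_field=config.node_feat_name)

    # Re-create the trained model and restore its weights from disk.
    model = gs.create_builtin_node_gnn_model(data.g, config, train_task=False)
    model.restore_model(config.restore_model_path,
                        model_layer_to_load=config.restore_model_layers)

    # Generate the node embeddings and write them to save_embed_path.
    emb_generator = GSgnnEmbGenInferer(model)
    emb_generator.setup_device(device=device)
    emb_generator.infer(data, config.task_type,
                        save_embed_path=config.save_embed_path,
                        eval_fanout=config.eval_fanout,
                        use_mini_batch_infer=config.use_mini_batch_infer,
                        node_id_mapping_file=config.node_id_mapping_file,
                        save_embed_format=config.save_embed_format)

if __name__ == "__main__":
    generate_embeddings(get_argument_parser().parse_args())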