-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Doc] New API example Notebook 2 - Link Prediction (#846)
*Issue #, if available:* *Description of changes:* This PR add a new Jupyter notebook example for demonstrating link prediction pipeline. For the notebook, this PR also adds a simple RGCN model in the `demo_models.py` file. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --------- Co-authored-by: Ubuntu <[email protected]>
- Loading branch information
Showing
2 changed files
with
423 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,362 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Notebook 2: Use GraphStorm APIs for Building a Link Prediction Pipeline\n", | ||
"\n", | ||
"This notebook demonstrates how to use GraphStorm's APIs to create a graph machine learning pipeline for a link prediction task.\n", | ||
"\n", | ||
"In this notebook, we modify the RGCN model used in the Notebook 1 to adapt to link prediction tasks and use it to conduct link prediction on the ACM dataset created by the **Notebook_0_Data_Prepare**. \n", | ||
"\n", | ||
"### Prerequsites\n", | ||
"\n", | ||
"- GraphStorm installed using pip. Please find [more details on installation of GraphStorm](https://graphstorm.readthedocs.io/en/latest/install/env-setup.html#setup-graphstorm-with-pip-packages).\n", | ||
"- ACM data created in the [Notebook 0: Data Prepare](https://graphstorm.readthedocs.io/en/latest/notebooks/Notebook_0_Data_Prepare.html), and is stored in the `./acm_gs_1p/` folder.\n", | ||
"- Installation of supporting libraries, e.g., matplotlib." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Setup log level in Jupyter Notebook\n", | ||
"import logging\n", | ||
"logging.basicConfig(level=20)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"---\n", | ||
"The major steps of creating a link prediction pipeline are same as the node classification pipeline in the Notebook 1. In this notebook, we will only highlight the different components for clarity.\n", | ||
"\n", | ||
"### 0. Initialize the GraphStorm Standalone Environment" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import graphstorm as gs\n", | ||
"gs.initialize()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### 1. Setup GraphStorm Dataset and DataLoaders\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"nfeats_4_modeling = {'author':['feat'], 'paper':['feat'],'subject':['feat']}\n", | ||
"\n", | ||
"# create a GraphStorm Dataset for the ACM graph data generated in the Notebook 0\n", | ||
"acm_data = gs.dataloading.GSgnnData(part_config='./acm_gs_1p/acm.json', node_feat_field=nfeats_4_modeling)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Because link prediction needs both positive and negative edges for training, we use GraphStorm's `GSgnnLinkPredictionDataloader` which is dedicated for link prediction dataloading. This class takes some common arugments as these `NodePredictionDataloader`s, such as `dataset`, `target_idx`, `node_feats`, and `batch_size`. It also takes some link prediction-related arguments, e.g., `num_negative_edges`, `exlude_training_targets`, and etc." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 15, | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# define dataloaders for training and validation\n", | ||
"train_dataloader = gs.dataloading.GSgnnLinkPredictionDataLoader(\n", | ||
" dataset=acm_data,\n", | ||
" target_idx=acm_data.get_edge_train_set(etypes=[('paper', 'citing', 'paper')]),\n", | ||
" fanout=[20, 20],\n", | ||
" num_negative_edges=10,\n", | ||
" node_feats=nfeats_4_modeling,\n", | ||
" batch_size=64,\n", | ||
" exclude_training_targets=False,\n", | ||
" reverse_edge_types_map=[\"paper,citing,cited,paper\"],\n", | ||
" train_task=True)\n", | ||
"val_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(\n", | ||
" dataset=acm_data,\n", | ||
" target_idx=acm_data.get_edge_val_set(etypes=[('paper', 'citing', 'paper')]),\n", | ||
" fanout=[100, 100],\n", | ||
" num_negative_edges=100,\n", | ||
" node_feats=nfeats_4_modeling,\n", | ||
" batch_size=256)\n", | ||
"test_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(\n", | ||
" dataset=acm_data,\n", | ||
" target_idx=acm_data.get_edge_test_set(etypes=[('paper', 'citing', 'paper')]),\n", | ||
" fanout=[100, 100],\n", | ||
" num_negative_edges=100,\n", | ||
" node_feats=nfeats_4_modeling,\n", | ||
" batch_size=256)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### 2. Create a GraphStorm-compatible RGCN Model for Link Prediction \n", | ||
"\n", | ||
"For the link prediction task, we modified the RGCN model used for node classification to adopt to link prediction task. Users can find the details in the `demon_models.py` file." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# import a simplified RGCN model for node classification\n", | ||
"from demo_models import RgcnLPModel\n", | ||
"\n", | ||
"model = RgcnLPModel(g=acm_data.g,\n", | ||
" num_hid_layers=2,\n", | ||
" node_feat_field=nfeats_4_modeling,\n", | ||
" hid_size=128)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### 3. Setup a GraphStorm Evaluator\n", | ||
"\n", | ||
"Here we change evaluator to a `GSgnnMrrLPEvaluator` that uses \"mrr\" as the metric dedicated for evaluation of link prediction performance." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# setup a link prediction evaluator for the trainer\n", | ||
"evaluator = gs.eval.GSgnnMrrLPEvaluator(eval_frequency=1000)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### 4. Setup a Trainer and Training\n", | ||
"\n", | ||
"GraphStorm has the `GSgnnLinkPredictionTrainer` for link prediction training loop. The way of constructing this trainer and calling `fit()` method are same as the `GSgnnNodePredictionTrainer` used in the Notebook 1." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": { | ||
"scrolled": true, | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# create a GraphStorm link prediction task trainer for the RGCN model\n", | ||
"trainer = gs.trainer.GSgnnLinkPredictionTrainer(model, topk_model_to_save=1)\n", | ||
"trainer.setup_evaluator(evaluator)\n", | ||
"trainer.setup_device(gs.utils.get_device())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Train the model with the trainer using fit() function\n", | ||
"trainer.fit(train_loader=train_dataloader,\n", | ||
" val_loader=val_dataloader,\n", | ||
" test_loader=test_dataloader,\n", | ||
" num_epochs=5,\n", | ||
" save_model_path='a_save_path/',\n", | ||
" save_model_frequency=1000,\n", | ||
" use_mini_batch_infer=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### (Optional) 5. Visualize Model Performance History\n", | ||
"\n", | ||
"Same as the node classification pipeline, we can use the history stored in the evaluator." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 17, | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import matplotlib.pyplot as plt\n", | ||
"\n", | ||
"# extract evaluation history of metrics from the trainer's evaluator:\n", | ||
"val_metrics, test_metrics = [], []\n", | ||
"for val_metric, test_metric in trainer.evaluator.history:\n", | ||
" val_metrics.append(val_metric['mrr'])\n", | ||
" test_metrics.append(test_metric['mrr'])\n", | ||
"\n", | ||
"# plot the performance curves\n", | ||
"fig, ax = plt.subplots()\n", | ||
"ax.plot(val_metrics, label='val')\n", | ||
"ax.plot(test_metrics, label='test')\n", | ||
"ax.set(xlabel='Epoch', ylabel='Mrr')\n", | ||
"ax.legend(loc='best')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### 6. Inference with the Trained Model\n", | ||
"\n", | ||
"The operations of model restore are same as those used in the Notebook 1. Users can find the best model path first, and use model's `restore_model()` to load the trained model file." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 18, | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# after training, the best model is saved to disk:\n", | ||
"best_model_path = trainer.get_best_model_path()\n", | ||
"print('Best model path:', best_model_path)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 19, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# we can restore the model from the saved path using the model's restore_model() function.\n", | ||
"model.restore_model(best_model_path)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"To do inference, users can either create a new dataloader as the following code does, or reuse one of the dataloaders defined in training." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Setup dataloader for inference\n", | ||
"infer_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(\n", | ||
" dataset=acm_data,\n", | ||
" target_idx=acm_data.get_edge_infer_set(etypes=[('paper', 'citing', 'paper')]),\n", | ||
" fanout=[100, 100],\n", | ||
" num_negative_edges=100,\n", | ||
" node_feats=nfeats_4_modeling,\n", | ||
" batch_size=256)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Now we can define a `GSgnnLinkPredictionInferrer` by giving the restored model and do inference by calling its `infer()` method." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 20, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Create an Inferrer object\n", | ||
"infer = gs.inference.GSgnnLinkPredictionInferrer(model)\n", | ||
"\n", | ||
"# Run inference on the inference dataset\n", | ||
"infer.infer(acm_data,\n", | ||
" infer_dataloader,\n", | ||
" save_embed_path='infer/embeddings',\n", | ||
" use_mini_batch_infer=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"For link prediction task, the inference outputs are embeddings of all nodes in the inference graph." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 21, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# The GNN embeddings of all nodes in the inference graph are saved to the folder named after the target_ntype\n", | ||
"!ls -lh infer/embeddings/paper" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "gsf", | ||
"language": "python", | ||
"name": "gsf" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.18" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |
Oops, something went wrong.