[Doc] New API example Notebook 2 - Link Prediction #846

Merged · 7 commits · May 21, 2024
362 changes: 362 additions & 0 deletions docs/source/notebooks/Notebook_2_LP_Pipeline.ipynb
@@ -0,0 +1,362 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Notebook 2: Use GraphStorm APIs for Building a Link Prediction Pipeline\n",
"\n",
"This notebook demonstrates how to use GraphStorm's APIs to create a graph machine learning pipeline for a link prediction task.\n",
"\n",
"In this notebook, we modify the RGCN model used in Notebook 1 to adapt it to link prediction, and then use it to conduct link prediction on the ACM dataset created in **Notebook_0_Data_Prepare**.\n",
"\n",
"### Prerequisites\n",
"\n",
"- GraphStorm installed using pip. Please find [more details on installing GraphStorm](https://graphstorm.readthedocs.io/en/latest/install/env-setup.html#setup-graphstorm-with-pip-packages).\n",
"- ACM data created in [Notebook 0: Data Prepare](https://graphstorm.readthedocs.io/en/latest/notebooks/Notebook_0_Data_Prepare.html) and stored in the `./acm_gs_1p/` folder.\n",
"- Installation of supporting libraries, e.g., matplotlib."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Setup log level in Jupyter Notebook\n",
"import logging\n",
"logging.basicConfig(level=20)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"The major steps of creating a link prediction pipeline are the same as those of the node classification pipeline in Notebook 1. For clarity, this notebook only highlights the components that differ.\n",
"\n",
"### 0. Initialize the GraphStorm Standalone Environment"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import graphstorm as gs\n",
"gs.initialize()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Setup GraphStorm Dataset and DataLoaders\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"nfeats_4_modeling = {'author':['feat'], 'paper':['feat'],'subject':['feat']}\n",
"\n",
"# create a GraphStorm Dataset for the ACM graph data generated in the Notebook 0\n",
"acm_data = gs.dataloading.GSgnnData(part_config='./acm_gs_1p/acm.json', node_feat_field=nfeats_4_modeling)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because link prediction needs both positive and negative edges for training, we use GraphStorm's `GSgnnLinkPredictionDataLoader`, which is dedicated to link prediction dataloading. This class takes the same common arguments as the node prediction dataloaders, such as `dataset`, `target_idx`, `node_feats`, and `batch_size`. It also takes link prediction-specific arguments, e.g., `num_negative_edges` and `exclude_training_targets`."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# define dataloaders for training and validation\n",
"train_dataloader = gs.dataloading.GSgnnLinkPredictionDataLoader(\n",
" dataset=acm_data,\n",
" target_idx=acm_data.get_edge_train_set(etypes=[('paper', 'citing', 'paper')]),\n",
" fanout=[20, 20],\n",
" num_negative_edges=10,\n",
" node_feats=nfeats_4_modeling,\n",
" batch_size=64,\n",
" exclude_training_targets=False,\n",
" reverse_edge_types_map=[\"paper,citing,cited,paper\"],\n",
" train_task=True)\n",
"val_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(\n",
" dataset=acm_data,\n",
" target_idx=acm_data.get_edge_val_set(etypes=[('paper', 'citing', 'paper')]),\n",
" fanout=[100, 100],\n",
" num_negative_edges=100,\n",
" node_feats=nfeats_4_modeling,\n",
" batch_size=256)\n",
"test_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(\n",
" dataset=acm_data,\n",
" target_idx=acm_data.get_edge_test_set(etypes=[('paper', 'citing', 'paper')]),\n",
" fanout=[100, 100],\n",
" num_negative_edges=100,\n",
" node_feats=nfeats_4_modeling,\n",
" batch_size=256)"
]
},
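The `num_negative_edges` argument above controls how many negative (non-existent) edges are sampled per positive edge during training. As a plain-Python illustration of the idea only (this is not GraphStorm's actual sampler), uniform negative sampling can be sketched by corrupting the destination node of each positive edge:

```python
import random

def sample_negatives(pos_edges, num_nodes, k, seed=0):
    """For each positive (src, dst) edge, sample k negative edges
    by replacing dst with a uniformly random node id."""
    rng = random.Random(seed)
    negatives = []
    for src, _ in pos_edges:
        for _ in range(k):
            negatives.append((src, rng.randrange(num_nodes)))
    return negatives

# two positive edges, 10 negatives each, mirroring num_negative_edges=10
neg = sample_negatives([(0, 1), (2, 3)], num_nodes=100, k=10)
print(len(neg))  # 2 positives x 10 negatives = 20
```

GraphStorm's dataloader performs this sampling per mini-batch on the graph itself; the sketch only shows the corruption scheme.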
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Create a GraphStorm-compatible RGCN Model for Link Prediction \n",
"\n",
"For the link prediction task, we modified the RGCN model used for node classification to adapt it to link prediction. Users can find the details in the `demo_models.py` file."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# import a simplified RGCN model for link prediction\n",
"from demo_models import RgcnLPModel\n",
"\n",
"model = RgcnLPModel(g=acm_data.g,\n",
" num_hid_layers=2,\n",
" node_feat_field=nfeats_4_modeling,\n",
" hid_size=128)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. Setup a GraphStorm Evaluator\n",
"\n",
"Here we change the evaluator to a `GSgnnMrrLPEvaluator`, which uses \"mrr\" (mean reciprocal rank) as the metric dedicated to evaluating link prediction performance."
]
},
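For reference, MRR averages the reciprocal rank of each positive edge among its scored candidates (its negatives), so it lies in (0, 1] and higher is better. A minimal, GraphStorm-independent sketch of the metric, which the evaluator computes internally:

```python
def mean_reciprocal_rank(ranks):
    """ranks: 1-based rank of each positive edge among its candidates."""
    return sum(1.0 / r for r in ranks) / len(ranks)

# three positive edges ranked 1st, 2nd, and 4th among their candidates
print(mean_reciprocal_rank([1, 2, 4]))  # (1/1 + 1/2 + 1/4) / 3 = 0.5833...
```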
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# setup a link prediction evaluator for the trainer\n",
"evaluator = gs.eval.GSgnnMrrLPEvaluator(eval_frequency=1000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. Setup a Trainer and Training\n",
"\n",
"GraphStorm provides the `GSgnnLinkPredictionTrainer` for the link prediction training loop. The way of constructing this trainer and calling its `fit()` method is the same as for the `GSgnnNodePredictionTrainer` used in Notebook 1."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"# create a GraphStorm link prediction task trainer for the RGCN model\n",
"trainer = gs.trainer.GSgnnLinkPredictionTrainer(model, topk_model_to_save=1)\n",
"trainer.setup_evaluator(evaluator)\n",
"trainer.setup_device(gs.utils.get_device())"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Train the model with the trainer using fit() function\n",
"trainer.fit(train_loader=train_dataloader,\n",
" val_loader=val_dataloader,\n",
" test_loader=test_dataloader,\n",
" num_epochs=5,\n",
" save_model_path='a_save_path/',\n",
" save_model_frequency=1000,\n",
" use_mini_batch_infer=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### (Optional) 5. Visualize Model Performance History\n",
"\n",
"As in the node classification pipeline, we can plot the model performance history stored in the trainer's evaluator."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# extract evaluation history of metrics from the trainer's evaluator:\n",
"val_metrics, test_metrics = [], []\n",
"for val_metric, test_metric in trainer.evaluator.history:\n",
" val_metrics.append(val_metric['mrr'])\n",
" test_metrics.append(test_metric['mrr'])\n",
"\n",
"# plot the performance curves\n",
"fig, ax = plt.subplots()\n",
"ax.plot(val_metrics, label='val')\n",
"ax.plot(test_metrics, label='test')\n",
"ax.set(xlabel='Epoch', ylabel='MRR')\n",
"ax.legend(loc='best')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 6. Inference with the Trained Model\n",
"\n",
"The operations for restoring a model are the same as those used in Notebook 1. Users first find the best model path, and then use the model's `restore_model()` method to load the trained model file."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# after training, the best model is saved to disk:\n",
"best_model_path = trainer.get_best_model_path()\n",
"print('Best model path:', best_model_path)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# we can restore the model from the saved path using the model's restore_model() function.\n",
"model.restore_model(best_model_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To do inference, users can either create a new dataloader as the following code does, or reuse one of the dataloaders defined in training."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Setup dataloader for inference\n",
"infer_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(\n",
" dataset=acm_data,\n",
" target_idx=acm_data.get_edge_infer_set(etypes=[('paper', 'citing', 'paper')]),\n",
" fanout=[100, 100],\n",
" num_negative_edges=100,\n",
" node_feats=nfeats_4_modeling,\n",
" batch_size=256)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can create a `GSgnnLinkPredictionInferrer` with the restored model and run inference by calling its `infer()` method."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# Create an Inferrer object\n",
"infer = gs.inference.GSgnnLinkPredictionInferrer(model)\n",
"\n",
"# Run inference on the inference dataset\n",
"infer.infer(acm_data,\n",
" infer_dataloader,\n",
" save_embed_path='infer/embeddings',\n",
" use_mini_batch_infer=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For the link prediction task, the inference outputs are the embeddings of all nodes in the inference graph."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# The GNN embeddings of all nodes in the inference graph are saved to the folder named after the target_ntype\n",
"!ls -lh infer/embeddings/paper"
]
},
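These saved embeddings can then be used to score candidate links downstream. As a hypothetical sketch (the dot-product scorer and the random stand-in vectors are assumptions for illustration, not GraphStorm's saved-file format), candidate citations for a paper can be ranked by the dot product of its embedding with the candidates' embeddings:

```python
import numpy as np

rng = np.random.default_rng(42)
# stand-ins for five saved 128-dim 'paper' embeddings
paper_emb = rng.standard_normal((5, 128))

def score_candidates(src_emb, cand_embs):
    """Dot-product link scores between one source paper and candidate papers."""
    return cand_embs @ src_emb

scores = score_candidates(paper_emb[0], paper_emb[1:])
ranking = np.argsort(-scores) + 1  # candidate paper ids (1..4), best first
print(ranking)
```

In practice one would load the embeddings written under `infer/embeddings/` instead of generating random stand-ins.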
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "gsf",
"language": "python",
"name": "gsf"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 4
}