Skip to content

Commit

Permalink
[Documentation] Add multi-task learning CLI notebook example (#1082)
Browse files Browse the repository at this point in the history
Add yaml configuration file for model training

*Issue #, if available:*

*Description of changes:*


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Xiang Song <[email protected]>
  • Loading branch information
classicsong and Xiang Song authored Nov 7, 2024
1 parent ee16e74 commit af7ae16
Show file tree
Hide file tree
Showing 2 changed files with 372 additions and 0 deletions.
313 changes: 313 additions & 0 deletions docs/source/cli/notebooks/Notebook_4_MT_CLI.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "fa464346",
"metadata": {},
"source": [
"# Use GraphStorm CLIs for Multi-task Learning"
]
},
{
"cell_type": "markdown",
"id": "bb315f10",
"metadata": {},
"source": [
"This notebook demonstrates how to use GraphStorm Command Line Interfaces (CLIs) to run multi-task GNN model training and inference. By playing with this nodebook, users will be able to get familiar with GraphStom CLIs, hence furhter using them on their own tasks and models."
]
},
{
"cell_type": "markdown",
"id": "9a5c6df6",
"metadata": {},
"source": [
"In this notebook, we will train a RGCN model on the ACM dataset with two training supervisions, i.e., link prediction and node feature reconstruction."
]
},
{
"cell_type": "markdown",
"id": "3087d88f",
"metadata": {},
"source": [
"**Note:** For more details about multi-task learning please refer to [Multi-task Learning in GraphStorm](https://graphstorm.readthedocs.io/en/latest/advanced/multi-task-learning.html)"
]
},
{
"cell_type": "markdown",
"id": "a0c7764a",
"metadata": {},
"source": [
"### 0. Setup environment\n",
"----\n",
"First let's install GraphStorm and its dependencies, PyTorch and DGL."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7e6423b7",
"metadata": {},
"outputs": [],
"source": [
"!pip install scikit-learn==1.4.2\n",
"!pip install scipy==1.13.0\n",
"!pip install pandas==1.3.5\n",
"!pip install pyarrow==14.0.0\n",
"!pip install graphstorm\n",
"!pip install torch==2.1.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu\n",
"!pip install dgl==1.1.3 -f https://data.dgl.ai/wheels-internal/repo.html"
]
},
{
"cell_type": "markdown",
"id": "76f8f9c9",
"metadata": {},
"source": [
"### 1. Create the example ACM graph data\n",
"---\n",
"This notebook uses the ACM graph as an example. We use the following script to create the ACM graph data."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e3ae0728",
"metadata": {},
"outputs": [],
"source": [
"!mkdir example\n",
"!wget -O /example/acm_data.py https://github.com/awslabs/graphstorm/raw/main/examples/acm_data.py\n",
"!python /example/acm_data.py --output-path /example/acm_raw"
]
},
{
"cell_type": "markdown",
"id": "e524e0c6",
"metadata": {},
"source": [
"The ACM graph data includes node files and edge files. It also includes a JSON configuration file describing how to construct a graph for model training. More details can be found in [Use Your Own Data (ACM data example)](https://graphstorm.readthedocs.io/en/latest/tutorials/own-data.html#use-your-own-data)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2c3e63a9",
"metadata": {},
"outputs": [],
"source": [
"!ls -al /example/acm_raw/"
]
},
{
"cell_type": "markdown",
"id": "ee3ca793",
"metadata": {},
"source": [
"### 2. Construct and Partition ACM Graph\n",
"---\n",
"Since GraphStorm is designed naturally for distributed GNN training, we need to construct a graph and split it into multiple partitions. In this example, for simplicity, we create a graph with one partition (no actual splitting)."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "8602b955",
"metadata": {},
"outputs": [],
"source": [
"!python -m graphstorm.gconstruct.construct_graph \\\n",
" --conf-file /example/acm_raw/config.json \\\n",
" --output-dir /example/acm_gs \\\n",
" --num-parts 1 \\\n",
" --graph-name acm"
]
},
{
"cell_type": "markdown",
"id": "de68edcf",
"metadata": {},
"source": [
"The generated ACM graph contains all the information required for GNN model training. For more details of preparing data for multi-task learning, please refer to [Preparing multi-task learning data](https://graphstorm.readthedocs.io/en/latest/advanced/multi-task-learning.html#preparing-the-training-data)."
]
},
{
"cell_type": "markdown",
"id": "5735fdf1",
"metadata": {},
"source": [
"### 3. GNN Model Training \n",
"---\n",
"\n",
"Once the graph constucted, we can call the GraphStorm multi-task learning CLI to run model training. Before kicking off the model training, we need to create a YAML configuration file for the CLI."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "92e982df",
"metadata": {},
"outputs": [],
"source": [
"!wget -O /example/acm_mt.yaml https://github.com/awslabs/graphstorm/raw/main/examples/use_your_own_data/acm_mt.yaml"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ba138428",
"metadata": {},
"outputs": [],
"source": [
"!cat /example/acm_mt.yaml"
]
},
{
"cell_type": "markdown",
"id": "994ea961",
"metadata": {},
"source": [
"The YAML configuration file defines two training tasks: \n",
" * A link prediction task on the `<paper, citing, paper>` edges. The task specific settings are under the`gsf::multi_task_learning::link_prediction` configuration block.\n",
" * A node feature reconstruction task on the `paper` nodes with the node feature `label` to be reconstructed. The task specific settings are under the`gsf::multi_task_learning::reconstruct_node_feat` configuration block.\n",
" \n",
"For more details of multi-task YAML configuration, please refer to [Define Multi-task for training](https://graphstorm.readthedocs.io/en/latest/advanced/multi-task-learning.html#define-multi-task-for-training)."
]
},
{
"cell_type": "markdown",
"id": "e7d3182a",
"metadata": {},
"source": [
"#### Launch the training job"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "67397950",
"metadata": {},
"outputs": [],
"source": [
"!python -m graphstorm.run.gs_multi_task_learning \\\n",
" --workspace /example \\\n",
" --part-config /example/acm_gs/acm.json \\\n",
" --num-trainers 1 \\\n",
" --cf /example/acm_mt.yaml \\\n",
" --num-epochs 4"
]
},
{
"cell_type": "markdown",
"id": "e20bce57",
"metadata": {},
"source": [
"The saved model is under `/example/acm_lp/models/`."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "0212a5b1",
"metadata": {},
"outputs": [],
"source": [
"!ls -a /example/acm_lp/models/"
]
},
{
"cell_type": "markdown",
"id": "8b1ca0af",
"metadata": {},
"source": [
"### 4. GNN Model Inference \n",
"---\n",
"\n",
"Once the model is trained, we can do model inference with the trained model artifacts by using the GraphStorm multi-task learning CLI. We can use the same YAML configuration file for model inference."
]
},
{
"cell_type": "markdown",
"id": "27fe766a",
"metadata": {},
"source": [
"#### Launch the inference job\n",
"We load the model checkpoint of epoch-2 in the example to do inference. The inference command will report the test scores for both link prediction task and node feature reconstruction task."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a3324ee5",
"metadata": {},
"outputs": [],
"source": [
"!python -m graphstorm.run.gs_multi_task_learning \\\n",
" --inference \\\n",
" --workspace /example \\\n",
" --part-config /example/acm_gs/acm.json \\\n",
" --restore-model-path /example/acm_lp/models/epoch-2 \\\n",
" --num-trainers 1 \\\n",
" --cf /example/acm_mt.yaml"
]
},
{
"cell_type": "markdown",
"id": "9134523f",
"metadata": {},
"source": [
"#### Launch the embedding generation inference job\n",
"\n",
"You can also use the GraphStorm `gs_gen_node_embedding` CLI to generate node embeddings with the trained GNN model on the ACM graph. The saved node embeddings are under `/example/acm_lp/emb/`."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "b868af38",
"metadata": {},
"outputs": [],
"source": [
"!python -m graphstorm.run.gs_gen_node_embedding \\\n",
" --inference \\\n",
" --workspace /example \\\n",
" --part-config /example/acm_gs/acm.json \\\n",
" --restore-model-path /example/acm_lp/models/epoch-2 \\\n",
" --save-embed-path /example/acm_lp/emb/ \\\n",
" --restore-model-layers \"embed,gnn\" \\\n",
" --num-trainers 1 \\\n",
" --cf /example/acm_mt.yaml"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "16161461",
"metadata": {},
"outputs": [],
"source": [
"!ls -al /example/acm_lp/emb/"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "conda_python3",
"language": "python",
"name": "conda_python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
59 changes: 59 additions & 0 deletions examples/use_your_own_data/acm_mt.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
---
version: 1.0
gsf:
basic:
model_encoder_type: rgcn
backend: gloo
verbose: false
gnn:
fanout: "50"
num_layers: 1
hidden_size: 256
use_mini_batch_infer: false
lp_decoder_type: dot_product
input:
restore_model_path: null
output:
save_model_path: /tmp/acm_lp/models
save_embeds_path: /tmp/acm_lp/embeds
hyperparam:
dropout: 0.
lr: 0.0001
lm_tune_lr: 0.0001
num_epochs: 200
batch_size: 1024
bert_infer_bs: 128
wd_l2norm: 0
alpha_l2norm: 0.
rgcn:
num_bases: -1
use_self_loop: true
sparse_optimizer_lr: 1e-2
use_node_embeddings: false
lp_decoder_type: dot_product
multi_task_learning:
- link_prediction:
num_negative_edges: 4
num_negative_edges_eval: 100
train_negative_sampler: joint
eval_etype:
- "paper,citing,paper"
train_etype:
- "paper,citing,paper"
exclude_training_targets: false
reverse_edge_types_map: ["paper,citing,cited,paper"]
mask_fields:
- "train_mask" # edge classification mask
- "val_mask"
- "test_mask"
- reconstruct_node_feat:
reconstruct_nfeat_name: "feat"
target_ntype: "paper"
batch_size: 32
mask_fields:
- "train_mask" # node classification mask 0
- "val_mask"
- "test_mask"
task_weight: 1.0
eval_metric:
- "mse"

0 comments on commit af7ae16

Please sign in to comment.