-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into gsprocessing-hard-negative
- Loading branch information
Showing
2 changed files
with
372 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,313 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "fa464346", | ||
"metadata": {}, | ||
"source": [ | ||
"# Use GraphStorm CLIs for Multi-task Learning" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "bb315f10", | ||
"metadata": {}, | ||
"source": [ | ||
"This notebook demonstrates how to use GraphStorm Command Line Interfaces (CLIs) to run multi-task GNN model training and inference. By playing with this nodebook, users will be able to get familiar with GraphStom CLIs, hence furhter using them on their own tasks and models." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "9a5c6df6", | ||
"metadata": {}, | ||
"source": [ | ||
"In this notebook, we will train a RGCN model on the ACM dataset with two training supervisions, i.e., link prediction and node feature reconstruction." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "3087d88f", | ||
"metadata": {}, | ||
"source": [ | ||
"**Note:** For more details about multi-task learning please refer to [Multi-task Learning in GraphStorm](https://graphstorm.readthedocs.io/en/latest/advanced/multi-task-learning.html)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "a0c7764a", | ||
"metadata": {}, | ||
"source": [ | ||
"### 0. Setup environment\n", | ||
"----\n", | ||
"First let's install GraphStorm and its dependencies, PyTorch and DGL." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "7e6423b7", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip install scikit-learn==1.4.2\n", | ||
"!pip install scipy==1.13.0\n", | ||
"!pip install pandas==1.3.5\n", | ||
"!pip install pyarrow==14.0.0\n", | ||
"!pip install graphstorm\n", | ||
"!pip install torch==2.1.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu\n", | ||
"!pip install dgl==1.1.3 -f https://data.dgl.ai/wheels-internal/repo.html" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "76f8f9c9", | ||
"metadata": {}, | ||
"source": [ | ||
"### 1. Create the example ACM graph data\n", | ||
"---\n", | ||
"This notebook uses the ACM graph as an example. We use the following script to create the ACM graph data." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "e3ae0728", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!mkdir example\n", | ||
"!wget -O /example/acm_data.py https://github.com/awslabs/graphstorm/raw/main/examples/acm_data.py\n", | ||
"!python /example/acm_data.py --output-path /example/acm_raw" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "e524e0c6", | ||
"metadata": {}, | ||
"source": [ | ||
"The ACM graph data includes node files and edge files. It also includes a JSON configuration file describing how to construct a graph for model training. More details can be found in [Use Your Own Data (ACM data example)](https://graphstorm.readthedocs.io/en/latest/tutorials/own-data.html#use-your-own-data)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "2c3e63a9", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!ls -al /example/acm_raw/" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ee3ca793", | ||
"metadata": {}, | ||
"source": [ | ||
"### 2. Construct and Partition ACM Graph\n", | ||
"---\n", | ||
"Since GraphStorm is designed naturally for distributed GNN training, we need to construct a graph and split it into multiple partitions. In this example, for simplicity, we create a graph with one partition (no actual splitting)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "8602b955", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!python -m graphstorm.gconstruct.construct_graph \\\n", | ||
" --conf-file /example/acm_raw/config.json \\\n", | ||
" --output-dir /example/acm_gs \\\n", | ||
" --num-parts 1 \\\n", | ||
" --graph-name acm" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "de68edcf", | ||
"metadata": {}, | ||
"source": [ | ||
"The generated ACM graph contains all the information required for GNN model training. For more details of preparing data for multi-task learning, please refer to [Preparing multi-task learning data](https://graphstorm.readthedocs.io/en/latest/advanced/multi-task-learning.html#preparing-the-training-data)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "5735fdf1", | ||
"metadata": {}, | ||
"source": [ | ||
"### 3. GNN Model Training \n", | ||
"---\n", | ||
"\n", | ||
"Once the graph constucted, we can call the GraphStorm multi-task learning CLI to run model training. Before kicking off the model training, we need to create a YAML configuration file for the CLI." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "92e982df", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!wget -O /example/acm_mt.yaml https://github.com/awslabs/graphstorm/raw/main/examples/use_your_own_data/acm_mt.yaml" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"id": "ba138428", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!cat /example/acm_mt.yaml" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "994ea961", | ||
"metadata": {}, | ||
"source": [ | ||
"The YAML configuration file defines two training tasks: \n", | ||
" * A link prediction task on the `<paper, citing, paper>` edges. The task specific settings are under the`gsf::multi_task_learning::link_prediction` configuration block.\n", | ||
" * A node feature reconstruction task on the `paper` nodes with the node feature `label` to be reconstructed. The task specific settings are under the`gsf::multi_task_learning::reconstruct_node_feat` configuration block.\n", | ||
" \n", | ||
"For more details of multi-task YAML configuration, please refer to [Define Multi-task for training](https://graphstorm.readthedocs.io/en/latest/advanced/multi-task-learning.html#define-multi-task-for-training)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "e7d3182a", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Launch the training job" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"id": "67397950", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!python -m graphstorm.run.gs_multi_task_learning \\\n", | ||
" --workspace /example \\\n", | ||
" --part-config /example/acm_gs/acm.json \\\n", | ||
" --num-trainers 1 \\\n", | ||
" --cf /example/acm_mt.yaml \\\n", | ||
" --num-epochs 4" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "e20bce57", | ||
"metadata": {}, | ||
"source": [ | ||
"The saved model is under `/example/acm_lp/models/`." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"id": "0212a5b1", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!ls -a /example/acm_lp/models/" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "8b1ca0af", | ||
"metadata": {}, | ||
"source": [ | ||
"### 4. GNN Model Inference \n", | ||
"---\n", | ||
"\n", | ||
"Once the model is trained, we can do model inference with the trained model artifacts by using the GraphStorm multi-task learning CLI. We can use the same YAML configuration file for model inference." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "27fe766a", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Launch the inference job\n", | ||
"We load the model checkpoint of epoch-2 in the example to do inference. The inference command will report the test scores for both link prediction task and node feature reconstruction task." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"id": "a3324ee5", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!python -m graphstorm.run.gs_multi_task_learning \\\n", | ||
" --inference \\\n", | ||
" --workspace /example \\\n", | ||
" --part-config /example/acm_gs/acm.json \\\n", | ||
" --restore-model-path /example/acm_lp/models/epoch-2 \\\n", | ||
" --num-trainers 1 \\\n", | ||
" --cf /example/acm_mt.yaml" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "9134523f", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Launch the embedding generation inference job\n", | ||
"\n", | ||
"You can also use the GraphStorm `gs_gen_node_embedding` CLI to generate node embeddings with the trained GNN model on the ACM graph. The saved node embeddings are under `/example/acm_lp/emb/`." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"id": "b868af38", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!python -m graphstorm.run.gs_gen_node_embedding \\\n", | ||
" --inference \\\n", | ||
" --workspace /example \\\n", | ||
" --part-config /example/acm_gs/acm.json \\\n", | ||
" --restore-model-path /example/acm_lp/models/epoch-2 \\\n", | ||
" --save-embed-path /example/acm_lp/emb/ \\\n", | ||
" --restore-model-layers \"embed,gnn\" \\\n", | ||
" --num-trainers 1 \\\n", | ||
" --cf /example/acm_mt.yaml" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"id": "16161461", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!ls -al /example/acm_lp/emb/" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "conda_python3", | ||
"language": "python", | ||
"name": "conda_python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.15" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
--- | ||
version: 1.0 | ||
gsf: | ||
basic: | ||
model_encoder_type: rgcn | ||
backend: gloo | ||
verbose: false | ||
gnn: | ||
fanout: "50" | ||
num_layers: 1 | ||
hidden_size: 256 | ||
use_mini_batch_infer: false | ||
lp_decoder_type: dot_product | ||
input: | ||
restore_model_path: null | ||
output: | ||
save_model_path: /tmp/acm_lp/models | ||
save_embeds_path: /tmp/acm_lp/embeds | ||
hyperparam: | ||
dropout: 0. | ||
lr: 0.0001 | ||
lm_tune_lr: 0.0001 | ||
num_epochs: 200 | ||
batch_size: 1024 | ||
bert_infer_bs: 128 | ||
wd_l2norm: 0 | ||
alpha_l2norm: 0. | ||
rgcn: | ||
num_bases: -1 | ||
use_self_loop: true | ||
sparse_optimizer_lr: 1e-2 | ||
use_node_embeddings: false | ||
lp_decoder_type: dot_product | ||
multi_task_learning: | ||
- link_prediction: | ||
num_negative_edges: 4 | ||
num_negative_edges_eval: 100 | ||
train_negative_sampler: joint | ||
eval_etype: | ||
- "paper,citing,paper" | ||
train_etype: | ||
- "paper,citing,paper" | ||
exclude_training_targets: false | ||
reverse_edge_types_map: ["paper,citing,cited,paper"] | ||
mask_fields: | ||
- "train_mask" # edge classification mask | ||
- "val_mask" | ||
- "test_mask" | ||
- reconstruct_node_feat: | ||
reconstruct_nfeat_name: "feat" | ||
target_ntype: "paper" | ||
batch_size: 32 | ||
mask_fields: | ||
- "train_mask" # node classification mask 0 | ||
- "val_mask" | ||
- "test_mask" | ||
task_weight: 1.0 | ||
eval_metric: | ||
- "mse" |