Use Your Own Models
Currently GraphStorm has two built-in GNN models, i.e., the RGCN and the RGAT models. If you want to explore different GNN models while leveraging GraphStorm's ease of use and scalability, you can create your own GNN models according to GraphStorm's customized model APIs. This tutorial explains in detail how to do this, with a runnable example that customizes the HGT DGL model implementation.
Prerequisites
Before using GraphStorm's customized model APIs, please make sure your GNN models meet the following prerequisites.
Use DGL to implement your GNN models
The GraphStorm Framework relies on the DGL library to implement and run GNN models. In particular, GraphStorm's scalability comes from DGL's distributed libraries. For this reason, your GNN models should be implemented with the DGL library. You can learn how to do this via DGL's User Guide. In addition, there are many GNN model examples implemented by the DGL community. Please explore these materials to check whether any model meets your requirements.
Modify your GNN models to use mini-batch training and inference
Many existing GNN models were implemented for popular academic graphs, which, compared to enterprise-level graphs, are relatively small and lack node/edge features. Therefore, implementors use the full-graph training/inference mode, i.e., feeding the entire graph along with its node/edge features into GNN models in one epoch. When dealing with large graphs, this mode will fail due either to the limits of the GPUs' memory or to the slow speed when using CPUs.
In order to tackle large graphs, we can change GNN models to perform stochastic mini-batch training. You can learn how to modify GNN models into mini-batch training/inference mode via the DGL User Guide Chapter 6. For examples of the different implementations between full-graph mode and mini-batch mode, please look at the DGL model examples, in which mini-batch mode files normally have a file name ending with the "_mb" string, like the RGCN model, or file names including the "dist" string, like the GraphSage distributed model.
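As a rough illustration of the mini-batch pattern, the sketch below follows the DGL user guide, assuming a DGL graph g and a tensor of training node IDs train_nids already exist; the fan-out and batch size values are arbitrary.

import dgl

# Sample 10 neighbors per layer for a 2-layer GNN.
sampler = dgl.dataloading.NeighborSampler([10, 10])
dataloader = dgl.dataloading.DataLoader(g, train_nids, sampler,
                                        batch_size=1024, shuffle=True, drop_last=False)

for input_nodes, output_nodes, blocks in dataloader:
    # `blocks` are the sampled computation graphs of one mini-batch; a
    # mini-batch model's forward() consumes them instead of the full graph.
    ...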
Learn how to run GraphStorm in a Docker environment
Currently GraphStorm runs in a Docker environment. The rest of this tutorial assumes execution within the GraphStorm Docker container. Please refer to the first two sections of the Environment Setup to learn how to run GraphStorm in a Docker environment and set up your environment.
Modifications required for custom models
Step 1: Convert your graph data into the required format
You can follow the Use Your Own Graph Data tutorial to prepare your graph data for GraphStorm.
Step 2: Modify your GNN model to use the GraphStorm APIs
To plug your GNN models into GraphStorm, you need to use the GraphStorm model APIs. The key model APIs are the classes GSgnnNodeModelBase, GSgnnEdgeModelBase, and GSgnnLinkPredictionModelBase. Your GNN models should inherit one of the three classes depending on your task.
Here we use the DGL HGT example model to demonstrate how to modify the GNN models.
class HGT(nn.Module):
    def __init__(self, ......):
        super(HGT, self).__init__()
        ......

    def forward(self, G, out_key):
        h = {}
        for ntype in G.ntypes:
            n_id = self.node_dict[ntype]
            h[ntype] = F.gelu(self.adapt_ws[n_id](G.nodes[ntype].data["inp"]))
        for i in range(self.n_layers):
            h = self.gcs[i](G, h)
        return self.out(h[out_key])
The original HGT model implementation uses the full-graph training and inference mode. Its forward() function takes a DGL graph, G, and the to-be-predicted node type, out_key, as input arguments.
As required by the Prerequisites, we first revise this model to use the mini-batch training and inference mode, as shown below.
class HGT_mb(nn.Module):
    def __init__(self, ......):
        super(HGT_mb, self).__init__()
        ......

    def forward(self, blocks, n_feats_dict, out_ntype):
        h = {}
        for ntype in blocks[0].ntypes:
            if self.adapt_ws[ntype] is None:
                n_id = self.node_dict[ntype]
                emb_id = self.ntype_id_map[n_id]
                n_embed = self.ntype_embed(torch.Tensor([emb_id] * blocks[0].num_nodes(ntype)).long().to(self.device))
            else:
                n_embed = self.adapt_ws[ntype](n_feats_dict[ntype])
            h[ntype] = F.gelu(n_embed)
        for i in range(self.n_layers):
            h = self.gcs[i](blocks[i], h)
        return self.out(h[out_ntype])
The new HGT_mb model's forward() function takes mini-batch blocks, blocks, and their corresponding node feature dictionary, n_feats_dict, as inputs to replace the original full graph data, G.
Then, to further make this HGT_mb model work in GraphStorm, we need to replace the PyTorch nn.Module with GraphStorm's GSgnnNodeModelBase and implement the required functions.
The GSgnnNodeModelBase class, which is also a PyTorch Module extension, has three required functions that users' own GNN models need to implement: forward(self, blocks, node_feats, edge_feats, labels, input_nodes), predict(self, blocks, node_feats, edge_feats, input_nodes), and create_optimizer(self).
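A minimal skeleton of a model implementing these three functions could look like the sketch below; it assumes graphstorm.model is imported as gsmodel, as in the example later in this step, and the function bodies are placeholders only.

import graphstorm.model as gsmodel

class MyGNN(gsmodel.GSgnnNodeModelBase):
    def forward(self, blocks, node_feats, edge_feats, labels, input_nodes):
        # Compute logits from the blocks and node features,
        # then return a loss value instead of the logits.
        ...

    def predict(self, blocks, node_feats, edge_feats, input_nodes):
        # Return the prediction results and the node representations.
        ...

    def create_optimizer(self):
        # Return a PyTorch optimizer over self.parameters().
        ...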
The GSgnnNodeModelBase class' forward() function is similar to the PyTorch Module's forward() function, except that its input arguments MUST include:
- blocks, which are DGL blocks sampled for a mini-batch.
- labels, which is a dictionary whose key is the to-be-predicted node type and whose value is the labels of the to-be-predicted nodes in a mini-batch.
- node_feats, which is a dictionary whose keys are the node types in the graph and whose values are the associated node features.
- edge_feats. Currently GraphStorm does NOT support edge features, so leave this argument as it is.
- input_nodes, which is optional and only needed if your GNN model uses it.
Unlike common cases where the forward function returns logits computed by the model, the return value of forward() should be a loss value, which GraphStorm uses to perform the backward operation. Because of this change, you need to include a loss function within your GNN model instead of computing the loss outside of it. Following these requirements, our revised model has a few more lines added, as shown below.
class HGT(gsmodel.GSgnnNodeModelBase):
    def __init__(self, ......):
        ......
        # use GraphStorm loss function components
        self._loss_fn = gsmodel.ClassifyLossFunc(multilabel=False)

    def forward(self, blocks, node_feats, edge_feats, labels, input_nodes):
        h = {}
        for ntype in blocks[0].ntypes:
            if self.adapt_ws[ntype] is None:
                n_id = self.node_dict[ntype]
                emb_id = self.ntype_id_map[n_id]
                embedding = self.ntype_embed(torch.Tensor([emb_id]).long().to('cuda'))
                n_embed = embedding.expand(blocks[0].num_nodes(ntype), -1)
            else:
                n_embed = self.adapt_ws[ntype](node_feats[ntype])
            h[ntype] = F.gelu(n_embed)
        for i in range(self.num_layers):
            h = self.gcs[i](blocks[i], h)
        pred_loss = self._loss_fn(h[self.target_ntype], labels[self.target_ntype])
        return pred_loss
You may notice that GraphStorm already provides common loss functions for classification, regression, and link prediction, which can be easily imported and used in your model. But you are free to use any PyTorch loss function or even your own loss function. In the above example, we also store the to-be-predicted node type as a class variable and use it when computing the loss value.
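For instance, swapping in a standard PyTorch loss in place of GraphStorm's built-in ClassifyLossFunc could look like the sketch below, assuming the same h and labels as in the forward() example above.

import torch.nn as nn

# In __init__(): any PyTorch classification loss works as a drop-in replacement.
self._loss_fn = nn.CrossEntropyLoss()

# In forward(): compute the loss from the mini-batch logits and labels.
pred_loss = self._loss_fn(h[self.target_ntype], labels[self.target_ntype])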
The predict() function is for inference and will not be used for backward operations. Its input arguments are similar to those of the forward() function, but labels are not needed. The predict() function returns two values. The first is the prediction results. The second is the logits, which could be used for some specific purposes (this return value is uncommon for some users, and we are working on fixing this confusion). With these requirements, the predict() function of the modified HGT model looks like the code below.
def predict(self, blocks, node_feats, _, input_nodes):
    h = {}
    for ntype in blocks[0].ntypes:
        if self.adapt_ws[ntype] is None:
            n_id = self.node_dict[ntype]
            emb_id = self.ntype_id_map[n_id]
            embedding = self.ntype_embed(torch.Tensor([emb_id]).long().to('cuda'))
            n_embed = embedding.expand(blocks[0].num_nodes(ntype), -1)
        else:
            n_embed = self.adapt_ws[ntype](node_feats[ntype])
        h[ntype] = F.gelu(n_embed)
    for i in range(self.num_layers):
        h = self.gcs[i](blocks[i], h)
    # Return two values: one is the prediction results, while the other is
    # the computed node representations, which can be saved.
    return h[self.target_ntype].argmax(dim=1), h[self.target_ntype]
The create_optimizer() function is for users to define their own optimizer. You can move the optimizer definition from the training flow into the model, like the code below.
def create_optimizer(self, lr=0.001):
    return torch.optim.Adam(self.parameters(), lr=lr)
There are other optional functions in the GSgnnNodeModelBase class, including restore_model(self, restore_model_path) and save_model(self, model_path), which are used to restore and save models. If you want to save or restore models, implement these two functions too.
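A minimal sketch of these two functions, assuming the model's dense parameters are all that needs to be persisted; the file name model.pt is an arbitrary choice for this example.

import os
import torch

def save_model(self, model_path):
    # Persist the dense parameters under the given directory.
    os.makedirs(model_path, exist_ok=True)
    torch.save(self.state_dict(), os.path.join(model_path, 'model.pt'))

def restore_model(self, restore_model_path):
    # Load the dense parameters saved by save_model().
    self.load_state_dict(torch.load(os.path.join(restore_model_path, 'model.pt')))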
Step 3. Modify the training/inference flow with the GraphStorm APIs
With the modified GNN model ready, the next step is to modify the training/inference loop by replacing the datasets and dataloaders with GraphStorm's dataloading classes.
The original HGT_mb model uses the method from DGL's Stochastic Training on Large Graphs guide for its training/inference flow. GraphStorm's training/inference flow is similar, with a few modifications.
Start the training process with GraphStorm's initialization
Any GraphStorm training process MUST start with a proper initialization. You can use the following code at the beginning of the training flow.
import graphstorm as gs

......

def main(args):
    gs.initialize(ip_config=args.ip_config, backend="gloo")
The ip_config argument specifies an IP configuration file, which contains the IP addresses of the machines in a GraphStorm distributed cluster. You can find its description in the Launch Training section of the Quick Start Tutorial.
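For reference, the file simply lists one IP address per line. A hypothetical ip_list.txt for a two-machine cluster could look like this:

10.0.0.1
10.0.0.2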
Replace the DGL DataLoader with GraphStorm's dataset and dataloader
Because GraphStorm uses distributed graphs, we first need to load the partitioned graph created in Step 1 with the GSgnnNodeTrainData class (for edge tasks, GraphStorm also provides GSgnnEdgeTrainData). A GSgnnNodeTrainData could be created as shown in the code below.
train_data = GSgnnNodeTrainData(config.graph_name,
                                config.part_config,
                                train_ntypes=config.target_ntype,
                                node_feat_field=node_feat_fields,
                                label_field=config.label_field)
Arguments of this class include the partition configuration JSON file path, which is an output of Step 1. The graph_name can be found in that JSON file.
The other values, train_ntypes, label_field, and node_feat_field, should be consistent with the values in the raw data input configuration JSON defined in Step 1. The train_ntypes is the node_type that has labels specified. The label_field is the value specified in label_col of the train_ntype. The node_feat_field is a dictionary whose keys are the values of node_type and whose values are the values of feature_name.
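For illustration, a hypothetical node_feat_field dictionary for the ACM example could look like the sketch below, assuming each node type has a single feature named feat, which matches the --node-feat paper:feat-author:feat-subject:feat argument used in the launch command at the end of this tutorial.

# Hypothetical node_feat_field for the ACM data: node_type -> feature names.
node_feat_fields = {
    'paper':   ['feat'],
    'author':  ['feat'],
    'subject': ['feat'],
}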
Then we can pass this dataset to GraphStorm's GSgnnNodeDataLoader, like the code below.
# Define the GraphStorm train dataloader
dataloader = GSgnnNodeDataLoader(train_data, train_data.train_idxs, fanout=config.fanout,
                                 batch_size=config.batch_size, device=device, train_task=True)

# Optional: Define the evaluation dataloader
eval_dataloader = GSgnnNodeDataLoader(train_data, train_data.val_idxs, fanout=config.fanout,
                                      batch_size=config.eval_batch_size, device=device, train_task=False)

# Optional: Define the test dataloader
test_dataloader = GSgnnNodeDataLoader(train_data, train_data.test_idxs, fanout=config.fanout,
                                      batch_size=config.eval_batch_size, device=device, train_task=False)
GraphStorm provides a set of dataloaders for different GML tasks. Here we deal with a node task, hence using the node dataloader, which takes the graph data created above as the first argument. The second argument is the label index that the GraphStorm dataset extracts from the graph, as indicated by the target nodes' train_mask, val_mask, and test_mask, which are automatically generated by the GraphStorm graph construction tool with the specified split_pct field. The GSgnnNodeTrainData automatically extracts these indexes and sets them as its properties so that you can directly use them, like train_data.train_idxs, train_data.val_idxs, and train_data.test_idxs. The rest of the arguments are similar to the common training flow, except that we set train_task to False for the evaluation and test dataloaders.
Use GraphStorm's model trainer to wrap your model and attach an evaluator and a task tracker to it
Unlike the common flow, GraphStorm wraps GNN models with different trainers, just like other frameworks, e.g., scikit-learn. GraphStorm provides node prediction, edge prediction, and link prediction trainers. Creating them is easy.
First we create the modified HGT model like the following code.
# Define the HGT model
model = HGT(node_dict, edge_dict,
            n_inp_dict=nfeat_dims,
            n_hid=config.hidden_size,
            n_out=config.num_classes,
            num_layers=num_layers,
            num_heads=args.num_heads,
            target_ntype=config.target_ntype,
            use_norm=True,
            alpha_l2norm=config.alpha_l2norm)
Then we can use the GSgnnNodePredictionTrainer class to wrap it, like the code below.
# Create a trainer for the node classification task.
trainer = GSgnnNodePredictionTrainer(model, gs.get_rank())
The GSgnnNodePredictionTrainer takes a GraphStorm model as its first argument. The second argument is the rank of the trainer, which is used for assigning different GPUs.
The GraphStorm trainers can have evaluators and task trackers associated. The following code shows how to do this.
# Optional: set up an evaluator
evaluator = GSgnnAccEvaluator(config.eval_frequency,
                              config.eval_metric,
                              config.multilabel,
                              config.use_early_stop,
                              config.early_stop_burnin_rounds,
                              config.early_stop_rounds,
                              config.early_stop_strategy)
trainer.setup_evaluator(evaluator)

# Optional: set up a task tracker to show the progress of training.
tracker = GSSageMakerTaskTracker(config, gs.get_rank())
trainer.setup_task_tracker(tracker)
GraphStorm's evaluators help to compute the required evaluation metrics, such as accuracy, f1, and mrr. Users can select the proper evaluator and use the trainer's setup_evaluator() method to attach it. GraphStorm's task trackers serve as log collectors, which are used to show the progress information.
Use the trainer's fit() function to run training
Once all trainers, evaluators, and task trackers are set, the last step is to use the trainer's fit() function to run training, validation, and testing on the three sets, like the code below.
# Start the training process.
trainer.fit(train_loader=dataloader,
            num_epochs=config.num_epochs,
            val_loader=eval_dataloader,
            test_loader=test_dataloader,
            save_model_path=config.save_model_path,
            use_mini_batch_infer=True)
The fit() function wraps the dataloaders and the number of epochs, replacing the common "for loops" seen in the common training flow. The fit() function also takes additional arguments, such as save_model_path, to save different model artifacts. BUT before setting these arguments, you need to implement the restore_model(self, restore_model_path) and save_model(self, model_path) functions in Step 2.
Step 4. Handle the unused weights error
Although uncommon in full-graph training or mini-batch training on a single GPU, the unused weights error can frequently occur when we start to train models on multiple GPUs in parallel. The PyTorch distributed framework's inner mechanism causes this problem. One easy way to solve this error is to add a regularization over all trainable parameters to the loss computation, like the code below.
pred_loss = self._loss_fn(h[self.target_ntype], labels[self.target_ntype])

# L2 regularization of dense parameters
reg_loss = torch.tensor(0.).to(pred_loss.device)
for d_para in self.parameters():
    reg_loss += d_para.square().sum()
reg_loss = self.alpha_l2norm * reg_loss

return pred_loss + reg_loss
You can add a coefficient, like the alpha_l2norm, to control the influence of the regularization.
Step 5. Add a few additional arguments for the Python main function
GraphStorm relies on a few arguments to launch training and inference commands, including part-config, ip-config, verbose, and local_rank. GraphStorm's built-in launch scripts have these arguments configured already, but for customized models it is required to add them as arguments of the Python main function, even though they are not used anywhere in the customized model. A sample code is shown below.
if __name__ == '__main__':
    argparser = argparse.ArgumentParser("Training HGT model with the GraphStorm Framework")
    ......
    argparser.add_argument("--part-config", type=str, required=True,
                           help="The partition config file. \
                                 For customized models, MUST have this argument!!")
    argparser.add_argument("--ip-config", type=str, required=True,
                           help="The IP config file for the cluster. \
                                 For customized models, MUST have this argument!!")
    argparser.add_argument("--verbose",
                           type=lambda x: (str(x).lower() in ['true', '1']),
                           default=argparse.SUPPRESS,
                           help="Print more information. \
                                 For customized models, MUST have this argument!!")
    argparser.add_argument("--local_rank", type=int,
                           help="The rank of the trainer. \
                                 For customized models, MUST have this argument!!")
Note
PyTorch v2.0 changed the argument local_rank to local-rank. Therefore, if you use PyTorch v2.0 or a later version, please change this argument accordingly.
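If you need to support both PyTorch versions, one option is to register both spellings as aliases of the same argparse option, as in the sketch below.

# Accept both the pre-2.0 and the post-2.0 spellings; argparse maps both
# option strings to the same `local_rank` destination.
argparser.add_argument("--local_rank", "--local-rank", type=int,
                       help="The rank of the trainer. \
                             For customized models, MUST have this argument!!")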
Step 6. Set up the GraphStorm configuration YAML file
GraphStorm has a set of parameters that control various aspects of the model training and inference process. You can find the details of these parameters in the GraphStorm Training and Inference Configurations. These parameters can either be passed as input arguments or set in a YAML format file. Below is an example of the YAML file.
---
version: 1.0
gsf:
  basic:
    backend: gloo
    ip_config: ip_list.txt
    part_config: /data/acm_nc/acm.json
    verbose: false
    alpha_l2norm: 0.
  gnn:
    fanout: "50,50"
    num_layers: 2
    hidden_size: 256
    use_mini_batch_infer: false
  input:
    restore_model_path: null
  output:
    topk_model_to_save: 7
    save_model_path: /data/outputs
    save_embeds_path: /data/outputs
    save_prediction_path: /data/outputs
  hyperparam:
    dropout: 0.
    lr: 0.0001
    num_epochs: 20
    batch_size: 1024
    wd_l2norm: 0
  rgcn:
    num_bases: -1
    use_self_loop: true
    sparse_optimizer_lr: 1e-2
    use_node_embeddings: false
  node_classification:
    target_ntype: "paper"
    label_field: "label"
    multilabel: false
    num_classes: 14
Users can use an argument to read in this YAML file and construct a GSConfig object, like the code below, and then use the GSConfig instance, e.g., config, to provide the arguments that GraphStorm supports.
from graphstorm.config import GSConfig

......

argparser.add_argument("--yaml-config-file", type=str, required=True,
                       help="The GraphStorm YAML configuration file path.")
args = argparser.parse_args()
config = GSConfig(args)
For your own configurations, you can still pass them as input arguments of the training script and extract them from the args object.
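For example, mixing a custom argument with GSConfig could look like the sketch below; --num-heads is the HGT-specific argument that appears in the model creation and the launch command of this tutorial.

# A custom argument is not a GraphStorm configuration, so read it from `args`
# rather than from the GSConfig object.
argparser.add_argument("--num-heads", type=int, default=8,
                       help="The number of attention heads of the HGT model.")
args = argparser.parse_args()
config = GSConfig(args)        # GraphStorm-supported configurations
num_heads = args.num_heads     # users' own configuration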
Put Everything Together and Run Them
With all required modifications ready, let's put everything of the modified HGT model together in a Python file, e.g., hgt_nc.py. We can put the Python file and the related artifacts, including the YAML file, e.g., acm_nc.yaml, and the ip_list.txt file, in a folder, e.g., /hgt_nc/, and then use GraphStorm's launch script to run this modified HGT model.
python3 -m graphstorm.run.launch \
        --workspace /graphstorm/examples/customized_models/HGT \
        --part-config /data/acm_nc/acm.json \
        --ip-config /data/ip_list.txt \
        --num-trainers 2 \
        --num-servers 1 \
        --num-samplers 0 \
        --ssh-port 2222 \
        hgt_nc.py --yaml-config-file acm_nc.yaml \
                  --node-feat paper:feat-author:feat-subject:feat \
                  --num-heads 8
The argument value of --part-config is the JSON file that comes from the outputs of Step 1.
Note
To try this runnable example, please follow the GraphStorm examples readme.