diff --git a/python/graphstorm/trainer/ep_trainer.py b/python/graphstorm/trainer/ep_trainer.py index 5530f2ae87..255de81b70 100644 --- a/python/graphstorm/trainer/ep_trainer.py +++ b/python/graphstorm/trainer/ep_trainer.py @@ -48,6 +48,26 @@ class GSgnnEdgePredictionTrainer(GSgnnTrainer): The GNN model for edge prediction. topk_model_to_save : int The top K model to save. + + Example + ------- + + .. code:: python + + from graphstorm.dataloading import GSgnnEdgeDataLoader + from graphstorm.dataset import GSgnnEdgeData + from graphstorm.model import GSgnnEdgeModel + from graphstorm.trainer import GSgnnEdgePredictionTrainer + + my_dataset = GSgnnEdgeData("my_graph", "/path/to/part_config") + target_idx = {"edge_type": target_edges_tensor} + my_data_loader = GSgnnEdgeDataLoader( + my_dataset, target_idx, fanout=[10], batch_size=1024) + my_model = GSgnnEdgeModel(alpha_l2norm=0.0) + + trainer = GSgnnEdgePredictionTrainer(my_model, topk_model_to_save=1) + + trainer.fit(my_data_loader, num_epochs=2) """ def __init__(self, model, topk_model_to_save): super(GSgnnEdgePredictionTrainer, self).__init__(model, topk_model_to_save) diff --git a/python/graphstorm/trainer/glem_np_trainer.py b/python/graphstorm/trainer/glem_np_trainer.py index 269f01b127..5e6ac664f8 100644 --- a/python/graphstorm/trainer/glem_np_trainer.py +++ b/python/graphstorm/trainer/glem_np_trainer.py @@ -52,6 +52,26 @@ class GLEMNodePredictionTrainer(GSgnnNodePredictionTrainer): `model.node_glem.GLEM`. topk_model_to_save : int The top K model to save. + + Example + ------- + + .. code:: python + + from graphstorm.dataloading import GSgnnNodeDataLoader + from graphstorm.dataset import GSgnnNodeData + from graphstorm.model.node_glem import GLEM + from graphstorm.trainer import GLEMNodePredictionTrainer + + my_dataset = GSgnnNodeData("my_graph", "/path/to/part_config") + target_idx = {"my_node_type": target_nodes_tensor} + my_data_loader = GSgnnNodeDataLoader( + my_dataset, target_idx, fanout=[10], batch_size=1024, device='cpu') + my_model = GLEM(alpha_l2norm=0.0, target_ntype="my_node_type") + + trainer = GLEMNodePredictionTrainer(my_model, topk_model_to_save=1) + + trainer.fit(my_data_loader, num_epochs=2) """ def __init__(self, model, topk_model_to_save=1): super(GLEMNodePredictionTrainer, self).__init__(model, topk_model_to_save) diff --git a/python/graphstorm/trainer/gsgnn_trainer.py b/python/graphstorm/trainer/gsgnn_trainer.py index 488add2e7c..90b656c9f9 100644 --- a/python/graphstorm/trainer/gsgnn_trainer.py +++ b/python/graphstorm/trainer/gsgnn_trainer.py @@ -37,6 +37,9 @@ class GSgnnTrainer(): It contains functions that can be used in the implementing classes' `fit` and `eval` functions. + To implement your own trainers, extend this class and add implementations + for the `fit` and `eval` functions. + Parameters ---------- model : GSgnnModel diff --git a/python/graphstorm/trainer/np_trainer.py b/python/graphstorm/trainer/np_trainer.py index 68bb52217d..f565a069d6 100644 --- a/python/graphstorm/trainer/np_trainer.py +++ b/python/graphstorm/trainer/np_trainer.py @@ -47,6 +47,26 @@ class GSgnnNodePredictionTrainer(GSgnnTrainer): The GNN model for node prediction. topk_model_to_save : int The top K model to save. + + Example + ------- + + .. code:: python + + from graphstorm.dataloading import GSgnnNodeDataLoader + from graphstorm.dataset import GSgnnNodeData + from graphstorm.model.node_gnn import GSgnnNodeModel + from graphstorm.trainer import GSgnnNodePredictionTrainer + + my_dataset = GSgnnNodeData("my_graph", "/path/to/part_config") + target_idx = {"my_node_type": target_nodes_tensor} + my_data_loader = GSgnnNodeDataLoader( + my_dataset, target_idx, fanout=[10], batch_size=1024, device='cpu') + my_model = GSgnnNodeModel(alpha_l2norm=0.0) + + trainer = GSgnnNodePredictionTrainer(my_model, topk_model_to_save=1) + + trainer.fit(my_data_loader, num_epochs=2) """ def __init__(self, model, topk_model_to_save=1): super(GSgnnNodePredictionTrainer, self).__init__(model, topk_model_to_save)