Skip to content

Commit

Permalink
Review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
thvasilo committed Sep 28, 2023
1 parent fa9aeea commit 029f11e
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 0 deletions.
20 changes: 20 additions & 0 deletions python/graphstorm/trainer/ep_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,26 @@ class GSgnnEdgePredictionTrainer(GSgnnTrainer):
The GNN model for edge prediction.
topk_model_to_save : int
The top K model to save.
Example
-------
.. code:: python
from graphstorm.dataloading import GSgnnEdgeDataLoader
from graphstorm.dataset import GSgnnEdgeData
from graphstorm.model import GSgnnEdgeModel
from graphstorm.trainer import GSgnnEdgePredictionTrainer
my_dataset = GSgnnEdgeData("my_graph", "/path/to/part_config")
target_idx = {"edge_type": target_edges_tensor}
my_data_loader = GSgnnEdgeDataLoader(
my_dataset, target_idx, fanout=[10], batch_size=1024)
my_model = GSgnnEdgeModel(alpha_l2norm=0.0)
trainer = GSgnnEdgePredictionTrainer(my_model, topk_model_to_save=1)
trainer.fit(my_data_loader, num_epochs=2)
"""
def __init__(self, model, topk_model_to_save):
super(GSgnnEdgePredictionTrainer, self).__init__(model, topk_model_to_save)
Expand Down
20 changes: 20 additions & 0 deletions python/graphstorm/trainer/glem_np_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,26 @@ class GLEMNodePredictionTrainer(GSgnnNodePredictionTrainer):
`model.node_glem.GLEM`.
topk_model_to_save : int
The top K model to save.
Example
-------
.. code:: python
from graphstorm.dataloading import GSgnnNodeDataLoader
from graphstorm.dataset import GSgnnNodeData
from graphstorm.model.node_glem import GLEM
from graphstorm.trainer import GLEMNodePredictionTrainer
my_dataset = GSgnnNodeData("my_graph", "/path/to/part_config")
target_idx = {"my_node_type": target_nodes_tensor}
my_data_loader = GSgnnNodeDataLoader(
my_dataset, target_idx, fanout=[10], batch_size=1024, device='cpu')
my_model = GLEM(alpha_l2norm=0.0, target_ntype="my_node_type")
trainer = GLEMNodePredictionTrainer(my_model, topk_model_to_save=1)
trainer.fit(my_data_loader, num_epochs=2)
"""
def __init__(self, model, topk_model_to_save=1):
super(GLEMNodePredictionTrainer, self).__init__(model, topk_model_to_save)
Expand Down
3 changes: 3 additions & 0 deletions python/graphstorm/trainer/gsgnn_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class GSgnnTrainer():
It contains functions that can be used in the implementing classes'
`fit` and `eval` functions.
To implement your own trainers, extend this class and add implementations
for the `fit` and `eval` functions.
Parameters
----------
model : GSgnnModel
Expand Down
20 changes: 20 additions & 0 deletions python/graphstorm/trainer/np_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,26 @@ class GSgnnNodePredictionTrainer(GSgnnTrainer):
The GNN model for node prediction.
topk_model_to_save : int
The top K model to save.
Example
-------
.. code:: python
from graphstorm.dataloading import GSgnnNodeDataLoader
from graphstorm.dataset import GSgnnNodeData
from graphstorm.model.node_gnn import GSgnnNodeModel
from graphstorm.trainer import GSgnnNodePredictionTrainer
my_dataset = GSgnnNodeData("my_graph", "/path/to/part_config")
target_idx = {"my_node_type": target_nodes_tensor}
my_data_loader = GSgnnNodeDataLoader(
my_dataset, target_idx, fanout=[10], batch_size=1024, device='cpu')
my_model = GSgnnNodeModel(alpha_l2norm=0.0)
trainer = GSgnnNodePredictionTrainer(my_model, topk_model_to_save=1)
trainer.fit(my_data_loader, num_epochs=2)
"""
def __init__(self, model, topk_model_to_save=1):
super(GSgnnNodePredictionTrainer, self).__init__(model, topk_model_to_save)
Expand Down

0 comments on commit 029f11e

Please sign in to comment.