Skip to content

Commit

Permalink
Update docs for Dataset and DataLoader
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiang Song committed Apr 29, 2024
1 parent 6ccfa64 commit 03fa76f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 18 deletions.
47 changes: 31 additions & 16 deletions docs/source/advanced/own-models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -198,37 +198,52 @@ the ``ip_config`` argument specifies a ip configuration file, which contains the

Replace DGL DataLoader with the GraphStorm's dataset and dataloader
`````````````````````````````````````````````````````````````````````
Because the GraphStorm uses distributed graphs, we need to first load the partitioned graph, which is created in the :ref:`Step 1 <step-1>`, with the `GSgnnNodeTrainData <https://github.com/awslabs/graphstorm/blob/main/python/graphstorm/dataloading/dataset.py#L469>`_ class (for edge tasks, GraphStorm also provides `GSgnnEdgeTrainData <https://github.com/awslabs/graphstorm/blob/main/python/graphstorm/dataloading/dataset.py#L216>`_). The ``GSgnnNodeTrainData`` could be created as shown in the code below.
Because the GraphStorm uses distributed graphs, we need to first load the partitioned graph, which is created in the :ref:`Step 1 <step-1>`, with the `GSgnnData <https://github.com/awslabs/graphstorm/blob/main/python/graphstorm/dataloading/dataset.py#L57>`_ class (for edge tasks, the same class is used). The ``GSgnnData`` could be created as shown in the code below.

.. code-block:: python
train_data = GSgnnNodeTrainData(config.graph_name,
config.part_config,
train_ntypes=config.target_ntype,
node_feat_field=node_feat_fields,
label_field=config.label_field)
train_data = GSgnnData(config.part_config,
node_feat_field=node_feat_fields)
Arguments of this class include the partition configuration JSON file path, which are the outputs of the :ref:`Step 1 <step-1>`. The ``graph_name`` can be found in the JSON file.
Arguments of this class include the partition configuration JSON file path, which are the outputs of the :ref:`Step 1 <step-1>`.

The other values, the ``train_ntypes``, the ``label_field``, and the ``node_feat_field``, should be consistent with the values in the raw data :ref:`input configuration JSON <input-config>` defined in the :ref:`Step 1 <step-1>`. The ``train_ntypes`` is the ``node_type`` that has ``labels`` specified. The ``label_fields`` is the value specified in ``label_col`` of the ``train_ntype``. The ``node_feat_field`` is a dictionary, whose key is the values of ``node_type``, and value is the values of ``feature_name``.
The ``node_feat_field``, should be consistent with the values in the raw data :ref:`input configuration JSON <input-config>` defined in the :ref:`Step 1 <step-1>`. The ``node_feat_field`` is a dictionary, whose keys are the values of ``node_type``, and values are the values of ``feature_name``.

Then we can put this dataset into GraphStorm's `GSgnnNodeDataLoader <https://github.com/awslabs/graphstorm/blob/main/python/graphstorm/dataloading/dataloading.py#L544>`_, which is like:
Then we can put this dataset into GraphStorm's `GSgnnNodeDataLoader <https://github.com/awslabs/graphstorm/blob/main/python/graphstorm/dataloading/dataloading.py#L1237>`_, which is like:

.. code-block:: python
# Get train idx
train_idxs = train_data.get_node_train_set(config.target_ntype)
# Define the GraphStorm train dataloader
dataloader = GSgnnNodeDataLoader(train_data, train_data.train_idxs, fanout=config.fanout,
batch_size=config.batch_size, device=device, train_task=True)
dataloader = GSgnnNodeDataLoader(train_data,
train_idxs, fanout=config.fanout,
batch_size=config.batch_size,
node_feats=node_feat_fields,
label_field=config.label_field,
device=device, train_task=True)
# Optional: Define the evaluation dataloader
eval_dataloader = GSgnnNodeDataLoader(train_data, train_data.val_idxs,fanout=config.fanout,
batch_size=config.eval_batch_size, device=device,
val_idxs = train_data.get_node_val_set(eval_ntype)
eval_dataloader = GSgnnNodeDataLoader(train_data,
val_idxs,
fanout=config.fanout,
batch_size=config.eval_batch_size,
node_feats=node_feat_fields,
label_field=config.label_field,
device=device,
train_task=False)
# Optional: Define the evaluation dataloader
test_dataloader = GSgnnNodeDataLoader(train_data, train_data.test_idxs,fanout=config.fanout,
batch_size=config.eval_batch_size, device=device,
test_idxs = train_data.get_node_test_set(eval_ntype)
test_dataloader = GSgnnNodeDataLoader(train_data,
test_idxs,
fanout=config.fanout,
batch_size=config.eval_batch_size,
node_feats=node_feat_fields,
label_field=config.label_field,device=device,
train_task=False)
GraphStorm provides a set of dataloaders for different GML tasks. Here we deal with a node task, hence using the node dataloader, which takes the graph data created above as the first argument. The second argument is the label index that the GraphStorm dataset extracts from the graph as indicated in the target nodes' ``train_mask``, ``val_mask``, and ``test_mask``, which are automatically generated by GraphStorm graph construction tool with the specified ``split_pct`` field. The ``GSgnnNodeTrainData`` automatically extracts these indexes out and set its properties so that you can directly use them like ``graph_data.train_idxs`` and ``graph_data.val_idxs``, and ``graph_data.test_idxs``. The rest of arguments are similar to the common training flow, except that we set the ``train_task`` to be ``False`` for the evaluation and test dataloader.
GraphStorm provides a set of dataloaders for different GML tasks. Here we deal with a node task, hence using the node dataloader, which takes the graph data created above as the first argument. The second argument is the label index that the GraphStorm dataset extracts from the graph as indicated in the target nodes' ``train_mask``, ``val_mask``, and ``test_mask``, which are automatically generated by GraphStorm graph construction tool with the specified ``split_pct`` field. The ``GSgnnData`` provides functions to get the indexes of train data, validation data and test data through ``get_node_train_set``, ``get_node_val_set`` and ``get_node_test_set``, respectively. The ``label_field`` is also required by the GSgnnNodeDataLoader to get the labels for model training and evaluation. The rest of arguments are similar to the common training flow, except that we set the ``train_task`` to be ``False`` for the evaluation and test dataloader.

Use GraphStorm's model trainer to wrap your model and attach evaluator and task tracker to it
````````````````````````````````````````````````````````````````````````````````````````````````
Expand Down
11 changes: 9 additions & 2 deletions python/graphstorm/dataloading/dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,11 @@ class GSgnnEdgeDataLoader(GSgnnEdgeDataLoaderBase):
from graphstorm.trainer import GSgnnEdgePredictionTrainer
ep_data = GSgnnData(...)
ep_dataloader = GSgnnEdgeDataLoader(ep_data, target_idx, fanout=[15, 10], batch_size=128)
target_idx = ep_data.get_edge_train_set(...)
ep_dataloader = GSgnnEdgeDataLoader(
ep_data, target_idx,
fanout=[15, 10], batch_size=128,
label_field=config.label_field)
ep_trainer = GSgnnEdgePredictionTrainer(...)
ep_trainer.fit(ep_dataloader, num_epochs=10)
"""
Expand Down Expand Up @@ -679,6 +683,7 @@ class GSgnnLinkPredictionDataLoader(GSgnnLinkPredictionDataLoaderBase):
from graphstorm.trainer import GSgnnLinkPredictionTrainer
lp_data = GSgnnData(...)
target_idx = lp_data.get_edge_train_set(...)
lp_dataloader = GSgnnLinkPredictionDataLoader(lp_data, target_idx, fanout=[15, 10],
num_negative_edges=10, batch_size=128)
lp_trainer = GSgnnLinkPredictionTrainer(...)
Expand Down Expand Up @@ -1542,9 +1547,11 @@ class GSgnnNodeDataLoader(GSgnnNodeDataLoaderBase):
from graphstorm.trainer import GSgnnNodePredictionTrainer
np_data = GSgnnData(...)
target_idx = np_data.get_node_train_set(...)
np_dataloader = GSgnnNodeDataLoader(np_data, target_idx, fanout=[15, 10],
batch_size=128,
label_field="label", node_feats="feat")
label_field="label",
node_feats="feat")
np_trainer = GSgnnNodePredictionTrainer(...)
np_trainer.fit(np_dataloader, num_epochs=10)
"""
Expand Down

0 comments on commit 03fa76f

Please sign in to comment.