Refactor Dataloading and Dataset (#795)
*Issue #, if available:*
#755 
#756 

*Description of changes:*
Change the Dataloader API and Dataset API according to #755 and #756.
Update the unit tests accordingly.
Update the training and inference code accordingly.

TODO: Update examples
TODO: Update docs.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Xiang Song <[email protected]>
classicsong and Xiang Song authored Apr 18, 2024
1 parent 8ee4db3 commit 9a3242f
Showing 31 changed files with 1,871 additions and 1,556 deletions.
6 changes: 3 additions & 3 deletions docs/source/tutorials/own-data.rst
@@ -244,8 +244,8 @@ Required DGL graph format
```````````````````````````
- a `dgl.heterograph <https://docs.dgl.ai/generated/dgl.heterograph.html#dgl.heterograph>`_.
- All nodes/edges features are set in nodes/edges' data field, and remember the feature names, which will be used in the later steps.
- - For nodes' features, the common way to set features is like ``g.nodes['nodetypename'].data['featurename']=nodefeaturetensor``, The formal explanation of DGL's node feature could be found in the `Using node features <https://docs.dgl.ai/generated/dgl.DGLGraph.nodes.html>`_.
- - For edges' features, the common way to set features is like ``g.edges['edgetypename'].data['featurename']=edgefeaturetensor``, The formal explanation of DGL's edge feature could be found in the `Using edge features <https://docs.dgl.ai/generated/dgl.DGLGraph.edges.html>`_.
+ - For nodes' features, the common way to set features is like ``g.nodes['nodetypename'].data['featurename']=nodefeaturetensor``, The formal explanation of DGL's node feature could be found in the `Using node features <https://docs.dgl.ai/generated/dgl.DGLGraph.nodes.html>`_. Please make sure every node feature is a 2D tensor.
+ - For edges' features, the common way to set features is like ``g.edges['edgetypename'].data['featurename']=edgefeaturetensor``, The formal explanation of DGL's edge feature could be found in the `Using edge features <https://docs.dgl.ai/generated/dgl.DGLGraph.edges.html>`_. Please make sure every edge feature is a 2D tensor.
- Save labels (for node/edge tasks) into the target nodes/edges as a feature, and remember the label feature names, which will be used in the later steps.
- The common way to set node-related labels as a feature is like ``g.nodes['predictnodetypename'].data['labelname']=nodelabeltensor``.
- The common way to set edge-related labels as a feature is like ``g.nodes['predictedgetypename'].data['labelname']=edgelabeltensor``.
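
The 2D requirement added in this hunk can be met by reshaping scalar-per-node (or scalar-per-edge) features before they are assigned. The following is a minimal sketch under stated assumptions, not GraphStorm's own code: ``as_2d_feature`` is a hypothetical helper, the data is made up, and the graph ``g`` in the closing comment is assumed to exist already.

```python
import torch


def as_2d_feature(feat: torch.Tensor) -> torch.Tensor:
    """Return `feat` with shape (num_items, feat_dim).

    Hypothetical helper for illustration; not part of DGL or GraphStorm.
    """
    if feat.dim() == 1:
        # One scalar per node/edge: add a trailing feature dimension.
        return feat.unsqueeze(1)
    if feat.dim() == 2:
        # Already in the expected layout.
        return feat
    raise ValueError(f"expected a 1D or 2D tensor, got {feat.dim()}D")


# Made-up data: 3 nodes with a scalar feature and a 16-dim feature.
year = torch.tensor([2019.0, 2020.0, 2021.0])
emb = torch.zeros(3, 16)

year_2d = as_2d_feature(year)  # shape (3, 1)
emb_2d = as_2d_feature(emb)    # shape (3, 16), unchanged

# With a DGL heterograph `g` (assumed to exist, with 3 'paper' nodes),
# the assignment form from the text would then be:
#   g.nodes['paper'].data['year'] = year_2d
#   g.nodes['paper'].data['emb'] = emb_2d
```

Raising on tensors with more than two dimensions is a choice made for this sketch; whether such features should instead be flattened depends on the model that consumes them.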
@@ -435,7 +435,7 @@ Similar to the :ref:`Quick-Start <quick-start-standalone>` tutorial, users can l
--node-feat-name paper:feat author:feat subject:feat \
--restore-model-path /tmp/acm_nc/models/epoch-0 \
--save-prediction-path /tmp/acm_nc/predictions
# Link Prediction
python -m graphstorm.run.gs_link_prediction \
--inference \
8 changes: 8 additions & 0 deletions python/graphstorm/config/argument.py
@@ -754,6 +754,14 @@ def input_activate(self):
"and 'relu' for torch.nn.functional.relu")
return None

+    @property
+    def edge_feat_name(self):
+        """ User defined edge feature name
+            Not used by GraphStorm, reserved for future usage.
+        """
+        return None
+
     @property
     def node_feat_name(self):
         """ User defined node feature name
5 changes: 1 addition & 4 deletions python/graphstorm/dataloading/__init__.py
@@ -35,10 +35,7 @@
GSgnnLinkPredictionDataLoaderBase,
GSgnnNodeDataLoaderBase)

-from .dataset import GSgnnEdgeTrainData, GSgnnLPTrainData
-from .dataset import GSgnnEdgeInferData
-from .dataset import GSgnnNodeTrainData
-from .dataset import GSgnnNodeInferData
+from .dataset import GSgnnData

from .dataloading import (BUILTIN_LP_UNIFORM_NEG_SAMPLER,
BUILTIN_LP_JOINT_NEG_SAMPLER,