Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiang Song committed Dec 20, 2023
1 parent 776a862 commit fc35ba1
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 58 deletions.
8 changes: 5 additions & 3 deletions python/graphstorm/dataloading/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
from .dataloading import GSgnnAllEtypeLinkPredictionDataLoader
from .dataloading import GSgnnEdgeDataLoader
from .dataloading import GSgnnNodeDataLoader, GSgnnNodeSemiSupDataLoader
from .dataloading import GSgnnLinkPredictionTestDataLoader
from .dataloading import GSgnnLinkPredictionJointTestDataLoader
from .dataloading import (GSgnnLinkPredictionTestDataLoader,
GSgnnLinkPredictionJointTestDataLoader,
GSgnnLinkPredictionPredefinedTestDataLoader)
from .dataloading import (FastGSgnnLinkPredictionDataLoader,
FastGSgnnLPLocalJointNegDataLoader,
FastGSgnnLPJointNegDataLoader,
Expand All @@ -43,7 +44,8 @@
BUILTIN_LP_JOINT_NEG_SAMPLER,
BUILTIN_LP_INBATCH_JOINT_NEG_SAMPLER,
BUILTIN_LP_LOCALUNIFORM_NEG_SAMPLER,
BUILTIN_LP_LOCALJOINT_NEG_SAMPLER)
BUILTIN_LP_LOCALJOINT_NEG_SAMPLER,
BUILTIN_LP_FIXED_NEG_SAMPLER)
from .dataloading import BUILTIN_LP_ALL_ETYPE_UNIFORM_NEG_SAMPLER
from .dataloading import BUILTIN_LP_ALL_ETYPE_JOINT_NEG_SAMPLER
from .dataloading import (BUILTIN_FAST_LP_UNIFORM_NEG_SAMPLER,
Expand Down
80 changes: 61 additions & 19 deletions python/graphstorm/dataloading/dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,10 @@ class GSgnnLinkPredictionDataLoader(GSgnnLinkPredictionDataLoaderBase):
The node types that requires to construct node features.
construct_feat_fanout : int
The fanout required to construct node features.
edge_dst_negative_field: str or dict of str
The feature field(s) that store the hard negative edges for each edge type.
num_hard_negs: int or dict of int
The number of hard negatives per positive edge for each edge type
Examples
------------
Expand All @@ -510,7 +514,9 @@ class GSgnnLinkPredictionDataLoader(GSgnnLinkPredictionDataLoaderBase):
def __init__(self, dataset, target_idx, fanout, batch_size, num_negative_edges, device='cpu',
train_task=True, reverse_edge_types_map=None, exclude_training_targets=False,
edge_mask_for_gnn_embeddings='train_mask',
construct_feat_ntype=None, construct_feat_fanout=5):
construct_feat_ntype=None, construct_feat_fanout=5,
edge_dst_negative_field=None,
num_hard_negs=None):
super().__init__(dataset, target_idx, fanout)
self._device = device
for etype in target_idx:
Expand All @@ -524,7 +530,9 @@ def __init__(self, dataset, target_idx, fanout, batch_size, num_negative_edges,
reverse_edge_types_map=reverse_edge_types_map,
edge_mask_for_gnn_embeddings=edge_mask_for_gnn_embeddings,
construct_feat_ntype=construct_feat_ntype,
construct_feat_fanout=construct_feat_fanout)
construct_feat_fanout=construct_feat_fanout,
edge_dst_negative_field=edge_dst_negative_field,
num_hard_negs=num_hard_negs)

def _prepare_negative_sampler(self, num_negative_edges):
# the default negative sampler is uniform sampler
Expand All @@ -535,7 +543,8 @@ def _prepare_dataloader(self, dataset, target_idxs, fanout,
num_negative_edges, batch_size, device, train_task=True,
exclude_training_targets=False, reverse_edge_types_map=None,
edge_mask_for_gnn_embeddings=None, construct_feat_ntype=None,
construct_feat_fanout=5):
construct_feat_fanout=5, edge_dst_negative_field=None,
num_hard_negs=None):
g = dataset.g
if construct_feat_ntype is None:
construct_feat_ntype = []
Expand All @@ -556,6 +565,11 @@ def _prepare_dataloader(self, dataset, target_idxs, fanout,
sampler = MultiLayerNeighborSamplerForReconstruct(sampler,
dataset, construct_feat_ntype, construct_feat_fanout)
negative_sampler = self._prepare_negative_sampler(num_negative_edges)
if edge_dst_negative_field is not None:
negative_sampler = GSHardEdgeDstNegativeSampler(num_negative_edges,
edge_dst_negative_field,
negative_sampler,
num_hard_negs)

# edge loader
if train_task:
Expand Down Expand Up @@ -642,7 +656,8 @@ def _prepare_dataloader(self, dataset, target_idxs, fanout,
num_negative_edges, batch_size, device, train_task=True,
exclude_training_targets=False, reverse_edge_types_map=None,
edge_mask_for_gnn_embeddings=None, construct_feat_ntype=None,
construct_feat_fanout=5):
construct_feat_fanout=5, edge_dst_negative_field=None,
num_hard_negs=None):
g = dataset.g
if construct_feat_ntype is None:
construct_feat_ntype = []
Expand All @@ -663,6 +678,11 @@ def _prepare_dataloader(self, dataset, target_idxs, fanout,
sampler = MultiLayerNeighborSamplerForReconstruct(sampler,
dataset, construct_feat_ntype, construct_feat_fanout)
negative_sampler = self._prepare_negative_sampler(num_negative_edges)
if edge_dst_negative_field is not None:
negative_sampler = GSHardEdgeDstNegativeSampler(num_negative_edges,
edge_dst_negative_field,
negative_sampler,
num_hard_negs)

# edge loader
if train_task:
Expand Down Expand Up @@ -975,11 +995,9 @@ class GSgnnLinkPredictionTestDataLoader():
When test is huge, using fixed_test_size
can save validation and test time.
Default: None.
fixed_edge_dst_negative_field: str or list of str
The feature field(s) that store the fixed negative set for each edge.
"""
def __init__(self, dataset, target_idx, batch_size, num_negative_edges,
fanout=None, fixed_test_size=None, fixed_edge_dst_negative_field=None):
fanout=None, fixed_test_size=None):
self._data = dataset
self._fanout = fanout
for etype in target_idx:
Expand All @@ -996,7 +1014,6 @@ def __init__(self, dataset, target_idx, batch_size, num_negative_edges,
"is %d, which is smaller than the expected"
"test size %d, force it to %d",
etype, len(t_idx), self._fixed_test_size[etype], len(t_idx))
self._fixed_edge_dst_negative_field = fixed_edge_dst_negative_field
self._negative_sampler = self._prepare_negative_sampler(num_negative_edges)
self._reinit_dataset()

Expand All @@ -1014,13 +1031,7 @@ def _reinit_dataset(self):
def _prepare_negative_sampler(self, num_negative_edges):
# the default negative sampler is uniform sampler
self._neg_sample_type = BUILTIN_LP_UNIFORM_NEG_SAMPLER

if self._fixed_edge_dst_negative_field:
negative_sampler = GSFixedEdgeDstNegativeSampler(self._fixed_edge_dst_negative_field)
self._neg_sample_type = BUILTIN_LP_FIXED_NEG_SAMPLER
else:
negative_sampler = GlobalUniform(num_negative_edges)
self._neg_sample_type = BUILTIN_LP_UNIFORM_NEG_SAMPLER
negative_sampler = GlobalUniform(num_negative_edges)
return negative_sampler

def __iter__(self):
Expand Down Expand Up @@ -1069,10 +1080,41 @@ def _prepare_negative_sampler(self, num_negative_edges):
# the default negative sampler is uniform sampler
negative_sampler = JointUniform(num_negative_edges)
self._neg_sample_type = BUILTIN_LP_JOINT_NEG_SAMPLER
if self._fixed_edge_dst_negative_field:
negative_sampler = GSHardEdgeDstNegativeSampler(num_negative_edges,
self._fixed_edge_dst_negative_field,
negative_sampler)
return negative_sampler

class GSgnnLinkPredictionPredefinedTestDataLoader(GSgnnLinkPredictionTestDataLoader):
""" Link prediction minibatch dataloader for validation and test
with predefined negatives.
Parameters
-----------
dataset: GSgnnEdgeData
The GraphStorm edge dataset
target_idx : dict of Tensors
The target edges for prediction
batch_size: int
Batch size
fanout: int
Evaluation fanout for computing node embedding
fixed_test_size: int
Fixed number of test data used in evaluation.
If it is none, use the whole testset.
When test is huge, using fixed_test_size
can save validation and test time.
Default: None.
fixed_edge_dst_negative_field: str or list of str
The feature field(s) that store the fixed negative set for each edge.
"""
def __init__(self, dataset, target_idx, batch_size, fixed_edge_dst_negative_field,
fanout=None, fixed_test_size=None):
self._fixed_edge_dst_negative_field = fixed_edge_dst_negative_field
super().__init__(dataset, target_idx, batch_size,
num_negative_edges=0, # num_negative_edges is not used
fanout=fanout, fixed_test_size=fixed_test_size)

def _prepare_negative_sampler(self, _):
negative_sampler = GSFixedEdgeDstNegativeSampler(self._fixed_edge_dst_negative_field)
self._neg_sample_type = BUILTIN_LP_FIXED_NEG_SAMPLER
return negative_sampler

################ Minibatch DataLoader (Node classification) #######################
Expand Down
2 changes: 1 addition & 1 deletion python/graphstorm/dataloading/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def _generate(self, g, eids, canonical_etype):
hard_negative[:num_hard_neg if num_hard_neg < self._k else self._k]
return src, neg

def gen_neg_pairs(self, g, pos_pairs):
def gen_neg_pairs(self, _):
""" TODO: Do not support generating negative pairs for evaluation in the same way as
generating negative pairs for training now.
Please use GSFixedEdgeDstNegativeSampler instead.
Expand Down
29 changes: 21 additions & 8 deletions python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
GSgnnLPInBatchJointNegDataLoader)
from graphstorm.dataloading import GSgnnAllEtypeLPJointNegDataLoader
from graphstorm.dataloading import GSgnnAllEtypeLinkPredictionDataLoader
from graphstorm.dataloading import GSgnnLinkPredictionTestDataLoader
from graphstorm.dataloading import GSgnnLinkPredictionJointTestDataLoader
from graphstorm.dataloading import (GSgnnLinkPredictionTestDataLoader,
GSgnnLinkPredictionJointTestDataLoader,
GSgnnLinkPredictionPredefinedTestDataLoader)
from graphstorm.dataloading import (BUILTIN_LP_UNIFORM_NEG_SAMPLER,
BUILTIN_LP_JOINT_NEG_SAMPLER,
BUILTIN_LP_INBATCH_JOINT_NEG_SAMPLER,
Expand Down Expand Up @@ -129,10 +130,14 @@ def main(config_args):
raise ValueError('Unknown negative sampler')
dataloader = dataloader_cls(train_data, train_data.train_idxs, [],
config.batch_size, config.num_negative_edges, device,
train_task=True)
train_task=True,
edge_dst_negative_field=config.train_etypes_negative_dstnode,
num_hard_negs=config.num_train_hard_negatives)

# TODO(zhengda) let's use full-graph inference for now.
if config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER:
if config.eval_etypes_negative_dstnode is not None:
test_dataloader_cls = GSgnnLinkPredictionPredefinedTestDataLoader
elif config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER:
test_dataloader_cls = GSgnnLinkPredictionTestDataLoader
elif config.eval_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER:
test_dataloader_cls = GSgnnLinkPredictionJointTestDataLoader
Expand All @@ -143,11 +148,19 @@ def main(config_args):
val_dataloader = None
test_dataloader = None
if len(train_data.val_idxs) > 0:
val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs,
config.eval_batch_size, config.num_negative_edges_eval)
if config.eval_etypes_negative_dstnode is not None:
val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs,
config.eval_batch_size, config.eval_etypes_negative_dstnode)
else:
val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs,
config.eval_batch_size, config.num_negative_edges_eval)
if len(train_data.test_idxs) > 0:
test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs,
config.eval_batch_size, config.num_negative_edges_eval)
if config.eval_etypes_negative_dstnode is not None:
test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs,
config.eval_batch_size, config.eval_etypes_negative_dstnode)
else:
test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs,
config.eval_batch_size, config.num_negative_edges_eval)

# Preparing input layer for training or inference.
# The input layer can pre-compute node features in the preparing step if needed.
Expand Down
39 changes: 29 additions & 10 deletions python/graphstorm/run/gsgnn_lp/gsgnn_lp.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
GSgnnLPInBatchJointNegDataLoader)
from graphstorm.dataloading import GSgnnAllEtypeLPJointNegDataLoader
from graphstorm.dataloading import GSgnnAllEtypeLinkPredictionDataLoader
from graphstorm.dataloading import GSgnnLinkPredictionTestDataLoader
from graphstorm.dataloading import GSgnnLinkPredictionJointTestDataLoader
from graphstorm.dataloading import (GSgnnLinkPredictionTestDataLoader,
GSgnnLinkPredictionJointTestDataLoader,
GSgnnLinkPredictionPredefinedTestDataLoader)
from graphstorm.dataloading import (BUILTIN_LP_UNIFORM_NEG_SAMPLER,
BUILTIN_LP_JOINT_NEG_SAMPLER,
BUILTIN_LP_INBATCH_JOINT_NEG_SAMPLER,
Expand Down Expand Up @@ -152,10 +153,14 @@ def main(config_args):
reverse_edge_types_map=config.reverse_edge_types_map,
exclude_training_targets=config.exclude_training_targets,
construct_feat_ntype=config.construct_feat_ntype,
construct_feat_fanout=config.construct_feat_fanout)
construct_feat_fanout=config.construct_feat_fanout,
edge_dst_negative_field=config.train_etypes_negative_dstnode,
num_hard_negs=config.num_train_hard_negatives)

# TODO(zhengda) let's use full-graph inference for now.
if config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER:
if config.eval_etypes_negative_dstnode is not None:
test_dataloader_cls = GSgnnLinkPredictionPredefinedTestDataLoader
elif config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER:
test_dataloader_cls = GSgnnLinkPredictionTestDataLoader
elif config.eval_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER:
test_dataloader_cls = GSgnnLinkPredictionJointTestDataLoader
Expand All @@ -166,13 +171,27 @@ def main(config_args):
val_dataloader = None
test_dataloader = None
if len(train_data.val_idxs) > 0:
val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs,
config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout,
fixed_test_size=config.fixed_test_size)
if config.eval_etypes_negative_dstnode is not None:
val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs,
config.eval_batch_size,
fixed_edge_dst_negative_field=config.eval_etypes_negative_dstnode,
fanout=config.eval_fanout,
fixed_test_size=config.fixed_test_size)
else:
val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs,
config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout,
fixed_test_size=config.fixed_test_size)
if len(train_data.test_idxs) > 0:
test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs,
config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout,
fixed_test_size=config.fixed_test_size)
if config.eval_etypes_negative_dstnode is not None:
test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs,
config.eval_batch_size,
fixed_edge_dst_negative_field=config.eval_etypes_negative_dstnode,
fanout=config.eval_fanout,
fixed_test_size=config.fixed_test_size)
else:
test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs,
config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout,
fixed_test_size=config.fixed_test_size)

# Preparing input layer for training or inference.
# The input layer can pre-compute node features in the preparing step if needed.
Expand Down
Loading

0 comments on commit fc35ba1

Please sign in to comment.