Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add edge hard negative support. #678

Merged
merged 18 commits into from
Jan 11, 2024
166 changes: 165 additions & 1 deletion python/graphstorm/config/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -1874,9 +1874,143 @@ def lp_edge_weight_for_loss(self):

return None

@property
def train_etypes_negative_dstnode(self):
    """ Edge feature field(s) storing hard negative destination nodes
        used to corrupt positive edges during link prediction training.

        The format of the argument should be:
        train_etypes_negative_dstnode:
          - src_type,rel_type0,dst_type:negative_nid_field
          - src_type,rel_type1,dst_type:negative_nid_field
        Each edge type can have a different field storing the hard negatives.

        or
        train_etypes_negative_dstnode:
          - negative_nid_field
        All the edge types use the same field storing the hard negatives.

        Returns
        -------
        str: the shared field name, when a single entry without ``:`` is given.
        dict: canonical etype tuple (src, rel, dst) -> field name, otherwise.
        None: when the argument is not provided.

        An AssertionError is raised when the task is not link prediction,
        when an entry is malformed, or when an edge type is given twice.
    """
    # pylint: disable=no-member
    if not hasattr(self, "_train_etypes_negative_dstnode"):
        # The argument was not provided: no per-edge negatives are configured.
        return None

    assert self.task_type == BUILTIN_TASK_LINK_PREDICTION, \
        "Hard negative only works with link prediction"
    entries = self._train_etypes_negative_dstnode

    # A single entry without ':' is a field name shared by all edge types.
    if len(entries) == 1 and ":" not in entries[0]:
        return entries[0]

    # Otherwise every entry must be "src,relation,dst:field_name".
    field_by_etype = {}
    for entry in entries:
        parts = entry.split(":")
        assert len(parts) == 2, \
            "negative dstnode information must be provided in format of " \
            f"src,relation,dst:feature_name, but get {entry}"
        etype_str, field_name = parts

        etype = tuple(etype_str.split(","))
        assert len(etype) == 3, \
            f"Edge type must in format of (src,relation,dst), but get {etype}"

        # The message is only formatted when the assertion fails,
        # i.e., when etype is already a key of the dict.
        assert etype not in field_by_etype, \
            f"You already specify the fixed negative of {etype} " \
            f"as {field_by_etype[etype]}"

        field_by_etype[etype] = field_name
    return field_by_etype

@property
def num_train_hard_negatives(self):
    """ Number of hard negatives per edge type used in link prediction training.

        The format of the argument should be:
        num_train_hard_negatives:
          - src_type,rel_type0,dst_type:num_negatives
          - src_type,rel_type1,dst_type:num_negatives
        Each edge type can have a different number of hard negatives.

        or
        num_train_hard_negatives:
          - num_negatives
        All the edge types use the same number of hard negatives.

        Returns
        -------
        int: the shared number, when a single entry without ``:`` is given.
        dict: canonical etype tuple (src, rel, dst) -> int, otherwise.
        None: when the argument is not provided.

        An AssertionError is raised when the task is not link prediction,
        when an entry is not in the ``src,relation,dst:num`` format, or when
        an edge type is given twice. ``int()`` raises ValueError when the
        number part is not an integer literal.
    """
    # pylint: disable=no-member
    if hasattr(self, "_num_train_hard_negatives"):
        assert self.task_type == BUILTIN_TASK_LINK_PREDICTION, \
            "Hard negative only works with link prediction"
        num_negatives = self._num_train_hard_negatives
        if len(num_negatives) == 1 and \
            ":" not in num_negatives[0]:
            # global number of hard negatives shared by all edge types
            return int(num_negatives[0])

        # per edge type number of hard negatives
        num_hard_negative_dict = {}
        for num_negative in num_negatives:
            negative_info = num_negative.split(":")
            # Validate the entry format up front, as the *_negative_dstnode
            # properties do, so a malformed entry gives a clear message
            # instead of an IndexError or a silently wrong dict key.
            assert len(negative_info) == 2, \
                "Number of train hard negative information must be provided " \
                f"in format of src,relation,dst:10, but get {num_negative}"

            etype = tuple(negative_info[0].split(","))
            assert len(etype) == 3, \
                f"Edge type must in format of (src,relation,dst), but get {etype}"
            assert etype not in num_hard_negative_dict, \
                f"You already specify the fixed negative of {etype} " \
                f"as {num_hard_negative_dict[etype]}"

            num_hard_negative_dict[etype] = int(negative_info[1])
        return num_hard_negative_dict

    # By default the number of hard negatives is not set
    return None

@property
def eval_etypes_negative_dstnode(self):
    """ Edge feature field(s) storing predefined negative destination nodes
        used to corrupt edges during link prediction evaluation.

        The format of the argument should be:
        eval_etypes_negative_dstnode:
          - src_type,rel_type0,dst_type:negative_nid_field
          - src_type,rel_type1,dst_type:negative_nid_field
        Each edge type can have a different field storing the fixed negatives.

        or
        eval_etypes_negative_dstnode:
          - negative_nid_field
        All the edge types use the same field storing the fixed negatives.

        Returns
        -------
        str: the shared field name, when a single entry without ``:`` is given.
        dict: canonical etype tuple (src, rel, dst) -> field name, otherwise.
        None: when the argument is not provided.

        An AssertionError is raised when the task is not link prediction,
        when an entry is malformed, or when an edge type is given twice.
    """
    # pylint: disable=no-member
    if not hasattr(self, "_eval_etypes_negative_dstnode"):
        # By default fixed negative is not used
        return None

    assert self.task_type == BUILTIN_TASK_LINK_PREDICTION, \
        "Fixed negative only works with link prediction"
    specs = self._eval_etypes_negative_dstnode

    is_global_field = len(specs) == 1 and ":" not in specs[0]
    if is_global_field:
        # One bare field name applies to every edge type.
        return specs[0]

    # Per edge type fields: each spec is "src,relation,dst:field_name".
    etype_to_field = {}
    for spec in specs:
        pieces = spec.split(":")
        assert len(pieces) == 2, \
            "negative dstnode information must be provided in format of " \
            f"src,relation,dst:feature_name, but get {spec}"

        etype = tuple(pieces[0].split(","))
        assert len(etype) == 3, \
            f"Edge type must in format of (src,relation,dst), but get {etype}"
        # The message is only formatted when the assertion fails,
        # i.e., when etype is already a key of the dict.
        assert etype not in etype_to_field, \
            f"You already specify the fixed negative of {etype} " \
            f"as {etype_to_field[etype]}"

        etype_to_field[etype] = pieces[1]
    return etype_to_field

@property
def train_etype(self):
""" The list of canonical etype that will be added as
""" The list of canonical etypes that will be added as
training target with the target e type(s)

If not provided, all edge types will be used as training target.
Expand Down Expand Up @@ -2480,6 +2614,36 @@ def _add_link_prediction_args(parser):
"metrics of each edge type to select the best model"
"2) '--model-select-etype query,adds,item': Use the evaluation "
"metric of the edge type (query,adds,item) to select the best model")
group.add_argument("--train-etypes-negative-dstnode", nargs='+',
type=str, default=argparse.SUPPRESS,
help="Edge feature field name for user defined negative destination ndoes "
"for training. The negative nodes are used to construct hard negative edges "
"by corrupting positive edges' destination nodes."
"It can be in following format: "
"1) '--train-etypes-negative-dstnode negative_nid_field', "
"if all edge types use the same negative destination node filed."
"2) '--train-etypes-negative-dstnode query,adds,asin:neg0 query,clicks,asin:neg1 ...'"
"Different edge types have different negative destination node fields."
)
group.add_argument("--eval-etypes-negative-dstnode", nargs='+',
type=str, default=argparse.SUPPRESS,
help="Edge feature field name for user defined negative destination ndoes "
"for evaluation. The negative nodes are used to construct negative edges "
"by corrupting test edges' destination nodes."
"It can be in following format: "
"1) '--eval-etypes-negative-dstnode negative_nid_field', "
"if all edge types use the same negative destination node filed."
"2) '--eval-etypes-negative-dstnode query,adds,asin:neg0 query,clicks,asin:neg1 ...'"
"Different edge types have different negative destination node fields."
)
group.add_argument("--num-train-hard-negatives", nargs='+',
type=str, default=argparse.SUPPRESS,
help="Number of hard negatives for each edge type during training."
"It can be in following format: "
"1) '--num-train-hard-negatives 10', "
"if all edge types use the same number of hard negatives."
"2) '--num-train-hard-negatives query,adds,asin:5 query,clicks,asin:10 ...'"
"Different edge types have different number of hard negatives.")

return parser

Expand Down
24 changes: 20 additions & 4 deletions python/graphstorm/dataloading/dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
JointLocalUniform,
InbatchJointUniform,
FastMultiLayerNeighborSampler,
DistributedFileSampler)
DistributedFileSampler,
GSHardEdgeDstNegativeSampler,
GSFixedEdgeDstNegativeSampler)
from .utils import trim_data, modify_fanout_for_target_etype
from .dataset import GSDistillData

Expand Down Expand Up @@ -368,6 +370,7 @@ def fanout(self):
BUILTIN_FAST_LP_JOINT_NEG_SAMPLER = 'fast_joint'
BUILTIN_FAST_LP_LOCALUNIFORM_NEG_SAMPLER = 'fast_localuniform'
BUILTIN_FAST_LP_LOCALJOINT_NEG_SAMPLER = 'fast_localjoint'
BUILTIN_LP_FIXED_NEG_SAMPLER = 'fixed'

class GSgnnLinkPredictionDataLoaderBase():
""" The base class of link prediction dataloader.
Expand Down Expand Up @@ -972,9 +975,11 @@ class GSgnnLinkPredictionTestDataLoader():
When test is huge, using fixed_test_size
can save validation and test time.
Default: None.
fixed_edge_dst_negative_field: str or list of str
The feature field(s) that store the fixed negative set for each edge.
"""
def __init__(self, dataset, target_idx, batch_size, num_negative_edges, fanout=None,
fixed_test_size=None):
def __init__(self, dataset, target_idx, batch_size, num_negative_edges,
fanout=None, fixed_test_size=None, fixed_edge_dst_negative_field=None):
self._data = dataset
self._fanout = fanout
for etype in target_idx:
Expand All @@ -991,6 +996,7 @@ def __init__(self, dataset, target_idx, batch_size, num_negative_edges, fanout=N
"is %d, which is smaller than the expected"
"test size %d, force it to %d",
etype, len(t_idx), self._fixed_test_size[etype], len(t_idx))
self._fixed_edge_dst_negative_field = fixed_edge_dst_negative_field
self._negative_sampler = self._prepare_negative_sampler(num_negative_edges)
self._reinit_dataset()

Expand All @@ -1008,7 +1014,13 @@ def _reinit_dataset(self):
def _prepare_negative_sampler(self, num_negative_edges):
# the default negative sampler is uniform sampler
self._neg_sample_type = BUILTIN_LP_UNIFORM_NEG_SAMPLER
negative_sampler = GlobalUniform(num_negative_edges)

if self._fixed_edge_dst_negative_field:
negative_sampler = GSFixedEdgeDstNegativeSampler(self._fixed_edge_dst_negative_field)
self._neg_sample_type = BUILTIN_LP_FIXED_NEG_SAMPLER
else:
negative_sampler = GlobalUniform(num_negative_edges)
self._neg_sample_type = BUILTIN_LP_UNIFORM_NEG_SAMPLER
return negative_sampler

def __iter__(self):
Expand Down Expand Up @@ -1057,6 +1069,10 @@ def _prepare_negative_sampler(self, num_negative_edges):
# the default negative sampler is uniform sampler
negative_sampler = JointUniform(num_negative_edges)
self._neg_sample_type = BUILTIN_LP_JOINT_NEG_SAMPLER
if self._fixed_edge_dst_negative_field:
negative_sampler = GSHardEdgeDstNegativeSampler(num_negative_edges,
self._fixed_edge_dst_negative_field,
negative_sampler)
return negative_sampler

################ Minibatch DataLoader (Node classification) #######################
Expand Down
Loading