Skip to content

Commit

Permalink
resolve some comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiang Song committed Dec 19, 2023
1 parent 497ab0e commit 7961e42
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 44 deletions.
79 changes: 62 additions & 17 deletions python/graphstorm/config/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -1875,25 +1875,26 @@ def lp_edge_weight_for_loss(self):
return None

@property
def train_hard_edge_dstnode_negative(self):
""" The list of canonical etypes that have hard negative sets
def train_etypes_negative_dstnode(self):
""" The list of canonical etypes that have hard negative edges
constructed by corrupting destination nodes.
The format of the arguement should be:
train_hard_edge_dstnode_negative:
train_etypes_negative_dstnode:
- src_type,rel_type0,dst_type:negative_nid_field
- src_type,rel_type1,dst_type:negative_nid_field
Each edge type can have different fields storing the hard negatives.
or
train_hard_edge_dstnode_negative:
train_etypes_negative_dstnode:
- negative_nid_field
All the edge types use the same filed storing the hard negatives.
"""
# pylint: disable=no-member
if hasattr(self, "_train_hard_edge_dstnode_negative"):
if hasattr(self, "_train_etypes_negative_dstnode"):
assert self.task_type == BUILTIN_TASK_LINK_PREDICTION, \
"Hard negative only works with link prediction"
hard_negatives = self._train_hard_edge_dstnode_negative
hard_negatives = self._train_etypes_negative_dstnode
if len(hard_negatives) == 1 and \
":" not in hard_negatives[0]:
# global feat_name
Expand All @@ -1903,7 +1904,14 @@ def train_hard_edge_dstnode_negative(self):
hard_negative_dict = {}
for hard_negative in hard_negatives:
negative_info = hard_negative.split(":")
assert len(negative_info) == 2, \
"negative dstnode information must be provided in format of " \
f"src,relation,dst:feature_name, but get {hard_negative}"

etype = tuple(negative_info[0].split(","))
assert len(etype) == 3, \
f"Edge type must in format of (src,relation,dst), but get {etype}"

assert etype not in hard_negative_dict, \
f"You already specify the fixed negative of {etype} " \
f"as {hard_negative_dict[etype]}"
Expand All @@ -1915,25 +1923,25 @@ def train_hard_edge_dstnode_negative(self):
return None

@property
def num_hard_negatives(self):
def num_train_hard_negatives(self):
""" Number of hard negatives per edge type
The format of the arguement should be:
num_hard_negatives:
num_train_hard_negatives:
- src_type,rel_type0,dst_type:num_negatives
- src_type,rel_type1,dst_type:num_negatives
Each edge type can have different number of hard negatives.
or
num_hard_negatives:
num_train_hard_negatives:
- num_negatives
All the edge types use the same number of hard negatives.
"""
# pylint: disable=no-member
if hasattr(self, "_num_hard_negatives"):
if hasattr(self, "_num_train_hard_negatives"):
assert self.task_type == BUILTIN_TASK_LINK_PREDICTION, \
"Hard negative only works with link prediction"
num_negatives = self._num_hard_negatives
num_negatives = self._num_train_hard_negatives
if len(num_negatives) == 1 and \
":" not in num_negatives[0]:
# global feat_name
Expand All @@ -1954,25 +1962,26 @@ def num_hard_negatives(self):
return None

@property
def eval_fixed_edge_dstnode_negative(self):
""" The list of canonical etypes that have predefined negative sets
def eval_etypes_negative_dstnode(self):
""" The list of canonical etypes that have predefined negative edges
constructed by corrupting destination nodes.
The format of the arguement should be:
eval_fixed_edge_dstnode_negative:
eval_etypes_negative_dstnode:
- src_type,rel_type0,dst_type:negative_nid_field
- src_type,rel_type1,dst_type:negative_nid_field
Each edge type can have different fields storing the fixed negatives.
or
eval_fixed_edge_dstnode_negative:
eval_etypes_negative_dstnode:
- negative_nid_field
All the edge types use the same filed storing the fixed negatives.
"""
# pylint: disable=no-member
if hasattr(self, "_eval_fixed_edge_dstnode_negative"):
if hasattr(self, "_eval_etypes_negative_dstnode"):
assert self.task_type == BUILTIN_TASK_LINK_PREDICTION, \
"Fixed negative only works with link prediction"
fixed_negatives = self._eval_fixed_edge_dstnode_negative
fixed_negatives = self._eval_etypes_negative_dstnode
if len(fixed_negatives) == 1 and \
":" not in fixed_negatives[0]:
# global feat_name
Expand All @@ -1982,7 +1991,13 @@ def eval_fixed_edge_dstnode_negative(self):
fixed_negative_dict = {}
for fixed_negative in fixed_negatives:
negative_info = fixed_negative.split(":")
assert len(negative_info) == 2, \
"negative dstnode information must be provided in format of " \
f"src,relation,dst:feature_name, but get {fixed_negative}"

etype = tuple(negative_info[0].split(","))
assert len(etype) == 3, \
f"Edge type must in format of (src,relation,dst), but get {etype}"
assert etype not in fixed_negative_dict, \
f"You already specify the fixed negative of {etype} " \
f"as {fixed_negative_dict[etype]}"
Expand Down Expand Up @@ -2599,6 +2614,36 @@ def _add_link_prediction_args(parser):
"metrics of each edge type to select the best model"
"2) '--model-select-etype query,adds,item': Use the evaluation "
"metric of the edge type (query,adds,item) to select the best model")
group.add_argument("--train-etypes-negative-dstnode", nargs='+',
type=str, default=argparse.SUPPRESS,
help="Edge feature field name for user defined negative destination ndoes "
"for training. The negative nodes are used to construct hard negative edges "
"by corrupting positive edges' destination nodes."
"It can be in following format: "
"1) '--train-etypes-negative-dstnode negative_nid_field', "
"if all edge types use the same negative destination node filed."
"2) '--train-etypes-negative-dstnode query,adds,asin:neg0 query,clicks,asin:neg1 ...'"
"Different edge types have different negative destination node fields."
)
group.add_argument("--eval-etypes-negative-dstnode", nargs='+',
type=str, default=argparse.SUPPRESS,
help="Edge feature field name for user defined negative destination ndoes "
"for evaluation. The negative nodes are used to construct negative edges "
"by corrupting test edges' destination nodes."
"It can be in following format: "
"1) '--eval-etypes-negative-dstnode negative_nid_field', "
"if all edge types use the same negative destination node filed."
"2) '--eval-etypes-negative-dstnode query,adds,asin:neg0 query,clicks,asin:neg1 ...'"
"Different edge types have different negative destination node fields."
)
group.add_argument("--num-train-hard-negatives", nargs='+',
type=str, default=argparse.SUPPRESS,
help="Number of hard negatives for each edge type during training."
"It can be in following format: "
"1) '--num-train-hard-negatives 10', "
"if all edge types use the same number of hard negatives."
"2) '--num-train-hard-negatives query,adds,asin:5 query,clicks,asin:10 ...'"
"Different edge types have different number of hard negatives.")

return parser

Expand Down
6 changes: 3 additions & 3 deletions python/graphstorm/dataloading/dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
InbatchJointUniform,
FastMultiLayerNeighborSampler,
DistributedFileSampler,
GSHardEdgeDstNegative)
GSHardEdgeDstNegativeSampler)
from .utils import trim_data, modify_fanout_for_target_etype
from .dataset import GSDistillData

Expand Down Expand Up @@ -1014,7 +1014,7 @@ def _prepare_negative_sampler(self, num_negative_edges):
self._neg_sample_type = BUILTIN_LP_UNIFORM_NEG_SAMPLER
negative_sampler = GlobalUniform(num_negative_edges)
if self._fixed_edge_dst_negative_field:
negative_sampler = GSHardEdgeDstNegative(num_negative_edges,
negative_sampler = GSHardEdgeDstNegativeSampler(num_negative_edges,
self._fixed_edge_dst_negative_field,
negative_sampler)
return negative_sampler
Expand Down Expand Up @@ -1066,7 +1066,7 @@ def _prepare_negative_sampler(self, num_negative_edges):
negative_sampler = JointUniform(num_negative_edges)
self._neg_sample_type = BUILTIN_LP_JOINT_NEG_SAMPLER
if self._fixed_edge_dst_negative_field:
negative_sampler = GSHardEdgeDstNegative(num_negative_edges,
negative_sampler = GSHardEdgeDstNegativeSampler(num_negative_edges,
self._fixed_edge_dst_negative_field,
negative_sampler)
return negative_sampler
Expand Down
19 changes: 16 additions & 3 deletions python/graphstorm/dataloading/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def _generate(self, g, eids, canonical_etype):
dst = F.randint(shape, dtype, ctx, 0, self._local_neg_nids[vtype].shape[0])
return src, self._local_neg_nids[vtype][dst]

class GSHardEdgeDstNegative(object):
""" GraphStorm negativer sampler that chooses negative destination nodes
class GSHardEdgeDstNegativeSampler(object):
""" GraphStorm negative sampler that chooses negative destination nodes
from a fixed set to create negative edges.
Parameters
Expand Down Expand Up @@ -127,7 +127,6 @@ def _generate(self, g, eids, canonical_etype):
# shuffle the hard negatives
hard_negatives = hard_negatives[:,neg_idx]

print(f"{required_num_hard_neg} {max_num_hard_neg} {self._k}")
if required_num_hard_neg >= self._k and max_num_hard_neg >= self._k:
# All negative should be hard negative and
# there are enough hard negatives.
Expand Down Expand Up @@ -175,6 +174,20 @@ def _generate(self, g, eids, canonical_etype):
hard_negative[:num_hard_neg if num_hard_neg < self._k else self._k]
return src, neg

class GSFixedEdgeDstNegativeSampler(object):
""" GraphStorm negative sampler that uses fixed negative destination nodes
to create negative edges.
Parameters
----------
dst_negative_field: str or dict of str
The field storing the hard negatives.
"""
def __init__(self, dst_negative_field):
assert is_wholegraph() is False, \
"Hard negative is not supported for WholeGraph."
self._dst_negative_field = dst_negative_field

def gen_neg_pairs(self, g, pos_pairs):
""" Returns negative examples associated with positive examples.
It only return dst negatives.
Expand Down
Loading

0 comments on commit 7961e42

Please sign in to comment.