Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add edge hard negative support. #678

Merged
merged 18 commits into from
Jan 11, 2024
121 changes: 120 additions & 1 deletion python/graphstorm/config/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -1874,9 +1874,128 @@ def lp_edge_weight_for_loss(self):

return None

@property
def train_hard_edge_dstnode_negative(self):
    """ The list of canonical etypes that have hard negative sets.

    The format of the argument should be:
        train_hard_edge_dstnode_negative:
            - src_type,rel_type0,dst_type:negative_nid_field
            - src_type,rel_type1,dst_type:negative_nid_field
    Each edge type can have a different field storing its hard negatives.

    or
        train_hard_edge_dstnode_negative:
            - negative_nid_field
    All the edge types use the same field storing the hard negatives.

    Returns
    -------
    str, dict of str, or None
        A single feature-field name shared by all edge types, a mapping
        from canonical etype to feature-field name, or None when hard
        negatives are not configured.
    """
    # pylint: disable=no-member
    if hasattr(self, "_train_hard_edge_dstnode_negative"):
        assert self.task_type == BUILTIN_TASK_LINK_PREDICTION, \
            "Hard negative only works with link prediction"
        hard_negatives = self._train_hard_edge_dstnode_negative
        if len(hard_negatives) == 1 and \
            ":" not in hard_negatives[0]:
            # A single entry without ":" is a global feature field
            # shared by every edge type.
            return hard_negatives[0]

        # per edge type feature
        hard_negative_dict = {}
        for hard_negative in hard_negatives:
            negative_info = hard_negative.split(":")
            # Validate the entry format early so a malformed config fails
            # with a clear message instead of an IndexError below.
            assert len(negative_info) == 2, \
                "The format of a hard negative entry must be " \
                "src_type,rel_type,dst_type:negative_nid_field, " \
                f"but got {hard_negative}"
            etype = tuple(negative_info[0].split(","))
            assert etype not in hard_negative_dict, \
                f"You already specify the hard negative of {etype} " \
                f"as {hard_negative_dict[etype]}"

            hard_negative_dict[etype] = negative_info[1]
        return hard_negative_dict

    # By default hard negative is not used
    return None

@property
def num_hard_negatives(self):
    """ Number of hard negatives per edge type.

    The format of the argument should be:
        num_hard_negatives:
            - src_type,rel_type0,dst_type:num_negatives
            - src_type,rel_type1,dst_type:num_negatives
    Each edge type can have a different number of hard negatives.

    or
        num_hard_negatives:
            - num_negatives
    All the edge types use the same number of hard negatives.

    Returns
    -------
    int, dict of int, or None
        A single number shared by all edge types, a mapping from
        canonical etype to the number of hard negatives, or None when
        not configured.
    """
    # pylint: disable=no-member
    if hasattr(self, "_num_hard_negatives"):
        assert self.task_type == BUILTIN_TASK_LINK_PREDICTION, \
            "Hard negative only works with link prediction"
        num_negatives = self._num_hard_negatives
        if len(num_negatives) == 1 and \
            ":" not in num_negatives[0]:
            # A single entry without ":" is a global number of hard
            # negatives shared by every edge type.
            return int(num_negatives[0])

        # per edge type number of hard negatives
        num_hard_negative_dict = {}
        for num_negative in num_negatives:
            negative_info = num_negative.split(":")
            # Validate the entry format early so a malformed config fails
            # with a clear message instead of an IndexError below.
            assert len(negative_info) == 2, \
                "The format of a hard negative entry must be " \
                "src_type,rel_type,dst_type:num_negatives, " \
                f"but got {num_negative}"
            etype = tuple(negative_info[0].split(","))
            assert etype not in num_hard_negative_dict, \
                f"You already specify the number of hard negatives " \
                f"of {etype} as {num_hard_negative_dict[etype]}"

            num_hard_negative_dict[etype] = int(negative_info[1])
        return num_hard_negative_dict

    return None

@property
def eval_fixed_edge_dstnode_negative(self):
    """ The list of canonical etypes that have predefined negative sets.

    The format of the argument should be:
        eval_fixed_edge_dstnode_negative:
            - src_type,rel_type0,dst_type:negative_nid_field
            - src_type,rel_type1,dst_type:negative_nid_field
    Each edge type can have a different field storing its fixed negatives.

    or
        eval_fixed_edge_dstnode_negative:
            - negative_nid_field
    All the edge types use the same field storing the fixed negatives.

    Returns
    -------
    str, dict of str, or None
        A single feature-field name shared by all edge types, a mapping
        from canonical etype to feature-field name, or None when fixed
        negatives are not configured.
    """
    # pylint: disable=no-member
    if hasattr(self, "_eval_fixed_edge_dstnode_negative"):
        assert self.task_type == BUILTIN_TASK_LINK_PREDICTION, \
            "Fixed negative only works with link prediction"
        fixed_negatives = self._eval_fixed_edge_dstnode_negative
        if len(fixed_negatives) == 1 and \
            ":" not in fixed_negatives[0]:
            # A single entry without ":" is a global feature field
            # shared by every edge type.
            return fixed_negatives[0]

        # per edge type feature
        fixed_negative_dict = {}
        for fixed_negative in fixed_negatives:
            negative_info = fixed_negative.split(":")
            # Validate the entry format early so a malformed config fails
            # with a clear message instead of an IndexError below.
            assert len(negative_info) == 2, \
                "The format of a fixed negative entry must be " \
                "src_type,rel_type,dst_type:negative_nid_field, " \
                f"but got {fixed_negative}"
            etype = tuple(negative_info[0].split(","))
            assert etype not in fixed_negative_dict, \
                f"You already specify the fixed negative of {etype} " \
                f"as {fixed_negative_dict[etype]}"

            fixed_negative_dict[etype] = negative_info[1]
        return fixed_negative_dict

    # By default fixed negative is not used
    return None

@property
def train_etype(self):
""" The list of canonical etype that will be added as
""" The list of canonical etypes that will be added as
training target with the target e type(s)

If not provided, all edge types will be used as training target.
Expand Down
18 changes: 15 additions & 3 deletions python/graphstorm/dataloading/dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
JointLocalUniform,
InbatchJointUniform,
FastMultiLayerNeighborSampler,
DistributedFileSampler)
DistributedFileSampler,
GSHardEdgeDstNegative)
from .utils import trim_data, modify_fanout_for_target_etype
from .dataset import GSDistillData

Expand Down Expand Up @@ -972,9 +973,11 @@ class GSgnnLinkPredictionTestDataLoader():
When test is huge, using fixed_test_size
can save validation and test time.
Default: None.
fixed_edge_dst_negative_field: str or list of str
The feature field(s) that store the fixed negative set for each edge.
"""
def __init__(self, dataset, target_idx, batch_size, num_negative_edges, fanout=None,
fixed_test_size=None):
def __init__(self, dataset, target_idx, batch_size, num_negative_edges,
fanout=None, fixed_test_size=None, fixed_edge_dst_negative_field=None):
self._data = dataset
self._fanout = fanout
for etype in target_idx:
Expand All @@ -991,6 +994,7 @@ def __init__(self, dataset, target_idx, batch_size, num_negative_edges, fanout=N
"is %d, which is smaller than the expected"
"test size %d, force it to %d",
etype, len(t_idx), self._fixed_test_size[etype], len(t_idx))
self._fixed_edge_dst_negative_field = fixed_edge_dst_negative_field
self._negative_sampler = self._prepare_negative_sampler(num_negative_edges)
self._reinit_dataset()

Expand All @@ -1009,6 +1013,10 @@ def _prepare_negative_sampler(self, num_negative_edges):
# the default negative sampler is uniform sampler
self._neg_sample_type = BUILTIN_LP_UNIFORM_NEG_SAMPLER
negative_sampler = GlobalUniform(num_negative_edges)
if self._fixed_edge_dst_negative_field:
negative_sampler = GSHardEdgeDstNegative(num_negative_edges,
self._fixed_edge_dst_negative_field,
negative_sampler)
return negative_sampler

def __iter__(self):
Expand Down Expand Up @@ -1057,6 +1065,10 @@ def _prepare_negative_sampler(self, num_negative_edges):
# the default negative sampler is uniform sampler
negative_sampler = JointUniform(num_negative_edges)
self._neg_sample_type = BUILTIN_LP_JOINT_NEG_SAMPLER
if self._fixed_edge_dst_negative_field:
negative_sampler = GSHardEdgeDstNegative(num_negative_edges,
self._fixed_edge_dst_negative_field,
negative_sampler)
return negative_sampler

################ Minibatch DataLoader (Node classification) #######################
Expand Down
186 changes: 186 additions & 0 deletions python/graphstorm/dataloading/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from dgl.dataloading import NeighborSampler
from dgl.transforms import to_block

from ..utils import is_wholegraph

class LocalUniform(Uniform):
"""Negative sampler that randomly chooses negative destination nodes
for each source node according to a uniform distribution.
Expand Down Expand Up @@ -70,6 +72,190 @@ def _generate(self, g, eids, canonical_etype):
dst = F.randint(shape, dtype, ctx, 0, self._local_neg_nids[vtype].shape[0])
return src, self._local_neg_nids[vtype][dst]

class GSHardEdgeDstNegative(object):
    """ GraphStorm negative sampler that chooses negative destination nodes
    from a fixed set to create negative edges.

    Hard negatives are stored per edge in an edge-feature field; the
    corresponding value holds destination node ids, padded with -1 when an
    edge has fewer hard negatives than the field width.

    Parameters
    ----------
    k: int
        Number of negatives to sample per positive edge.
    dst_negative_field: str or dict of str
        The edge-feature field(s) storing the hard negatives. A str applies
        to all edge types; a dict maps canonical etype to field name.
    negative_sampler: sampler
        The negative sampler used to generate random negatives
        when there are not enough hard negatives.
    num_hard_negs: int or dict of int, optional
        Number of hard negatives per positive edge. If None, up to ``k``
        hard negatives are used when available.
    """
    def __init__(self, k, dst_negative_field, negative_sampler, num_hard_negs=None):
        assert is_wholegraph() is False, \
            "Hard negative is not supported for WholeGraph."
        self._dst_negative_field = dst_negative_field
        self._k = k
        self._negative_sampler = negative_sampler
        self._num_hard_negs = num_hard_negs

    def _resolve_negative_field(self, canonical_etype):
        """ Return the hard-negative feature field for ``canonical_etype``,
        or None when the etype has no hard negatives configured.
        """
        if isinstance(self._dst_negative_field, str):
            return self._dst_negative_field
        if canonical_etype in self._dst_negative_field:
            return self._dst_negative_field[canonical_etype]
        return None

    def _generate(self, g, eids, canonical_etype):
        """ Generate ``k`` negatives for each edge in ``eids``, preferring
        hard negatives and falling back to the wrapped random sampler.
        """
        dst_negative_field = self._resolve_negative_field(canonical_etype)

        # Resolve how many of the k negatives should be hard negatives.
        if self._num_hard_negs is None:
            # BUG FIX: the original did `canonical_etype in self._num_hard_negs`
            # when num_hard_negs was None (the default), raising TypeError.
            # Treat "unspecified" as "use as many hard negatives as
            # available, up to k".
            required_num_hard_neg = self._k
        elif isinstance(self._num_hard_negs, int):
            required_num_hard_neg = self._num_hard_negs
        elif canonical_etype in self._num_hard_negs:
            required_num_hard_neg = self._num_hard_negs[canonical_etype]
        else:
            required_num_hard_neg = 0

        if dst_negative_field is None or required_num_hard_neg == 0:
            # No hard negatives for this etype, fall back to random negatives.
            return self._negative_sampler._generate(g, eids, canonical_etype)

        hard_negatives = g.edges[canonical_etype].data[dst_negative_field][eids]
        # It is possible that different edges may have different numbers of
        # pre-defined negatives. Valid entries in `hard_negatives` are node
        # ids; missing entries are -1 paddings.
        if th.sum(hard_negatives == -1) == 0:
            # Fast track: there is no -1 padding in hard_negatives.
            max_num_hard_neg = hard_negatives.shape[1]
            # Shuffle so each edge draws a random subset of its hard negatives.
            hard_negatives = hard_negatives[:, th.randperm(max_num_hard_neg)]

            if required_num_hard_neg >= self._k and max_num_hard_neg >= self._k:
                # All negatives should be hard negatives and there are
                # enough of them: no random sampling needed.
                hard_negatives = hard_negatives[:, :self._k]
                src, _ = g.find_edges(eids, etype=canonical_etype)
                src = F.repeat(src, self._k, 0)
                return src, hard_negatives.reshape((-1,))

            # Not enough hard negatives (or fewer requested than k):
            # take what we can and fill the rest with random negatives.
            num_hard_neg = min(required_num_hard_neg, max_num_hard_neg)
            src, neg = self._negative_sampler._generate(g, eids, canonical_etype)
            neg = neg.reshape(-1, self._k)
            # Replace the leading random negatives with hard negatives.
            neg[:, :num_hard_neg] = hard_negatives[:, :num_hard_neg]
            return src, neg.reshape((-1,))

        # Slow track: some edges carry -1 paddings, handle edge by edge.
        # Sorting descending moves the -1 paddings to the tail of each row.
        hard_negatives, _ = th.sort(hard_negatives, dim=1, descending=True)
        src, neg = self._negative_sampler._generate(g, eids, canonical_etype)
        for i in range(len(eids)):
            hard_negative = hard_negatives[i]
            hard_negative = hard_negative[hard_negative > -1]  # drop paddings
            max_num_hard_neg = hard_negative.shape[0]
            # Shuffle this edge's hard negatives before taking a subset.
            hard_negative = hard_negative[th.randperm(max_num_hard_neg)]
            num_hard_neg = min(required_num_hard_neg, max_num_hard_neg, self._k)
            # Replace the leading random negatives of edge i with hard ones.
            neg[i * self._k:i * self._k + num_hard_neg] = \
                hard_negative[:num_hard_neg]
        return src, neg

    def gen_neg_pairs(self, g, pos_pairs):
        """ Returns negative examples associated with positive examples.
        It only returns dst negatives.

        Parameters
        ----------
        g : DGLGraph
            The graph.
        pos_pairs : (Tensor, Tensor) or dict[etype, (Tensor, Tensor)]
            The positive node pairs.

        Returns
        -------
        tuple[Tensor, Tensor, Tensor, Tensor] or
        dict[etype, tuple(Tensor, Tensor, Tensor, Tensor)]
            The returned [positive source, negative source,
            positive destination, negative destination]
            tuples as pos-neg examples. Negative source is always None.
        """
        def _gen_neg_pair(pos_pair, canonical_etype):
            # Build the (src, None, pos_dst, neg_dst) tuple for one etype.
            src, pos_dst = pos_pair
            eids = g.edge_ids(src, pos_dst, etype=canonical_etype)

            dst_negative_field = self._resolve_negative_field(canonical_etype)
            if dst_negative_field is None:
                # No hard negatives for this etype, use random negatives.
                random_neg_pairs = \
                    self._negative_sampler.gen_neg_pairs(
                        g, {canonical_etype: pos_pair})
                src, _, pos_dst, neg_dst = random_neg_pairs[canonical_etype]
                return (src, None, pos_dst, neg_dst)

            hard_negatives = g.edges[canonical_etype].data[dst_negative_field][eids]
            # Valid entries are node ids; missing entries are -1 paddings.
            if th.sum(hard_negatives == -1) == 0:
                # Fast track: there is no -1 padding in hard_negatives.
                num_hard_neg = hard_negatives.shape[1]
                if self._k < num_hard_neg:
                    # Enough hard negatives for every edge.
                    return (src, None, pos_dst, hard_negatives[:, :self._k])

                # Random negatives are needed to fill up to k.
                random_neg_pairs = \
                    self._negative_sampler.gen_neg_pairs(
                        g, {canonical_etype: pos_pair})
                src, _, pos_dst, neg_dst = random_neg_pairs[canonical_etype]
                neg_dst[:, :num_hard_neg] = hard_negatives
                return (src, None, pos_dst, neg_dst)

            # Slow track: some edges carry -1 paddings, handle edge by edge.
            # Sorting descending moves the -1 paddings to the tail of each row.
            hard_negatives, _ = th.sort(hard_negatives, dim=1, descending=True)
            random_neg_pairs = \
                self._negative_sampler.gen_neg_pairs(g, {canonical_etype: pos_pair})
            src, _, pos_dst, neg_dst = random_neg_pairs[canonical_etype]
            for i in range(len(eids)):
                hard_negative = hard_negatives[i]
                hard_negative = hard_negative[hard_negative > -1]  # drop paddings
                num_hard_neg = min(hard_negative.shape[0], self._k)
                neg_dst[i][:num_hard_neg] = hard_negative[:num_hard_neg]
            return (src, None, pos_dst, neg_dst)

        if isinstance(pos_pairs, Mapping):
            pos_neg_tuple = {}
            for canonical_etype, pos_pair in pos_pairs.items():
                pos_neg_tuple[canonical_etype] = \
                    _gen_neg_pair(pos_pair, canonical_etype)
        else:
            assert len(g.canonical_etypes) == 1, \
                'please specify a dict of etypes and ids for graphs with multiple edge types'
            # BUG FIX: the original referenced an undefined name
            # `canonical_etype` here; use the graph's single etype.
            pos_neg_tuple = _gen_neg_pair(pos_pairs, g.canonical_etypes[0])
        return pos_neg_tuple

class GlobalUniform(Uniform):
"""Negative sampler that randomly chooses negative destination nodes
for each source node according to a uniform distribution.
Expand Down
Loading