Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiang Song committed Dec 20, 2023
1 parent 7961e42 commit f9b8740
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 259 deletions.
47 changes: 13 additions & 34 deletions python/graphstorm/dataloading/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,40 +224,19 @@ def _gen_neg_pair(pos_pair, canonical_etype):
src, _, pos_dst, neg_dst = random_neg_pairs[canonical_etype]
return (src, None, pos_dst, neg_dst)

hard_negatives = g.edges[canonical_etype].data[dst_negative_field][eids]
# It is possible that different edges may have different number of
# pre-defined negatives. For pre-defined negatives, the corresponding
# value in `hard_negatives` will be integers representing the node ids.
# For others, they will be -1s meaning there are missing fixed negatives.
if th.sum(hard_negatives == -1) == 0:
# Fast track, there is no -1 in hard_negatives
num_hard_neg = hard_negatives.shape[1]
if self._k < num_hard_neg:
hard_negatives = hard_negatives[:,:self._k]
return (src, None, pos_dst, hard_negatives)
else:
# random negative are needed
random_neg_pairs = \
self._negative_sampler.gen_neg_pairs(g,
{canonical_etype:pos_pair})
src, _, pos_dst, neg_dst = random_neg_pairs[canonical_etype]
neg_dst[:,:num_hard_neg] = hard_negatives
return (src, None, pos_dst, neg_dst)
else:
# slow track, we need to handle cases when there are -1s
hard_negatives, _ = th.sort(hard_negatives, dim=1, descending=True)

random_neg_pairs = \
self._negative_sampler.gen_neg_pairs(g, {canonical_etype:pos_pair})
src, _, pos_dst, neg_dst = random_neg_pairs[canonical_etype]
for i in range(len(eids)):
hard_negative = hard_negatives[i]
# ignore -1s
hard_negative = hard_negative[hard_negative > -1]
num_hard_neg = hard_negative.shape[0]
neg_dst[i][:num_hard_neg if num_hard_neg < self._k else self._k] = \
hard_negative[:num_hard_neg if num_hard_neg < self._k else self._k]
return (src, _, pos_dst, neg_dst)
fixed_negatives = g.edges[canonical_etype].data[dst_negative_field][eids]

# Users may use HardEdgeDstNegativeTransform
# to prepare the fixed negatives.
assert th.sum(fixed_negatives == -1) == 0, \
"When using fixed negative destination nodes to construct testing edges," \
"it is required that for each positive edge there are enough negative " \
f"destination nodes. Please check the {dst_negative_field} feature " \
f"of edge type {canonical_etype}"

num_fixed_neg = fixed_negatives.shape[1]
logging.debug("The number of fixed negative is %d", num_fixed_neg)
return (src, None, pos_dst, fixed_negatives)

if isinstance(pos_pairs, Mapping):
pos_neg_tuple = {}
Expand Down
245 changes: 20 additions & 225 deletions tests/unit-tests/test_dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@

from graphstorm.dataloading.sampler import InbatchJointUniform
from graphstorm.dataloading.sampler import GlobalUniform
from graphstorm.dataloading.sampler import GSHardEdgeDstNegativeSampler
from graphstorm.dataloading.sampler import (GSHardEdgeDstNegativeSampler,
GSFixedEdgeDstNegativeSampler)

from graphstorm.dataloading.dataset import (prepare_batch_input,
prepare_batch_edge_input)
Expand Down Expand Up @@ -1719,252 +1720,47 @@ def check_none_hard_negs(hard_neg_sampler, target_etype, hard_neg_data):
num_negs, "hard_negative", sampler, {})
check_none_hard_negs(hard_sampler, etype2, hard2)

def test_hard_edge_dst_negative_sample_gen_neg_pairs_complex_case():
# test GSHardEdgeDstNegativeSampler.gen_neg_pairs with slow track when not all edges have enough predefined negatives
num_nodes = 1000
# test GSHardEdgeDstNegativeSampler._generate when all some pos edges do not have enough hard negatives defined
num_negs = 10
etype0, etype1, etype2, hard0, hard1, hard2, src, dst, g = _create_hard_neg_graph(num_nodes, num_negs)

# not enough predefined hard negatives
# for hard0[0] and hard0[1]
hard0[0] = th.randperm(num_nodes)[:4]
hard0[0][-1] = -1
hard0[0][-2] = -1
hard0[1][-1] = -1

# not enough predefined hard negatives
# for hard0[0] and hard0[1]
hard1[0] = th.randperm(num_nodes)[:num_negs]
hard1[1] = th.randperm(num_nodes)[:num_negs]
hard1[0][-1] = -1
hard1[0][-2] = -1
hard1[1][-1] = -1

# not enough predefined hard negatives
# for hard0[0] and hard0[1]
hard2[0] = th.randperm(num_nodes)[:num_negs*2]
hard2[1] = th.randperm(num_nodes)[:num_negs*2]
hard2[0][-1] = -1
hard2[0][-2] = -1
hard2[1][-1] = -1

num_edges = 10
pos_pairs = {etype0: (th.arange(10), th.arange(10)),
etype1: (th.arange(10), th.arange(10)),
etype2: (th.arange(10), th.arange(10))}

def test_missing_hard_negs(neg_dst, hard_neg_data, num_hard_neg):
# hardx[0][-1] and hardx[0][-2] is -1,
# which means hardx[0] does not enough predefined negatives
# Random sample will be applied to -1s.
hard_neg_dst = neg_dst[0][:num_hard_neg]
hard_neg_rand_0 = hard_neg_dst[-1]
hard_neg_rand_1 = hard_neg_dst[-2]
hard_neg_dst = set(hard_neg_dst[:-2].tolist())
rand_neg_dst = neg_dst[0][num_hard_neg:]
rand_neg_dst = set(rand_neg_dst.tolist())
hard_neg_set = set(hard_neg_data[0].tolist())
assert hard_neg_dst.issubset(hard_neg_set)
assert len(rand_neg_dst) == 0 or \
rand_neg_dst.issubset(hard_neg_set) is False

rand_0_check = hard_neg_rand_0 not in hard_neg_set
rand_1_check = hard_neg_rand_1 not in hard_neg_set

# hardx[1][-1] is -1,
# which means hardx[0] does not enough predefined negatives
# Random sample will be applied to -1s.
hard_neg_dst = neg_dst[1][:num_hard_neg]
hard_neg_rand_2 = hard_neg_dst[-1]
hard_neg_dst = set(hard_neg_dst[:-1].tolist())
rand_neg_dst = neg_dst[1][num_hard_neg:]
rand_neg_dst = set(rand_neg_dst.tolist())
hard_neg_set = set(hard_neg_data[1].tolist())
assert hard_neg_dst.issubset(hard_neg_set)
assert len(rand_neg_dst) == 0 or \
rand_neg_dst.issubset(hard_neg_set) is False

rand_2_check = hard_neg_rand_2 not in hard_neg_set
# The chance is very to to have rand_0_check,
# rand_1_check and rand_2_check be true at the same time
# The change is (4/1000)^3
assert rand_0_check or rand_1_check or rand_2_check

def check_hard_negs(pos_neg_tuple, etype, hard_neg_data,
num_hard_neg, check_missing_hard_neg):
neg_src, _, pos_dst, neg_dst = pos_neg_tuple[etype]

assert len(neg_src) == num_edges
assert len(pos_dst) == num_edges
assert neg_dst.shape[0] == num_edges
assert neg_dst.shape[1] == num_negs
assert_equal(src[:10].numpy(), neg_src.numpy())
assert_equal(dst[:10].numpy(), pos_dst.numpy())

if check_missing_hard_neg:
test_missing_hard_negs(neg_dst, hard_neg_data, num_hard_neg)

start = 2 if check_missing_hard_neg else 0
for i in range(start, num_edges):
hard_neg_dst = neg_dst[i][:num_hard_neg]
hard_neg_dst = set(hard_neg_dst.tolist())
rand_neg_dst = neg_dst[i][num_hard_neg:]
rand_neg_dst = set(rand_neg_dst.tolist())
hard_neg_set = set(hard_neg_data[i].tolist())
assert hard_neg_dst.issubset(hard_neg_set)
assert len(rand_neg_dst) == 0 or \
rand_neg_dst.issubset(hard_neg_set) is False

sampler = GlobalUniform(num_negs)
hard_sampler = GSHardEdgeDstNegativeSampler(num_negs, "hard_negative", sampler)
pos_neg_tuple = hard_sampler.gen_neg_pairs(g, pos_pairs)

# Case 1:
# 1. hard_negative field is string
# 2. The is not enough predefined negative for gen_neg_pairs
# 3. fast track
# 4. slow track (-1 exists in hard neg feature)
#
# expected behavior:
# 1. Only 4 hard negatives are returned
# 2. Others will be random negatives
check_hard_negs(pos_neg_tuple, etype0, hard0, hard0.shape[1], check_missing_hard_neg=True)
# Case 2:
# 1. hard_negative field is string
# 2. num_negs == total number of predefined negatives
# 3. fast track
# 4. slow track (-1 exists in hard neg feature)
#
# expected behavior:
# 1. all negatives are predefined negatives
check_hard_negs(pos_neg_tuple, etype1, hard1, hard1.shape[1], check_missing_hard_neg=True)
# Case 3:
# 1. hard_negative field is string
# 2. num_negs < total number of predefined negatives
# 3. fast track
# 4. slow track (-1 exists in hard neg feature)
#
# expected behavior:
# 1. all negatives are predefined negatives
check_hard_negs(pos_neg_tuple, etype2, hard2, num_negs, check_missing_hard_neg=False)


def test_hard_edge_dst_negative_sample_gen_neg_pairs():
def test_edge_fixed_dst_negative_sample_gen_neg_pairs():
# test GSHardEdgeDstNegativeSampler.gen_neg_pairs with fast track when all edges have enough predefined negatives
num_nodes = 1000
# test GSHardEdgeDstNegativeSampler._generate when all some pos edges do not have enough hard negatives defined
num_negs = 10
etype0, etype1, etype2, hard0, hard1, hard2, src, dst, g = _create_hard_neg_graph(num_nodes, num_negs)
etype0, etype1, etype2, hard0, _, _, src, dst, g = _create_hard_neg_graph(num_nodes, num_negs)

num_edges = 10
pos_pairs = {etype0: (th.arange(10), th.arange(10)),
etype1: (th.arange(10), th.arange(10)),
etype2: (th.arange(10), th.arange(10))}

def check_hard_negs(pos_neg_tuple, etype, hard_neg_data, num_hard_neg):
def check_fixed_negs(pos_neg_tuple, etype, hard_neg_data):
neg_src, _, pos_dst, neg_dst = pos_neg_tuple[etype]

assert len(neg_src) == num_edges
assert len(pos_dst) == num_edges
assert neg_dst.shape[0] == num_edges
assert neg_dst.shape[1] == num_negs
assert_equal(src[:10].numpy(), neg_src.numpy())
assert_equal(dst[:10].numpy(), pos_dst.numpy())

# check hard negative
for i in range(num_edges):
hard_neg_dst = neg_dst[i][:num_hard_neg]
hard_neg_dst = set(hard_neg_dst.tolist())
rand_neg_dst = neg_dst[i][num_hard_neg:]
rand_neg_dst = set(rand_neg_dst.tolist())
hard_neg_set = set(hard_neg_data[i].tolist())
assert hard_neg_dst.issubset(hard_neg_set)
assert len(rand_neg_dst) == 0 or \
rand_neg_dst.issubset(hard_neg_set) is False
assert_equal(hard_neg_data[:10].numpy(), neg_dst.numpy())

sampler = GlobalUniform(num_negs)
hard_sampler = GSHardEdgeDstNegativeSampler(num_negs, "hard_negative", sampler)
hard_sampler = GSFixedEdgeDstNegativeSampler("hard_negative")
pos_neg_tuple = hard_sampler.gen_neg_pairs(g, pos_pairs)
check_fixed_negs(pos_neg_tuple, etype0, hard0)

# Case 1:
# 1. hard_negative field is string
# 2. The is not enough predefined negative for gen_neg_pairs
# 3. fast track
#
# expected behavior:
# 1. Only 4 hard negatives are returned
# 2. Others will be random negatives
check_hard_negs(pos_neg_tuple, etype0, hard0, hard0.shape[1])
# Case 2:
# 1. hard_negative field is string
# 2. num_negs == total number of predefined negatives
# 3. fast track
#
# expected behavior:
# 1. all negatives are predefined negatives
check_hard_negs(pos_neg_tuple, etype1, hard1, hard1.shape[1])
# Case 3:
# 1. hard_negative field is string
# 2. num_negs < total number of predefined negatives
# 3. fast track
#
# expected behavior:
# 1. all negatives are predefined negatives
check_hard_negs(pos_neg_tuple, etype2, hard2, num_negs)

hard_sampler = GSHardEdgeDstNegativeSampler(num_negs,
{etype0: "hard_negative",
hard_sampler = GSFixedEdgeDstNegativeSampler({etype0: "hard_negative",
etype1: "hard_negative",
etype2: "hard_negative"},
sampler)
# Case 4:
# 1. hard_negative field is dict
# 2. The is not enough predefined negative for gen_neg_pairs
# 3. fast track
#
# expected behavior:
# 1. Only 4 hard negatives are returned
# 2. Others will be random negatives
check_hard_negs(pos_neg_tuple, etype0, hard0, hard0.shape[1])
# Case 5:
# 1. hard_negative field is dict
# 2. num_negs == total number of predefined negatives
# 3. fast track
#
# expected behavior:
# 1. all negatives are predefined negatives
check_hard_negs(pos_neg_tuple, etype1, hard1, hard1.shape[1])
# Case 6:
# 1. hard_negative field is dict
# 2. num_negs < total number of predefined negatives
# 3. fast track
#
# expected behavior:
# 1. all negatives are predefined negatives
check_hard_negs(pos_neg_tuple, etype2, hard2, num_negs)
etype2: "hard_negative"})
check_fixed_negs(pos_neg_tuple, etype0, hard0)

def check_none_hard_negs(pos_neg_tuple, etype, hard_neg_data):
neg_src, _, pos_dst, neg_dst = pos_neg_tuple[etype]

assert len(neg_src) == num_edges
assert len(pos_dst) == num_edges
assert neg_dst.shape[0] == num_edges
assert neg_dst.shape[1] == num_negs
assert_equal(src[:10].numpy(), neg_src.numpy())
assert_equal(dst[:10].numpy(), pos_dst.numpy())

for i in range(num_edges):
hard_neg_dst = set(neg_dst[i].tolist())
hard_neg_set = set(hard_neg_data[i].tolist())
assert hard_neg_dst.issubset(hard_neg_set) is False
# each positive edge should have enough fixed negatives
hard0[0][-1] = -1
fail = False
try:
pos_neg_tuple = hard_sampler.gen_neg_pairs(g, pos_pairs)
except:
fail = True
assert fail

# Case 9:
# dst_negative_field is not provided
hard_sampler = GSHardEdgeDstNegativeSampler(
num_negs, {}, sampler)
pos_neg_tuple = hard_sampler.gen_neg_pairs(g, pos_pairs)
check_none_hard_negs(pos_neg_tuple, etype2, hard2)

@pytest.mark.parametrize("num_pos", [2, 10])
@pytest.mark.parametrize("num_neg", [5, 20])
Expand Down Expand Up @@ -1992,8 +1788,7 @@ def test_inbatch_joint_neg_sampler(num_pos, num_neg):


if __name__ == '__main__':
test_hard_edge_dst_negative_sample_gen_neg_pairs_complex_case()
test_hard_edge_dst_negative_sample_gen_neg_pairs()
test_edge_fixed_dst_negative_sample_gen_neg_pairs()
test_hard_edge_dst_negative_sample_generate_complex_case()
test_hard_edge_dst_negative_sample_generate()
test_inbatch_joint_neg_sampler(10, 20)
Expand Down

0 comments on commit f9b8740

Please sign in to comment.