Skip to content

Commit

Permalink
Add unitest for HardEdgeDstNegativeTransform and fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiang Song committed Dec 12, 2023
1 parent 81a37ee commit fa1f257
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 8 deletions.
2 changes: 1 addition & 1 deletion python/graphstorm/gconstruct/construct_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def process_edge_data(process_confs, node_id_map, arr_merger,
# For edge hard negative transformation ops, more information is needed
for op in hard_edge_neg_ops:
op.set_target_etype(edge_type)
op.set_id_map(id_map)
op.set_id_maps(id_map)

multiprocessing = do_multiprocess_transform(process_conf,
feat_ops,
Expand Down
25 changes: 19 additions & 6 deletions python/graphstorm/gconstruct/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ def call(self, feats):
return {self.feat_name: feats}

class HardEdgeNegativeTransform(TwoPhaseFeatTransform):
""" Translate input data into node ids for hard negative
""" Translate input data into node ids for hard negative stored as edge features
Parameters
----------
Expand All @@ -954,25 +954,33 @@ def __init__(self, col_name, feat_name, separator=None):
super().__init__(col_name, feat_name, out_dtype=np.int64)

def set_target_etype(self, etype):
""" Set the etype of this hard edge negative transformation ops.
""" Set the etype of this hard edge negative transformation ops
and associated hard negative information. For example,
self._target_ntype.
Parameters
----------
etype : tuple of str
The edge type the hard negatives belonging to.
"""
self._target_etype = etype
raise NotImplementedError

@property
def target_etype(self):
""" The the edge type of this hard negative transformation.
"""
return self._target_etype

def set_id_map(self, id_map):
def set_id_maps(self, id_maps):
""" Set ID mapping for converting raw string ID to Graph ID
"""
self._nid_map = id_map
assert self._target_ntype is not None, \
"The target node type should be set, it can be the source node type " \
"or the destination node type depending on the hard negative case."
assert self._target_ntype in id_maps, \
f"The nid mapping should have the mapping for {self._target_ntype}. " \
f"But only has {id_maps.keys()}"
self._nid_map = id_maps

def pre_process(self, feats):
""" Pre-process data
Expand All @@ -992,7 +1000,9 @@ def pre_process(self, feats):
max_dim = feats.shape[1]
else:
assert feats.dtype.type is np.str_, \
"We can only convert strings to multiple hard negatives with separators."
"We can only convert strings to multiple hard negatives when a separator is given."
assert len(feats.shape) == 1 or feats.shape[1] == 1, \
"When a separator is given, the input feats must be a list of strings."
max_dim = 0
for feat in feats:
dim_size = len(feat.split(self._separator))
Expand Down Expand Up @@ -1034,6 +1044,9 @@ def call(self, feats):
return {self.feat_name: neg_ids}

class HardEdgeDstNegativeTransform(HardEdgeNegativeTransform):
""" Translate input data (destination node raw id) into GraphStorm node ids
for hard negative stored as edge features.
"""

def set_target_etype(self, etype):
self._target_etype = etype
Expand Down
86 changes: 85 additions & 1 deletion tests/unit-tests/gconstruct/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
Noop,
RankGaussTransform,
CategoricalTransform,
BucketTransform)
BucketTransform,
HardEdgeDstNegativeTransform)
from graphstorm.gconstruct.transform import (_check_label_stats_type,
collect_label_stats,
CustomLabelProcessor,
ClassificationProcessor)
from graphstorm.gconstruct.transform import (LABEL_STATS_FIELD,
LABEL_STATS_FREQUENCY_COUNT)
from graphstorm.gconstruct.id_map import IdMap

def test_get_output_dtype():
assert _get_output_dtype("float16") == np.float16
Expand Down Expand Up @@ -779,7 +781,89 @@ def test_bucket_transform(out_dtype):
feats_tar = np.array([[1, 1], [1, 1], [1, 1], [1, 1]], dtype=out_dtype)
assert_equal(bucket_feats['test'], feats_tar)

def test_hard_edge_dst_negative_transform():
hard_neg_trasnform = HardEdgeDstNegativeTransform("hard_neg", "hard_neg")
assert hard_neg_trasnform.col_name == "hard_neg"
assert hard_neg_trasnform.feat_name == "hard_neg"
assert hard_neg_trasnform.out_dtype == np.int64

str_ids = np.array([(99-i) for i in range(100)])
id_maps = {"src": IdMap(str_ids.astype(str))}
pass_set_id_maps = False
try:
# set_id_maps will fail if target_ntype is None
assert hard_neg_trasnform._target_ntype is None
hard_neg_trasnform.set_id_maps(id_maps)
except:
pass_set_id_maps = True
assert pass_set_id_maps

hard_neg_trasnform.set_target_etype(("src", "rel", "dst"))
try:
# set_id_maps will fail as target_ntype is dst
# but only src has id mapping.
hard_neg_trasnform.set_id_maps(id_maps)
except:
pass_set_id_maps = True
assert pass_set_id_maps

id_maps = {"dst": IdMap(str_ids.astype(str))}
hard_neg_trasnform.set_id_maps(id_maps)

input_feats0 = np.random.randint(0, 100, size=(20, 10), dtype=np.int64)
input_str_feats0 = input_feats0.astype(str)
info0 = hard_neg_trasnform.pre_process(input_str_feats0)
assert info0["hard_neg"] == 10

input_feats1 = np.random.randint(0, 100, size=(20, 20), dtype=np.int64)
input_str_feats1 = input_feats1.astype(str)
info1 = hard_neg_trasnform.pre_process(input_str_feats1)
assert info1["hard_neg"] == 20

hard_neg_trasnform.update_info([info0["hard_neg"], info1["hard_neg"]])
assert hard_neg_trasnform._max_dim == 20

neg0 = hard_neg_trasnform(input_str_feats0)
assert_equal(neg0["hard_neg"][:,:10], 99-input_feats0)
assert_equal(neg0["hard_neg"][:,10:], np.full((20, 10), -1, dtype=np.int64))
neg1 = hard_neg_trasnform(input_str_feats1)
assert_equal(neg1["hard_neg"], 99-input_feats1)

hard_neg_trasnform = HardEdgeDstNegativeTransform("hard_neg", "hard_neg", separator=",")
hard_neg_trasnform.set_target_etype(("src", "rel", "dst"))
hard_neg_trasnform.set_id_maps(id_maps)

input_feats0 = np.random.randint(0, 100, size=(20, 10), dtype=np.int64)
input_str_feats0 = [",".join(feats) for feats in input_feats0.astype(str).tolist()]
input_str_feats0.append(",".join([str(i) for i in range(15)]))
input_str_feats0 = np.array(input_str_feats0)
info0 = hard_neg_trasnform.pre_process(input_str_feats0)
assert info0["hard_neg"] == 15

input_feats1 = np.random.randint(0, 100, size=(20, 20), dtype=np.int64)
input_str_feats1 = [",".join(feats) for feats in input_feats1.astype(str).tolist()]
input_str_feats1.append(",".join([str(i) for i in range(15)]))
input_str_feats1 = np.array(input_str_feats1)
info1 = hard_neg_trasnform.pre_process(input_str_feats1)
assert info1["hard_neg"] == 20

hard_neg_trasnform.update_info([info0["hard_neg"], info1["hard_neg"]])
assert hard_neg_trasnform._max_dim == 20

neg0 = hard_neg_trasnform(input_str_feats0)
assert_equal(neg0["hard_neg"][:20,:10], 99-input_feats0)
assert_equal(neg0["hard_neg"][:20,10:], np.full((20, 10), -1, dtype=np.int64))
assert_equal(neg0["hard_neg"][20][:15], np.array([(99-i) for i in range(15)]))
assert_equal(neg0["hard_neg"][20][15:], np.full((5,), -1, dtype=np.int64))
neg1 = hard_neg_trasnform(input_str_feats1)
assert_equal(neg1["hard_neg"][:20], 99-input_feats1)
assert_equal(neg1["hard_neg"][20][:15], np.array([(99-i) for i in range(15)]))
assert_equal(neg1["hard_neg"][20][15:], np.full((5,), -1, dtype=np.int64))


if __name__ == '__main__':
test_hard_edge_dst_negative_transform()

test_categorize_transform()
test_get_output_dtype()
test_fp_transform(np.cfloat)
Expand Down

0 comments on commit fa1f257

Please sign in to comment.