Skip to content

Commit

Permalink
Add some unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiang Song committed Dec 12, 2023
1 parent fa1f257 commit dd2074d
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 7 deletions.
30 changes: 23 additions & 7 deletions python/graphstorm/gconstruct/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,19 +1001,34 @@ def partition_graph(g, node_data, edge_data, graph_name, num_partitions, output_
save_maps(output_dir, "edge_mapping", new_edge_mapping)

def get_hard_edge_negs_feats(hard_edge_neg_ops):
    """ Get feature names of hard negatives for each edge type.

    Parameters
    ----------
    hard_edge_neg_ops: list of HardEdgeNegativeTransform
        A list of edge hard negative transformations. Only the
        ``target_etype`` and ``feat_name`` attributes of each
        transformation are read.

    Returns
    -------
    dict of lists
        Maps each target edge type to the list of hard negative feature
        names defined on it, in the order the transformations were given.
        An empty input yields an empty dict.
    """
    hard_edge_neg_feats = {}
    for hard_edge_neg_op in hard_edge_neg_ops:
        # setdefault creates the per-edge-type list the first time an edge
        # type is seen, so first and subsequent features share one code path.
        hard_edge_neg_feats.setdefault(
            hard_edge_neg_op.target_etype, []).append(hard_edge_neg_op.feat_name)
    return hard_edge_neg_feats

def shuffle_hard_nids(data_path, num_parts, hard_edge_neg_feats):
"""
""" Shuffle node ids of hard negatives from Graph node id space to
Partition Node id space.
Parameters
----------
data_path: str
Path to the directory storing the partitioned graph.
num_parts: int
Number of partitions.
hard_edge_neg_feats: dict of lists
A dictionary storing hard negative features for each edge type.
"""
# Load node id remapping
node_mapping = load_maps(data_path, "node_mapping")
Expand All @@ -1022,3 +1037,4 @@ def shuffle_hard_nids(data_path, num_parts, hard_edge_neg_feats):
for i in range(num_parts):
part_path = os.path.join(data_path, f"part{i}")
edge_faet_path = os.path.join(part_path, "edge_feat.dgl")

20 changes: 20 additions & 0 deletions tests/unit-tests/gconstruct/test_gconstruct_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@
from graphstorm.gconstruct.utils import HDF5Array, ExtNumpyWrapper
from graphstorm.gconstruct.utils import convert_to_ext_mem_numpy, _to_ext_memory
from graphstorm.gconstruct.utils import multiprocessing_data_read
from graphstorm.gconstruct.utils import get_hard_edge_negs_feats
from graphstorm.gconstruct.file_io import (write_data_hdf5,
read_data_hdf5,
get_in_files,
write_data_parquet)
from graphstorm.gconstruct.file_io import (read_data_csv,
read_data_json,
read_data_parquet)
from graphstorm.gconstruct.transform import HardEdgeDstNegativeTransform

def gen_data():
data_th = th.zeros((1024, 16), dtype=th.float32)
Expand Down Expand Up @@ -298,7 +300,25 @@ def test_get_in_files():
pass_test = True
assert pass_test

def test_get_hard_edge_negs_feats():
    """ Check that get_hard_edge_negs_feats groups feature names by edge type.

    Two transformations target ("src", "rel0", "dst") and one targets
    ("src", "rel1", "dst"), so the result must hold two edge types with
    two and one feature names respectively.
    """
    rel0_etype = ("src", "rel0", "dst")
    rel1_etype = ("src", "rel1", "dst")

    hard_trans0 = HardEdgeDstNegativeTransform("hard_neg", "hard_neg")
    hard_trans0.set_target_etype(rel0_etype)

    hard_trans1 = HardEdgeDstNegativeTransform("hard_neg", "hard_neg1")
    hard_trans1.set_target_etype(rel0_etype)

    hard_trans2 = HardEdgeDstNegativeTransform("hard_neg", "hard_neg")
    hard_trans2.set_target_etype(rel1_etype)

    hard_edge_neg_feats = get_hard_edge_negs_feats([hard_trans0, hard_trans1, hard_trans2])
    assert len(hard_edge_neg_feats) == 2
    assert len(hard_edge_neg_feats[rel0_etype]) == 2
    assert set(hard_edge_neg_feats[rel0_etype]) == {"hard_neg", "hard_neg1"}
    assert len(hard_edge_neg_feats[rel1_etype]) == 1
    # hard_trans2 is built with the same arguments as hard_trans0, so its
    # feature name must be the one hard_trans0 contributed for rel0.
    assert hard_edge_neg_feats[rel1_etype] == ["hard_neg"]


if __name__ == '__main__':
test_get_hard_edge_negs_feats()
test_get_in_files()
test_read_empty_parquet()
test_read_empty_json()
Expand Down

0 comments on commit dd2074d

Please sign in to comment.