Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the bug that, when edges reference missing src or dst nodes, the number of edge features does not equal the number of edges #585

Merged
merged 2 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion python/graphstorm/gconstruct/construct_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,22 @@ def parse_edge_data(in_file, feat_ops, label_ops, node_id_map, read_file,
src_ids = data[src_id_col] if src_id_col is not None else None
dst_ids = data[dst_id_col] if dst_id_col is not None else None
if src_ids is not None:
src_ids, dst_ids = map_node_ids(src_ids, dst_ids, edge_type, node_id_map,
src_ids, dst_ids, src_exist_locs, dst_exist_locs = \
map_node_ids(src_ids, dst_ids, edge_type, node_id_map,
skip_nonexist_edges)
if src_exist_locs is not None:
feat_data = {key: feat[src_exist_locs] \
for key, feat in feat_data.items()}
if dst_exist_locs is not None:
feat_data = {key: feat[dst_exist_locs] \
for key, feat in feat_data.items()}
# do some check
if src_exist_locs is not None or dst_exist_locs is not None:
for key, feat in feat_data.items():
assert len(src_ids) == len(feat), \
f"Expecting the edge feature {key} has the same length" \
f"as num existing edges {len(src_ids)}, but get {len(feat)}"

return (src_ids, dst_ids, feat_data)

def _process_data(user_pre_parser, user_parser,
Expand Down
19 changes: 17 additions & 2 deletions python/graphstorm/gconstruct/id_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,23 @@ def map_node_ids(src_ids, dst_ids, edge_type, node_id_map, skip_nonexist_edges):

Returns
-------
tuple of tensors : the remapped source and destination node IDs.
tuple of tensors :
src_ids: The remapped source node IDs.
dst_ids: The remapped destination node IDs.
src_exist_locs: The locations of source node IDs that
have existing edges. Only valid when
skip_nonexist_edges is True.
dst_exist_locs: The location of destination node IDs that
have existing edges. Only valid when
skip_nonexist_edges is True.

How to use src_exist_locs and dst_exist_locs:
feat_data = feat_data[src_exist_locs][dst_exist_locs]
"""
src_type, _, dst_type = edge_type
new_src_ids, orig_locs = node_id_map[src_type].map_id(src_ids)
src_exist_locs = None
dst_exist_locs = None
# If some of the source nodes don't exist in the node set.
if len(orig_locs) != len(src_ids):
bool_mask = np.ones(len(src_ids), dtype=bool)
Expand All @@ -195,6 +208,7 @@ def map_node_ids(src_ids, dst_ids, edge_type, node_id_map, skip_nonexist_edges):
else:
raise ValueError(f"source nodes of {src_type} do not exist: {src_ids[bool_mask]}")
dst_ids = dst_ids[orig_locs] if len(orig_locs) > 0 else np.array([], dtype=dst_ids.dtype)
src_exist_locs = orig_locs
src_ids = new_src_ids

new_dst_ids, orig_locs = node_id_map[dst_type].map_id(dst_ids)
Expand All @@ -210,5 +224,6 @@ def map_node_ids(src_ids, dst_ids, edge_type, node_id_map, skip_nonexist_edges):
raise ValueError(f"dest nodes of {dst_type} do not exist: {dst_ids[bool_mask]}")
# We need to remove the source nodes as well.
src_ids = src_ids[orig_locs] if len(orig_locs) > 0 else np.array([], dtype=src_ids.dtype)
dst_exist_locs = orig_locs
dst_ids = new_dst_ids
return src_ids, dst_ids
return src_ids, dst_ids, src_exist_locs, dst_exist_locs
78 changes: 68 additions & 10 deletions tests/unit-tests/gconstruct/test_construct_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,18 @@
import dgl
import torch as th

from functools import partial
from numpy.testing import assert_equal, assert_almost_equal

from graphstorm.gconstruct.construct_graph import parse_edge_data
from graphstorm.gconstruct.file_io import write_data_parquet, read_data_parquet
from graphstorm.gconstruct.file_io import write_data_json, read_data_json
from graphstorm.gconstruct.file_io import write_data_csv, read_data_csv
from graphstorm.gconstruct.file_io import write_data_hdf5, read_data_hdf5, HDF5Array
from graphstorm.gconstruct.file_io import write_index_json
from graphstorm.gconstruct.transform import parse_feat_ops, process_features, preprocess_features
from graphstorm.gconstruct.transform import parse_label_ops, process_labels
from graphstorm.gconstruct.transform import Noop, do_multiprocess_transform
from graphstorm.gconstruct.transform import Noop, do_multiprocess_transform, LinkPredictionProcessor
from graphstorm.gconstruct.id_map import IdMap, map_node_ids
from graphstorm.gconstruct.utils import (ExtMemArrayMerger,
ExtMemArrayWrapper,
Expand Down Expand Up @@ -907,10 +909,13 @@ def check_map_node_ids_exist(str_src_ids, str_dst_ids, id_map):
# Test the case that both source node IDs and destination node IDs exist.
src_ids = np.array([str(random.randint(0, len(str_src_ids) - 1)) for _ in range(15)])
dst_ids = np.array([str(random.randint(0, len(str_dst_ids) - 1)) for _ in range(15)])
new_src_ids, new_dst_ids = map_node_ids(src_ids, dst_ids, ("src", None, "dst"),
new_src_ids, new_dst_ids, src_exist_locs, dst_exist_locs = \
map_node_ids(src_ids, dst_ids, ("src", None, "dst"),
id_map, False)
assert len(new_src_ids) == len(src_ids)
assert len(new_dst_ids) == len(dst_ids)
assert src_exist_locs is None
assert dst_exist_locs is None
for src_id1, src_id2 in zip(new_src_ids, src_ids):
assert src_id1 == int(src_id2)
for dst_id1, dst_id2 in zip(new_dst_ids, dst_ids):
Expand All @@ -921,22 +926,28 @@ def check_map_node_ids_src_not_exist(str_src_ids, str_dst_ids, id_map):
src_ids = np.array([str(random.randint(0, 20)) for _ in range(15)])
dst_ids = np.array([str(random.randint(0, len(str_dst_ids) - 1)) for _ in range(15)])
try:
new_src_ids, new_dst_ids = map_node_ids(src_ids, dst_ids, ("src", None, "dst"),
new_src_ids, new_dst_ids, _, _ = \
map_node_ids(src_ids, dst_ids, ("src", None, "dst"),
id_map, False)
raise ValueError("fail")
except:
pass

# Test the case that source node IDs don't exist and we skip non exist edges.
new_src_ids, new_dst_ids = map_node_ids(src_ids, dst_ids, ("src", None, "dst"),
new_src_ids, new_dst_ids, src_exist_locs, dst_exist_locs \
= map_node_ids(src_ids, dst_ids, ("src", None, "dst"),
id_map, True)
num_valid = sum([int(id_) < len(str_src_ids) for id_ in src_ids])
assert len(new_src_ids) == num_valid
assert len(new_dst_ids) == num_valid
assert src_exist_locs is not None
assert_equal(src_ids[src_exist_locs].astype(np.int64), new_src_ids)
assert_equal(dst_ids[src_exist_locs].astype(np.int64), new_dst_ids)
assert dst_exist_locs is None

# Test the case that none of the source node IDs exists and we skip non exist edges.
src_ids = np.array([str(random.randint(20, 100)) for _ in range(15)])
new_src_ids, new_dst_ids = map_node_ids(src_ids, dst_ids, ("src", None, "dst"),
new_src_ids, new_dst_ids, _, _ = map_node_ids(src_ids, dst_ids, ("src", None, "dst"),
id_map, True)
assert len(new_src_ids) == 0
assert len(new_dst_ids) == 0
Expand All @@ -946,22 +957,27 @@ def check_map_node_ids_dst_not_exist(str_src_ids, str_dst_ids, id_map):
src_ids = np.array([str(random.randint(0, len(str_src_ids) - 1)) for _ in range(15)])
dst_ids = np.array([str(random.randint(0, 20)) for _ in range(15)])
try:
new_src_ids, new_dst_ids = map_node_ids(src_ids, dst_ids, ("src", None, "dst"),
new_src_ids, new_dst_ids, _, _ = \
map_node_ids(src_ids, dst_ids, ("src", None, "dst"),
id_map, False)
raise ValueError("fail")
except:
pass

# Test the case that destination node IDs don't exist and we skip non exist edges.
new_src_ids, new_dst_ids = map_node_ids(src_ids, dst_ids, ("src", None, "dst"),
new_src_ids, new_dst_ids, src_exist_locs, dst_exist_locs = map_node_ids(src_ids, dst_ids, ("src", None, "dst"),
id_map, True)
num_valid = sum([int(id_) < len(str_dst_ids) for id_ in dst_ids])
assert len(new_src_ids) == num_valid
assert len(new_dst_ids) == num_valid
assert src_exist_locs is None
assert_equal(src_ids[dst_exist_locs].astype(np.int64), new_src_ids)
assert_equal(dst_ids[dst_exist_locs].astype(np.int64), new_dst_ids)
assert dst_exist_locs is not None

# Test the case that none of the destination node IDs exists and we skip non exist edges.
dst_ids = np.array([str(random.randint(20, 100)) for _ in range(15)])
new_src_ids, new_dst_ids = map_node_ids(src_ids, dst_ids, ("src", None, "dst"),
new_src_ids, new_dst_ids, _, _ = map_node_ids(src_ids, dst_ids, ("src", None, "dst"),
id_map, True)
assert len(new_src_ids) == 0
assert len(new_dst_ids) == 0
Expand Down Expand Up @@ -1217,12 +1233,54 @@ def test_multiprocessing_checks():
multiprocessing = do_multiprocess_transform(conf, feat_ops, label_ops, in_files)
assert multiprocessing == False

def test_parse_edge_data():
    """Exercise parse_edge_data() with endpoint ids that may not exist in the
    node id maps, verifying that (with skip_nonexist_edges=True) every
    feature/label column stays aligned with the surviving edge list.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        # Node id maps: 10 valid "src" ids and 15 valid "dst" ids (as strings).
        str_src_ids = np.array([str(i) for i in range(10)])
        str_dst_ids = np.array([str(i) for i in range(15)])
        node_id_map = {
            "src": IdMap(str_src_ids),
            "dst": IdMap(str_dst_ids),
        }

        # Draw endpoint ids past the valid ranges (>=10 for src, >=15 for dst)
        # so some edges refer to nonexistent nodes and must be dropped.
        src_ids = np.array([str(random.randint(0, 20)) for _ in range(15)])
        dst_ids = np.array([str(random.randint(0, 25)) for _ in range(15)])
        feat = np.random.rand(15, 10)
        data = {"src_id": src_ids, "dst_id": dst_ids, "feat": feat}

        # One pass-through feature op plus a link-prediction label split.
        feat_ops = [Noop("feat", "feat", None)]
        label_ops = [
            LinkPredictionProcessor(None, None, [0.7,0.1,0.2], None)]
        data_file = os.path.join(tmpdirname, "data.parquet")
        write_data_parquet(data, data_file)

        conf = {
            "source_id_col": "src_id",
            "dest_id_col": "dst_id",
            "relation": ("src", "rel", "dst")
        }
        keys = ["src_id", "dst_id", "feat"]
        src_ids, dst_ids, feat_data = \
            parse_edge_data(data_file, feat_ops, label_ops, node_id_map,
                            partial(read_data_parquet, data_fields=keys),
                            conf, skip_nonexist_edges=True)

        # The feature and every generated label mask must survive the filtering.
        for column in ("feat", "train_mask", "val_mask", "test_mask"):
            assert column in feat_data

        # Each remaining column must have exactly one row per surviving edge.
        for arr in feat_data.values():
            assert len(src_ids) == len(arr)
            assert len(dst_ids) == len(arr)

# Ad-hoc driver: run the unit tests directly (without pytest) when this
# file is executed as a script.
if __name__ == '__main__':
    test_parse_edge_data()
    test_multiprocessing_checks()
    test_csv()
    test_csv(None)
    test_hdf5()
    test_json()
    test_partition_graph()
    test_partition_graph(1)
    test_merge_arrays()
    test_map_node_ids()
    test_id_map()
Expand Down
Loading