Skip to content

Commit

Permalink
[GConstruct] Allow users to define train/validation/test masks (#804)
Browse files Browse the repository at this point in the history
*Issue #, if available:*
#789 

*Description of changes:*
This is the first PR to implement #789. We need to allow users to define
the train, validation, and test mask names themselves.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Xiang Song <[email protected]>
  • Loading branch information
classicsong and Xiang Song authored Apr 24, 2024
1 parent 13851d2 commit b4318f3
Show file tree
Hide file tree
Showing 8 changed files with 867 additions and 93 deletions.
127 changes: 88 additions & 39 deletions python/graphstorm/gconstruct/construct_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,12 +286,19 @@ def process_node_data(process_confs, arr_merger, remap_id,
Returns
-------
dict: node ID map
dict: node features.
node_id_map: dict
Node ID map.
node_data: dict
Node features.
label_stats: dict
Node label statistics.
label_masks: dict
Node label mask field names.
"""
node_data = {}
node_id_map = {}
label_stats = {}
label_masks = {}
for process_conf in process_confs:
# each iteration is to process a node type.
assert 'node_type' in process_conf, \
Expand Down Expand Up @@ -374,6 +381,15 @@ def process_node_data(process_confs, arr_merger, remap_id,

if node_type not in label_stats:
label_stats[node_type] = {}
label_masks[node_type] = []

if label_ops is not None:
for label_op in label_ops:
train_mask = label_op.train_mask_name
val_mask = label_op.val_mask_name
test_mask = label_op.test_mask_name
label_masks[node_type].append((train_mask, val_mask, test_mask))

for feat_name in list(type_node_data):
# features start with LABEL_STATS_FIELD store label statistics
if feat_name.startswith(LABEL_STATS_FIELD):
Expand Down Expand Up @@ -416,7 +432,7 @@ def process_node_data(process_confs, arr_merger, remap_id,
f"Node data and node IDs for node type {node_type} does not match: " + \
f"{len(data)} vs. {len(node_id_map[node_type])}"
sys_tracker.check('Finish processing node data')
return (node_id_map, node_data, label_stats)
return (node_id_map, node_data, label_stats, label_masks)

def process_edge_data(process_confs, node_id_map, arr_merger,
ext_mem_workspace=None, num_processes=1,
Expand Down Expand Up @@ -469,12 +485,15 @@ def process_edge_data(process_confs, node_id_map, arr_merger,
Edge features.
label_stats: dict
Edge label statistics.
label_masks: dict
Edge label mask field names.
hard_edge_neg_ops: list
Hard edge negative ops.
"""
edges = {}
edge_data = {}
label_stats = {}
label_masks = {}
for process_conf in process_confs:
# each iteration is to process an edge type.
assert 'relation' in process_conf, \
Expand Down Expand Up @@ -543,6 +562,15 @@ def process_edge_data(process_confs, node_id_map, arr_merger,
edge_type = tuple(edge_type)
if edge_type not in label_stats:
label_stats[edge_type] = {}
label_masks[edge_type] = []

if label_ops is not None:
for label_op in label_ops:
train_mask = label_op.train_mask_name
val_mask = label_op.val_mask_name
test_mask = label_op.test_mask_name
label_masks[edge_type].append((train_mask, val_mask, test_mask))

# handle edge type
for feat_name in list(type_edge_data):
# features start with LABEL_STATS_FIELD store label statistics
Expand Down Expand Up @@ -595,7 +623,7 @@ def process_edge_data(process_confs, node_id_map, arr_merger,
f"does not match the number of edges of {edge_type}. " \
f"Expecting {len(edges[edge_type][0])}, but get {len(efeats)}"

return (edges, edge_data, label_stats, hard_edge_neg_ops)
return (edges, edge_data, label_stats, label_masks, hard_edge_neg_ops)

def is_homogeneous(confs):
""" Verify if it is a homogeneous graph
Expand Down Expand Up @@ -638,7 +666,8 @@ def verify_confs(confs):
for edge in confs['edges']:
edge['relation'] = list(DEFAULT_ETYPE)

def print_graph_info(g, node_data, edge_data, node_label_stats, edge_label_stats):
def print_graph_info(g, node_data, edge_data, node_label_stats, edge_label_stats,
node_label_masks, edge_label_masks):
""" Print graph information.
Parameters
Expand All @@ -653,6 +682,10 @@ def print_graph_info(g, node_data, edge_data, node_label_stats, edge_label_stats
Node label stats
edge_label_stats: dict of dict of tuple.
Edge label stats
node_label_masks: dict of list of tuple.
Node label masks
edge_label_masks: dict of list of tuple.
Edge label masks
"""
logging.info("The graph has %d node types and %d edge types.",
len(g.ntypes), len(g.etypes))
Expand All @@ -664,27 +697,35 @@ def print_graph_info(g, node_data, edge_data, node_label_stats, edge_label_stats
for ntype in node_data:
feat_names = list(node_data[ntype].keys())
logging.info("Node type %s has features: %s.", ntype, str(feat_names))
num_train = np.sum(node_data[ntype]["train_mask"]) \
if "train_mask" in node_data[ntype] else 0
num_val = np.sum(node_data[ntype]["val_mask"]) \
if "val_mask" in node_data[ntype] else 0
num_test = np.sum(node_data[ntype]["test_mask"]) \
if "test_mask" in node_data[ntype] else 0
if num_train + num_val + num_test > 0:
logging.info("Train/val/test on %s: %d, %d, %d",
ntype, num_train, num_val, num_test)

for label_mask in node_label_masks[ntype]:
train_mask, val_mask, test_mask = label_mask
num_train = np.sum(node_data[ntype][train_mask]) \
if train_mask in node_data[ntype] else 0
num_val = np.sum(node_data[ntype][val_mask]) \
if val_mask in node_data[ntype] else 0
num_test = np.sum(node_data[ntype][test_mask]) \
if test_mask in node_data[ntype] else 0
if num_train + num_val + num_test > 0:
logging.info("Train/val/test on %s with mask %s, %s, %s: %d, %d, %d",
ntype, train_mask, val_mask, test_mask,
num_train, num_val, num_test)
for etype in edge_data:
feat_names = list(edge_data[etype].keys())
logging.info("Edge type %s has features: %s.", str(etype), str(feat_names))
num_train = np.sum(edge_data[etype]["train_mask"]) \
if "train_mask" in edge_data[etype] else 0
num_val = np.sum(edge_data[etype]["val_mask"]) \
if "val_mask" in edge_data[etype] else 0
num_test = np.sum(edge_data[etype]["test_mask"]) \
if "test_mask" in edge_data[etype] else 0
if num_train + num_val + num_test > 0:
logging.info("Train/val/test on %s: %d, %d, %d",
str(etype), num_train, num_val, num_test)

for label_mask in edge_label_masks[etype]:
train_mask, val_mask, test_mask = label_mask
num_train = np.sum(edge_data[etype][train_mask]) \
if train_mask in edge_data[etype] else 0
num_val = np.sum(edge_data[etype][val_mask]) \
if val_mask in edge_data[etype] else 0
num_test = np.sum(edge_data[etype][test_mask]) \
if test_mask in edge_data[etype] else 0
if num_train + num_val + num_test > 0:
logging.info("Train/val/test on %s with mask %s, %s, %s: %d, %d, %d",
str(etype), train_mask, val_mask, test_mask,
num_train, num_val, num_test)

for ntype in node_label_stats:
for label_name, stats in node_label_stats[ntype].items():
Expand Down Expand Up @@ -717,12 +758,12 @@ def process_graph(args):
if len(output_format) == 1 and output_format[0] == "DistDGL" else None
convert2ext_mem = ExtMemArrayMerger(ext_mem_workspace, args.ext_mem_feat_size)

raw_node_id_maps, node_data, node_label_stats = \
raw_node_id_maps, node_data, node_label_stats, node_label_masks = \
process_node_data(process_confs['nodes'], convert2ext_mem,
args.remap_node_id, ext_mem_workspace,
num_processes=num_processes_for_nodes)
sys_tracker.check('Process the node data')
edges, edge_data, edge_label_stats, hard_edge_neg_ops = \
edges, edge_data, edge_label_stats, edge_label_masks, hard_edge_neg_ops = \
process_edge_data(process_confs['edges'], raw_node_id_maps,
convert2ext_mem, ext_mem_workspace,
num_processes=num_processes_for_edges,
Expand Down Expand Up @@ -751,8 +792,13 @@ def process_graph(args):
data = edge_data[DEFAULT_ETYPE]
logging.warning("Reverse edge for homogeneous graph will have same feature as "
"what we have in the original edges")
print(edge_label_masks)
edge_masks = []
for masks in edge_label_masks[DEFAULT_ETYPE]:
edge_masks.extend(list(masks))

for key, value in data.items():
if key not in ["train_mask", "test_mask", "val_mask"]:
if key not in edge_masks:
data[key] = np.concatenate([value, value])
else:
data[key] = np.concatenate([value, np.zeros(value.shape,
Expand All @@ -768,24 +814,27 @@ def process_graph(args):
edges = edges1
sys_tracker.check('Add reverse edges')
g = dgl.heterograph(edges, num_nodes_dict=num_nodes)
print_graph_info(g, node_data, edge_data, node_label_stats, edge_label_stats)
print_graph_info(g, node_data, edge_data, node_label_stats, edge_label_stats,
node_label_masks, edge_label_masks)
os.makedirs(args.output_dir, exist_ok=True)
sys_tracker.check('Construct DGL graph')

# reshape customized mask
for srctype_etype_dsttype in edge_data:
if "train_mask" in edge_data[srctype_etype_dsttype].keys() and \
len(edge_data[srctype_etype_dsttype]["train_mask"].shape) == 2:
edge_data[srctype_etype_dsttype]["train_mask"] = \
edge_data[srctype_etype_dsttype]["train_mask"].squeeze(1).astype('int8')
if "val_mask" in edge_data[srctype_etype_dsttype].keys() and \
len(edge_data[srctype_etype_dsttype]["val_mask"].shape) == 2:
edge_data[srctype_etype_dsttype]["val_mask"] = \
edge_data[srctype_etype_dsttype]["val_mask"].squeeze(1).astype('int8')
if "test_mask" in edge_data[srctype_etype_dsttype].keys() and \
len(edge_data[srctype_etype_dsttype]["test_mask"].shape) == 2:
edge_data[srctype_etype_dsttype]["test_mask"] = \
edge_data[srctype_etype_dsttype]["test_mask"].squeeze(1).astype('int8')
for label_mask in edge_label_masks[srctype_etype_dsttype]:
train_mask, val_mask, test_mask = label_mask
if train_mask in edge_data[srctype_etype_dsttype].keys() and \
len(edge_data[srctype_etype_dsttype][train_mask].shape) == 2:
edge_data[srctype_etype_dsttype][train_mask] = \
edge_data[srctype_etype_dsttype][train_mask].squeeze(1).astype('int8')
if val_mask in edge_data[srctype_etype_dsttype].keys() and \
len(edge_data[srctype_etype_dsttype][val_mask].shape) == 2:
edge_data[srctype_etype_dsttype][val_mask] = \
edge_data[srctype_etype_dsttype][val_mask].squeeze(1).astype('int8')
if test_mask in edge_data[srctype_etype_dsttype].keys() and \
len(edge_data[srctype_etype_dsttype][test_mask].shape) == 2:
edge_data[srctype_etype_dsttype][test_mask] = \
edge_data[srctype_etype_dsttype][test_mask].squeeze(1).astype('int8')

if "DistDGL" in output_format:
assert args.part_method in ["metis", "random"], \
Expand Down
Loading

0 comments on commit b4318f3

Please sign in to comment.