store features in the external mem
jalencato committed Oct 25, 2023
1 parent 0e9b2cf commit e13f7b4
Showing 3 changed files with 24 additions and 17 deletions.

python/graphstorm/gconstruct/construct_graph.py (27 changes: 16 additions & 11 deletions)
@@ -71,7 +71,7 @@ def prepare_node_data(in_file, feat_ops, read_file):
 
     return feat_info
 
-def parse_node_data(in_file, feat_ops, label_ops, node_id_col, read_file):
+def parse_node_data(in_file, feat_ops, label_ops, node_id_col, read_file, ext_mem):
     """ Parse node data.
 
     The function parses a node file that contains node IDs, features and labels
@@ -90,13 +90,15 @@ def parse_node_data(in_file, feat_ops, label_ops, node_id_col, read_file):
         The column name that contains the node ID.
     read_file : callable
         The function to read the node file
+    ext_mem: str
+        The address of external memory for multi-column feature
 
     Returns
     -------
     tuple : node ID array and a dict of node feature tensors.
     """
     data = read_file(in_file)
-    feat_data = process_features(data, feat_ops) if feat_ops is not None else {}
+    feat_data = process_features(data, feat_ops, ext_mem) if feat_ops is not None else {}
     if label_ops is not None:
         label_data = process_labels(data, label_ops)
         for key, val in label_data.items():
@@ -131,7 +133,7 @@ def prepare_edge_data(in_file, feat_ops, read_file):
     return feat_info
 
 def parse_edge_data(in_file, feat_ops, label_ops, node_id_map, read_file,
-                    conf, skip_nonexist_edges):
+                    conf, skip_nonexist_edges, ext_mem):
     """ Parse edge data.
 
     The function parses an edge file that contains the source and destination node
@@ -167,7 +169,7 @@ def parse_edge_data(in_file, feat_ops, label_ops, node_id_map, read_file,
     edge_type = conf['relation']
 
     data = read_file(in_file)
-    feat_data = process_features(data, feat_ops) if feat_ops is not None else {}
+    feat_data = process_features(data, feat_ops, ext_mem) if feat_ops is not None else {}
     if label_ops is not None:
         label_data = process_labels(data, label_ops)
         for key, val in label_data.items():
@@ -230,7 +232,7 @@ def _process_data(user_pre_parser, user_parser,
     return return_dict
 
 
-def process_node_data(process_confs, arr_merger, remap_id, num_processes=1):
+def process_node_data(process_confs, arr_merger, remap_id, ext_mem, num_processes=1):
     """ Process node data
 
     We need to process all node data before we can process edge data.
@@ -306,7 +308,8 @@ def process_node_data(process_confs, arr_merger, remap_id, num_processes=1):
         user_parser = partial(parse_node_data, feat_ops=feat_ops,
                               label_ops=label_ops,
                               node_id_col=node_id_col,
-                              read_file=read_file)
+                              read_file=read_file,
+                              ext_mem=ext_mem)
 
         return_dict = _process_data(user_pre_parser,
                                     user_parser,
@@ -400,7 +403,7 @@ def process_node_data(process_confs, arr_merger, remap_id, num_processes=1):
     return (node_id_map, node_data, label_stats)
 
 def process_edge_data(process_confs, node_id_map, arr_merger,
-                      num_processes=1,
+                      ext_mem, num_processes=1,
                       skip_nonexist_edges=False):
     """ Process edge data
 
@@ -483,7 +486,8 @@ def process_edge_data(process_confs, node_id_map, arr_merger,
                               node_id_map=id_map,
                               read_file=read_file,
                               conf=process_conf,
-                              skip_nonexist_edges=skip_nonexist_edges)
+                              skip_nonexist_edges=skip_nonexist_edges,
+                              ext_mem=ext_mem)
 
         return_dict = _process_data(user_pre_parser,
                                     user_parser,
@@ -659,12 +663,13 @@ def process_graph(args):
 
     node_id_map, node_data, node_label_stats = \
         process_node_data(process_confs['nodes'], convert2ext_mem,
-                          args.remap_node_id,
-                          num_processes=num_processes_for_nodes)
+                          args.remap_node_id, ext_mem_workspace,
+                          num_processes=num_processes_for_nodes
+                          )
     sys_tracker.check('Process the node data')
     edges, edge_data, edge_label_stats = \
         process_edge_data(process_confs['edges'], node_id_map,
-                          convert2ext_mem,
+                          convert2ext_mem, ext_mem_workspace,
                           num_processes=num_processes_for_edges,
                           skip_nonexist_edges=args.skip_nonexist_edges)
     sys_tracker.check('Process the edge data')
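
The construct_graph.py changes thread the new ext_mem workspace into the per-file parsers through functools.partial, so the worker processes spawned by _process_data receive the path without any other call site changing. A minimal sketch of that binding pattern; the parser body and the workspace path below are illustrative stand-ins, not the commit's code:

from functools import partial

def parse_node_data(in_file, feat_ops, ext_mem):
    # A real parser would read in_file and apply feat_ops; this stub only
    # shows that the bound keyword reaches each worker call unchanged.
    return {"file": in_file, "ext_mem": ext_mem}

# Hypothetical workspace path; the commit passes ext_mem_workspace here.
user_parser = partial(parse_node_data, feat_ops=None, ext_mem="/tmp/ext_mem/")
print(user_parser("nodes_part0.parquet"))
# -> {'file': 'nodes_part0.parquet', 'ext_mem': '/tmp/ext_mem/'}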

python/graphstorm/gconstruct/transform.py (10 changes: 7 additions & 3 deletions)
@@ -1069,7 +1069,7 @@ def preprocess_features(data, ops):
 
     return pre_data
 
-def process_features(data, ops):
+def process_features(data, ops, ext_mem=None):
     """ Process the data with the specified operations.
 
     This function runs the input operations on the corresponding data
@@ -1081,14 +1081,17 @@ def process_features(data, ops):
         The data stored as a dict.
     ops : list of FeatTransform
         The operations that transform features.
+    ext_mem: str
+        The address of external memory
 
     Returns
     -------
     dict : the key is the data name, the value is the processed data.
     """
     new_data = {}
     for op in ops:
-        feature_path = 'feature_{}'.format(op.feat_name)
+        feature_path = ext_mem + 'feature_intermediate/feature_{}'\
+            .format(op.feat_name)
         if os.path.exists(feature_path):
             shutil.rmtree(feature_path)
         if isinstance(op.col_name, str):
@@ -1110,7 +1113,8 @@ def process_features(data, ops):
                 new_data[key] = val
             else:
                 tmp_key = key
-                feature_path = 'feature_{}'.format(op.feat_name)
+                assert ext_mem is not None, \
+                    "external memory is necessary for multiple column"
                 if not os.path.exists(feature_path):
                     os.makedirs(feature_path)
                 wrapper = ExtFeatureWrapper(feature_path, val.shape, val.dtype)
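
In transform.py, process_features now takes ext_mem (defaulting to None), stages multi-column features on disk under ext_mem + 'feature_intermediate/' through ExtFeatureWrapper, and asserts that ext_mem is set before handling the multi-column case. The wrapper itself is not shown in this diff, so here is a self-contained sketch of the same disk-backed staging idea using numpy's memmap; all names are illustrative:

import os
import numpy as np

ext_mem = "/tmp/ext_mem_workspace/"          # hypothetical workspace root
feature_path = ext_mem + "feature_intermediate/feature_demo"
os.makedirs(feature_path, exist_ok=True)

# Stage a transformed feature block in a disk-backed array rather than RAM.
shape, dtype = (1000, 16), np.float32
arr = np.memmap(os.path.join(feature_path, "block0.bin"),
                dtype=dtype, mode="w+", shape=shape)
arr[:] = np.ones(shape, dtype=dtype)         # write the transformed features
arr.flush()                                  # persist before releasing the handle
del arr

Note that the diff builds feature_path by string concatenation, which assumes ext_mem ends with a path separator; os.path.join would avoid that assumption.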

python/graphstorm/gconstruct/utils.py (4 changes: 1 addition & 3 deletions)
@@ -613,11 +613,9 @@ def merge(self):
             out_arr[:, col_start:col_end] = arr
             col_start = col_end
 
-        print("out_arr:", out_arr)
+        out_arr.flush()
         del out_arr
 
 
-
-
 def _merge_arrs(arrs, tensor_path):
     """ Merge the arrays.
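
The utils.py change drops a debug print and instead flushes the memory-mapped output array before deleting the handle: np.memmap buffers writes in memory, and flush() forces them out to the backing file before the mapping goes away. A small, runnable demonstration of that write-back pattern:

import os
import tempfile
import numpy as np

path = os.path.join(tempfile.mkdtemp(), "merged.bin")
out_arr = np.memmap(path, dtype=np.float32, mode="w+", shape=(4, 2))
out_arr[:] = 1.0
out_arr.flush()   # push buffered pages to disk before dropping the mapping
del out_arr

# The file now holds the merged values and can be re-opened independently.
on_disk = np.fromfile(path, dtype=np.float32).reshape(4, 2)
print(on_disk.sum())   # 8.0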
