diff --git a/python/graphstorm/dataloading/dataset.py b/python/graphstorm/dataloading/dataset.py index 227ceff8bc..14e25eaf10 100644 --- a/python/graphstorm/dataloading/dataset.py +++ b/python/graphstorm/dataloading/dataset.py @@ -17,7 +17,6 @@ """ import os import abc -import json import logging import re @@ -27,9 +26,11 @@ from torch.utils.data import Dataset import pandas as pd -from ..utils import get_rank, get_world_size, is_distributed, barrier -from ..utils import sys_tracker, is_wholegraph -from .utils import dist_sum, flip_node_mask, is_wholegraph_embedding +from ..utils import get_rank, get_world_size, is_distributed, barrier, is_wholegraph +from ..utils import sys_tracker +from .utils import dist_sum, flip_node_mask + +from ..wholegraph import is_wholegraph_embedding def split_full_edge_list(g, etype, rank): ''' Split the full edge list of a graph. @@ -168,6 +169,7 @@ def __init__(self, graph_name, part_config, node_feat_field, edge_feat_field, # Use wholegraph for feature transfer if is_distributed() and is_wholegraph(): + from ..wholegraph import load_wg_feat logging.info("Allocate features with Wholegraph") num_parts = self._g.get_partition_book().num_partitions() @@ -184,7 +186,7 @@ def __init__(self, graph_name, part_config, node_feat_field, edge_feat_field, f"Feature '{name}' of '{ntype}' is not in WholeGraph format. " \ f"Please convert all the available features to WholeGraph " \ f"format to utilize WholeGraph." - data[name] = self.load_wg_feat(part_config, num_parts, ntype, name) + data[name] = load_wg_feat(part_config, num_parts, ntype, name) if len(self._g.ntypes) == 1: self._g._ndata_store.update(data) else: @@ -206,7 +208,7 @@ def __init__(self, graph_name, part_config, node_feat_field, edge_feat_field, f"Feature '{name}' of '{etype}' is not in WholeGraph format. " \ f"Please convert all the available features to WholeGraph " \ f"format to utilize WholeGraph." - data[name] = self.load_wg_feat(part_config, num_parts, etype_wg, name) + data[name] = load_wg_feat(part_config, num_parts, etype_wg, name) if len(self._g.canonical_etypes) == 1: self._g._edata_store.update(data) else: @@ -241,53 +243,6 @@ def edge_feat_field(self): """the field of edge feature""" return self._edge_feat_field - def load_wg_feat(self, part_config_path, num_parts, type_name, name): - """Load features from wholegraph memory - - Parameters - ---------- - part_config_path : str - The path of the partition configuration file. - num_parts : int - The number of partitions of the dataset - type_name: str - The type of node or edge for which to fetch features or labels for. 
- name: str - The name of the features to load - """ - import pylibwholegraph.torch as wgth - - global_comm = wgth.comm.get_global_communicator() - feature_comm = global_comm - embedding_wholememory_type = 'distributed' - embedding_wholememory_location = 'cpu' - cache_policy = wgth.create_builtin_cache_policy( - "none", # cache type - embedding_wholememory_type, - embedding_wholememory_location, - "readonly", # access type - 0.0, # cache ratio - ) - metadata_file = os.path.join(os.path.dirname(part_config_path), - 'wholegraph/metadata.json') - with open(metadata_file, encoding="utf8") as f: - wg_metadata = json.load(f) - data_shape = wg_metadata[type_name + '/' + name]['shape'] - feat_wm_embedding = wgth.create_embedding( - feature_comm, - embedding_wholememory_type, - embedding_wholememory_location, - getattr(th, wg_metadata[type_name + '/' + name]['dtype'].split('.')[1]), - [data_shape[0],1] if len(data_shape) == 1 else data_shape, - optimizer=None, - cache_policy=cache_policy, - ) - feat_path = os.path.join(os.path.dirname(part_config_path), 'wholegraph', \ - type_name + '~' + name) - feat_wm_embedding.get_embedding_tensor().from_file_prefix(feat_path, - part_count=num_parts) - return feat_wm_embedding - def has_node_feats(self, ntype): """ Test if the specified node type has features. @@ -514,7 +469,7 @@ class GSgnnEdgeTrainData(GSgnnEdgeData): different feature names. decoder_edge_feat: str or dict of list of str Edge features used by decoder - + Examples ---------- @@ -525,7 +480,7 @@ class GSgnnEdgeTrainData(GSgnnEdgeData): ep_data = GSgnnEdgeTrainData(graph_name='dummy', part_config=part_config, train_etypes=[('n1', 'e1', 'n2')], label_field='label', node_feat_field='node_feat', edge_feat_field='edge_feat') - ep_dataloader = GSgnnEdgeDataLoader(ep_data, target_idx={"e1":[0]}, + ep_dataloader = GSgnnEdgeDataLoader(ep_data, target_idx={"e1":[0]}, fanout=[15, 10], batch_size=128) """ def __init__(self, graph_name, part_config, train_etypes, eval_etypes=None, @@ -710,7 +665,7 @@ class GSgnnEdgeInferData(GSgnnEdgeData): The node types that contains text features. lm_feat_etypes : list of tuples The edge types that contains text features. - + Examples ---------- @@ -721,7 +676,7 @@ class GSgnnEdgeInferData(GSgnnEdgeData): ep_data = GSgnnEdgeInferData(graph_name='dummy', part_config=part_config, eval_etypes=[('n1', 'e1', 'n2')], label_field='label', node_feat_field='node_feat', edge_feat_field='edge_feat') - ep_dataloader = GSgnnEdgeDataLoader(ep_data, target_idx={"e1":[0]}, + ep_dataloader = GSgnnEdgeDataLoader(ep_data, target_idx={"e1":[0]}, fanout=[15, 10], batch_size=128) """ def __init__(self, graph_name, part_config, eval_etypes, @@ -911,7 +866,7 @@ class GSgnnNodeTrainData(GSgnnNodeData): The node types that contains text features. lm_feat_etypes : list of tuples The edge types that contains text features. - + Examples ---------- @@ -923,7 +878,7 @@ class GSgnnNodeTrainData(GSgnnNodeData): np_data = GSgnnNodeTrainData(graph_name='dummy', part_config=part_config, train_ntypes=['n1'], label_field='label', node_feat_field='feat') - np_dataloader = GSgnnNodeDataLoader(np_data, target_idx={'n1':[0]}, + np_dataloader = GSgnnNodeDataLoader(np_data, target_idx={'n1':[0]}, fanout=[15, 10], batch_size=128) """ def __init__(self, graph_name, part_config, train_ntypes, eval_ntypes=None, @@ -1066,19 +1021,19 @@ class GSgnnNodeInferData(GSgnnNodeData): The node types that contains text features. lm_feat_etypes : list of tuples The edge types that contains text features. 
- + Examples ---------- - + .. code:: python from graphstorm.dataloading import GSgnnNodeInferData - from graphstorm.dataloading import + from graphstorm.dataloading import np_data = GSgnnNodeInferData(graph_name='dummy', part_config=part_config, eval_ntypes=['n1'], label_field='label', node_feat_field='feat') - np_dataloader = GSgnnNodeDataLoader(np_data, target_idx={'n1':[0]}, + np_dataloader = GSgnnNodeDataLoader(np_data, target_idx={'n1':[0]}, fanout=[15, 10], batch_size=128) """ def __init__(self, graph_name, part_config, eval_ntypes, diff --git a/python/graphstorm/dataloading/utils.py b/python/graphstorm/dataloading/utils.py index 78ac1cc84d..122ddeab34 100644 --- a/python/graphstorm/dataloading/utils.py +++ b/python/graphstorm/dataloading/utils.py @@ -137,13 +137,3 @@ def flip_node_mask(dist_tensor, indices): part_policy=dist_tensor.part_policy) flipped_dist_tensor[indices] = 1 - dist_tensor[indices] return flipped_dist_tensor - -def is_wholegraph_embedding(data): - """ Check if the data is in WholeMemory emedding format which - is required to use wholegraph framework. - """ - try: - import pylibwholegraph - return isinstance(data, pylibwholegraph.torch.WholeMemoryEmbedding) - except: # pylint: disable=bare-except - return False diff --git a/python/graphstorm/gsf.py b/python/graphstorm/gsf.py index 97dabd1164..03afdb8603 100644 --- a/python/graphstorm/gsf.py +++ b/python/graphstorm/gsf.py @@ -22,7 +22,6 @@ import dgl import torch as th import torch.nn.functional as F -from dataclasses import dataclass from dgl.distributed import role from dgl.distributed.constants import DEFAULT_NTYPE from dgl.distributed.constants import DEFAULT_ETYPE @@ -65,29 +64,6 @@ LinkPredictWeightedDistMultDecoder) from .tracker import get_task_tracker_class -def init_wholegraph(): - """ Initialize Wholegraph""" - import pylibwholegraph.torch as wgth - import pylibwholegraph.binding.wholememory_binding as wmb - - @dataclass - class Options: # pylint: disable=missing-class-docstring - pass - Options.launch_agent = 'pytorch' - Options.launch_env_name_world_rank = 'RANK' - Options.launch_env_name_world_size = 'WORLD_SIZE' - Options.launch_env_name_local_rank = 'LOCAL_RANK' - Options.launch_env_name_local_size = 'LOCAL_WORLD_SIZE' - Options.launch_env_name_master_addr = 'MASTER_ADDR' - Options.launch_env_name_master_port = 'MASTER_PORT' - Options.local_rank = get_rank() % role.get_num_trainers() - Options.local_size = role.get_num_trainers() - - wgth.distributed_launch(Options, lambda: None) - wmb.init(0) - wgth.comm.set_world_info(get_rank(), get_world_size(), Options.local_rank, - Options.local_size) - def initialize(ip_config, backend, use_wholegraph=False): """ Initialize distributed training and inference context. @@ -108,6 +84,7 @@ def initialize(ip_config, backend, use_wholegraph=False): th.distributed.init_process_group(backend=backend) # Use wholegraph for feature and label fetching if use_wholegraph: + from .wholegraph import init_wholegraph init_wholegraph() sys_tracker.check("load DistDGL") diff --git a/python/graphstorm/wholegraph/__init__.py b/python/graphstorm/wholegraph/__init__.py new file mode 100644 index 0000000000..705b2f7a81 --- /dev/null +++ b/python/graphstorm/wholegraph/__init__.py @@ -0,0 +1,19 @@ +""" + Copyright 2023 Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+    You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
+
+    Package initializer that imports WholeGraph-related functions and classes.
+"""
+from .wholegraph import (init_wholegraph, convert_feat_to_wholegraph, load_wg_feat)
+from .utils import is_wholegraph_embedding
diff --git a/python/graphstorm/wholegraph/utils.py b/python/graphstorm/wholegraph/utils.py
new file mode 100644
index 0000000000..b0cbb7ec59
--- /dev/null
+++ b/python/graphstorm/wholegraph/utils.py
@@ -0,0 +1,28 @@
+"""
+    Copyright 2023 Contributors
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
+
+    Utilities for integrating WholeGraph into GraphStorm.
+"""
+
+
+def is_wholegraph_embedding(data):
+    """ Check if the data is in WholeMemory embedding format, which
+        is required to use the WholeGraph framework.
+    """
+    try:
+        import pylibwholegraph
+        return isinstance(data, pylibwholegraph.torch.WholeMemoryEmbedding)
+    except ImportError:
+        return False
diff --git a/python/graphstorm/wholegraph/wholegraph.py b/python/graphstorm/wholegraph/wholegraph.py
new file mode 100644
index 0000000000..6f212fec4d
--- /dev/null
+++ b/python/graphstorm/wholegraph/wholegraph.py
@@ -0,0 +1,285 @@
+"""
+    Copyright 2023 Contributors
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
+ + Functions/classes for integrating WholeGraph into GraphStorm +""" + +import os + +import json +import gc +import logging +import re + +import torch as th +import dgl +from typing import Optional +from dataclasses import dataclass + +from ..utils import get_rank, get_world_size + +try: + import pylibwholegraph.torch as wgth +except ImportError: + wgth = None + + +def init_wholegraph(): + """ Initialize Wholegraph""" + if wgth is None: + raise ImportError("WholeGraph is not installed") + from dgl.distributed import role + import pylibwholegraph.binding.wholememory_binding as wmb + + @dataclass + class Options: # pylint: disable=missing-class-docstring + pass + Options.launch_agent = 'pytorch' + Options.launch_env_name_world_rank = 'RANK' + Options.launch_env_name_world_size = 'WORLD_SIZE' + Options.launch_env_name_local_rank = 'LOCAL_RANK' + Options.launch_env_name_local_size = 'LOCAL_WORLD_SIZE' + Options.launch_env_name_master_addr = 'MASTER_ADDR' + Options.launch_env_name_master_port = 'MASTER_PORT' + Options.local_rank = get_rank() % role.get_num_trainers() + Options.local_size = role.get_num_trainers() + + wgth.distributed_launch(Options, lambda: None) + wmb.init(0) + wgth.comm.set_world_info(get_rank(), get_world_size(), Options.local_rank, + Options.local_size) + + +def wholegraph_processing( + whole_feat_tensor, metadata, feat, wg_folder, num_parts +): + """Convert DGL tensors to wholememory tensor + + Parameters + ---------- + whole_feat_tensor : Tensor + The concatenated feature tensor of different partitions + metadata : Tensor + Metadata of the feature tensor + feat : str + Name of the feature to be converted + wg_folder : str + Name of the folder to store the converted files + num_parts : int + Number of partitions of the input features + """ + metadata[feat] = { + "shape": list(whole_feat_tensor.shape), + "dtype": str(whole_feat_tensor.dtype), + } + local_comm = wgth.comm.get_local_device_communicator() + # Round up the integer division to match WholeGraph partitioning scheme + subpart_size = -(whole_feat_tensor.shape[0] // -num_parts) + + for part_num in range(num_parts): + st = part_num * subpart_size + end = (part_num + 1) * subpart_size \ + if part_num != (num_parts - 1) \ + else whole_feat_tensor.shape[0] + + wg_tensor = wgth.create_wholememory_tensor( + local_comm, + "continuous", + "cpu", + (end - st, *whole_feat_tensor.shape[1:]), + whole_feat_tensor.dtype, + None, + ) + local_tensor, _ = wg_tensor.get_local_tensor(host_view=True) + local_tensor.copy_(whole_feat_tensor[st:end]) + filename = wgth.utils.get_part_file_name( + feat.replace("/", "~"), part_num, num_parts + ) + wg_tensor.local_to_file(os.path.join(wg_folder, filename)) + wgth.destroy_wholememory_tensor(wg_tensor) + + +def trim_feat_files(trimmed_feats, folder, file_name, part): + """Save new truncated distDGL tensors + Parameters + ---------- + trimmed_feats : list of tensors + distDGL tensors after trimming out the processed features + folder : str + Name of the folder of the input feature files + file_name : str + Name of the feature file, either node_feat.dgl or edge_feat.dgl + part : int + Partition number of the input feature files + + """ + dgl.data.utils.save_tensors( + os.path.join(folder, f"part{part}", "new_" + file_name), trimmed_feats[part] + ) + os.rename( + os.path.join(folder, f"part{part}", file_name), + os.path.join(folder, f"part{part}", file_name + ".bak"), + ) + os.rename( + os.path.join(folder, f"part{part}", "new_" + file_name), + os.path.join(folder, f"part{part}", file_name), + ) + 
+ +def convert_feat_to_wholegraph(fname_dict, file_name, metadata, folder, use_low_mem): + """Convert features from distDGL tensor format to WholeGraph format + + Parameters + ---------- + fname_dict: dict of list + Dict of names of the edge features of different edge types + file_name: + Name of the feature file, either node_feat.dgl or edge_feat.dgl + metadata : Tensor + Metadata of the feature tensor + folder: str + Name of the folder of the input feature files + use_low_mem: bool + Whether to use low memory version for conversion + """ + wg_folder = os.path.join(folder, "wholegraph") + folder_pattern = re.compile(r"^part[0-9]+$") + part_files = [ + f + for f in os.listdir(folder) + if os.path.isdir(os.path.join(folder, f)) and folder_pattern.match(f) + ] + part_files = sorted(part_files, key=lambda x: int(x.split("part")[1])) + feats_data = [] + + # When 'use_low_mem' is not enabled, this code loads and appends features from individual + # partitions. Then features are concatenated and converted into the WholeGraph format one + # by one. The minimum memory requirement for this approach is 2X the size of the input + # nodes or edges features in the graph. + if not use_low_mem: + # Read features from file + for path in (os.path.join(folder, name) for name in part_files): + feats_data.append(dgl.data.utils.load_tensors(f"{path}/{file_name}")) + num_parts = len(feats_data) + for type_name, feats in fname_dict.items(): + for feat in feats: + feat = type_name + "/" + feat + if feat not in feats_data[0]: + raise RuntimeError( + f"Error: Unknown feature '{feat}'. Files contain \ + the following features: {feats_data[0].keys()}." + ) + logging.info("Processing %s features...", feat) + whole_feat_tensor = th.concat( + tuple(t[feat] for t in feats_data), dim=0 + ) + # Delete processed feature from memory + for t in feats_data: + del t[feat] + wholegraph_processing( + whole_feat_tensor, metadata, feat, wg_folder, num_parts + ) + # Trim the original distDGL tensors + for part in range(num_parts): + trim_feat_files(feats_data, folder, file_name, part) + + # This low-memory version loads one partition at a time. It processes features one by one, + # iterating through all the partitions and appending only the current feature, converting + # it to a WholeGraph. The minimum memory requirement for this approach is 2X the size of + # the largest node or edge feature in the graph. + else: # low-mem + for ntype, feats in fname_dict.items(): + for feat in feats: + feat = ntype + "/" + feat + node_feats_data = None + num_parts = 0 + # Read features from file + for path in (os.path.join(folder, name) for name in part_files): + nfeat = dgl.data.utils.load_tensors(f"{path}/{file_name}") + if feat not in nfeat: + raise RuntimeError( + f"Error: Unknown feature '{feat}'. Files contain \ + the following features: {nfeat.keys()}." 
+ ) + if node_feats_data is None: + node_feats_data = nfeat[feat] + else: + node_feats_data = th.concat((node_feats_data, nfeat[feat]), dim=0) + num_parts += 1 + del nfeat + gc.collect() + wholegraph_processing( + node_feats_data, + metadata, + feat, + wg_folder, + num_parts, + ) + num_parts = 0 + for path in (os.path.join(folder, name) for name in part_files): + feats_data = dgl.data.utils.load_tensors(f"{path}/{file_name}") + for type_name, feats in fname_dict.items(): + for feat in feats: + feat = type_name + "/" + feat + # Delete processed feature from memory + del feats_data[feat] + num_parts += 1 + trim_feat_files(feats_data, folder, file_name, num_parts) + + +def load_wg_feat(part_config_path, num_parts, type_name, name): + """Load features from wholegraph memory + + Parameters + ---------- + part_config_path : str + The path of the partition configuration file. + num_parts : int + The number of partitions of the dataset + type_name: str + The type of node or edge for which to fetch features or labels for. + name: str + The name of the features to load + """ + global_comm = wgth.comm.get_global_communicator() + feature_comm = global_comm + embedding_wholememory_type = 'distributed' + embedding_wholememory_location = 'cpu' + cache_policy = wgth.create_builtin_cache_policy( + "none", # cache type + embedding_wholememory_type, + embedding_wholememory_location, + "readonly", # access type + 0.0, # cache ratio + ) + metadata_file = os.path.join(os.path.dirname(part_config_path), + 'wholegraph/metadata.json') + with open(metadata_file, encoding="utf8") as f: + wg_metadata = json.load(f) + data_shape = wg_metadata[type_name + '/' + name]['shape'] + feat_wm_embedding = wgth.create_embedding( + feature_comm, + embedding_wholememory_type, + embedding_wholememory_location, + getattr(th, wg_metadata[type_name + '/' + name]['dtype'].split('.')[1]), + [data_shape[0], 1] if len(data_shape) == 1 else data_shape, + optimizer=None, + cache_policy=cache_policy, + ) + feat_path = os.path.join(os.path.dirname(part_config_path), 'wholegraph', \ + type_name + '~' + name) + feat_wm_embedding.get_embedding_tensor().from_file_prefix(feat_path, + part_count=num_parts) + return feat_wm_embedding diff --git a/tools/convert_feat_to_wholegraph.py b/tools/convert_feat_to_wholegraph.py index 969ca73c1d..9d2152b1c5 100755 --- a/tools/convert_feat_to_wholegraph.py +++ b/tools/convert_feat_to_wholegraph.py @@ -16,16 +16,14 @@ Tool to convert node features from distDGL format to WholeGraph format. 
""" import argparse -import gc + import json import logging import os -import re -import dgl -import pylibwholegraph.torch as wgth import torch +from graphstorm.wholegraph import convert_feat_to_wholegraph def get_node_feat_info(node_feat_names): """Process node feature names @@ -93,184 +91,6 @@ def get_edge_feat_info(edge_feat_names): fname_dict[etype] = feat_info[1].split(",") return fname_dict - -def wholegraph_processing( - whole_feat_tensor, metadata, feat, wg_folder, num_parts -): - """Convert DGL tensors to wholememory tensor - - Parameters - ---------- - whole_feat_tensor : Tensor - The concatenated feature tensor of different partitions - metadata : Tensor - Metadata of the feature tensor - feat : str - Name of the feature to be converted - wg_folder : str - Name of the folder to store the converted files - num_parts : int - Number of partitions of the input features - """ - metadata[feat] = { - "shape": list(whole_feat_tensor.shape), - "dtype": str(whole_feat_tensor.dtype), - } - local_comm = wgth.comm.get_local_device_communicator() - # Round up the integer division to match WholeGraph partitioning scheme - subpart_size = -(whole_feat_tensor.shape[0] // -num_parts) - - for part_num in range(num_parts): - st = part_num * subpart_size - end = (part_num + 1) * subpart_size \ - if part_num != (num_parts - 1) \ - else whole_feat_tensor.shape[0] - - wg_tensor = wgth.create_wholememory_tensor( - local_comm, - "continuous", - "cpu", - (end - st, *whole_feat_tensor.shape[1:]), - whole_feat_tensor.dtype, - None, - ) - local_tensor, _ = wg_tensor.get_local_tensor(host_view=True) - local_tensor.copy_(whole_feat_tensor[st:end]) - filename = wgth.utils.get_part_file_name( - feat.replace("/", "~"), part_num, num_parts - ) - wg_tensor.local_to_file(os.path.join(wg_folder, filename)) - wgth.destroy_wholememory_tensor(wg_tensor) - - -def convert_feat_to_wholegraph(fname_dict, file_name, metadata, folder, use_low_mem): - """Convert features from distDGL tensor format to WholeGraph format - - Parameters - ---------- - fname_dict: dict of list - Dict of names of the edge features of different edge types - file_name: - Name of the feature file, either node_feat.dgl or edge_feat.dgl - metadata : Tensor - Metadata of the feature tensor - folder: str - Name of the folder of the input feature files - use_low_mem: bool - Whether to use low memory version for conversion - """ - wg_folder = os.path.join(folder, "wholegraph") - folder_pattern = re.compile(r"^part[0-9]+$") - part_files = [ - f - for f in os.listdir(folder) - if os.path.isdir(os.path.join(folder, f)) and folder_pattern.match(f) - ] - part_files = sorted(part_files, key=lambda x: int(x.split("part")[1])) - feats_data = [] - - # When 'use_low_mem' is not enabled, this code loads and appends features from individual - # partitions. Then features are concatenated and converted into the WholeGraph format one - # by one. The minimum memory requirement for this approach is 2X the size of the input - # nodes or edges features in the graph. - if not use_low_mem: - # Read features from file - for path in (os.path.join(folder, name) for name in part_files): - feats_data.append(dgl.data.utils.load_tensors(f"{path}/{file_name}")) - num_parts = len(feats_data) - for type_name, feats in fname_dict.items(): - for feat in feats: - feat = type_name + "/" + feat - if feat not in feats_data[0]: - raise RuntimeError( - f"Error: Unknown feature '{feat}'. Files contain \ - the following features: {feats_data[0].keys()}." 
- ) - logging.info("Processing %s features...", feat) - whole_feat_tensor = torch.concat( - tuple(t[feat] for t in feats_data), dim=0 - ) - # Delete processed feature from memory - for t in feats_data: - del t[feat] - wholegraph_processing( - whole_feat_tensor, metadata, feat, wg_folder, num_parts - ) - # Trim the original distDGL tensors - for part in range(num_parts): - trim_feat_files(feats_data, folder, file_name, part) - - # This low-memory version loads one partition at a time. It processes features one by one, - # iterating through all the partitions and appending only the current feature, converting - # it to a WholeGraph. The minimum memory requirement for this approach is 2X the size of - # the largest node or edge feature in the graph. - else: # low-mem - for ntype, feats in fname_dict.items(): - for feat in feats: - feat = ntype + "/" + feat - node_feats_data = None - num_parts = 0 - # Read features from file - for path in (os.path.join(folder, name) for name in part_files): - nfeat = dgl.data.utils.load_tensors(f"{path}/{file_name}") - if feat not in nfeat: - raise RuntimeError( - f"Error: Unknown feature '{feat}'. Files contain \ - the following features: {nfeat.keys()}." - ) - if node_feats_data is None: - node_feats_data = nfeat[feat] - else: - node_feats_data = torch.concat((node_feats_data, nfeat[feat]), dim=0) - num_parts += 1 - del nfeat - gc.collect() - wholegraph_processing( - node_feats_data, - metadata, - feat, - wg_folder, - num_parts, - ) - num_parts = 0 - for path in (os.path.join(folder, name) for name in part_files): - feats_data = dgl.data.utils.load_tensors(f"{path}/{file_name}") - for type_name, feats in fname_dict.items(): - for feat in feats: - feat = type_name + "/" + feat - # Delete processed feature from memory - del feats_data[feat] - num_parts += 1 - trim_feat_files(feats_data, folder, file_name, num_parts) - - -def trim_feat_files(trimmed_feats, folder, file_name, part): - """Save new truncated distDGL tensors - Parameters - ---------- - trimmed_feats : list of tensors - distDGL tensors after trimming out the processed features - folder : str - Name of the folder of the input feature files - file_name : str - Name of the feature file, either node_feat.dgl or edge_feat.dgl - part : int - Partition number of the input feature files - - """ - dgl.data.utils.save_tensors( - os.path.join(folder, f"part{part}", "new_" + file_name), trimmed_feats[part] - ) - os.rename( - os.path.join(folder, f"part{part}", file_name), - os.path.join(folder, f"part{part}", file_name + ".bak"), - ) - os.rename( - os.path.join(folder, f"part{part}", "new_" + file_name), - os.path.join(folder, f"part{part}", file_name), - ) - - def main(folder, node_feat_names, edge_feat_names, use_low_mem=False): """Convert features from distDGL tensor format to WholeGraph format""" os.environ["MASTER_ADDR"] = "127.0.0.1"
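
A minimal usage sketch of the relocated helpers follows, for reviewers who want to see the new import paths end to end. It is illustrative only: the partition config path, IP list file, node type n1, feature name feat, and partition count are hypothetical, and it assumes pylibwholegraph is installed and the process is launched with torchrun (or an equivalent launcher that sets RANK, WORLD_SIZE, LOCAL_RANK, and the master address variables).

    # Sketch only: exercises the new graphstorm.wholegraph package introduced by this patch.
    # Assumes node features were already converted offline with
    # tools/convert_feat_to_wholegraph.py, which now wraps
    # graphstorm.wholegraph.convert_feat_to_wholegraph.
    from graphstorm import gsf
    from graphstorm.wholegraph import load_wg_feat, is_wholegraph_embedding

    # initialize() now lazily imports init_wholegraph() when use_wholegraph=True,
    # so pylibwholegraph is only needed on this code path.
    gsf.initialize(ip_config="/data/ip_list.txt", backend="gloo", use_wholegraph=True)

    # Load a converted node feature into WholeMemory; num_parts must match the
    # number of graph partitions used during conversion (2 is hypothetical).
    feat = load_wg_feat("/data/dummy/dummy.json", num_parts=2,
                        type_name="n1", name="feat")
    assert is_wholegraph_embedding(feat)  # WholeMemoryEmbedding, not a DistTensor

The same check, is_wholegraph_embedding, is what GSgnnGraphData now imports from graphstorm.wholegraph instead of dataloading.utils, so downstream callers only need to update their import paths; the runtime behavior is unchanged by this refactor.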