[WholeGraph] Refactor code to separate WholeGraph-related functions (#697)

This pull request refactors the code to move WholeGraph-related
functions into a dedicated graphstorm.wholegraph module. This improves
code organization and makes it easier to work with and extend
WholeGraph functionality.
chang-l authored Jan 12, 2024
1 parent f8f4937 commit aacf520
Showing 7 changed files with 353 additions and 279 deletions.
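Taken together, the refactor centralizes every WholeGraph entry point in one package. A minimal sketch of the new import surface, based on the wholegraph/__init__.py added below; only the grouping itself comes from the diff, the comments are editorial:

from graphstorm.wholegraph import (
    init_wholegraph,             # previously defined inline in gsf.py
    load_wg_feat,                # previously a method in dataloading/dataset.py
    convert_feat_to_wholegraph,  # also re-exported by the new package
    is_wholegraph_embedding,     # previously in dataloading/utils.py
)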
81 changes: 18 additions & 63 deletions python/graphstorm/dataloading/dataset.py
@@ -17,7 +17,6 @@
"""
import os
import abc
-import json
import logging
import re

@@ -27,9 +26,11 @@
from torch.utils.data import Dataset
import pandas as pd

-from ..utils import get_rank, get_world_size, is_distributed, barrier
-from ..utils import sys_tracker, is_wholegraph
-from .utils import dist_sum, flip_node_mask, is_wholegraph_embedding
+from ..utils import get_rank, get_world_size, is_distributed, barrier, is_wholegraph
+from ..utils import sys_tracker
+from .utils import dist_sum, flip_node_mask
+
+from ..wholegraph import is_wholegraph_embedding

def split_full_edge_list(g, etype, rank):
''' Split the full edge list of a graph.
@@ -168,6 +169,7 @@ def __init__(self, graph_name, part_config, node_feat_field, edge_feat_field,

# Use wholegraph for feature transfer
if is_distributed() and is_wholegraph():
+from ..wholegraph import load_wg_feat
logging.info("Allocate features with Wholegraph")
num_parts = self._g.get_partition_book().num_partitions()

@@ -184,7 +186,7 @@ def __init__(self, graph_name, part_config, node_feat_field, edge_feat_field,
f"Feature '{name}' of '{ntype}' is not in WholeGraph format. " \
f"Please convert all the available features to WholeGraph " \
f"format to utilize WholeGraph."
-data[name] = self.load_wg_feat(part_config, num_parts, ntype, name)
+data[name] = load_wg_feat(part_config, num_parts, ntype, name)
if len(self._g.ntypes) == 1:
self._g._ndata_store.update(data)
else:
@@ -206,7 +208,7 @@ def __init__(self, graph_name, part_config, node_feat_field, edge_feat_field,
f"Feature '{name}' of '{etype}' is not in WholeGraph format. " \
f"Please convert all the available features to WholeGraph " \
f"format to utilize WholeGraph."
-data[name] = self.load_wg_feat(part_config, num_parts, etype_wg, name)
+data[name] = load_wg_feat(part_config, num_parts, etype_wg, name)
if len(self._g.canonical_etypes) == 1:
self._g._edata_store.update(data)
else:
@@ -241,53 +243,6 @@ def edge_feat_field(self):
"""the field of edge feature"""
return self._edge_feat_field

-def load_wg_feat(self, part_config_path, num_parts, type_name, name):
-    """Load features from wholegraph memory
-    Parameters
-    ----------
-    part_config_path : str
-        The path of the partition configuration file.
-    num_parts : int
-        The number of partitions of the dataset
-    type_name: str
-        The type of node or edge for which to fetch features or labels for.
-    name: str
-        The name of the features to load
-    """
-    import pylibwholegraph.torch as wgth
-
-    global_comm = wgth.comm.get_global_communicator()
-    feature_comm = global_comm
-    embedding_wholememory_type = 'distributed'
-    embedding_wholememory_location = 'cpu'
-    cache_policy = wgth.create_builtin_cache_policy(
-        "none", # cache type
-        embedding_wholememory_type,
-        embedding_wholememory_location,
-        "readonly", # access type
-        0.0, # cache ratio
-    )
-    metadata_file = os.path.join(os.path.dirname(part_config_path),
-                                 'wholegraph/metadata.json')
-    with open(metadata_file, encoding="utf8") as f:
-        wg_metadata = json.load(f)
-    data_shape = wg_metadata[type_name + '/' + name]['shape']
-    feat_wm_embedding = wgth.create_embedding(
-        feature_comm,
-        embedding_wholememory_type,
-        embedding_wholememory_location,
-        getattr(th, wg_metadata[type_name + '/' + name]['dtype'].split('.')[1]),
-        [data_shape[0],1] if len(data_shape) == 1 else data_shape,
-        optimizer=None,
-        cache_policy=cache_policy,
-    )
-    feat_path = os.path.join(os.path.dirname(part_config_path), 'wholegraph', \
-                             type_name + '~' + name)
-    feat_wm_embedding.get_embedding_tensor().from_file_prefix(feat_path,
-                                                              part_count=num_parts)
-    return feat_wm_embedding

def has_node_feats(self, ntype):
""" Test if the specified node type has features.
@@ -514,7 +469,7 @@ class GSgnnEdgeTrainData(GSgnnEdgeData):
different feature names.
decoder_edge_feat: str or dict of list of str
Edge features used by decoder
Examples
----------
@@ -525,7 +480,7 @@
ep_data = GSgnnEdgeTrainData(graph_name='dummy', part_config=part_config,
train_etypes=[('n1', 'e1', 'n2')], label_field='label',
node_feat_field='node_feat', edge_feat_field='edge_feat')
ep_dataloader = GSgnnEdgeDataLoader(ep_data, target_idx={"e1":[0]},
fanout=[15, 10], batch_size=128)
"""
def __init__(self, graph_name, part_config, train_etypes, eval_etypes=None,
@@ -710,7 +665,7 @@ class GSgnnEdgeInferData(GSgnnEdgeData):
The node types that contain text features.
lm_feat_etypes : list of tuples
The edge types that contain text features.
Examples
----------
Expand All @@ -721,7 +676,7 @@ class GSgnnEdgeInferData(GSgnnEdgeData):
ep_data = GSgnnEdgeInferData(graph_name='dummy', part_config=part_config,
eval_etypes=[('n1', 'e1', 'n2')], label_field='label',
node_feat_field='node_feat', edge_feat_field='edge_feat')
ep_dataloader = GSgnnEdgeDataLoader(ep_data, target_idx={"e1":[0]},
fanout=[15, 10], batch_size=128)
"""
def __init__(self, graph_name, part_config, eval_etypes,
@@ -911,7 +866,7 @@ class GSgnnNodeTrainData(GSgnnNodeData):
The node types that contain text features.
lm_feat_etypes : list of tuples
The edge types that contain text features.
Examples
----------
Expand All @@ -923,7 +878,7 @@ class GSgnnNodeTrainData(GSgnnNodeData):
np_data = GSgnnNodeTrainData(graph_name='dummy', part_config=part_config,
train_ntypes=['n1'], label_field='label',
node_feat_field='feat')
np_dataloader = GSgnnNodeDataLoader(np_data, target_idx={'n1':[0]},
fanout=[15, 10], batch_size=128)
"""
def __init__(self, graph_name, part_config, train_ntypes, eval_ntypes=None,
@@ -1066,19 +1021,19 @@ class GSgnnNodeInferData(GSgnnNodeData):
The node types that contain text features.
lm_feat_etypes : list of tuples
The edge types that contain text features.
Examples
----------
.. code:: python
from graphstorm.dataloading import GSgnnNodeInferData
from graphstorm.dataloading import GSgnnNodeDataLoader
np_data = GSgnnNodeInferData(graph_name='dummy', part_config=part_config,
eval_ntypes=['n1'], label_field='label',
node_feat_field='feat')
np_dataloader = GSgnnNodeDataLoader(np_data, target_idx={'n1':[0]},
fanout=[15, 10], batch_size=128)
"""
def __init__(self, graph_name, part_config, eval_ntypes,
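Net effect in dataset.py: the loader is no longer a method, so the class sheds the 45-line helper above, and the import happens lazily inside the WholeGraph branch. A sketch of calling the relocated function directly, with hypothetical argument values; the parameter meanings are taken from the removed docstring:

from graphstorm.wholegraph import load_wg_feat

# Same positional arguments as the dataset call sites above; the values are illustrative.
feat = load_wg_feat(
    "data/my_graph.json",  # part_config_path: partition configuration file
    4,                     # num_parts: number of partitions of the dataset
    "paper",               # type_name: node or edge type to fetch features for
    "feat",                # name: name of the features to load
)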
10 changes: 0 additions & 10 deletions python/graphstorm/dataloading/utils.py
@@ -137,13 +137,3 @@ def flip_node_mask(dist_tensor, indices):
part_policy=dist_tensor.part_policy)
flipped_dist_tensor[indices] = 1 - dist_tensor[indices]
return flipped_dist_tensor

-def is_wholegraph_embedding(data):
-    """ Check if the data is in WholeMemory emedding format which
-        is required to use wholegraph framework.
-    """
-    try:
-        import pylibwholegraph
-        return isinstance(data, pylibwholegraph.torch.WholeMemoryEmbedding)
-    except: # pylint: disable=bare-except
-        return False
25 changes: 1 addition & 24 deletions python/graphstorm/gsf.py
@@ -22,7 +22,6 @@
import dgl
import torch as th
import torch.nn.functional as F
-from dataclasses import dataclass
from dgl.distributed import role
from dgl.distributed.constants import DEFAULT_NTYPE
from dgl.distributed.constants import DEFAULT_ETYPE
@@ -65,29 +64,6 @@
LinkPredictWeightedDistMultDecoder)
from .tracker import get_task_tracker_class

-def init_wholegraph():
-    """ Initialize Wholegraph"""
-    import pylibwholegraph.torch as wgth
-    import pylibwholegraph.binding.wholememory_binding as wmb
-
-    @dataclass
-    class Options: # pylint: disable=missing-class-docstring
-        pass
-    Options.launch_agent = 'pytorch'
-    Options.launch_env_name_world_rank = 'RANK'
-    Options.launch_env_name_world_size = 'WORLD_SIZE'
-    Options.launch_env_name_local_rank = 'LOCAL_RANK'
-    Options.launch_env_name_local_size = 'LOCAL_WORLD_SIZE'
-    Options.launch_env_name_master_addr = 'MASTER_ADDR'
-    Options.launch_env_name_master_port = 'MASTER_PORT'
-    Options.local_rank = get_rank() % role.get_num_trainers()
-    Options.local_size = role.get_num_trainers()
-
-    wgth.distributed_launch(Options, lambda: None)
-    wmb.init(0)
-    wgth.comm.set_world_info(get_rank(), get_world_size(), Options.local_rank,
-                             Options.local_size)

def initialize(ip_config, backend, use_wholegraph=False):
""" Initialize distributed training and inference context.
@@ -108,6 +84,7 @@ def initialize(ip_config, backend, use_wholegraph=False):
th.distributed.init_process_group(backend=backend)
# Use wholegraph for feature and label fetching
if use_wholegraph:
+from .wholegraph import init_wholegraph
init_wholegraph()
sys_tracker.check("load DistDGL")

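Moving the import inside the use_wholegraph branch means pylibwholegraph is only needed when WholeGraph is actually requested. A hedged usage sketch, assuming initialize is re-exported at the package top level like other GraphStorm entry points; the ip_config file and backend choice are illustrative:

import graphstorm as gs

# Only this call path triggers "from .wholegraph import init_wholegraph",
# so environments without pylibwholegraph can still run non-WholeGraph jobs.
gs.initialize(ip_config="ip_list.txt", backend="gloo", use_wholegraph=True)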
19 changes: 19 additions & 0 deletions python/graphstorm/wholegraph/__init__.py
@@ -0,0 +1,19 @@
"""
Copyright 2023 Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Initial module to import WholeGraph-related functions and classes
"""
from .wholegraph import (init_wholegraph, convert_feat_to_wholegraph, load_wg_feat)
from .utils import is_wholegraph_embedding
28 changes: 28 additions & 0 deletions python/graphstorm/wholegraph/utils.py
@@ -0,0 +1,28 @@
"""
Copyright 2023 Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Utils for integrating WholeGraph into GraphStorm
"""


def is_wholegraph_embedding(data):
""" Check if the data is in WholeMemory emedding format which
is required to use wholegraph framework.
"""
try:
import pylibwholegraph
return isinstance(data, pylibwholegraph.torch.WholeMemoryEmbedding)
except ImportError:
return False
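Catching ImportError instead of the old bare except keeps the check honest: it returns False when pylibwholegraph is absent rather than swallowing unrelated errors. A sketch of how a caller might branch on the storage format; read_feat and the gather call are illustrative, not GraphStorm APIs:

from graphstorm.wholegraph import is_wholegraph_embedding

def read_feat(data, idx):
    # WholeMemory embeddings are fetched by index gather; plain tensors are sliced.
    if is_wholegraph_embedding(data):
        return data.gather(idx)  # illustrative WholeMemoryEmbedding access
    return data[idx]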
2 remaining changed files not shown.