[WholeGraph] Refactor code to separate WholeGraph-related functions #697

Merged
merged 5 commits on Jan 12, 2024
Changes from 1 commit
80 changes: 18 additions & 62 deletions python/graphstorm/dataloading/dataset.py
@@ -26,9 +26,11 @@
from torch.utils.data import Dataset
import pandas as pd

from ..utils import get_rank, get_world_size, is_distributed, barrier
from ..utils import sys_tracker, is_wholegraph
from .utils import dist_sum, flip_node_mask, is_wholegraph_embedding
from ..utils import get_rank, get_world_size, is_distributed, barrier, is_wholegraph
from ..utils import sys_tracker
from .utils import dist_sum, flip_node_mask

from ..wholegraph_utils import is_wholegraph_embedding

def split_full_edge_list(g, etype, rank):
''' Split the full edge list of a graph.
@@ -167,6 +169,7 @@ def __init__(self, graph_name, part_config, node_feat_field, edge_feat_field,

# Use wholegraph for feature transfer
if is_distributed() and is_wholegraph():
from ..wholegraph import load_wg_feat
logging.info("Allocate features with Wholegraph")
num_parts = self._g.get_partition_book().num_partitions()

@@ -183,7 +186,7 @@ def __init__(self, graph_name, part_config, node_feat_field, edge_feat_field,
f"Feature '{name}' of '{ntype}' is not in WholeGraph format. " \
f"Please convert all the available features to WholeGraph " \
f"format to utilize WholeGraph."
data[name] = self.load_wg_feat(part_config, num_parts, ntype, name)
data[name] = load_wg_feat(part_config, num_parts, ntype, name)
if len(self._g.ntypes) == 1:
self._g._ndata_store.update(data)
else:
@@ -205,7 +208,7 @@ def __init__(self, graph_name, part_config, node_feat_field, edge_feat_field,
f"Feature '{name}' of '{etype}' is not in WholeGraph format. " \
f"Please convert all the available features to WholeGraph " \
f"format to utilize WholeGraph."
data[name] = self.load_wg_feat(part_config, num_parts, etype_wg, name)
data[name] = load_wg_feat(part_config, num_parts, etype_wg, name)
if len(self._g.canonical_etypes) == 1:
self._g._edata_store.update(data)
else:
@@ -240,53 +243,6 @@ def edge_feat_field(self):
"""the field of edge feature"""
return self._edge_feat_field

def load_wg_feat(self, part_config_path, num_parts, type_name, name):
"""Load features from wholegraph memory

Parameters
----------
part_config_path : str
The path of the partition configuration file.
num_parts : int
The number of partitions of the dataset
type_name: str
The type of node or edge for which to fetch features or labels.
name: str
The name of the features to load
"""
import pylibwholegraph.torch as wgth

global_comm = wgth.comm.get_global_communicator()
feature_comm = global_comm
embedding_wholememory_type = 'distributed'
embedding_wholememory_location = 'cpu'
cache_policy = wgth.create_builtin_cache_policy(
"none", # cache type
embedding_wholememory_type,
embedding_wholememory_location,
"readonly", # access type
0.0, # cache ratio
)
metadata_file = os.path.join(os.path.dirname(part_config_path),
'wholegraph/metadata.json')
with open(metadata_file, encoding="utf8") as f:
wg_metadata = json.load(f)
data_shape = wg_metadata[type_name + '/' + name]['shape']
feat_wm_embedding = wgth.create_embedding(
feature_comm,
embedding_wholememory_type,
embedding_wholememory_location,
getattr(th, wg_metadata[type_name + '/' + name]['dtype'].split('.')[1]),
[data_shape[0],1] if len(data_shape) == 1 else data_shape,
optimizer=None,
cache_policy=cache_policy,
)
feat_path = os.path.join(os.path.dirname(part_config_path), 'wholegraph', \
type_name + '~' + name)
feat_wm_embedding.get_embedding_tensor().from_file_prefix(feat_path,
part_count=num_parts)
return feat_wm_embedding

def has_node_feats(self, ntype):
""" Test if the specified node type has features.

@@ -513,7 +469,7 @@ class GSgnnEdgeTrainData(GSgnnEdgeData):
different feature names.
decoder_edge_feat: str or dict of list of str
Edge features used by decoder

Examples
----------

@@ -524,7 +480,7 @@ class GSgnnEdgeTrainData(GSgnnEdgeData):
ep_data = GSgnnEdgeTrainData(graph_name='dummy', part_config=part_config,
train_etypes=[('n1', 'e1', 'n2')], label_field='label',
node_feat_field='node_feat', edge_feat_field='edge_feat')
ep_dataloader = GSgnnEdgeDataLoader(ep_data, target_idx={"e1":[0]},
ep_dataloader = GSgnnEdgeDataLoader(ep_data, target_idx={"e1":[0]},
fanout=[15, 10], batch_size=128)
"""
def __init__(self, graph_name, part_config, train_etypes, eval_etypes=None,
@@ -700,7 +656,7 @@ class GSgnnEdgeInferData(GSgnnEdgeData):
The node types that contains text features.
lm_feat_etypes : list of tuples
The edge types that contains text features.

Examples
----------

@@ -711,7 +667,7 @@ class GSgnnEdgeInferData(GSgnnEdgeData):
ep_data = GSgnnEdgeInferData(graph_name='dummy', part_config=part_config,
eval_etypes=[('n1', 'e1', 'n2')], label_field='label',
node_feat_field='node_feat', edge_feat_field='edge_feat')
ep_dataloader = GSgnnEdgeDataLoader(ep_data, target_idx={"e1":[0]},
ep_dataloader = GSgnnEdgeDataLoader(ep_data, target_idx={"e1":[0]},
fanout=[15, 10], batch_size=128)
"""
def __init__(self, graph_name, part_config, eval_etypes,
@@ -893,7 +849,7 @@ class GSgnnNodeTrainData(GSgnnNodeData):
The node types that contains text features.
lm_feat_etypes : list of tuples
The edge types that contains text features.

Examples
----------

@@ -905,7 +861,7 @@ class GSgnnNodeTrainData(GSgnnNodeData):
np_data = GSgnnNodeTrainData(graph_name='dummy', part_config=part_config,
train_ntypes=['n1'], label_field='label',
node_feat_field='feat')
np_dataloader = GSgnnNodeDataLoader(np_data, target_idx={'n1':[0]},
np_dataloader = GSgnnNodeDataLoader(np_data, target_idx={'n1':[0]},
fanout=[15, 10], batch_size=128)
"""
def __init__(self, graph_name, part_config, train_ntypes, eval_ntypes=None,
@@ -1041,19 +997,19 @@ class GSgnnNodeInferData(GSgnnNodeData):
The node types that contains text features.
lm_feat_etypes : list of tuples
The edge types that contains text features.

Examples
----------

.. code:: python

from graphstorm.dataloading import GSgnnNodeInferData
from graphstorm.dataloading import GSgnnNodeDataLoader

np_data = GSgnnNodeInferData(graph_name='dummy', part_config=part_config,
eval_ntypes=['n1'], label_field='label',
node_feat_field='feat')
np_dataloader = GSgnnNodeDataLoader(np_data, target_idx={'n1':[0]},
np_dataloader = GSgnnNodeDataLoader(np_data, target_idx={'n1':[0]},
fanout=[15, 10], batch_size=128)
"""
def __init__(self, graph_name, part_config, eval_ntypes,
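The dataset now imports the loader lazily inside the WholeGraph branch, so runs that never enable WholeGraph never require pylibwholegraph. Below is a minimal sketch of that guard-plus-deferred-import pattern; the maybe_load_wg_feat helper is hypothetical and only stands in for the dataset code above.

.. code:: python

    # Minimal sketch, not from this PR: the guard + deferred import pattern used
    # by dataset.py, isolated into a hypothetical helper for clarity.
    from graphstorm.utils import is_distributed, is_wholegraph

    def maybe_load_wg_feat(part_config, num_parts, type_name, name):
        """Load a feature via WholeGraph only when WholeGraph is enabled."""
        if is_distributed() and is_wholegraph():
            # Imported lazily so non-WholeGraph runs never touch pylibwholegraph.
            from graphstorm.wholegraph import load_wg_feat
            return load_wg_feat(part_config, num_parts, type_name, name)
        return None  # caller falls back to the DistDGL node/edge data store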
10 changes: 0 additions & 10 deletions python/graphstorm/dataloading/utils.py
@@ -137,13 +137,3 @@ def flip_node_mask(dist_tensor, indices):
part_policy=dist_tensor.part_policy)
flipped_dist_tensor[indices] = 1 - dist_tensor[indices]
return flipped_dist_tensor

def is_wholegraph_embedding(data):
""" Check if the data is in WholeMemory emedding format which
is required to use wholegraph framework.
"""
try:
import pylibwholegraph
return isinstance(data, pylibwholegraph.torch.WholeMemoryEmbedding)
except: # pylint: disable=bare-except
return False
25 changes: 1 addition & 24 deletions python/graphstorm/gsf.py
@@ -22,7 +22,6 @@
import dgl
import torch as th
import torch.nn.functional as F
from dataclasses import dataclass
from dgl.distributed import role
from dgl.distributed.constants import DEFAULT_NTYPE
from dgl.distributed.constants import DEFAULT_ETYPE
@@ -65,29 +64,6 @@
LinkPredictWeightedDistMultDecoder)
from .tracker import get_task_tracker_class

def init_wholegraph():
""" Initialize Wholegraph"""
import pylibwholegraph.torch as wgth
import pylibwholegraph.binding.wholememory_binding as wmb

@dataclass
class Options: # pylint: disable=missing-class-docstring
pass
Options.launch_agent = 'pytorch'
Options.launch_env_name_world_rank = 'RANK'
Options.launch_env_name_world_size = 'WORLD_SIZE'
Options.launch_env_name_local_rank = 'LOCAL_RANK'
Options.launch_env_name_local_size = 'LOCAL_WORLD_SIZE'
Options.launch_env_name_master_addr = 'MASTER_ADDR'
Options.launch_env_name_master_port = 'MASTER_PORT'
Options.local_rank = get_rank() % role.get_num_trainers()
Options.local_size = role.get_num_trainers()

wgth.distributed_launch(Options, lambda: None)
wmb.init(0)
wgth.comm.set_world_info(get_rank(), get_world_size(), Options.local_rank,
Options.local_size)

def initialize(ip_config, backend, use_wholegraph=False):
""" Initialize distributed training and inference context.

@@ -108,6 +84,7 @@ def initialize(ip_config, backend, use_wholegraph=False):
th.distributed.init_process_group(backend=backend)
# Use wholegraph for feature and label fetching
if use_wholegraph:
from .wholegraph import init_wholegraph
init_wholegraph()
sys_tracker.check("load DistDGL")

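With this change, gsf.initialize pulls in the WholeGraph setup only when a run asks for it. A minimal sketch of how a training script might opt in, assuming a distributed launch where torch.distributed can be initialized; the ip_config path and the "gloo" backend below are illustrative choices, not taken from the PR.

.. code:: python

    # Minimal sketch, not from this PR: opting in to WholeGraph-backed feature
    # fetching through the refactored entry point in python/graphstorm/gsf.py.
    from graphstorm.gsf import initialize

    initialize(ip_config="ip_config.txt", backend="gloo", use_wholegraph=True)
    # init_wholegraph() is imported and invoked inside initialize() only when
    # use_wholegraph=True, so pylibwholegraph stays an optional dependency.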
85 changes: 85 additions & 0 deletions python/graphstorm/wholegraph.py
@@ -0,0 +1,85 @@
"""This module provides functionality for working with the WholeGraph"""
import os
import json
from typing import Optional

import torch as th
from dataclasses import dataclass
from .utils import get_rank, get_world_size

try:
import pylibwholegraph.torch as wgth
except ImportError:
wgth = None

def init_wholegraph():
""" Initialize Wholegraph"""
if wgth is None:
raise ImportError("WholeGraph is not installed")
from dgl.distributed import role
import pylibwholegraph.binding.wholememory_binding as wmb

@dataclass
class Options: # pylint: disable=missing-class-docstring
pass
Options.launch_agent = 'pytorch'
Options.launch_env_name_world_rank = 'RANK'
Options.launch_env_name_world_size = 'WORLD_SIZE'
Options.launch_env_name_local_rank = 'LOCAL_RANK'
Options.launch_env_name_local_size = 'LOCAL_WORLD_SIZE'
Options.launch_env_name_master_addr = 'MASTER_ADDR'
Options.launch_env_name_master_port = 'MASTER_PORT'
Options.local_rank = get_rank() % role.get_num_trainers()
Options.local_size = role.get_num_trainers()

wgth.distributed_launch(Options, lambda: None)
wmb.init(0)
wgth.comm.set_world_info(get_rank(), get_world_size(), Options.local_rank,
Options.local_size)

def load_wg_feat(part_config_path, num_parts, type_name, name):
"""Load features from wholegraph memory

Parameters
----------
part_config_path : str
The path of the partition configuration file.
num_parts : int
The number of partitions of the dataset
type_name: str
The type of node or edge for which to fetch features or labels.
name: str
The name of the features to load
"""
if wgth is None:
raise ImportError("WholeGraph is not installed")
global_comm = wgth.comm.get_global_communicator()
feature_comm = global_comm
embedding_wholememory_type = 'distributed'
embedding_wholememory_location = 'cpu'
cache_policy = wgth.create_builtin_cache_policy(
"none", # cache type
embedding_wholememory_type,
embedding_wholememory_location,
"readonly", # access type
0.0, # cache ratio
)
metadata_file = os.path.join(os.path.dirname(part_config_path),
'wholegraph/metadata.json')
with open(metadata_file, encoding="utf8") as f:
wg_metadata = json.load(f)
data_shape = wg_metadata[type_name + '/' + name]['shape']
feat_wm_embedding = wgth.create_embedding(
feature_comm,
embedding_wholememory_type,
embedding_wholememory_location,
getattr(th, wg_metadata[type_name + '/' + name]['dtype'].split('.')[1]),
[data_shape[0],1] if len(data_shape) == 1 else data_shape,
optimizer=None,
cache_policy=cache_policy,
)
feat_path = os.path.join(os.path.dirname(part_config_path), 'wholegraph', \
type_name + '~' + name)
feat_wm_embedding.get_embedding_tensor().from_file_prefix(feat_path,
part_count=num_parts)
return feat_wm_embedding
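For illustration, a sketch of calling the relocated functions directly, assuming pylibwholegraph is installed, torch.distributed is already set up with the usual RANK/WORLD_SIZE environment variables, and that the partition path, partition count, type name, and feature name below are hypothetical.

.. code:: python

    # Minimal sketch, not from this PR: load a node feature through the new
    # module-level API instead of the former GSgnnData.load_wg_feat method.
    from graphstorm.wholegraph import init_wholegraph, load_wg_feat

    init_wholegraph()  # sets up the WholeMemory communicator from the launch env vars

    # Hypothetical example: a 4-partition graph whose node type "paper" has a
    # feature "feat" already converted to WholeGraph format under wholegraph/.
    node_feat = load_wg_feat(part_config_path="/data/ogbn/ogbn.json",
                             num_parts=4,
                             type_name="paper",
                             name="feat")
    # The return value is a pylibwholegraph WholeMemoryEmbedding, not a torch.Tensor.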
15 changes: 15 additions & 0 deletions python/graphstorm/wholegraph_utils.py
@@ -0,0 +1,15 @@
"""This module provides utility functions for working with WholeGraph"""
from .utils import is_wholegraph

def is_wholegraph_embedding(data):
""" Check if the data is in WholeMemory emedding format which
is required to use wholegraph framework.
"""
try:
import pylibwholegraph
assert (
is_wholegraph()
), "WholeGraph needs to be enabled first."
return isinstance(data, pylibwholegraph.torch.WholeMemoryEmbedding)
except ImportError:
return False
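A short sketch of how a caller might use this check when reading features. The fetch_feat helper is hypothetical, and the gather() call is assumed from the pylibwholegraph WholeMemoryEmbedding API rather than taken from this PR.

.. code:: python

    # Minimal sketch, not from this PR: branch between a WholeMemory embedding
    # lookup and plain tensor indexing when fetching features for a mini-batch.
    from graphstorm.wholegraph_utils import is_wholegraph_embedding

    def fetch_feat(data, idx):
        """Return the rows of `data` selected by `idx` for either storage format."""
        if is_wholegraph_embedding(data):
            # WholeMemoryEmbedding reads go through gather(); depending on the
            # WholeMemory configuration the indices may need to live on the GPU.
            return data.gather(idx)
        return data[idx]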