Add support for using WholeGraph distributed embedding to store/update sparse_emb #677

Merged · 35 commits · Feb 8, 2024
Commits
cc907b9
Add wholegraph distributed embedding support for sparse_emb
chang-l Dec 6, 2023
ebe46e6
Remove the test that is not ready
chang-l Dec 6, 2023
ab492a6
Add scatter op to load embeddings
chang-l Dec 8, 2023
2c0b535
Complete the tests
chang-l Dec 12, 2023
32a1f66
Minor update: reorder code
chang-l Dec 12, 2023
a407a01
Add more tests
chang-l Dec 12, 2023
46ad27d
Formatting for lint and better case control for tests
chang-l Dec 12, 2023
48ed6e4
Add env to turn on/off wg sparse emb
chang-l Dec 13, 2023
3c1e131
Fix a bug in wholegraph sparse_emb forward call
chang-l Dec 14, 2023
003dfe0
Refactor code to separate WholeGraph-related functions
chang-l Jan 2, 2024
3ef307e
Merge branch 'wholegraph_reorg' into add_wg_sparse_emb_rebased
chang-l Jan 3, 2024
26eb51c
Refactor and simplify WholeGraph integration of Sparse Opt
chang-l Jan 9, 2024
8bc04ce
Fix lint
chang-l Jan 10, 2024
04a47ea
Fix lint
chang-l Jan 10, 2024
8164218
Address comment
chang-l Jan 10, 2024
ef2b789
Fix lint
chang-l Jan 11, 2024
0c36e8c
Add Copyright
chang-l Jan 11, 2024
58b3e3a
Merge branch 'wholegraph_reorg' into add_wg_sparse_emb_rebased
chang-l Jan 12, 2024
afa3d05
Fix all tests
chang-l Jan 13, 2024
f8ab9df
Merge branch 'main' into add_wg_sparse_emb_rebased
chang-l Jan 13, 2024
a904722
Minor update
chang-l Jan 13, 2024
532ffcf
Update to be compatible when wholegraph is not available
chang-l Jan 13, 2024
6ebb492
Partly address comments
chang-l Jan 19, 2024
4728512
Intermediate commit of refactoring WholeGraph Tensor class
chang-l Jan 22, 2024
f5b93ae
Refactor to materialize sparse_emb later
chang-l Jan 23, 2024
81b2b88
Update WG sparse opt unit test to compare against distDGL
chang-l Jan 24, 2024
bef8406
Address comments
chang-l Jan 26, 2024
a93b5ae
Minor update
chang-l Jan 26, 2024
5a8742f
Address comments
chang-l Feb 2, 2024
9f59760
Add cmd argument
chang-l Feb 2, 2024
b990920
Add e2e tests
chang-l Feb 3, 2024
cc4fd30
Add checker to see if wholegraph is installed or not
chang-l Feb 3, 2024
7fd694a
Roll back to remove e2e tests
chang-l Feb 3, 2024
8bad6bd
Merge branch 'main' into add_wg_sparse_emb_rebased
classicsong Feb 8, 2024
158e8c9
Resolve conflicts in main branch for unit tests
chang-l Feb 8, 2024
3 changes: 2 additions & 1 deletion python/graphstorm/config/argument.py
@@ -1184,7 +1184,8 @@ def use_wholegraph_sparse_emb(self):
"""
# pylint: disable=no-member
if hasattr(self, "_use_wholegraph_sparse_emb"):
assert self._use_wholegraph_sparse_emb in [True, False]
assert self._use_wholegraph_sparse_emb in [True, False], \
"Invalid value for _use_wholegraph_sparse_emb. Must be either True or False."
return self._use_wholegraph_sparse_emb
# By default do not use wholegraph for learnable node embeddings
return False
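The config change above only tightens the assertion message. For readers new to the feature, a minimal sketch of how this boolean is meant to gate WholeGraph-backed sparse embeddings (mirroring the prerequisite checks added to embed.py below) might look as follows; the function name and arguments are illustrative and not part of the PR.

```python
def can_use_wholegraph_sparse_emb(use_wholegraph_sparse_emb: bool,
                                  backend: str,
                                  wholegraph_initialized: bool) -> bool:
    """Return True only when WholeGraph-hosted sparse embeddings are usable."""
    if not use_wholegraph_sparse_emb:      # config property defaults to False
        return False
    if backend != "nccl":
        # Same constraint embed.py enforces: WholeGraph sparse embedding needs NCCL.
        raise AssertionError(
            "WholeGraph sparse embedding is only supported on NCCL backend.")
    if not wholegraph_initialized:
        raise AssertionError("WholeGraph is not initialized yet.")
    return True
```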
70 changes: 36 additions & 34 deletions python/graphstorm/model/embed.py
@@ -35,7 +35,7 @@
create_dist_tensor,
)
from .ngnn_mlp import NGNNMLP
from ..wholegraph import create_wholememory_optimizer, WholeGraphSparseEmbedding
from ..wholegraph import WholeGraphDistTensor
from ..wholegraph import is_wholegraph_init


@@ -221,19 +221,27 @@ def __init__(self,
self.embed_size = embed_size
self.dropout = nn.Dropout(dropout)
self.use_node_embeddings = use_node_embeddings
self.use_wholegraph_sparse_emb = use_wholegraph_sparse_emb
self._use_wholegraph_sparse_emb = use_wholegraph_sparse_emb
self.feat_size = feat_size
if force_no_embeddings is None:
force_no_embeddings = []

self.activation = activation
self.cache_embed = cache_embed

if self._use_wholegraph_sparse_emb:
if get_backend() != "nccl":
raise AssertionError(
"WholeGraph sparse embedding is only supported on NCCL backend."
)
if not is_wholegraph_init():
raise AssertionError("WholeGraph is not initialized yet.")

if (
dgl.__version__ <= "1.1.2"
and is_distributed()
and get_backend() == "nccl"
and not self.use_wholegraph_sparse_emb
and not self._use_wholegraph_sparse_emb
):
if self.use_node_embeddings:
raise NotImplementedError(
@@ -247,23 +255,11 @@ def __init__(self,
+ "learnable embeddings on featureless nodes. Please use DGL version "
+ ">=1.1.2 or gloo backend."
)
if self.use_wholegraph_sparse_emb:
if get_backend() != "nccl":
raise AssertionError(
"WholeGraph sparse embedding is only supported on NCCL backend."
)
if not is_wholegraph_init():
raise AssertionError("WholeGraph is not initialized yet.")

# create weight embeddings for each node for each relation
self.proj_matrix = nn.ParameterDict()
self.input_projs = nn.ParameterDict()
embed_name = "embed"
if self.use_wholegraph_sparse_emb:
# WG sparse optimizer has to be created at first like below
# This is because WG embedding depends on WG sparse optimizer to track/trace
# the gradients for embeddings.
self.wg_sparse_embs_optimizer = create_wholememory_optimizer("adam", {})
for ntype in g.ntypes:
feat_dim = 0
if feat_size[ntype] > 0:
@@ -275,17 +271,17 @@
nn.init.xavier_uniform_(input_projs, gain=nn.init.calculate_gain("relu"))
self.input_projs[ntype] = input_projs
if self.use_node_embeddings:
if self.use_wholegraph_sparse_emb:
if self._use_wholegraph_sparse_emb:
if get_rank() == 0:
logging.debug(
"Use WholeGraph to host additional sparse embeddings on node %s",
ntype,
)
self._sparse_embeds[ntype] = WholeGraphSparseEmbedding(
g.number_of_nodes(ntype),
self.embed_size,
self._sparse_embeds[ntype] = WholeGraphDistTensor(
(g.number_of_nodes(ntype), self.embed_size),
th.float32, # to be consistent with distDGL's DistEmbedding dtype
embed_name + "_" + ntype,
self.wg_sparse_embs_optimizer
use_wg_optimizer=True, # no memory allocation before opt available
)
else:
if get_rank() == 0:
@@ -305,18 +301,18 @@
self.proj_matrix[ntype] = proj_matrix

elif ntype not in force_no_embeddings:
if self.use_wholegraph_sparse_emb:
if self._use_wholegraph_sparse_emb:
if get_rank() == 0:
logging.debug(
"Use WholeGraph to host sparse embeddings on node %s:%d",
ntype,
g.number_of_nodes(ntype),
)
self._sparse_embeds[ntype] = WholeGraphSparseEmbedding(
g.number_of_nodes(ntype),
self.embed_size,
embed_name + '_' + ntype,
self.wg_sparse_embs_optimizer
self._sparse_embeds[ntype] = WholeGraphDistTensor(
(g.number_of_nodes(ntype), self.embed_size),
th.float32, # to be consistent with distDGL's DistEmbedding dtype
embed_name + "_" + ntype,
use_wg_optimizer=True, # no memory allocation before opt available
)
else:
if get_rank() == 0:
@@ -372,22 +368,22 @@ def forward(self, input_feats, input_nodes):
assert ntype in self.sparse_embeds, \
f"We need sparse embedding for node type {ntype}"
# emb.device: target device to put the gathered results
node_emb = self.sparse_embeds[ntype](input_nodes[ntype], emb.device)
if self._use_wholegraph_sparse_emb:
node_emb = self.sparse_embeds[ntype].module(input_nodes[ntype].cuda())
node_emb = node_emb.to(emb.device, non_blocking=True)
else:
node_emb = self.sparse_embeds[ntype](input_nodes[ntype], emb.device)
concat_emb = th.cat((emb, node_emb), dim=1)
emb = concat_emb @ self.proj_matrix[ntype]
elif ntype in self.sparse_embeds: # nodes do not have input features
# If the number of the input node of a node type is 0,
# return an empty tensor with shape (0, emb_size)
device = self.proj_matrix[ntype].device
# If DistEmbedding supports 0-size input, we can remove this if statement.
if isinstance(self.sparse_embeds[ntype], WholeGraphSparseEmbedding):
if isinstance(self.sparse_embeds[ntype], WholeGraphDistTensor):
# Need all procs to pass the following due to nccl alltoallv in wholegraph
emb = self.sparse_embeds[ntype](input_nodes[ntype], device)
if len(input_nodes[ntype]) == 0:
dtype = self.sparse_embeds[ntype].weight.dtype
embs[ntype] = th.zeros((0, self.sparse_embeds[ntype].embedding_dim),
device=device, dtype=dtype)
continue
emb = self.sparse_embeds[ntype].module(input_nodes[ntype].cuda())
emb = emb.to(device, non_blocking=True)
else:
if len(input_nodes[ntype]) == 0:
dtype = self.sparse_embeds[ntype].weight.dtype
@@ -449,6 +445,12 @@ def out_dims(self):
"""
return self.embed_size

@property
def use_wholegraph_sparse_emb(self):
""" Whether or not to use WholeGraph to host embeddings for sparse updates.
"""
return self._use_wholegraph_sparse_emb


def _gen_emb(g, feat_field, embed_layer, ntype):
""" Test if the embed layer can generate embeddings on the node type.
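A minimal sketch of the lookup path forward() now takes for a WholeGraph-hosted table, assuming the WholeGraphDistTensor API shown in the diff (`.module(...)` as the gather entry point); the helper name and the explicit `embed_size` argument are illustrative only.

```python
import torch as th

def lookup_sparse_emb(wg_emb, idx, embed_size, out_device):
    """Gather learnable embedding rows for `idx` and move them to `out_device`."""
    if len(idx) == 0:
        # A rank with no nodes of this type still returns an empty (0, embed_size)
        # tensor so downstream concatenation/projection keeps working.
        return th.zeros((0, embed_size), device=out_device, dtype=th.float32)
    # WholeGraph gathers on GPU, so indices must live on the current CUDA device.
    rows = wg_emb.module(idx.cuda())
    return rows.to(out_device, non_blocking=True)
```

The th.float32 dtype follows the diff's choice to stay consistent with distDGL's DistEmbedding.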
22 changes: 9 additions & 13 deletions python/graphstorm/model/gnn.py
@@ -39,7 +39,7 @@
get_world_size,
barrier
)
from ..wholegraph import is_wholegraph_optimizer
from ..wholegraph import is_wholegraph_optimizer, create_wholememory_optimizer, WholeGraphDistTensor

from ..dataloading.dataset import prepare_batch_input

@@ -562,13 +562,6 @@ def use_wholegraph_sparse_emb(self):
return self.node_input_encoder.use_wholegraph_sparse_emb
return False

def get_wholegraph_optimizer(self):
""" Get the WholeGraph optimizer for updating WholeGraph hosted embeddings .
"""
if self.node_input_encoder is not None:
return self.node_input_encoder.wg_sparse_embs_optimizer
return None

def set_node_input_encoder(self, encoder):
"""set the input encoder for nodes.

@@ -759,11 +752,14 @@ def init_optimizer(self, lr, sparse_optimizer_lr, weight_decay, lm_lr=None):
if len(sparse_params) > 0:
if self.use_wholegraph_sparse_emb():
# To use wholegraph sparse optimizer, optimizer needs to be created
# before sparse embeddings. So, here we just get the optimizer from
# WholeGraphSparseEmbedding and ensure the identity of the optimizer
emb_optimizer = self.get_wholegraph_optimizer()
assert all(params.optimizer is emb_optimizer for params in sparse_params), \
"We only need one wholegraph optimizer for all wm_embeddings."
# before sparse embeddings. Within attach_wg_optimizer, we materialize
# the WG distributed tensor and then attach the optimizer.
emb_optimizer = create_wholememory_optimizer("adam", {})
for params in sparse_params:
for param in params:
assert isinstance(param, WholeGraphDistTensor) and param.use_wg_optimizer, \
"Please create params (WG tensor) with use_wg_optimizer=True."
param.attach_wg_optimizer(emb_optimizer)
# TODO(@chang-l): Wrap the wholegraph optimizer in a class to
# take an extra input argument: lr
emb_optimizer.lr = sparse_optimizer_lr
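Sketched below is the ordering this hunk enforces: the WholeMemory optimizer is created first and then attached to each lazily materialized WholeGraphDistTensor, which is why the tensors are built with use_wg_optimizer=True. Aside from the names taken from the diff (create_wholememory_optimizer, attach_wg_optimizer, use_wg_optimizer), the helper is illustrative.

```python
from graphstorm.wholegraph import WholeGraphDistTensor, create_wholememory_optimizer

def build_wg_sparse_optimizer(sparse_params, sparse_optimizer_lr):
    """Create one WholeMemory 'adam' optimizer and attach it to every sparse param."""
    emb_optimizer = create_wholememory_optimizer("adam", {})
    for params in sparse_params:            # same nested iteration as the diff
        for param in params:
            # Tensors must have been created with use_wg_optimizer=True so that
            # memory allocation is deferred until an optimizer is available.
            assert isinstance(param, WholeGraphDistTensor) and param.use_wg_optimizer, \
                "Please create params (WG tensor) with use_wg_optimizer=True."
            param.attach_wg_optimizer(emb_optimizer)
    emb_optimizer.lr = sparse_optimizer_lr  # lr set directly for now (see TODO above)
    return emb_optimizer
```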
12 changes: 6 additions & 6 deletions python/graphstorm/model/utils.py
@@ -36,7 +36,7 @@
get_world_size,
create_dist_tensor,
)
from ..wholegraph import WholeGraphSparseEmbedding
from ..wholegraph import WholeGraphDistTensor
from ..data.utils import alltoallv_cpu, alltoallv_nccl
from ..distributed import flush_data

@@ -186,7 +186,7 @@ def save_sparse_emb(model_path, sparse_emb, ntype):
----------
model_path: str
The path of the model is saved.
sparse_emb: dgl.distributed.DistEmbedding or wholegraph.WholeGraphSparseEmbedding
sparse_emb: dgl.distributed.DistEmbedding or wholegraph.WholeGraphDistTensor
A Distributed node embedding.
ntype: str
The node type the embedding belongs to.
@@ -204,8 +204,8 @@ def save_sparse_emb(model_path, sparse_emb, ntype):
emb_path = os.path.join(model_path, ntype)
os.makedirs(emb_path, exist_ok=True)

if isinstance(sparse_emb, WholeGraphSparseEmbedding):
(local_tensor, _) = sparse_emb.weight.get_local_tensor(host_view=True)
if isinstance(sparse_emb, WholeGraphDistTensor):
(local_tensor, _) = sparse_emb.get_local_tensor()
# Using WholeGraph will save sparse emb in binary format (evenly distributed)
# Example: wg_sparse_emb_part_1_of_2, wg_sparse_emb_part_2_of_2
assert (
@@ -1380,9 +1380,9 @@ def load_sparse_emb(target_sparse_emb, ntype_emb_path):
num_files = len(os.listdir(ntype_emb_path))
num_embs = target_sparse_emb.num_embeddings

if isinstance(target_sparse_emb, WholeGraphSparseEmbedding):
if isinstance(target_sparse_emb, WholeGraphDistTensor):
# Using WholeGraph will load sparse emb in binary format, let's assume
# the sparse emb is saved by WholeGraphSparseEmbedding.save_to_file(), i.e.,
# the sparse emb is saved by WholeGraphDistTensor.save_to_file(), i.e.,
# the meta info remains the same.
# Example: wg_sparse_emb_part_1_of_2, wg_sparse_emb_part_2_of_2
target_sparse_emb.load_from_file(ntype_emb_path, "wg_sparse_emb", num_files)
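The save/load round trip for a WholeGraph-hosted embedding, per the utils.py changes above, reduces to the sketch below. The load_from_file call matches the diff; the save_to_file signature is an assumption (the diff only references the method by name), and the part files follow the wg_sparse_emb_part_i_of_n naming scheme.

```python
import os
from graphstorm.wholegraph import WholeGraphDistTensor

def save_wg_sparse_emb(emb_path, sparse_emb):
    if isinstance(sparse_emb, WholeGraphDistTensor):
        os.makedirs(emb_path, exist_ok=True)
        # Each rank writes its evenly distributed local shard in binary form,
        # e.g. wg_sparse_emb_part_1_of_2, wg_sparse_emb_part_2_of_2.
        sparse_emb.save_to_file(emb_path, "wg_sparse_emb")  # signature assumed

def load_wg_sparse_emb(target_sparse_emb, ntype_emb_path):
    if isinstance(target_sparse_emb, WholeGraphDistTensor):
        num_files = len(os.listdir(ntype_emb_path))  # one part file per saver rank
        target_sparse_emb.load_from_file(ntype_emb_path, "wg_sparse_emb", num_files)
```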
4 changes: 2 additions & 2 deletions python/graphstorm/wholegraph/__init__.py
@@ -22,8 +22,8 @@
load_wg_feat
)

from .wholegraph import (create_wholememory_optimizer, create_wg_sparse_params)
from .wholegraph import WholeGraphSparseEmbedding
from .wholegraph import create_wholememory_optimizer
from .wholegraph import WholeGraphDistTensor

from .utils import (
is_wholegraph_embedding,
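Putting the exported pieces together, the lifecycle intended after this PR is roughly: build a WholeGraphDistTensor without backing memory, create the WholeMemory optimizer, attach it (which materializes the tensor), then train with sparse updates. Constructor arguments follow the embed.py diff; the node-type name and learning rate below are illustrative.

```python
import torch as th
from graphstorm.wholegraph import WholeGraphDistTensor, create_wholememory_optimizer

num_nodes, embed_size = 1_000_000, 128
sparse_emb = WholeGraphDistTensor(
    (num_nodes, embed_size),
    th.float32,              # consistent with distDGL's DistEmbedding dtype
    "embed_paper",           # hypothetical embedding name for a 'paper' node type
    use_wg_optimizer=True,   # no memory allocation before the optimizer is available
)
wg_opt = create_wholememory_optimizer("adam", {})
sparse_emb.attach_wg_optimizer(wg_opt)  # materializes the tensor and ties gradient tracking to it
wg_opt.lr = 0.01
```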