Add support for using WholeGraph distributed embedding to store/update sparse_emb #677

Merged
35 commits merged on Feb 8, 2024
Changes from 34 commits
Commits (35)
cc907b9
Add wholegraph distributed embedding support for sparse_emb
chang-l Dec 6, 2023
ebe46e6
Remove the test that is not ready
chang-l Dec 6, 2023
ab492a6
Add scatter op to load embeddings
chang-l Dec 8, 2023
2c0b535
Complete the tests
chang-l Dec 12, 2023
32a1f66
Minor update: reorder code
chang-l Dec 12, 2023
a407a01
Add more tests
chang-l Dec 12, 2023
46ad27d
Formatting for lint and better case control for tests
chang-l Dec 12, 2023
48ed6e4
Add env to turn on/off wg sparse emb
chang-l Dec 13, 2023
3c1e131
Fix a bug in wholegraph sparse_emb forward call
chang-l Dec 14, 2023
003dfe0
Refactor code to separate WholeGraph-related functions
chang-l Jan 2, 2024
3ef307e
Merge branch 'wholegraph_reorg' into add_wg_sparse_emb_rebased
chang-l Jan 3, 2024
26eb51c
Refactor and simplify WholeGraph integration of Sparse Opt
chang-l Jan 9, 2024
8bc04ce
Fix lint
chang-l Jan 10, 2024
04a47ea
Fix lint
chang-l Jan 10, 2024
8164218
Address comment
chang-l Jan 10, 2024
ef2b789
Fix lint
chang-l Jan 11, 2024
0c36e8c
Add Copyright
chang-l Jan 11, 2024
58b3e3a
Merge branch 'wholegraph_reorg' into add_wg_sparse_emb_rebased
chang-l Jan 12, 2024
afa3d05
Fix all tests
chang-l Jan 13, 2024
f8ab9df
Merge branch 'main' into add_wg_sparse_emb_rebased
chang-l Jan 13, 2024
a904722
Minor update
chang-l Jan 13, 2024
532ffcf
Update to be compatible when wholegraph is not available
chang-l Jan 13, 2024
6ebb492
Partly address comments
chang-l Jan 19, 2024
4728512
Intermediate commit of refactoring WholeGraph Tensor class
chang-l Jan 22, 2024
f5b93ae
Refactor to materialize sparse_emb later
chang-l Jan 23, 2024
81b2b88
Update WG sparse opt unit test to compare against distDGL
chang-l Jan 24, 2024
bef8406
Address comments
chang-l Jan 26, 2024
a93b5ae
Minor update
chang-l Jan 26, 2024
5a8742f
Address comments
chang-l Feb 2, 2024
9f59760
Add cmd argument
chang-l Feb 2, 2024
b990920
Add e2e tests
chang-l Feb 3, 2024
cc4fd30
Add checker to see if wholegraph is installed or not
chang-l Feb 3, 2024
7fd694a
Roll back to remove e2e tests
chang-l Feb 3, 2024
8bad6bd
Merge branch 'main' into add_wg_sparse_emb_rebased
classicsong Feb 8, 2024
158e8c9
Resolve conflicts in main branch for unit tests
chang-l Feb 8, 2024
18 changes: 18 additions & 0 deletions python/graphstorm/config/argument.py
@@ -289,6 +289,7 @@ def verify_arguments(self, is_train):
_ = self.grad_norm_type
_ = self.gnn_norm
_ = self.sparse_optimizer_lr
_ = self.use_wholegraph_sparse_emb
_ = self.num_epochs
_ = self.save_model_path
_ = self.save_model_frequency
@@ -1177,6 +1178,18 @@ def sparse_optimizer_lr(self): # pylint: disable=invalid-name

return self.lr

@property
def use_wholegraph_sparse_emb(self):
""" Whether to use wholegraph for updating learnable node embeddings
"""
# pylint: disable=no-member
if hasattr(self, "_use_wholegraph_sparse_emb"):
assert self._use_wholegraph_sparse_emb in [True, False], \
"Invalid value for _use_wholegraph_sparse_emb. Must be either True or False."
return self._use_wholegraph_sparse_emb
# By default do not use wholegraph for learnable node embeddings
return False

@property
def use_node_embeddings(self):
""" Whether to use extra learnable node embeddings
@@ -2398,6 +2411,11 @@ def _add_hyperparam_args(parser):
type=lambda x: (str(x).lower() in ['true', '1']),
default=argparse.SUPPRESS,
help="Whether to use extra learnable node embeddings")
group.add_argument(
"--use-wholegraph-sparse-emb",
type=lambda x: (str(x).lower() in ['true', '1']),
default=argparse.SUPPRESS,
help="Whether to use WholeGraph library to update learnable node embeddings")
group.add_argument("--construct-feat-ntype", type=str, nargs="+",
help="The node types whose features are constructed from neighbors' features.")
group.add_argument("--construct-feat-encoder", type=str, default=argparse.SUPPRESS,
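Note: the new --use-wholegraph-sparse-emb option parses its value with the same truthy-string lambda as the other boolean hyperparameters, and its argparse.SUPPRESS default is why the use_wholegraph_sparse_emb property above falls back to False via hasattr(). A minimal, self-contained sketch of that parsing behavior (the parser here is a stand-in, not GraphStorm's real argument group):

import argparse

# Stand-in parser mirroring the flag added in _add_hyperparam_args() above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--use-wholegraph-sparse-emb",
    type=lambda x: (str(x).lower() in ['true', '1']),
    default=argparse.SUPPRESS,
    help="Whether to use WholeGraph library to update learnable node embeddings")

args = parser.parse_args(["--use-wholegraph-sparse-emb", "true"])
print(args.use_wholegraph_sparse_emb)  # True; "false" or "0" would parse to False
# With SUPPRESS, omitting the flag leaves the attribute unset, so the GSConfig
# property returns its default of False.
print(hasattr(parser.parse_args([]), "use_wholegraph_sparse_emb"))  # False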
3 changes: 2 additions & 1 deletion python/graphstorm/gsf.py
@@ -533,7 +533,8 @@ def set_encoder(model, g, config, train_task):
activation=config.input_activate,
use_node_embeddings=config.use_node_embeddings,
force_no_embeddings=config.construct_feat_ntype,
num_ffn_layers_in_input=config.num_ffn_layers_in_input)
num_ffn_layers_in_input=config.num_ffn_layers_in_input,
use_wholegraph_sparse_emb=config.use_wholegraph_sparse_emb)
# The number of feature dimensions can change. For example, the feature dimensions
# of BERT embeddings are determined when the input encoder is created.
feat_size = encoder.in_dims
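Note: set_encoder() now forwards config.use_wholegraph_sparse_emb into the input encoder's constructor. A hedged sketch of building that encoder directly, assuming an already-initialized DistGraph g, a placeholder feature field name "feat", and a hypothetical helper name; WholeGraph and the NCCL backend must be initialized first, or the assertions in GSNodeEncoderInputLayer.__init__ (shown below) will fail:

from graphstorm import get_node_feat_size
from graphstorm.model.embed import GSNodeEncoderInputLayer

def build_input_layer(g, embed_size=128):
    # "feat" is a placeholder feature field; adjust to the dataset's schema.
    feat_size = get_node_feat_size(g, "feat")
    # Learnable embeddings for featureless node types will be hosted by
    # WholeGraph instead of distDGL's DistEmbedding.
    return GSNodeEncoderInputLayer(g, feat_size, embed_size,
                                   use_wholegraph_sparse_emb=True)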
180 changes: 133 additions & 47 deletions python/graphstorm/model/embed.py
@@ -18,6 +18,7 @@

import time
import logging
import numpy as np
import torch as th
from torch import nn
import torch.nn.functional as F
@@ -26,8 +27,17 @@

from .gs_layer import GSLayer
from ..dataloading.dataset import prepare_batch_input
from ..utils import get_rank, barrier, is_distributed, get_backend, create_dist_tensor
from ..utils import (
get_rank,
barrier,
is_distributed,
get_backend,
create_dist_tensor,
)
from .ngnn_mlp import NGNNMLP
from ..wholegraph import WholeGraphDistTensor
from ..wholegraph import is_wholegraph_init


def init_emb(shape, dtype):
"""Create a tensor with the given shape and date type.
@@ -50,7 +60,8 @@ def init_emb(shape, dtype):
nn.init.uniform_(arr, -1.0, 1.0)
return arr

class GSNodeInputLayer(GSLayer): # pylint: disable=abstract-method

class GSNodeInputLayer(GSLayer): # pylint: disable=abstract-method
"""The input layer for all nodes in a heterogeneous graph.

Parameters
@@ -140,6 +151,7 @@ def in_dims(self):
"""
return None


class GSNodeEncoderInputLayer(GSNodeInputLayer):
"""The input encoder layer for all nodes in a heterogeneous graph.

@@ -172,10 +184,12 @@ class GSNodeEncoderInputLayer(GSNodeInputLayer):
The activation function for the feedforward neural networks.
cache_embed : bool
Whether or not to cache the embeddings.
use_wholegraph_sparse_emb : bool
Whether or not to use WholeGraph to host embeddings for sparse updates.

Examples:
----------

.. code:: python

from graphstorm import get_node_feat_size
@@ -201,69 +215,116 @@ def __init__(self,
force_no_embeddings=None,
num_ffn_layers_in_input=0,
ffn_activation=F.relu,
cache_embed=False):
cache_embed=False,
use_wholegraph_sparse_emb=False):
super(GSNodeEncoderInputLayer, self).__init__(g)
self.embed_size = embed_size
self.dropout = nn.Dropout(dropout)
self.use_node_embeddings = use_node_embeddings
self._use_wholegraph_sparse_emb = use_wholegraph_sparse_emb
self.feat_size = feat_size
if force_no_embeddings is None:
force_no_embeddings = []

self.activation = activation
self.cache_embed = cache_embed

if dgl.__version__ <= "1.1.2" and is_distributed() and get_backend() == "nccl":
if self._use_wholegraph_sparse_emb:
assert get_backend() == "nccl", \
"WholeGraph sparse embedding is only supported on NCCL backend."
assert is_wholegraph_init(), \
"WholeGraph is not initialized yet."
if (
dgl.__version__ <= "1.1.2"
and is_distributed()
and get_backend() == "nccl"
and not self._use_wholegraph_sparse_emb
):
if self.use_node_embeddings:
raise NotImplementedError('NCCL backend is not supported for utilizing ' +
'node embeddings. Please use DGL version >=1.1.2 or gloo backend.')
raise NotImplementedError(
"NCCL backend is not supported for utilizing "
+ "node embeddings. Please use DGL version >=1.1.2 or gloo backend."
)
for ntype in g.ntypes:
if not feat_size[ntype]:
raise NotImplementedError('NCCL backend is not supported for utilizing ' +
'learnable embeddings on featureless nodes. Please use DGL version ' +
'>=1.1.2 or gloo backend.')
raise NotImplementedError(
"NCCL backend is not supported for utilizing "
+ "learnable embeddings on featureless nodes. Please use DGL version "
+ ">=1.1.2 or gloo backend."
)

# create weight embeddings for each node for each relation
self.proj_matrix = nn.ParameterDict()
self.input_projs = nn.ParameterDict()
embed_name = 'embed'
embed_name = "embed"
for ntype in g.ntypes:
feat_dim = 0
if feat_size[ntype] > 0:
feat_dim += feat_size[ntype]
if feat_dim > 0:
if get_rank() == 0:
logging.debug('Node %s has %d features.', ntype, feat_dim)
logging.debug("Node %s has %d features.", ntype, feat_dim)
input_projs = nn.Parameter(th.Tensor(feat_dim, self.embed_size))
nn.init.xavier_uniform_(input_projs, gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(input_projs, gain=nn.init.calculate_gain("relu"))
self.input_projs[ntype] = input_projs
if self.use_node_embeddings:
if get_rank() == 0:
logging.debug('Use additional sparse embeddings on node %s', ntype)
part_policy = g.get_node_partition_policy(ntype)
self._sparse_embeds[ntype] = DistEmbedding(g.number_of_nodes(ntype),
self.embed_size,
embed_name + '_' + ntype,
init_emb,
part_policy)
if self._use_wholegraph_sparse_emb:
if get_rank() == 0:
logging.debug(
"Use WholeGraph to host additional sparse embeddings on node %s",
ntype,
)
self._sparse_embeds[ntype] = WholeGraphDistTensor(
(g.number_of_nodes(ntype), self.embed_size),
th.float32, # to be consistent with distDGL's DistEmbedding dtype
embed_name + "_" + ntype,
use_wg_optimizer=True, # no memory allocation before opt available
)
else:
if get_rank() == 0:
logging.debug("Use additional sparse embeddings on node %s", ntype)
part_policy = g.get_node_partition_policy(ntype)
self._sparse_embeds[ntype] = DistEmbedding(
g.number_of_nodes(ntype),
self.embed_size,
embed_name + "_" + ntype,
init_emb,
part_policy,
)
proj_matrix = nn.Parameter(th.Tensor(2 * self.embed_size, self.embed_size))
nn.init.xavier_uniform_(proj_matrix, gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(proj_matrix, gain=nn.init.calculate_gain("relu"))
# nn.ParameterDict supports this assignment operation if not None,
# so disable the pylint error
self.proj_matrix[ntype] = proj_matrix # pylint: disable=unsupported-assignment-operation
self.proj_matrix[ntype] = proj_matrix

elif ntype not in force_no_embeddings:
part_policy = g.get_node_partition_policy(ntype)
if get_rank() == 0:
logging.debug('Use sparse embeddings on node %s:%d',
ntype, g.number_of_nodes(ntype))
if self._use_wholegraph_sparse_emb:
if get_rank() == 0:
logging.debug(
"Use WholeGraph to host sparse embeddings on node %s:%d",
ntype,
g.number_of_nodes(ntype),
)
self._sparse_embeds[ntype] = WholeGraphDistTensor(
(g.number_of_nodes(ntype), self.embed_size),
th.float32, # to be consistent with distDGL's DistEmbedding dtype
embed_name + "_" + ntype,
use_wg_optimizer=True, # no memory allocation before opt available
)
else:
if get_rank() == 0:
logging.debug('Use sparse embeddings on node %s:%d',
ntype, g.number_of_nodes(ntype))
part_policy = g.get_node_partition_policy(ntype)
self._sparse_embeds[ntype] = DistEmbedding(g.number_of_nodes(ntype),
self.embed_size,
embed_name + '_' + ntype,
init_emb,
part_policy=part_policy)

proj_matrix = nn.Parameter(th.Tensor(self.embed_size, self.embed_size))
nn.init.xavier_uniform_(proj_matrix, gain=nn.init.calculate_gain('relu'))
self.proj_matrix[ntype] = proj_matrix
self._sparse_embeds[ntype] = DistEmbedding(g.number_of_nodes(ntype),
self.embed_size,
embed_name + '_' + ntype,
init_emb,
part_policy=part_policy)

# ngnn
self.num_ffn_layers_in_input = num_ffn_layers_in_input
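Note: both constructor paths above choose between the same two backing stores for the learnable embeddings of a node type. A condensed, hypothetical helper restating that selection, assuming only the WholeGraphDistTensor constructor used above (shape tuple, dtype, name, use_wg_optimizer):

import torch as th
from dgl.distributed import DistEmbedding
from graphstorm.wholegraph import WholeGraphDistTensor

def create_sparse_emb(g, ntype, embed_size, use_wholegraph_sparse_emb, init_emb):
    """Hypothetical helper: pick the table that hosts one node type's embeddings."""
    num_nodes = g.number_of_nodes(ntype)
    if use_wholegraph_sparse_emb:
        # WholeGraph-backed table; with use_wg_optimizer=True, memory is
        # materialized later, once the WholeGraph sparse optimizer is attached.
        return WholeGraphDistTensor((num_nodes, embed_size), th.float32,
                                    "embed_" + ntype, use_wg_optimizer=True)
    # Default distDGL path: a DistEmbedding partitioned like the node type.
    return DistEmbedding(num_nodes, embed_size, "embed_" + ntype, init_emb,
                         part_policy=g.get_node_partition_policy(ntype))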
Expand All @@ -290,29 +351,46 @@ def forward(self, input_feats, input_nodes):
assert isinstance(input_nodes, dict), 'The input node IDs should be in a dict.'
embs = {}
for ntype in input_nodes:
if isinstance(input_nodes[ntype], np.ndarray):
# WholeGraphSparseEmbedding requires the input nodes (indexing tensor)
# to be a th.Tensor
input_nodes[ntype] = th.from_numpy(input_nodes[ntype])
emb = None
if ntype in input_feats:
assert ntype in self.input_projs, \
f"We need a projection for node type {ntype}"
f"We need a projection for node type {ntype}"
# If the input data is not float, we need to convert it to float first.
emb = input_feats[ntype].float() @ self.input_projs[ntype]
if self.use_node_embeddings:
assert ntype in self.sparse_embeds, \
f"We need sparse embedding for node type {ntype}"
node_emb = self.sparse_embeds[ntype](input_nodes[ntype], emb.device)
concat_emb=th.cat((emb, node_emb),dim=1)
f"We need sparse embedding for node type {ntype}"
# emb.device: target device to put the gathered results
if self._use_wholegraph_sparse_emb:
node_emb = self.sparse_embeds[ntype].module(input_nodes[ntype].cuda())
node_emb = node_emb.to(emb.device, non_blocking=True)
else:
node_emb = self.sparse_embeds[ntype](input_nodes[ntype], emb.device)
concat_emb = th.cat((emb, node_emb), dim=1)
emb = concat_emb @ self.proj_matrix[ntype]
elif ntype in self.sparse_embeds: # nodes do not have input features
elif ntype in self.sparse_embeds: # nodes do not have input features
# If the number of the input node of a node type is 0,
# return an empty tensor with shape (0, emb_size)
device = self.proj_matrix[ntype].device
if len(input_nodes[ntype]) == 0:
dtype = self.sparse_embeds[ntype].weight.dtype
embs[ntype] = th.zeros((0, self.sparse_embeds[ntype].embedding_dim),
device=device, dtype=dtype)
continue
emb = self.sparse_embeds[ntype](input_nodes[ntype], device)
# If DistEmbedding supports 0-size input, we can remove this if statement.
if isinstance(self.sparse_embeds[ntype], WholeGraphDistTensor):
# All procs need to run the following due to the NCCL all2allv in WholeGraph
emb = self.sparse_embeds[ntype].module(input_nodes[ntype].cuda())
emb = emb.to(device, non_blocking=True)
else:
if len(input_nodes[ntype]) == 0:
dtype = self.sparse_embeds[ntype].weight.dtype
embs[ntype] = th.zeros((0, self.sparse_embeds[ntype].embedding_dim),
device=device, dtype=dtype)
continue
emb = self.sparse_embeds[ntype](input_nodes[ntype], device)

emb = emb @ self.proj_matrix[ntype]

if emb is not None:
if self.activation is not None:
emb = self.activation(emb)
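Note: a behavioral difference in the forward path above: the WholeGraph lookup goes through the wrapped module with GPU indices and is a collective (NCCL all2allv), so every rank must execute it even for an empty index tensor, whereas the DistEmbedding path short-circuits zero-size inputs with an empty (0, embed_size) tensor. A small sketch of the WholeGraph lookup pattern in isolation, with a hypothetical helper name and wg_emb assumed to be a WholeGraphDistTensor created as in __init__:

import torch as th

def lookup_wg_emb(wg_emb, input_nodes, out_device):
    """Hypothetical sketch of the WholeGraph gather used in forward()."""
    if not isinstance(input_nodes, th.Tensor):
        # WholeGraph indexing requires a torch tensor (e.g. convert numpy arrays).
        input_nodes = th.as_tensor(input_nodes)
    # Indices must be on the GPU; all ranks must reach this call, even with
    # zero rows, because the underlying gather is a collective operation.
    node_emb = wg_emb.module(input_nodes.cuda())
    # Move the gathered rows to wherever the projection weights live.
    return node_emb.to(out_device, non_blocking=True)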
@@ -364,6 +442,13 @@ def out_dims(self):
"""
return self.embed_size

@property
def use_wholegraph_sparse_emb(self):
""" Whether or not to use WholeGraph to host embeddings for sparse updates.
"""
return self._use_wholegraph_sparse_emb


def _gen_emb(g, feat_field, embed_layer, ntype):
""" Test if the embed layer can generate embeddings on the node type.

@@ -394,6 +479,7 @@ def _gen_emb(g, feat_field, embed_layer, ntype):
emb = embed_layer(feat, {ntype: input_nodes})
return ntype in emb


def compute_node_input_embeddings(g, batch_size, embed_layer,
task_tracker=None, feat_field='feat',
target_ntypes=None):
@@ -442,10 +528,10 @@ def compute_node_input_embeddings(g, batch_size, embed_layer,
# a lot of memory.
if 'input_emb' not in g.nodes[ntype].data:
g.nodes[ntype].data['input_emb'] = create_dist_tensor(
(g.number_of_nodes(ntype), embed_size),
dtype=th.float32, name=f'{ntype}_input_emb',
part_policy=g.get_node_partition_policy(ntype),
persistent=True)
(g.number_of_nodes(ntype), embed_size),
dtype=th.float32, name=f'{ntype}_input_emb',
part_policy=g.get_node_partition_policy(ntype),
persistent=True)
else:
assert g.nodes[ntype].data['input_emb'].shape[1] == embed_size
input_emb = g.nodes[ntype].data['input_emb']