[WholeGraph] Add support of using WholeGraph to store/load cache_lm_emb #737

Merged: 16 commits, Feb 20, 2024
20 changes: 20 additions & 0 deletions python/graphstorm/config/argument.py
@@ -261,6 +261,7 @@ def verify_arguments(self, is_train):
_ = self.node_id_mapping_file
_ = self.edge_id_mapping_file
_ = self.verbose
_ = self.use_wholegraph_embed

# Data
_ = self.node_feat_name
@@ -537,6 +538,18 @@ def verbose(self):

return False

@property
def use_wholegraph_embed(self):
""" Whether to use WholeGraph to store intermediate embeddings/tensors generated
during training or inference, e.g., cache_lm_emb, sparse_emb, etc.
"""
if hasattr(self, "_use_wholegraph_embed"):
assert self._use_wholegraph_embed in [True, False], \
"Invalid value for _use_wholegraph_embed. Must be either True or False."
return self._use_wholegraph_embed
else:
return None

###################### language model support #########################
# Bert related
@property
@@ -2285,6 +2298,13 @@ def _add_initialization_args(parser):
default=argparse.SUPPRESS,
help="Print more information.",
)
group.add_argument(
"--use-wholegraph-embed",
type=lambda x: (str(x).lower() in ['true', '1']),
default=argparse.SUPPRESS,
help="Whether to use WholeGraph to store intermediate embeddings/tensors generated \
during training or inference, e.g., cache_lm_emb, sparse_emb, etc."
)
return parser

def _add_gsgnn_basic_args(parser):
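Note: a minimal, self-contained sketch of how the new flag is parsed (the standalone parser below is hypothetical and only mirrors the option added above). Because the default is argparse.SUPPRESS, the attribute is absent when the flag is not given, which is why the use_wholegraph_embed property above falls back to returning None.

import argparse

parser = argparse.ArgumentParser()
# Mirrors the --use-wholegraph-embed option added above.
parser.add_argument(
    "--use-wholegraph-embed",
    type=lambda x: (str(x).lower() in ['true', '1']),
    default=argparse.SUPPRESS,
)

print(parser.parse_args(["--use-wholegraph-embed", "true"]))   # Namespace(use_wholegraph_embed=True)
print(parser.parse_args(["--use-wholegraph-embed", "no"]))     # Namespace(use_wholegraph_embed=False)
print(parser.parse_args([]))  # Namespace() -- attribute absent, so the config property returns None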
8 changes: 5 additions & 3 deletions python/graphstorm/gsf.py
@@ -515,9 +515,10 @@ def set_encoder(model, g, config, train_task):
if model_encoder_type == "lm":
# only use language model(s) as input layer encoder(s)
encoder = GSPureLMNodeInputLayer(g, config.node_lm_configs,
num_train=config.lm_train_nodes,
lm_infer_batch_size=config.lm_infer_batch_size,
cached_embed_path=emb_path)
num_train=config.lm_train_nodes,
lm_infer_batch_size=config.lm_infer_batch_size,
cached_embed_path=emb_path,
wg_cached_embed=config.use_wholegraph_embed)
else:
encoder = GSLMNodeEncoderInputLayer(g, config.node_lm_configs,
feat_size, config.hidden_size,
@@ -526,6 +527,7 @@
dropout=config.dropout,
use_node_embeddings=config.use_node_embeddings,
cached_embed_path=emb_path,
wg_cached_embed=config.use_wholegraph_embed,
force_no_embeddings=config.construct_feat_ntype)
else:
encoder = GSNodeEncoderInputLayer(g, feat_size, config.hidden_size,
68 changes: 52 additions & 16 deletions python/graphstorm/model/lm_embed.py
@@ -31,8 +31,14 @@
from .embed import GSNodeEncoderInputLayer
from .lm_model import init_lm_model
from .lm_model import get_lm_node_feats
from .utils import load_pytorch_embedding, save_pytorch_embedding
from .utils import (
load_pytorch_embedding,
save_pytorch_embedding,
load_wholegraph_embedding,
save_wholegraph_embedding
)
from ..utils import get_rank, get_world_size, create_dist_tensor
from ..wholegraph import WholeGraphDistTensor
from ..distributed import flush_data

class LMModels(nn.Module):
@@ -239,11 +245,12 @@ class LMCache:
embed_path : str
The path where the embedding files are stored.
"""
def __init__(self, g, lm_models, embed_path=None):
def __init__(self, g, lm_models, embed_path=None, use_wg=False):
self._g = g
self._lm_models = lm_models
self._lm_emb_cache = {}
self._embed_path = embed_path
self._use_wg = use_wg
self._lm_hash = ''

def _get_model_hash(self, ntype):
@@ -281,7 +288,11 @@ def _load_embeddings(self):
logging.info("load LM embedding from %s for node type %s",
embed_path, ntype)
embed_name = embed_ndata_names[ntype]
self._lm_emb_cache[ntype] = load_pytorch_embedding(embed_path,
if self.use_wg:
self._lm_emb_cache[ntype] = load_wholegraph_embedding(
embed_path, embed_name)
else:
self._lm_emb_cache[ntype] = load_pytorch_embedding(embed_path,
self._g.get_node_partition_policy(ntype), embed_name)
if set(self._lm_emb_cache.keys()) == set(self._lm_models.ntypes):
logging.debug("Successfully load all embeddings from the cache.")
@@ -305,10 +316,17 @@ def _save_embeddings(self):
embed_path = os.path.join(os.path.join(
os.path.join(self._embed_path, "lm_cache"), ntype),
self._get_model_name(ntype))
save_pytorch_embedding(embed_path,
self._lm_emb_cache[ntype],
get_rank(),
get_world_size())
if self.use_wg:
save_wholegraph_embedding(embed_path,
self._lm_emb_cache[ntype],
get_rank(),
get_world_size(),
fmt="binary")
else:
save_pytorch_embedding(embed_path,
self._lm_emb_cache[ntype],
get_rank(),
get_world_size())

def __len__(self):
return len(self._lm_emb_cache)
@@ -329,6 +347,12 @@ def ntypes(self):
"""
return self._lm_models.ntypes

@property
def use_wg(self):
""" Whether to use WholeGraph to store the embeddings.
"""
return self._use_wg

@property
def embed_ndata_name(self):
""" The embed name of the node data
@@ -375,12 +399,18 @@ def update_cache(self, lm_infer_batch_size, use_fp16=True):
hidden_size = lm_model.feat_size
if ntype not in self._lm_emb_cache:
embed_name = embed_ndata_names[ntype]
self._lm_emb_cache[ntype] = create_dist_tensor(
if self.use_wg:
self._lm_emb_cache[ntype] = WholeGraphDistTensor(
(self._g.number_of_nodes(ntype), hidden_size),
name=embed_name,
dtype=th.float16 if use_fp16 else th.float32,
part_policy=self._g.get_node_partition_policy(ntype),
persistent=True)
name=embed_name)
else:
self._lm_emb_cache[ntype] = create_dist_tensor(
(self._g.number_of_nodes(ntype), hidden_size),
name=embed_name,
dtype=th.float16 if use_fp16 else th.float32,
part_policy=self._g.get_node_partition_policy(ntype),
persistent=True)
emb = self._lm_emb_cache[ntype]
# LM computations are very computationally expensive. It's better to force
# an even split to ensure all processes have roughly the same number of nodes
@@ -401,10 +431,11 @@ def update_cache(self, lm_infer_batch_size, use_fp16=True):
fname: feat[input_nodes] for fname, feat in lm_node_feat.items()
}
text_embs = lm_model(input_ntypes, input_lm_feats)
device = 'cuda' if self.use_wg else 'cpu'
if use_fp16:
emb[input_nodes] = text_embs[ntype].half().to('cpu')
emb[input_nodes] = text_embs[ntype].half().to(device)
else:
emb[input_nodes] = text_embs[ntype].to('cpu')
emb[input_nodes] = text_embs[ntype].to(device)
if i % 1000 == 0 and get_rank() == 0:
logging.debug("Compute LM embeddings on %d batches out of %d",
i, len(node_list))
@@ -484,7 +515,9 @@ def __init__(self,
num_train=0,
lm_infer_batch_size=16,
use_fp16=True,
cached_embed_path=None):
cached_embed_path=None,
wg_cached_embed=False):

super(GSPureLMNodeInputLayer, self).__init__(g)
assert node_lm_configs is not None and len(node_lm_configs) > 0, \
"language model configurations must be provided"
@@ -494,7 +527,8 @@ def __init__(self,
self.lm_infer_batch_size = lm_infer_batch_size
self.use_fp16 = use_fp16
self.use_cache = False
self.lm_emb_cache = LMCache(g, self._lm_models, embed_path=cached_embed_path)
self.lm_emb_cache = LMCache(g, self._lm_models, embed_path=cached_embed_path,
use_wg=wg_cached_embed)

self._feat_size = self._lm_models.get_feat_size(self._lm_models.ntypes[0])
for lm_model in self._lm_models.lm_models:
@@ -684,6 +718,7 @@ def __init__(self,
use_node_embeddings=False,
use_fp16=True,
cached_embed_path=None,
wg_cached_embed=False,
force_no_embeddings=None):
assert node_lm_configs is not None and len(node_lm_configs) > 0, \
"language model configurations must be provided"
@@ -705,7 +740,8 @@ def __init__(self,
self.use_fp16 = use_fp16
self.lm_infer_batch_size = lm_infer_batch_size
self.use_cache = False
self.lm_emb_cache = LMCache(g, lm_models, embed_path=cached_embed_path)
self.lm_emb_cache = LMCache(g, lm_models, embed_path=cached_embed_path,
use_wg=wg_cached_embed)

super(GSLMNodeEncoderInputLayer, self).__init__(
g, adjust_feat_size, embed_size,
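Note: a minimal sketch of how LMCache composes the on-disk cache path used in _save_embeddings and _load_embeddings above. The node type and model name below are made-up placeholders; in the real code the last component comes from LMCache._get_model_name(), so switching the LM model naturally points at a different cache directory.

import os

def lm_cache_path(embed_path, ntype, model_name):
    # Mirrors the nested os.path.join calls in LMCache._save_embeddings / _load_embeddings.
    return os.path.join(embed_path, "lm_cache", ntype, model_name)

# Hypothetical values for illustration only.
print(lm_cache_path("/data/gs_embs", "paper", "bert-base-uncased-1a2b3c"))
# -> /data/gs_embs/lm_cache/paper/bert-base-uncased-1a2b3c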
139 changes: 139 additions & 0 deletions python/graphstorm/model/utils.py
@@ -565,6 +565,145 @@ def save_pytorch_embedding(emb_path, embedding, rank, world_size):
embedding = embedding[start:end]
th.save(embedding, os.path.join(emb_path, f'embed-{pad_file_index(rank)}.pt'))

def save_wholegraph_embedding(emb_path, embedding, rank, world_size, fmt="binary"):
""" Save Dist embedding tensor in binary format for WholeGraph.

Parameters
----------
emb_path : str
The path where the embedding files are saved.
embedding : WholeGraphDistTensor
The WholeGraph dist tensor to save.
rank : int
Rank of the current process in a distributed environment.
world_size : int
World size in a distributed env.
fmt : str
The format of the saved embeddings. Currently only "binary" and "pytorch" are supported.
"""
assert fmt in ["binary", "pytorch"], \
"Using WholeGraph, the supported formats of the saved embeddings " + \
"are 'binary' and 'pytorch'."
os.makedirs(emb_path, exist_ok=True)
# Only rank 0 does chmod so that all other ranks can write files.
if rank == 0:
# mode 767 means rwx-rw-rwx:
# - owner of the folder can read, write, and execute;
# - the owner's group can read and write;
# - others can read, write, and execute.
os.chmod(emb_path, 0o767)

# make sure the emb_path permission is changed before other process start to save
barrier()

assert rank < world_size, \
f"Process rank {rank} must be smaller than the distributed cluster size {world_size}"

assert isinstance(embedding, WholeGraphDistTensor), \
"Input embedding must be a WholeGraphDistTensor."

emb_num = embedding.num_embeddings
emb_dim = embedding.embedding_dim
emb_dtype = embedding.dtype
emb_name = embedding.name
emb_fmt = "wholegraph-" + fmt
emb_info = {
"format": emb_fmt,
"emb_num": str(emb_num),
"emb_dim": str(emb_dim),
"emb_dtype": str(emb_dtype),
"emb_name": emb_name,
"world_size": world_size
}
if fmt == "binary":
# use binary format to save the embedding (supported by native WholeGraph APIs)
# Example: wg-embed_part_0_of_2, wg-embed_part_1_of_2
# Pros: uses WholeGraph's native API to load the embedding directly;
# no RAM duplication; supports save/load with a different world_size.
embedding.save_to_file(emb_path, file_prefix="wg-embed")
elif fmt == "pytorch":
# use pytorch format to save the embedding (dump local tensor to pt file)
# Example: embed-00000.pt, embed-00001.pt
# Pros: compatible with the format used when WholeGraph is not enabled,
# but it still follows WholeGraph's even partition policy and duplicates RAM when loading.
emb = embedding.get_local_tensor()[0]
wg_rank = embedding.get_comm().get_rank()
th.save(emb, os.path.join(emb_path, f'embed-{pad_file_index(wg_rank)}.pt'))

if rank == 0:
with open(os.path.join(emb_path, "emb_info.json"), 'w', encoding='utf-8') as f:
json.dump(emb_info, f, indent=4)

def load_wholegraph_embedding(emb_path, name):
""" Load embedding tensor in binary format for WholeGraph.

Parameters
----------
emb_path : str
The path where the embedding files are saved.
name : str
The name of the created distributed tensor.

Returns
-------
WholeGraphDistTensor : the loaded embeddings in WholeGraph.
"""
file_path = os.path.join(emb_path, "emb_info.json")
assert os.path.exists(file_path), \
f"Embedding JSON file: {file_path} not found. " + \
"This file is needed for storing embedding with WholeGraph. It's generated when " + \
"you save embeddings with '--use-wholegraph-embed' flag."
with open(file_path, 'r', encoding='utf-8') as f:
emb_info = json.load(f)

emb_fmt = emb_info['format']
assert emb_fmt.startswith("wholegraph-"), \
"The format of the saved embeddings should be started with 'wholegraph-'."
emb_fmt = emb_fmt.split("-")[1]
emb_num = int(emb_info['emb_num'])
emb_dim = int(emb_info['emb_dim'])
world_size_in_save = int(emb_info['world_size'])
supported_dtypes = {
'torch.half': th.half,
'torch.float16': th.float16,
'torch.float32': th.float32,
'torch.float': th.float,
'torch.int64': th.int64,
'torch.int32': th.int32
}
emb_dtype = supported_dtypes[emb_info['emb_dtype']]
dist_emb = WholeGraphDistTensor((emb_num, emb_dim), emb_dtype, name=name)
if emb_fmt == "pytorch":
assert dist_emb.get_comm().get_size() == world_size_in_save, \
"World_size when save the embedding is different than the current world_size. " \
"Please switch to the binary format."
wg_rank = dist_emb.get_comm().get_rank()
file_path = os.path.join(emb_path, f'embed-{pad_file_index(wg_rank)}.pt')
assert os.path.exists(file_path), f"Embedding file {file_path} of \
my rank {wg_rank} doesn't exist."
emb = th.load(file_path)
local_emb = dist_emb.get_local_tensor()[0]
assert emb.shape[0] == local_emb.shape[0] and emb.shape[1] == local_emb.shape[1], \
f"Embedding shape of {name} does not match! " + \
f"Expect {emb.shape}, but get {local_emb.shape}"

assert emb.dtype == local_emb.dtype, "Embedding data types do not match!"
local_emb.copy_(emb)
elif emb_fmt == "binary":
files = os.listdir(emb_path)
filtered_files = [file for file in files if file.startswith("wg-embed")]
num_files = len(filtered_files)
assert num_files > 0, "No WholeGraph embedding files found."
assert world_size_in_save == num_files, \
f"World_size when save the embedding {world_size_in_save} \
doesn't match the number of files {num_files}."
dist_emb.load_from_file(emb_path, "wg-embed", num_files)

barrier()
return dist_emb

def load_pytorch_embedding(emb_path, part_policy, name):
""" Load embedding tensor in Pytorch format.

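Note: for reference, a sketch of the metadata that save_wholegraph_embedding() writes to emb_info.json and that load_wholegraph_embedding() reads back. The keys match the emb_info dict above; the concrete values and the embedding name are invented for illustration.

import json

# Hypothetical emb_info.json for a binary save from a 2-process job.
emb_info = {
    "format": "wholegraph-binary",   # or "wholegraph-pytorch"
    "emb_num": "1000000",            # number of rows (nodes); stored as a string
    "emb_dim": "768",                # embedding width; stored as a string
    "emb_dtype": "torch.float16",    # mapped back to a torch dtype on load
    "emb_name": "lm_emb_paper",      # hypothetical embedding name
    "world_size": 2,                 # number of saver processes
}
print(json.dumps(emb_info, indent=4))

# Alongside it, the "binary" format writes files such as
#   wg-embed_part_0_of_2, wg-embed_part_1_of_2
# while the "pytorch" format writes embed-00000.pt, embed-00001.pt, ...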
3 changes: 2 additions & 1 deletion python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py
@@ -35,8 +35,9 @@ def main(config_args):
config.verify_arguments(True)

use_wg_feats = use_wholegraph(config.part_config)
use_wg_embed = config.use_wholegraph_sparse_emb or config.use_wholegraph_embed
gs.initialize(ip_config=config.ip_config, backend=config.backend,
use_wholegraph=config.use_wholegraph_sparse_emb or use_wg_feats)
use_wholegraph=use_wg_embed or use_wg_feats)
rt_profiler.init(config.profile_path, rank=gs.get_rank())
sys_tracker.init(config.verbose, rank=gs.get_rank())
device = setup_device(config.local_rank)
3 changes: 2 additions & 1 deletion python/graphstorm/run/gsgnn_ep/ep_infer_gnn.py
@@ -44,8 +44,9 @@ def main(config_args):
config.verify_arguments(False)

use_wg_feats = use_wholegraph(config.part_config)
use_wg_embed = config.use_wholegraph_sparse_emb or config.use_wholegraph_embed
gs.initialize(ip_config=config.ip_config, backend=config.backend,
use_wholegraph=config.use_wholegraph_sparse_emb or use_wg_feats)
use_wholegraph=use_wg_embed or use_wg_feats)
device = setup_device(config.local_rank)

infer_data = GSgnnEdgeInferData(config.graph_name,
3 changes: 2 additions & 1 deletion python/graphstorm/run/gsgnn_ep/gsgnn_ep.py
@@ -58,8 +58,9 @@ def main(config_args):
config.verify_arguments(True)

use_wg_feats = use_wholegraph(config.part_config)
use_wg_embed = config.use_wholegraph_sparse_emb or config.use_wholegraph_embed
gs.initialize(ip_config=config.ip_config, backend=config.backend,
use_wholegraph=config.use_wholegraph_sparse_emb or use_wg_feats)
use_wholegraph=use_wg_embed or use_wg_feats)
rt_profiler.init(config.profile_path, rank=gs.get_rank())
sys_tracker.init(config.verbose, rank=gs.get_rank())
device = setup_device(config.local_rank)
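Note: the same two-line change is applied in all three entry points above; a tiny standalone sketch of the combined condition (the function name and flag values below are hypothetical), showing that WholeGraph is initialized whenever the sparse-embedding flag, the new embedding-cache flag, or WholeGraph-backed node features require it.

def needs_wholegraph(use_wg_sparse_emb, use_wg_embed_flag, use_wg_feats):
    # Mirrors: use_wg_embed = config.use_wholegraph_sparse_emb or config.use_wholegraph_embed
    #          gs.initialize(..., use_wholegraph=use_wg_embed or use_wg_feats)
    use_wg_embed = use_wg_sparse_emb or use_wg_embed_flag
    return bool(use_wg_embed or use_wg_feats)

print(needs_wholegraph(False, True, False))   # True: the new --use-wholegraph-embed flag alone is enough
print(needs_wholegraph(False, None, False))   # False: flag absent (None) and nothing else requests WholeGraph
print(needs_wholegraph(False, None, True))    # True: node features already live in WholeGraph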