[WholeGraph] Add support for using WholeGraph to store/load cache_lm_emb #737

Merged: 16 commits, Feb 20, 2024
Changes from 5 commits
19 changes: 19 additions & 0 deletions python/graphstorm/config/argument.py
@@ -307,6 +307,7 @@ def verify_arguments(self, is_train):
if self.node_lm_configs:
_ = self.lm_infer_batch_size
_ = self.freeze_lm_encoder_epochs
_ = self.use_wholegraph_cache_lm_embed

if self.distill_lm_configs:
_ = self.textual_data_path
@@ -690,6 +691,18 @@ def cache_lm_embed(self):
else:
return None

@property
def use_wholegraph_cache_lm_embed(self):
""" Whether to cache the LM embeddings on files by using WholeGraph.
"""
if hasattr(self, "_use_wholegraph_cache_lm_embed"):
if self._use_wholegraph_cache_lm_embed:
assert self.cache_lm_embed, "You must enable cache_lm_embed " \
"to cache LM embeddings with WholeGraph."
return self._use_wholegraph_cache_lm_embed
else:
return None

###################### general gnn model related ######################
@property
def model_encoder_type(self):
@@ -2481,6 +2494,12 @@ def _add_lm_model_args(parser):
help="Whether to cache the LM embeddings in files. " + \
"If the LM embeddings have been saved before, load the saved embeddings " + \
"instead of computing the LM embeddings again.")
group.add_argument("--use-wholegraph-cache-lm-embed",
type=lambda x: (str(x).lower() in ['true', '1']),
default=argparse.SUPPRESS,
help="Whether to use WholeGraph to cache the LM embeddings in files. " + \
"If the LM embeddings have been saved before, load the saved embeddings " + \
"instead of computing the LM embeddings again.")
return parser

def _add_rgat_args(parser):
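As a side note, below is a minimal, self-contained sketch (not taken from this PR) of how such boolean-style flags parse their values; the --cache-lm-embed flag name is an assumption inferred from the existing help text, and only "true"/"1" (case-insensitive) map to True, everything else to False.

# Hedged sketch of the boolean-flag parsing pattern used above; --cache-lm-embed is assumed.
import argparse

parser = argparse.ArgumentParser()
for flag in ("--cache-lm-embed", "--use-wholegraph-cache-lm-embed"):
    parser.add_argument(flag,
                        type=lambda x: (str(x).lower() in ['true', '1']),
                        default=argparse.SUPPRESS)  # unset flags never appear in the namespace

args = parser.parse_args(["--cache-lm-embed", "true",
                          "--use-wholegraph-cache-lm-embed", "1"])
print(args)  # Namespace(cache_lm_embed=True, use_wholegraph_cache_lm_embed=True)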
9 changes: 6 additions & 3 deletions python/graphstorm/gsf.py
@@ -512,12 +512,14 @@ def set_encoder(model, g, config, train_task):
if config.node_lm_configs is not None:
emb_path = os.path.join(os.path.dirname(config.part_config),
"cached_embs") if config.cache_lm_embed else None
wg_cached_embed = config.use_wholegraph_cache_lm_embed
if model_encoder_type == "lm":
# only use language model(s) as input layer encoder(s)
encoder = GSPureLMNodeInputLayer(g, config.node_lm_configs,
num_train=config.lm_train_nodes,
lm_infer_batch_size=config.lm_infer_batch_size,
cached_embed_path=emb_path)
num_train=config.lm_train_nodes,
lm_infer_batch_size=config.lm_infer_batch_size,
cached_embed_path=emb_path,
wg_cached_embed=wg_cached_embed)
else:
encoder = GSLMNodeEncoderInputLayer(g, config.node_lm_configs,
feat_size, config.hidden_size,
@@ -526,6 +528,7 @@
dropout=config.dropout,
use_node_embeddings=config.use_node_embeddings,
cached_embed_path=emb_path,
wg_cached_embed=wg_cached_embed,
force_no_embeddings=config.construct_feat_ntype)
else:
encoder = GSNodeEncoderInputLayer(g, feat_size, config.hidden_size,
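A note on defaults (a sketch, not part of this PR): use_wholegraph_cache_lm_embed returns None when the flag is unset, and LMCache only tests the value for truthiness, so the WholeGraph cache stays disabled unless the flag is explicitly turned on.

# Hedged sketch of the default behavior when the flag is not provided.
wg_cached_embed = None              # what config.use_wholegraph_cache_lm_embed returns when unset
use_wg = bool(wg_cached_embed)      # LMCache stores the value and branches on "if self.use_wg:"
print(use_wg)                       # False -> the cache falls back to a DGL DistTensor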
67 changes: 51 additions & 16 deletions python/graphstorm/model/lm_embed.py
@@ -31,8 +31,14 @@
from .embed import GSNodeEncoderInputLayer
from .lm_model import init_lm_model
from .lm_model import get_lm_node_feats
from .utils import load_pytorch_embedding, save_pytorch_embedding
from .utils import (
load_pytorch_embedding,
save_pytorch_embedding,
load_wholegraph_embedding,
save_wholegraph_embedding
)
from ..utils import get_rank, get_world_size, create_dist_tensor
from ..wholegraph import WholeGraphDistTensor
from ..distributed import flush_data

class LMModels(nn.Module):
@@ -239,11 +245,12 @@ class LMCache:
embed_path : str
The path where the embedding files are stored.
"""
def __init__(self, g, lm_models, embed_path=None):
def __init__(self, g, lm_models, embed_path=None, use_wg=False):
self._g = g
self._lm_models = lm_models
self._lm_emb_cache = {}
self._embed_path = embed_path
self._use_wg = use_wg
self._lm_hash = ''

def _get_model_hash(self, ntype):
@@ -281,7 +288,11 @@ def _load_embeddings(self):
logging.info("load LM embedding from %s for node type %s",
embed_path, ntype)
embed_name = embed_ndata_names[ntype]
self._lm_emb_cache[ntype] = load_pytorch_embedding(embed_path,
if self.use_wg:
self._lm_emb_cache[ntype] = load_wholegraph_embedding(
embed_path, embed_name)
else:
self._lm_emb_cache[ntype] = load_pytorch_embedding(embed_path,
self._g.get_node_partition_policy(ntype), embed_name)
if set(self._lm_emb_cache.keys()) == set(self._lm_models.ntypes):
logging.debug("Successfully load all embeddings from the cache.")
@@ -305,10 +316,16 @@ def _save_embeddings(self):
embed_path = os.path.join(os.path.join(
os.path.join(self._embed_path, "lm_cache"), ntype),
self._get_model_name(ntype))
save_pytorch_embedding(embed_path,
self._lm_emb_cache[ntype],
get_rank(),
get_world_size())
if self.use_wg:
save_wholegraph_embedding(embed_path,
self._lm_emb_cache[ntype],
get_rank(),
get_world_size())
else:
save_pytorch_embedding(embed_path,
self._lm_emb_cache[ntype],
get_rank(),
get_world_size())

def __len__(self):
return len(self._lm_emb_cache)
@@ -329,6 +346,12 @@ def ntypes(self):
"""
return self._lm_models.ntypes

@property
def use_wg(self):
""" Whether to use WholeGraph to store the embeddings.
"""
return self._use_wg

@property
def embed_ndata_name(self):
""" The embed name of the node data
@@ -375,12 +398,18 @@ def update_cache(self, lm_infer_batch_size, use_fp16=True):
hidden_size = lm_model.feat_size
if ntype not in self._lm_emb_cache:
embed_name = embed_ndata_names[ntype]
self._lm_emb_cache[ntype] = create_dist_tensor(
if self.use_wg:
self._lm_emb_cache[ntype] = WholeGraphDistTensor(
(self._g.number_of_nodes(ntype), hidden_size),
name=embed_name,
dtype=th.float16 if use_fp16 else th.float32,
part_policy=self._g.get_node_partition_policy(ntype),
persistent=True)
name=embed_name)
else:
self._lm_emb_cache[ntype] = create_dist_tensor(
(self._g.number_of_nodes(ntype), hidden_size),
name=embed_name,
dtype=th.float16 if use_fp16 else th.float32,
part_policy=self._g.get_node_partition_policy(ntype),
persistent=True)
emb = self._lm_emb_cache[ntype]
# LM computations are very computationally expensive. It's better to force
# an even split to ensure all processes have roughly the same number of nodes
@@ -401,10 +430,11 @@ def update_cache(self, lm_infer_batch_size, use_fp16=True):
fname: feat[input_nodes] for fname, feat in lm_node_feat.items()
}
text_embs = lm_model(input_ntypes, input_lm_feats)
device = 'cuda' if self.use_wg else 'cpu'
if use_fp16:
emb[input_nodes] = text_embs[ntype].half().to('cpu')
emb[input_nodes] = text_embs[ntype].half().to(device)
else:
emb[input_nodes] = text_embs[ntype].to('cpu')
emb[input_nodes] = text_embs[ntype].to(device)
if i % 1000 == 0 and get_rank() == 0:
logging.debug("Compute LM embeddings on %d batches out of %d",
i, len(node_list))
@@ -484,7 +514,9 @@ def __init__(self,
num_train=0,
lm_infer_batch_size=16,
use_fp16=True,
cached_embed_path=None):
cached_embed_path=None,
wg_cached_embed=False):

super(GSPureLMNodeInputLayer, self).__init__(g)
assert node_lm_configs is not None and len(node_lm_configs) > 0, \
"language model configurations must be provided"
@@ -494,7 +526,8 @@ def __init__(self,
self.lm_infer_batch_size = lm_infer_batch_size
self.use_fp16 = use_fp16
self.use_cache = False
self.lm_emb_cache = LMCache(g, self._lm_models, embed_path=cached_embed_path)
self.lm_emb_cache = LMCache(g, self._lm_models, embed_path=cached_embed_path,
use_wg=wg_cached_embed)

self._feat_size = self._lm_models.get_feat_size(self._lm_models.ntypes[0])
for lm_model in self._lm_models.lm_models:
@@ -684,6 +717,7 @@ def __init__(self,
use_node_embeddings=False,
use_fp16=True,
cached_embed_path=None,
wg_cached_embed=False,
force_no_embeddings=None):
assert node_lm_configs is not None and len(node_lm_configs) > 0, \
"language model configurations must be provided"
@@ -705,7 +739,8 @@ def __init__(self,
self.use_fp16 = use_fp16
self.lm_infer_batch_size = lm_infer_batch_size
self.use_cache = False
self.lm_emb_cache = LMCache(g, lm_models, embed_path=cached_embed_path)
self.lm_emb_cache = LMCache(g, lm_models, embed_path=cached_embed_path,
use_wg=wg_cached_embed)

super(GSLMNodeEncoderInputLayer, self).__init__(
g, adjust_feat_size, embed_size,
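For illustration, here is a hedged usage sketch of building the LM embedding cache with WholeGraph storage; g (a DistGraph) and lm_models (an LMModels instance) are placeholders assumed to exist, and the cache path is hypothetical.

# Hedged usage sketch; g and lm_models are placeholders, the path is hypothetical.
from graphstorm.model.lm_embed import LMCache

lm_emb_cache = LMCache(g, lm_models,
                       embed_path="cached_embs",   # where cached embeddings are kept on disk
                       use_wg=True)                # keep the cache in a WholeGraphDistTensor
lm_emb_cache.update_cache(lm_infer_batch_size=16,  # batch size for the LM forward passes
                          use_fp16=True)           # fp16 halves the memory of the cached embeddings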
93 changes: 93 additions & 0 deletions python/graphstorm/model/utils.py
@@ -565,6 +565,99 @@ def save_pytorch_embedding(emb_path, embedding, rank, world_size):
embedding = embedding[start:end]
th.save(embedding, os.path.join(emb_path, f'embed-{pad_file_index(rank)}.pt'))

def save_wholegraph_embedding(emb_path, embedding, rank, world_size):
""" Save Dist embedding tensor in binary format for WholeGraph.

Parameters
----------
emb_path : str
The path where the embedding files will be saved.
embedding : WholeGraphDistTensor
The WholeGraph dist tensor to save.
rank : int
Rank of the current process in a distributed environment.
world_size : int
World size in a distributed env.
"""
os.makedirs(emb_path, exist_ok=True)
# [04/16]: Only rank 0 runs chmod so that all other ranks can write files.
if rank == 0:
# mode 767 means rwx-rw-rwx:
# - owner of the folder can read, write, and execute;
# - owner's group can read and write;
# - others can read, write, and execute.
os.chmod(emb_path, 0o767)

# make sure the emb_path permission is changed before other process start to save
barrier()

assert rank < world_size, \
f"Process rank {rank} must be smaller than the distributed cluster size {world_size}"

assert isinstance(embedding, WholeGraphDistTensor), \
"Input embedding must be a WholeGraphDistTensor."

emb_num = embedding.num_embeddings
emb_dim = embedding.embedding_dim
emb_dtype = embedding.dtype
emb_name = embedding.name
emb_info = {
"format": "binary",
"emb_num": str(emb_num),
"emb_dim": str(emb_dim),
"emb_dtype": str(emb_dtype),
"emb_name": emb_name,
"world_size": world_size
}

embedding.save_to_file(emb_path, file_prefix="wg-embed")
if rank == 0:
with open(os.path.join(emb_path, "emb_info.json"), 'w', encoding='utf-8') as f:
json.dump(emb_info, f, indent=4)

def load_wholegraph_embedding(emb_path, name):
""" Load embedding tensor in binary format for WholeGraph.

Parameters
----------
emb_path : str
The path where the embedding files are stored.
name : str
The name of the created distributed tensor.

Returns
-------
WholeGraphDistTensor : the loaded embeddings in WholeGraph.
"""
with open(os.path.join(emb_path, "emb_info.json"), 'r', encoding='utf-8') as f:
emb_info = json.load(f)

chang-l marked this conversation as resolved.
Show resolved Hide resolved
emb_num = int(emb_info['emb_num'])
emb_dim = int(emb_info['emb_dim'])
world_size_in_save = int(emb_info['world_size'])
supported_dtypes = {
'torch.half': th.half,
'torch.float16': th.float16,
'torch.float32': th.float32,
'torch.float': th.float,
'torch.int64': th.int64,
'torch.int32': th.int32
}
emb_dtype = supported_dtypes[emb_info['emb_dtype']]
dist_emb = WholeGraphDistTensor((emb_num, emb_dim), emb_dtype, name=name)
files = os.listdir(emb_path)
filtered_files = [file for file in files if file.startswith("wg-embed")]
num_files = len(filtered_files)
assert num_files > 0, "No WholeGraph embedding files found."
assert world_size_in_save == num_files, \
f"World size when the embedding was saved ({world_size_in_save}) " \
f"doesn't match the number of embedding files ({num_files})."
dist_emb.load_from_file(emb_path, "wg-embed", num_files)
barrier()
return dist_emb

def load_pytorch_embedding(emb_path, part_policy, name):
""" Load embedding tensor in Pytorch format.

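Finally, a hedged round-trip sketch of the two new helpers; emb is assumed to be an existing WholeGraphDistTensor, and the cache folder and tensor name are hypothetical.

# Hedged round-trip sketch; emb, the path, and the tensor name are placeholders.
from graphstorm.model.utils import (save_wholegraph_embedding,
                                    load_wholegraph_embedding)
from graphstorm.utils import get_rank, get_world_size

emb_path = "cached_embs/lm_cache/paper/bert-base-uncased"   # hypothetical cache folder
save_wholegraph_embedding(emb_path, emb, get_rank(), get_world_size())
# Writes wg-embed* binary shards plus an emb_info.json describing shape, dtype, and world size.

emb2 = load_wholegraph_embedding(emb_path, name="lm_cache_emb")
# Checks that the number of wg-embed* shards matches the world size recorded at save time.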