diff --git a/python/graphstorm/model/embed.py b/python/graphstorm/model/embed.py
index eed50977c7..62c5aa66e6 100644
--- a/python/graphstorm/model/embed.py
+++ b/python/graphstorm/model/embed.py
@@ -450,6 +450,10 @@ def compute_node_input_embeddings(g, batch_size, embed_layer,
                 assert g.nodes[ntype].data['input_emb'].shape[1] == embed_size
             input_emb = g.nodes[ntype].data['input_emb']
             # TODO(zhengda) this is not a memory efficient way of implementing this.
+            # Here `force_even` is set to False, which means that we compute the input node
+            # embeddings for the nodes in the local partition and save the embeddings to
+            # the local partition with shared memory. Therefore, we don't need to call
+            # flush at the end of inference.
             infer_nodes = node_split(th.ones((g.number_of_nodes(ntype),), dtype=th.bool),
                                      partition_book=g.get_partition_book(),
                                      ntype=ntype, force_even=False)
diff --git a/python/graphstorm/model/gnn_encoder_base.py b/python/graphstorm/model/gnn_encoder_base.py
index 450711a7cc..a820fa3ea0 100644
--- a/python/graphstorm/model/gnn_encoder_base.py
+++ b/python/graphstorm/model/gnn_encoder_base.py
@@ -227,6 +227,12 @@ def dist_minibatch_inference(g, gnn_encoder, get_input_embeds, batch_size, fanou
             for ntype, out_nodes in output_nodes.items():
                 out_embs[ntype][out_nodes] = output[ntype].cpu()
 
+    # The nodes are split in such a way that all processes only need to compute
+    # the embeddings of the nodes in the local partition. Therefore, a barrier
+    # is enough to ensure that all data have been written to memory for distributed
+    # read after this function returns.
+    # Note: there is a risk here. If the nodes for inference on each partition
+    # are very skewed, some of the processes may time out in the barrier.
     barrier()
     return out_embs
 
@@ -392,5 +398,4 @@ def get_input_embeds1(input_nodes, node_feats):
                                              list(infer_nodes.keys()), layer,
                                              get_input_embeds, device, task_tracker)
-        barrier()
     return next_layer_input
 
diff --git a/python/graphstorm/model/lm_embed.py b/python/graphstorm/model/lm_embed.py
index 7adcb29c1d..5c13ad167d 100644
--- a/python/graphstorm/model/lm_embed.py
+++ b/python/graphstorm/model/lm_embed.py
@@ -32,7 +32,8 @@
 from .lm_model import init_lm_model
 from .lm_model import get_lm_node_feats
 from .utils import load_pytorch_embedding, save_pytorch_embedding
-from ..utils import get_rank, get_world_size, barrier, create_dist_tensor
+from ..utils import get_rank, get_world_size, create_dist_tensor
+from ..distributed import flush_data
 
 class LMModels(nn.Module):
     """ LM model collection
@@ -381,17 +382,20 @@ def update_cache(self, lm_infer_batch_size, use_fp16=True):
                     part_policy=self._g.get_node_partition_policy(ntype),
                     persistent=True)
             emb = self._lm_emb_cache[ntype]
+            # LM computations are very expensive. It's better to force
+            # an even split to ensure all processes have roughly the same number of nodes
+            # for LM inference.
             infer_nodes = dgl.distributed.node_split(
                     th.ones((self._g.number_of_nodes(ntype),), dtype=th.bool),
                     partition_book=self._g.get_partition_book(),
-                    ntype=ntype, force_even=False)
-            logging.debug("node %s, local infer set: %d, batch size: %d",
-                          ntype, len(infer_nodes), lm_infer_batch_size)
+                    ntype=ntype, force_even=True)
+            logging.debug("Rank %d: node %s, local infer set: %d, batch size: %d",
+                          get_rank(), ntype, len(infer_nodes), lm_infer_batch_size)
             node_list = th.split(infer_nodes, lm_infer_batch_size)
             input_ntypes = [ntype]
             with th.no_grad():
-                for input_nodes in node_list:
+                for i, input_nodes in enumerate(node_list):
                     input_lm_feats = {}
                     input_lm_feats[ntype] = {
                         fname: feat[input_nodes] for fname, feat in lm_node_feat.items()
@@ -401,7 +405,13 @@ def update_cache(self, lm_infer_batch_size, use_fp16=True):
                         emb[input_nodes] = text_embs[ntype].half().to('cpu')
                     else:
                         emb[input_nodes] = text_embs[ntype].to('cpu')
-            barrier()
+                    if i % 1000 == 0 and get_rank() == 0:
+                        logging.debug("Compute LM embeddings on %d batches out of %d",
+                                      i, len(node_list))
+            # Because we split the nodes evenly, we may need to write data to remote machines.
+            # Therefore, we need to flush the data here to ensure that the data can be loaded
+            # correctly afterwards.
+            flush_data()
             if get_rank() == 0:
                 logging.info('Computing bert embedding on node %s takes %.3f seconds',
                              ntype, time.time() - start)
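
For context, a minimal sketch (not part of the patch) of the synchronization pattern the new comments describe: split the nodes, write each batch of embeddings into a distributed tensor, then synchronize before anyone reads the results. The function name infer_node_embeddings and the compute_embedding callable are hypothetical; g is assumed to be an initialized dgl.distributed.DistGraph, emb a writable DistTensor, and barrier/flush_data the GraphStorm helpers imported in the diff above.

import torch as th
import dgl

from graphstorm.utils import barrier
from graphstorm.distributed import flush_data

def infer_node_embeddings(g, ntype, emb, compute_embedding, batch_size, force_even):
    """Compute embeddings for all nodes of `ntype` and write them into `emb`."""
    # Select the nodes this process handles. With force_even=False the split follows
    # the graph partitions; with force_even=True every process gets roughly the same
    # number of nodes, some of which may live in remote partitions.
    infer_nodes = dgl.distributed.node_split(
        th.ones((g.number_of_nodes(ntype),), dtype=th.bool),
        partition_book=g.get_partition_book(),
        ntype=ntype, force_even=force_even)
    with th.no_grad():
        for input_nodes in th.split(infer_nodes, batch_size):
            emb[input_nodes] = compute_embedding(input_nodes).to('cpu')
    if force_even:
        # Even split: some writes may target remote partitions, so flush them
        # before any process reads the embeddings back.
        flush_data()
    else:
        # Partition-aligned split: each process only wrote to its local partition
        # (shared memory), so a barrier is sufficient.
        barrier()

The trade-off mirrors the patch: partition-aligned splits avoid remote writes but can be skewed, while even splits balance the expensive LM work at the cost of an explicit flush.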