diff --git a/python/graphstorm/model/embed.py b/python/graphstorm/model/embed.py
index ee1fa9b963..d86aced478 100644
--- a/python/graphstorm/model/embed.py
+++ b/python/graphstorm/model/embed.py
@@ -18,6 +18,7 @@
 import time
 import logging

+import numpy as np
 import torch as th
 from torch import nn
 import torch.nn.functional as F
@@ -314,6 +315,8 @@ def forward(self, input_feats, input_nodes):
         assert isinstance(input_nodes, dict), 'The input node IDs should be in a dict.'
         embs = {}
         for ntype in input_nodes:
+            if isinstance(input_nodes[ntype], np.ndarray):
+                input_nodes[ntype] = th.from_numpy(input_nodes[ntype])
             emb = None
             if ntype in input_feats:
                 assert ntype in self.input_projs, \
@@ -346,12 +349,13 @@ def forward(self, input_feats, input_nodes):
                     embs[ntype] = th.zeros((0, embedding_dim), device=device, dtype=dtype)
                     continue
+
                 if is_wholegraph_embedding_module(self.sparse_embeds[ntype]):
                     # output to local device
                     node_emb = self.sparse_embeds[ntype](input_nodes[ntype].cuda())
-                    node_emb = node_emb.to(device, non_blocking=True)
+                    emb = node_emb.to(device, non_blocking=True)
                 else:
-                    node_emb = self.sparse_embeds[ntype](input_nodes[ntype], device)
+                    emb = self.sparse_embeds[ntype](input_nodes[ntype], device)
                 emb = emb @ self.proj_matrix[ntype]

             if emb is not None:
diff --git a/python/graphstorm/model/utils.py b/python/graphstorm/model/utils.py
index b43188fcdb..7929958ee1 100644
--- a/python/graphstorm/model/utils.py
+++ b/python/graphstorm/model/utils.py
@@ -192,7 +192,6 @@ def save_sparse_emb(model_path, sparse_emb, ntype):
         num_embs = sparse_emb.wm_embedding.shape[0]
     else:
         num_embs = sparse_emb.num_embeddings
-    # wholegraph local embedding boundary is consistent to _get_sparse_emb_range
     start, end = _get_sparse_emb_range(num_embs, rank, world_size)

     # collect sparse_emb in a iterative way
@@ -204,12 +203,15 @@ def save_sparse_emb(model_path, sparse_emb, ntype):
     os.makedirs(emb_path, exist_ok=True)

     if is_wholegraph_sparse_emb() and is_wholegraph_embedding_module(sparse_emb):
+        # WholeGraph saves the sparse emb in npy format so that all ranks can load the
+        # file concurrently via np.load(mmap_mode), which torch.load does not support.
         emb_file_path = os.path.join(emb_path, f'sparse_emb_{pad_file_index(rank)}.npy')
         local_tensor, _ = sparse_emb.wm_embedding.get_embedding_tensor().get_local_tensor(host_view=True)
-        if local_tensor.shape[0] > end - start: # this could only happen in unit test
+        # wholegraph local embedding boundary is consistent with the output of _get_sparse_emb_range
+        if local_tensor.shape[0] > end - start: # this should only happen in unit tests
             embs = local_tensor[start:end].numpy()
         else:
-            assert local_tensor.shape[0] == end - start, "Save/Load has invalid dimensions."
+            assert local_tensor.shape[0] == end - start, "WholeGraph tensor local boundary has invalid dimensions."
             embs = local_tensor.numpy()
         np.save(emb_file_path, embs)
     else:
@@ -1382,101 +1384,80 @@ def load_sparse_emb(target_sparse_emb, ntype_emb_path):
     else:
         num_embs = target_sparse_emb.num_embeddings
+    # Suppose a sparse embedding is trained and saved using N trainers (e.g., GPUs).
+    # We are going to use K trainers/infers to load it.
+    # The code handles the following cases:
+    # 1. N == K
+    # 2. N > K, some trainers/infers need to load more than one file
+    # 3. N < K, some trainers/infers do not need to load any files
     if is_wholegraph_sparse_emb() and is_wholegraph_embedding_module(target_sparse_emb):
-        # Suppose a sparse embedding is trained and saved using N trainers (e.g., GPUs).
-        # We have to use the same number of N trainers/infers to load it.
-        # save and load need to have the same dtype
-        if False: #(num_files == world_size):
+        if (num_files == world_size):
+            # When N == K, assume the process group has not changed between save and load.
+            # Each rank just needs to load its own part.
             file_idx = rank
-            filepath = os.path.join(ntype_emb_path, f'sparse_emb_{pad_file_index(file_idx)}.bin')
-
+            file_path = os.path.join(ntype_emb_path, f'sparse_emb_{pad_file_index(file_idx)}.npy')
+            np_emb = np.load(file_path)
+            emb = th.from_numpy(np_emb)
             loc_sta, loc_end = _get_sparse_emb_range(num_embs, rank=rank, world_size=world_size)
-            assert loc_end - loc_sta == emb.shape[0], "Save/Load has invalid dimensions."
             local_tensor, _ = target_sparse_emb.wm_embedding.get_embedding_tensor().get_local_tensor(host_view=True)
-            if local_tensor.shape[0] > emb.shape[0]: # this could only happen in unit test
+            if local_tensor.shape[0] > emb.shape[0]: # this should only happen in unit tests
+                assert loc_end - loc_sta == emb.shape[0], "Saved WholeGraph tensor has invalid file boundaries."
                 local_tensor[loc_sta:loc_end] = emb
             else:
+                assert local_tensor.shape[0] == emb.shape[0], "WholeGraph Save/Load has invalid dimensions."
                 local_tensor.copy_(emb)
-
-            target_sparse_emb.wm_embedding.get_embedding_tensor().from_filelist([filepath])
-
         else:
-            loc_sta, loc_end = _get_sparse_emb_range(num_embs, rank=rank, world_size=world_size)
-            assert loc_end - loc_sta == emb.shape[0], "Save/Load has invalid dimensions."
-            local_tensor, _ = target_sparse_emb.wm_embedding.get_embedding_tensor().get_local_tensor(host_view=True)
-
-            # Dict of torch dtype -> numpy dtype (when the correspondence exists)
-            th_to_np_dtype_dict = {
-                th.bool : np.bool_,
-                th.uint8 : np.uint8,
-                th.int8 : np.int8,
-                th.int16 : np.int16,
-                th.int32 : np.int32,
-                th.int64 : np.int64,
-                th.float16 : np.float16,
-                th.float32 : np.float32,
-                th.float64 : np.float64,
-                th.complex64 : np.complex64,
-                th.complex128 : np.complex128
-            }
-            for i in range(num_files):
-                file_idx = i
-                file_path=os.path.join(ntype_emb_path, f'sparse_emb_{pad_file_index(file_idx)}.bin')
-                dtype = target_sparse_emb.wm_embedding.get_embedding_tensor().dtype
-                start, end = _get_sparse_emb_range(num_embs, rank=file_idx, world_size=num_files)
-                file_offset = start
-                file_size = end - start
-
-                loc_sta, loc_end = _get_sparse_emb_range(file_size, rank, world_size)
-                rank_offset = loc_sta
-                # TODO(chang-l): when loc_sta == loc_end, ie., file_size < world_size
-                # if so, we can let every proc read every file and load the needed part
-                assert(loc_end > loc_sta, "File size too small. # of store embs at each file must be > world_size.")
-                offset = rank_offset * target_sparse_emb.wm_embedding.shape[1] * th.tensor([], dtype=dtype).element_size()
-                shape = (loc_end-loc_sta, target_sparse_emb.wm_embedding.shape[1])
-
-                # memmap from file underneath, no heap memory allocation involved
-                emb = th.from_numpy(np.memmap(file_path, dtype=th_to_np_dtype_dict[dtype], mode='r', offset=offset, shape=shape))
-
+            # When N != K, we process all saved files one by one, using all procs,
+            # followed by a distributed scatter to propagate the embeddings.
+
+            def _wholegraph_load_scatter(num_embs, file_id, num_files, file_path):
+                file_start, file_end = _get_sparse_emb_range(num_embs, rank=file_id, world_size=num_files)
+                file_size = file_end - file_start
+                rank_sta, rank_end = _get_sparse_emb_lowerbound_range(file_size, rank, world_size)
+                # TODO(chang-l): verify that rank_sta == rank_end, i.e., file_size < world_size, still works.
+
+                # memmap from file, no heap memory allocation involved here for now
+                np_emb = np.load(file_path, mmap_mode='r')
+                emb = th.from_numpy(np_emb)[rank_sta:rank_end]
                 # write sparse_emb back by wm_scatter function distributedly via nccl
                 # due to device memory limitation (scattered embeddings have to go through device), write back in a batched way
                 batch_size = 102400
-                loc_part_size = math.ceil(file_size / world_size)
-                standard_idxs = th.split(th.arange(loc_part_size), batch_size, dim=0)
-
-                idxs = th.split(th.arange(loc_end - loc_sta), batch_size, dim=0) # file local idx
-                assert len(standard_idxs) >= len(idxs)
-                if len(standard_idxs) != len(idxs): # last rank
-                    t1 = th.arange(loc_end - loc_sta)
-                    nrepeat = loc_part_size + loc_sta - loc_end
-                    t2 = t1[-1].repeat(nrepeat)
-                    idxs = th.split(th.cat((t1,t2),0), batch_size, dim=0)
-                assert len(standard_idxs) == len(idxs)
+                std_part_size = file_size // world_size
+                nbatches = std_part_size // batch_size
+                if (nbatches != (rank_end-rank_sta) // batch_size):
+                    batch_size = (rank_end-rank_sta) // nbatches
+                assert(nbatches == (rank_end-rank_sta) // batch_size)
+                idxs = th.split(th.arange(rank_end - rank_sta), batch_size, dim=0) # local idx for embs saved in each file read by each rank
                 for idx in idxs:
                     scatter_input = emb[idx].cuda() # read from file into device memory
-                    scatter_gidx = file_offset + rank_offset + idx
+                    scatter_gidx = file_start + rank_sta + idx # file offset and rank offset
                     scatter_gidx = scatter_gidx.cuda()
-                    target_sparse_emb.wm_embedding.get_embedding_tensor().scatter(scatter_input, scatter_gidx)
-
-            #loc_sta, loc_end = _get_sparse_emb_range(num_embs, rank=rank, world_size=world_size)
-            #assert loc_end - loc_sta == emb.shape[0], "Save/Load has invalid dimensions."
-            #local_tensor, _ = target_sparse_emb.wm_embedding.get_embedding_tensor().get_local_tensor(host_view=True)
-            #if local_tensor.shape[0] > emb.shape[0]: # this could only happen in unit test
-            #    local_tensor[loc_sta:loc_end] = emb
-            #else:
-            #    local_tensor.copy_(emb)
-
-            # TODO(chang-l): Extend to the case when N!=K, same as DistEmbedding (need scatter to broadcast the values)
-            # only copy_ the overlapped chunks
-            # scatter the non-local(non-overlapped) chunks of emb to target_sparse_emb
-            # e.g., target_sparse_emb.wm_embedding.get_embedding_tensor().scatter(emb, indices)
+                    import pylibwholegraph
+                    if pylibwholegraph.__version__ < "23.12.00":
+                        import pylibwholegraph.torch.wholememory_ops as wm_ops
+                        wmb_tensor = target_sparse_emb.wm_embedding.wmb_embedding.get_embedding_tensor()
+                        wm_ops.wholememory_scatter_functor(scatter_input, scatter_gidx, wmb_tensor)
+                    else:
+                        target_sparse_emb.wm_embedding.get_embedding_tensor().scatter(scatter_input, scatter_gidx)
+
+            def _get_sparse_emb_lowerbound_range(file_size, rank, world_size):
+                assert rank < world_size, \
+                    f"local rank {rank} should be smaller than world size {world_size}"
+                if file_size < world_size:
+                    start = rank if rank < file_size else file_size
+                    end = rank + 1 if rank < file_size else file_size
+                else:
+                    part = file_size // world_size
+                    start = rank * part
+                    end = (rank + 1) * part
+                    end = file_size if rank + 1 == world_size else end
+                return start, end
+
+            for i in range(num_files):
+                file_idx = i
+                file_path=os.path.join(ntype_emb_path, f'sparse_emb_{pad_file_index(file_idx)}.npy')
+                _wholegraph_load_scatter(num_embs, file_idx, num_files, file_path)
     else:
-        # Suppose a sparse embedding is trained and saved using N trainers (e.g., GPUs).
-        # We are going to use K trainers/infers to load it.
-        # The code handles the following cases:
-        # 1. N == K
-        # 2. N > K, some trainers/infers need to load more than one files
-        # 3. N < K, some trainers/infers do not need to load any files
         for i in range(math.ceil(num_files/world_size)):
             file_idx = i * world_size + rank
             if file_idx < num_files:
diff --git a/tests/unit-tests/test_wg_sparse_emb_save_load.py b/tests/unit-tests/test_wg_sparse_emb.py
similarity index 52%
rename from tests/unit-tests/test_wg_sparse_emb_save_load.py
rename to tests/unit-tests/test_wg_sparse_emb.py
index 8fb9db909b..a33b545832 100644
--- a/tests/unit-tests/test_wg_sparse_emb_save_load.py
+++ b/tests/unit-tests/test_wg_sparse_emb.py
@@ -23,12 +23,16 @@
 import numpy as np
 import torch as th
-from numpy.testing import assert_equal
+import torch.nn.functional as F
+from torch import nn
+from numpy.testing import assert_equal, assert_almost_equal
+
 from unittest.mock import patch
 from graphstorm.gsf import init_wholegraph
 from graphstorm.utils import use_wholegraph_sparse_emb, is_wholegraph_sparse_emb
 from graphstorm.model import GSNodeEncoderInputLayer
+from graphstorm.model.embed import compute_node_input_embeddings
 from graphstorm.model.utils import save_sparse_embeds
 from graphstorm.model.utils import load_sparse_embeds
 from graphstorm.model.utils import _get_sparse_emb_range
@@ -36,7 +40,6 @@
 from graphstorm import get_feat_size

 from data_utils import generate_dummy_dist_graph
-import pylibwholegraph.torch as wgth


 def initialize(use_wholegraph=True):
@@ -65,6 +68,7 @@ def test_wg_sparse_embed_save(world_size):
     And then check the value of the saved embedding.
""" # initialize the torch distributed environment + wgth = pytest.importorskip("pylibwholegraph.torch") use_wholegraph_sparse_emb() initialize(use_wholegraph=is_wholegraph_sparse_emb()) @@ -114,7 +118,7 @@ def check_saved_sparse_emb(mock_get_world_size, mock_get_rank): th.distributed.destroy_process_group() dgl.distributed.kvstore.close_kvstore() -@pytest.mark.parametrize("infer_world_size", [8]) +@pytest.mark.parametrize("infer_world_size", [3, 8, 16]) @pytest.mark.parametrize("train_world_size", [8]) def test_wg_sparse_embed_load(infer_world_size, train_world_size): """ Test sparse embedding loading logic using wholegraph. (graphstorm.model.utils.load_sparse_embeds) @@ -125,6 +129,7 @@ def test_wg_sparse_embed_load(infer_world_size, train_world_size): It will compare the embedings stored and loaded. """ # initialize the torch distributed environment + wgth = pytest.importorskip("pylibwholegraph.torch") use_wholegraph_sparse_emb() initialize(use_wholegraph=is_wholegraph_sparse_emb()) @@ -159,7 +164,7 @@ def check_sparse_emb(mock_get_world_size, mock_get_rank): for i in range(infer_world_size): mock_get_rank.side_effect = [i] * 2 - mock_get_world_size.side_effect = [train_world_size] * 2 + mock_get_world_size.side_effect = [infer_world_size] * 2 load_sparse_embeds(model_path, embed_layer) if is_wholegraph_sparse_emb(): load_sparse_embs = \ @@ -181,6 +186,147 @@ def check_sparse_emb(mock_get_world_size, mock_get_rank): th.distributed.destroy_process_group() dgl.distributed.kvstore.close_kvstore() +# In this case, we use node feature on one node type and +# use sparse embedding on the other node type. +@pytest.mark.parametrize("dev", ['cpu','cuda:0']) +def test_wg_input_layer3(dev): + # initialize the torch distributed environment + wgth = pytest.importorskip("pylibwholegraph.torch") + use_wholegraph_sparse_emb() + initialize(use_wholegraph=is_wholegraph_sparse_emb()) + with tempfile.TemporaryDirectory() as tmpdirname: + # get the test dummy distributed graph + g, _ = generate_dummy_dist_graph(tmpdirname) + + feat_size = get_feat_size(g, {'n0' : ['feat']}) + layer = GSNodeEncoderInputLayer(g, feat_size, 2) + assert len(layer.input_projs) == 1 + assert list(layer.input_projs.keys())[0] == 'n0' + assert len(layer.sparse_embeds) == 1 + layer = layer.to(dev) + + node_feat = {} + node_embs = {} + input_nodes = {} + for ntype in g.ntypes: + input_nodes[ntype] = np.arange(10) + nn.init.eye_(layer.input_projs['n0']) + nn.init.eye_(layer.proj_matrix['n1']) + node_feat['n0'] = g.nodes['n0'].data['feat'][input_nodes['n0']].to(dev) + + node_embs['n1'] = layer.sparse_embeds['n1'](th.from_numpy(input_nodes['n1']).cuda()) + + embed = layer(node_feat, input_nodes) + assert len(embed) == len(input_nodes) + # check emb device + for _, emb in embed.items(): + assert emb.get_device() == (-1 if dev == 'cpu' else 0) + assert_almost_equal(embed['n0'].detach().cpu().numpy(), + node_feat['n0'].detach().cpu().numpy()) + assert_almost_equal(embed['n1'].detach().cpu().numpy(), + node_embs['n1'].detach().cpu().numpy()) + + # test the case that one node type has no input nodes. + input_nodes['n0'] = np.arange(10) + + # TODO(chang-l): Somehow, WholeGraph does not support empty indices created from numpy then converted to torch, i.e., + # empty_nodes = th.from_numpy(np.zeros((0,), dtype=int)) does not work (segfault in wholegraph.gather). + # Need to submit an issue to WholeGraph team + input_nodes['n1'] = th.tensor([],dtype=th.int64) #np.zeros((0,), dtype=int) should work but not!! 
+
+        nn.init.eye_(layer.input_projs['n0'])
+        node_feat['n0'] = g.nodes['n0'].data['feat'][input_nodes['n0']].to(dev)
+        node_embs['n1'] = layer.sparse_embeds['n1'](input_nodes['n1'].cuda())
+
+        embed = layer(node_feat, input_nodes)
+        assert len(embed) == len(input_nodes)
+        # check emb device
+        for _, emb in embed.items():
+            assert emb.get_device() == (-1 if dev == 'cpu' else 0)
+        assert_almost_equal(embed['n0'].detach().cpu().numpy(),
+                            node_feat['n0'].detach().cpu().numpy())
+        assert_almost_equal(embed['n1'].detach().cpu().numpy(),
+                            node_embs['n1'].detach().cpu().numpy())
+
+    if is_wholegraph_sparse_emb():
+        wgth.finalize()
+    th.distributed.destroy_process_group()
+
+# In this case, we use both node features and sparse embeddings.
+def test_wg_input_layer2():
+    # initialize the torch distributed environment
+    wgth = pytest.importorskip("pylibwholegraph.torch")
+    use_wholegraph_sparse_emb()
+    initialize(use_wholegraph=is_wholegraph_sparse_emb())
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        # get the test dummy distributed graph
+        g, _ = generate_dummy_dist_graph(tmpdirname)
+
+        feat_size = get_feat_size(g, 'feat')
+        layer = GSNodeEncoderInputLayer(g, feat_size, 2, use_node_embeddings=True)
+        assert set(layer.input_projs.keys()) == set(g.ntypes)
+        assert set(layer.sparse_embeds.keys()) == set(g.ntypes)
+        assert set(layer.proj_matrix.keys()) == set(g.ntypes)
+        node_feat = {}
+        node_embs = {}
+        input_nodes = {}
+        for ntype in g.ntypes:
+            # We make the projection matrix a diagonal matrix so that
+            # the input and output matrices are identical.
+            nn.init.eye_(layer.input_projs[ntype])
+            assert layer.proj_matrix[ntype].shape == (4, 2)
+            # We make the projection matrix such that it simply adds the node features
+            # and the node sparse embeddings after projection.
+            with th.no_grad():
+                layer.proj_matrix[ntype][:2,:] = layer.input_projs[ntype]
+                layer.proj_matrix[ntype][2:,:] = layer.input_projs[ntype]
+            input_nodes[ntype] = np.arange(10)
+            node_feat[ntype] = g.nodes[ntype].data['feat'][input_nodes[ntype]]
+            node_embs[ntype] = layer.sparse_embeds[ntype](th.from_numpy(input_nodes[ntype]).cuda())
+        embed = layer(node_feat, input_nodes)
+        assert len(embed) == len(input_nodes)
+        assert len(embed) == len(node_feat)
+        for ntype in embed:
+            true_val = node_feat[ntype].detach().numpy() + node_embs[ntype].detach().cpu().numpy()
+            assert_almost_equal(embed[ntype].detach().cpu().numpy(), true_val)
+    if is_wholegraph_sparse_emb():
+        wgth.finalize()
+    th.distributed.destroy_process_group()
+
+@pytest.mark.parametrize("dev", ['cpu','cuda:0'])
+def test_wg_compute_embed(dev):
+    # initialize the torch distributed environment
+    wgth = pytest.importorskip("pylibwholegraph.torch")
+    use_wholegraph_sparse_emb()
+    initialize(use_wholegraph=is_wholegraph_sparse_emb())
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        # get the test dummy distributed graph
+        g, _ = generate_dummy_dist_graph(tmpdirname)
+        print('g has {} nodes of n0 and {} nodes of n1'.format(
+            g.number_of_nodes('n0'), g.number_of_nodes('n1')))
+
+        feat_size = get_feat_size(g, {'n0' : ['feat']})
+        layer = GSNodeEncoderInputLayer(g, feat_size, 2)
+        nn.init.eye_(layer.input_projs['n0'])
+        nn.init.eye_(layer.proj_matrix['n1'])
+        layer.to(dev)
+
+        embeds = compute_node_input_embeddings(g, 10, layer,
+                                               feat_field={'n0' : ['feat']})
+        assert len(embeds) == len(g.ntypes)
+        assert_almost_equal(embeds['n0'][0:len(embeds['n1'])].cpu().numpy(),
+                            g.nodes['n0'].data['feat'][0:g.number_of_nodes('n0')].cpu().numpy())
+        indices = th.arange(g.number_of_nodes('n1'))
+        assert_almost_equal(embeds['n1'][0:len(embeds['n1'])].cpu().numpy(),
+                            layer.sparse_embeds['n1'](indices.cuda()).cpu().detach().numpy())
+        # Run it again to trigger the branch that accesses 'input_emb' directly.
+        embeds = compute_node_input_embeddings(g, 10, layer,
+                                               feat_field={'n0' : ['feat']})
+    if is_wholegraph_sparse_emb():
+        wgth.finalize()
+    th.distributed.destroy_process_group()
+
 if __name__ == '__main__':
     test_wg_sparse_embed_save(4)
-    test_wg_sparse_embed_load(8, 8)
\ No newline at end of file
+    test_wg_sparse_embed_load(3, 8)
+    test_wg_sparse_embed_load(8, 8)
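Reviewer note, not part of the patch: the N != K load path above is easiest to sanity-check in isolation. The sketch below is a minimal, self-contained illustration under stated assumptions: `partition` and `plan_load` are hypothetical stand-ins that mirror the slicing scheme of `_get_sparse_emb_lowerbound_range` in this diff (and only approximate `_get_sparse_emb_range`, whose exact formula may differ), and the sizes (103 embeddings, 8 saving trainers, 3 loading ranks) are made up for the demo. It prints which global rows each (file, loading rank) pair would scatter and checks that the slices tile the embedding table exactly once.

# Illustrative sketch only -- not part of the patch. `partition` is a simplified
# stand-in for _get_sparse_emb_range / _get_sparse_emb_lowerbound_range, and the
# sizes below (num_embs=103, N=8 savers, K=3 loaders) are hypothetical.

def partition(total, rank, world_size):
    # Lower-bound split: every rank gets total // world_size rows and the last
    # rank also takes the remainder; tiny inputs fall back to one row per rank.
    assert rank < world_size, f"rank {rank} must be smaller than world size {world_size}"
    if total < world_size:
        start = min(rank, total)
        end = min(rank + 1, total)
    else:
        part = total // world_size
        start = rank * part
        end = total if rank + 1 == world_size else (rank + 1) * part
    return start, end

def plan_load(num_embs, num_files, world_size):
    # Each saved file holds one saver's slice of the table; every loading rank
    # reads a contiguous piece of that file and scatters it back to the global
    # range [file_start + rank_sta, file_start + rank_end).
    plan = []
    for file_id in range(num_files):
        file_start, file_end = partition(num_embs, file_id, num_files)
        for rank in range(world_size):
            rank_sta, rank_end = partition(file_end - file_start, rank, world_size)
            plan.append((file_id, rank, file_start + rank_sta, file_start + rank_end))
    return plan

if __name__ == '__main__':
    num_embs, num_files, world_size = 103, 8, 3
    plan = plan_load(num_embs, num_files, world_size)
    for file_id, rank, gstart, gend in plan:
        print(f"file {file_id}, rank {rank}: scatters global rows [{gstart}, {gend})")
    covered = set()
    for _, _, gstart, gend in plan:
        covered.update(range(gstart, gend))
    # The slices cover every row and their total length equals num_embs,
    # so no row is written twice.
    assert covered == set(range(num_embs))
    assert sum(gend - gstart for _, _, gstart, gend in plan) == num_embs

The same tiling argument is what makes the batched wm_scatter writes in load_sparse_emb safe to run on all ranks concurrently: each (file, rank) pair owns a disjoint global index range, so no coordination beyond the scatter itself is needed.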