Complete the tests
chang-l committed Dec 12, 2023
1 parent ab492a6 commit 2c0b535
Showing 3 changed files with 221 additions and 90 deletions.
8 changes: 6 additions & 2 deletions python/graphstorm/model/embed.py
@@ -18,6 +18,7 @@

import time
import logging
import numpy as np
import torch as th
from torch import nn
import torch.nn.functional as F
@@ -314,6 +315,8 @@ def forward(self, input_feats, input_nodes):
assert isinstance(input_nodes, dict), 'The input node IDs should be in a dict.'
embs = {}
for ntype in input_nodes:
if isinstance(input_nodes[ntype], np.ndarray):
input_nodes[ntype] = th.from_numpy(input_nodes[ntype])
emb = None
if ntype in input_feats:
assert ntype in self.input_projs, \
@@ -346,12 +349,13 @@ def forward(self, input_feats, input_nodes):
embs[ntype] = th.zeros((0, embedding_dim),
device=device, dtype=dtype)
continue

if is_wholegraph_embedding_module(self.sparse_embeds[ntype]):
# output to local device
node_emb = self.sparse_embeds[ntype](input_nodes[ntype].cuda())
node_emb = node_emb.to(device, non_blocking=True)
emb = node_emb.to(device, non_blocking=True)
else:
node_emb = self.sparse_embeds[ntype](input_nodes[ntype], device)
emb = self.sparse_embeds[ntype](input_nodes[ntype], device)

emb = emb @ self.proj_matrix[ntype]
if emb is not None:
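Note on the embed.py hunk above: forward now accepts node IDs as either np.ndarray or torch.Tensor, and WholeGraph-backed sparse embeddings are gathered on GPU before being copied back to the target device. A minimal standalone sketch of that dispatch pattern (the helper name and signature below are illustrative, not a GraphStorm API):

import numpy as np
import torch as th

def lookup_sparse_emb(sparse_emb, input_nodes, device, is_wholegraph):
    # Normalize node IDs: WholeGraph dataloaders may hand back numpy arrays.
    if isinstance(input_nodes, np.ndarray):
        input_nodes = th.from_numpy(input_nodes)
    if is_wholegraph:
        # WholeGraph gathers on GPU: move IDs to CUDA, then copy the result
        # back to the target device without blocking the stream.
        emb = sparse_emb(input_nodes.cuda())
        emb = emb.to(device, non_blocking=True)
    else:
        # A DGL DistEmbedding-style lookup accepts the target device directly.
        emb = sparse_emb(input_nodes, device)
    return emb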
147 changes: 64 additions & 83 deletions python/graphstorm/model/utils.py
@@ -192,7 +192,6 @@ def save_sparse_emb(model_path, sparse_emb, ntype):
num_embs = sparse_emb.wm_embedding.shape[0]
else:
num_embs = sparse_emb.num_embeddings
# wholegraph local embedding boundary is consistent to _get_sparse_emb_range
start, end = _get_sparse_emb_range(num_embs, rank, world_size)
# collect sparse_emb in an iterative way

@@ -204,12 +203,15 @@ def save_sparse_emb(model_path, sparse_emb, ntype):
os.makedirs(emb_path, exist_ok=True)

if is_wholegraph_sparse_emb() and is_wholegraph_embedding_module(sparse_emb):
# WholeGraph saves the sparse emb in npy format so that all ranks can load the
# file concurrently via np.load(mmap_mode), which torch.load does not support
emb_file_path = os.path.join(emb_path, f'sparse_emb_{pad_file_index(rank)}.npy')
local_tensor, _ = sparse_emb.wm_embedding.get_embedding_tensor().get_local_tensor(host_view=True)
if local_tensor.shape[0] > end - start: # this could only happen in unit test
# wholegraph local embedding boundary is consistent to the output of _get_sparse_emb_range
if local_tensor.shape[0] > end - start: # this should only happen in unit test
embs = local_tensor[start:end].numpy()
else:
assert local_tensor.shape[0] == end - start, "Save/Load has invalid dimensions."
assert local_tensor.shape[0] == end - start, "WholeGraph tensor local boundary has invalid dimensions."
embs = local_tensor.numpy()
np.save(emb_file_path, embs)
else:
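The comment in this hunk captures the format choice: a .npy file can be memory-mapped concurrently by every rank via np.load(mmap_mode), which torch.load does not support. A toy sketch of the save/concurrent-load pattern (file name and shapes are illustrative only):

import numpy as np

# Each rank saves only its partition of the embedding table.
emb_part = np.random.rand(1000, 128).astype(np.float32)  # illustrative data
np.save('sparse_emb_00000.npy', emb_part)

# Later, any number of ranks can open the same file at once:
# mmap_mode='r' maps the file read-only instead of copying it onto the heap,
# so each rank materializes only the rows it actually needs.
mapped = np.load('sparse_emb_00000.npy', mmap_mode='r')
rows = np.asarray(mapped[100:200])  # only this slice is read from disk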
@@ -1382,101 +1384,80 @@ def load_sparse_emb(target_sparse_emb, ntype_emb_path):
else:
num_embs = target_sparse_emb.num_embeddings

# Suppose a sparse embedding is trained and saved using N trainers (e.g., GPUs).
# We are going to use K trainers/infers to load it.
# The code handles the following cases:
# 1. N == K
# 2. N > K, some trainers/infers need to load more than one file
# 3. N < K, some trainers/infers do not need to load any files
if is_wholegraph_sparse_emb() and is_wholegraph_embedding_module(target_sparse_emb):
# Suppose a sparse embedding is trained and saved using N trainers (e.g., GPUs).
# We have to use the same number of N trainers/infers to load it.
# save and load need to have the same dtype
if False: #(num_files == world_size):
if (num_files == world_size):
# When N==K, assume process group has not changed between save and load
# Each rank just needs to load its own part.
file_idx = rank
filepath = os.path.join(ntype_emb_path, f'sparse_emb_{pad_file_index(file_idx)}.bin')

file_path = os.path.join(ntype_emb_path, f'sparse_emb_{pad_file_index(file_idx)}.npy')
np_emb = np.load(file_path)
emb = th.from_numpy(np_emb)
loc_sta, loc_end = _get_sparse_emb_range(num_embs, rank=rank, world_size=world_size)
assert loc_end - loc_sta == emb.shape[0], "Save/Load has invalid dimensions."
local_tensor, _ = target_sparse_emb.wm_embedding.get_embedding_tensor().get_local_tensor(host_view=True)
if local_tensor.shape[0] > emb.shape[0]: # this could only happen in unit test
if local_tensor.shape[0] > emb.shape[0]: # this should only happen in unit test
assert loc_end - loc_sta == emb.shape[0], "Saved WholeGraph tensor has invalid file boundaries."
local_tensor[loc_sta:loc_end] = emb
else:
assert local_tensor.shape[0] == emb.shape[0], "WholeGraph Save/Load has invalid dimensions."
local_tensor.copy_(emb)

target_sparse_emb.wm_embedding.get_embedding_tensor().from_filelist([filepath])

else:
loc_sta, loc_end = _get_sparse_emb_range(num_embs, rank=rank, world_size=world_size)
assert loc_end - loc_sta == emb.shape[0], "Save/Load has invalid dimensions."
local_tensor, _ = target_sparse_emb.wm_embedding.get_embedding_tensor().get_local_tensor(host_view=True)

# Dict of torch dtype -> numpy dtype (when the correspondence exists)
th_to_np_dtype_dict = {
th.bool : np.bool_,
th.uint8 : np.uint8,
th.int8 : np.int8,
th.int16 : np.int16,
th.int32 : np.int32,
th.int64 : np.int64,
th.float16 : np.float16,
th.float32 : np.float32,
th.float64 : np.float64,
th.complex64 : np.complex64,
th.complex128 : np.complex128
}
for i in range(num_files):
file_idx = i
file_path=os.path.join(ntype_emb_path, f'sparse_emb_{pad_file_index(file_idx)}.bin')
dtype = target_sparse_emb.wm_embedding.get_embedding_tensor().dtype
start, end = _get_sparse_emb_range(num_embs, rank=file_idx, world_size=num_files)
file_offset = start
file_size = end - start

loc_sta, loc_end = _get_sparse_emb_range(file_size, rank, world_size)
rank_offset = loc_sta
# TODO(chang-l): when loc_sta == loc_end, ie., file_size < world_size
# if so, we can let every proc read every file and load the needed part
assert(loc_end > loc_sta, "File size too small. # of store embs at each file must be > world_size.")
offset = rank_offset * target_sparse_emb.wm_embedding.shape[1] * th.tensor([], dtype=dtype).element_size()
shape = (loc_end-loc_sta, target_sparse_emb.wm_embedding.shape[1])

# memmap from file underneath, no heap memory allocation involved
emb = th.from_numpy(np.memmap(file_path, dtype=th_to_np_dtype_dict[dtype], mode='r', offset=offset, shape=shape))

# When N!=K, all procs process the saved files one by one,
# followed by a distributed scatter operation to propagate the embs

def _wholegraph_load_scatter(num_embs, file_id, num_files, file_path):
file_start, file_end = _get_sparse_emb_range(num_embs, rank=file_id, world_size=num_files)
file_size = file_end - file_start
rank_sta, rank_end = _get_sparse_emb_lowerbound_range(file_size, rank, world_size)
# TODO(chang-l): verify that the case rank_sta == rank_end, i.e., file_size < world_size, still works.

# memmap from file, no heap memory allocation involved here for now
np_emb = np.load(file_path, mmap_mode='r')
emb = th.from_numpy(np_emb)[rank_sta:rank_end]
# write sparse_emb back with the wm scatter function, distributed via nccl
# due to device memory limits (scattered embeddings must go through the device), write back in batches
batch_size = 102400
loc_part_size = math.ceil(file_size / world_size)
standard_idxs = th.split(th.arange(loc_part_size), batch_size, dim=0)

idxs = th.split(th.arange(loc_end - loc_sta), batch_size, dim=0) # file local idx
assert len(standard_idxs) >= len(idxs)
if len(standard_idxs) != len(idxs): # last rank
t1 = th.arange(loc_end - loc_sta)
nrepeat = loc_part_size + loc_sta - loc_end
t2 = t1[-1].repeat(nrepeat)
idxs = th.split(th.cat((t1,t2),0), batch_size, dim=0)
assert len(standard_idxs) == len(idxs)
std_part_size = file_size // world_size
nbatches = std_part_size // batch_size
if (nbatches != (rank_end-rank_sta) // batch_size):
batch_size = (rank_end-rank_sta) // nbatches
assert nbatches == (rank_end-rank_sta) // batch_size
idxs = th.split(th.arange(rank_end - rank_sta), batch_size, dim=0) # local idx for embs saved in each file read by each rank
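# Note: nbatches is derived from the floor-partition size so every rank
# issues the same number of scatter calls (the NCCL-backed scatter appears
# to require all ranks to step together); the last rank may own a slightly
# larger slice, so its batch_size is re-derived above to keep its batch
# count equal to nbatches.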
for idx in idxs:
scatter_input = emb[idx].cuda() # read from file into device memory
scatter_gidx = file_offset + rank_offset + idx
scatter_gidx = file_start + rank_sta + idx # file offset and rank offset
scatter_gidx = scatter_gidx.cuda()
target_sparse_emb.wm_embedding.get_embedding_tensor().scatter(scatter_input, scatter_gidx)

#loc_sta, loc_end = _get_sparse_emb_range(num_embs, rank=rank, world_size=world_size)
#assert loc_end - loc_sta == emb.shape[0], "Save/Load has invalid dimensions."
#local_tensor, _ = target_sparse_emb.wm_embedding.get_embedding_tensor().get_local_tensor(host_view=True)
#if local_tensor.shape[0] > emb.shape[0]: # this could only happen in unit test
# local_tensor[loc_sta:loc_end] = emb
#else:
# local_tensor.copy_(emb)

# TODO(chang-l): Extend to the case when N!=K, same as DistEmbedding (need scatter to broadcast the values)
# only copy_ the overlapped chunks
# scatter the non-local(non-overlapped) chunks of emb to target_sparse_emb
# e.g., target_sparse_emb.wm_embedding.get_embedding_tensor().scatter(emb, indices)
import pylibwholegraph
if pylibwholegraph.__version__ < "23.12.00":
import pylibwholegraph.torch.wholememory_ops as wm_ops
wmb_tensor = target_sparse_emb.wm_embedding.wmb_embedding.get_embedding_tensor()
wm_ops.wholememory_scatter_functor(scatter_input, scatter_gidx, wmb_tensor)
else:
target_sparse_emb.wm_embedding.get_embedding_tensor().scatter(scatter_input, scatter_gidx)

def _get_sparse_emb_lowerbound_range(file_size, rank, world_size):
assert rank < world_size, \
f"local rank {rank} should be smaller than world size {world_size}"
if file_size < world_size:
start = rank if rank < file_size else file_size
end = rank + 1 if rank < file_size else file_size
else:
part = file_size // world_size
start = rank * part
end = (rank + 1) * part
end = file_size if rank + 1 == world_size else end
return start, end

for i in range(num_files):
file_idx = i
file_path = os.path.join(ntype_emb_path, f'sparse_emb_{pad_file_index(file_idx)}.npy')
_wholegraph_load_scatter(num_embs, file_idx, num_files, file_path)
else:
# Suppose a sparse embedding is trained and saved using N trainers (e.g., GPUs).
# We are going to use K trainers/infers to load it.
# The code handles the following cases:
# 1. N == K
# 2. N > K, some trainers/infers need to load more than one file
# 3. N < K, some trainers/infers do not need to load any files
for i in range(math.ceil(num_files/world_size)):
file_idx = i * world_size + rank
if file_idx < num_files:
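To make the partitioning helper concrete, here is a small worked example of _get_sparse_emb_lowerbound_range as defined in the diff above: every rank except the last gets exactly file_size // world_size rows, the last rank absorbs the remainder, and when file_size < world_size the trailing ranks get empty slices.

# With file_size=10 and world_size=4, the floor partition is 10 // 4 = 2
# and the last rank takes the remainder:
for rank in range(4):
    print(rank, _get_sparse_emb_lowerbound_range(10, rank, 4))
# -> (0, 2), (2, 4), (4, 6), (6, 10)

# With file_size=3 and world_size=4 (fewer rows than ranks):
for rank in range(4):
    print(rank, _get_sparse_emb_lowerbound_range(3, rank, 4))
# -> (0, 1), (1, 2), (2, 3), (3, 3)  # rank 3 gets an empty slice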
