Change how node embeddings are saved (#527)
Change how node embeddings are saved

From
PATH_TO_EMB:
  |- emb_info.json
  |- ntype0_emb.part00000.bin
  |- ...
  |- ntype1_emb.part00000.bin
  |- ...

To
PATH_TO_EMB:
  |- emb_info.json
  |- ntype0
     |- emb.part00000.bin
     |- emb.part00001.bin
     |- ...
  |- ntype1
     |- emb.part00000.bin

*Issue #, if available:*
#508 


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Xiang Song <[email protected]>
classicsong and Xiang Song authored Oct 6, 2023
1 parent f3df132 commit 09cd6d0
Showing 3 changed files with 78 additions and 28 deletions.
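
For orientation before the diffs, a condensed sketch of the save-side change in save_pytorch_embeddings: each rank now writes its shard into a per-ntype subdirectory instead of a flat <ntype>_emb.part*.bin file. This paraphrases the utils.py diff below rather than reproducing it; pad_file_index pads the rank to five digits, written here as f"{rank:05d}", and the real function also records "world_size" (the total number of file parts) in emb_info.

# Condensed sketch of the per-rank save step under the new layout
# (paraphrased from the utils.py diff below; NTYPE is dgl.NTYPE, the
# placeholder ntype used for homogeneous graphs).
import os

import dgl
import torch as th

NTYPE = dgl.NTYPE

def save_rank_embeddings(model_path, embeddings, rank):
    """One rank's share of the save step, reduced to the layout logic."""
    emb_info = {"format": "pytorch", "emb_name": []}
    if isinstance(embeddings, dict):
        # One subdirectory per node type: model_path/<ntype>/emb.partXXXXX.bin
        for name, emb in embeddings.items():
            os.makedirs(os.path.join(model_path, name), exist_ok=True)
            th.save(emb, os.path.join(model_path, name, f"emb.part{rank:05d}.bin"))
            emb_info["emb_name"].append(name)
    else:
        # Homogeneous graph: there is no real ntype, so fall back to NTYPE.
        os.makedirs(os.path.join(model_path, NTYPE), exist_ok=True)
        th.save(embeddings, os.path.join(model_path, NTYPE, f"emb.part{rank:05d}.bin"))
        emb_info["emb_name"] = NTYPE
    return emb_info
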
48 changes: 44 additions & 4 deletions python/graphstorm/model/utils.py
@@ -31,6 +31,9 @@
from ..utils import get_rank, barrier, get_world_size
from ..data.utils import alltoallv_cpu, alltoallv_nccl

# placeholder of the ntype for homogeneous graphs
NTYPE = dgl.NTYPE

def pad_file_index(file_index, width=5):
""" Left pad file_index with zerros.
@@ -475,6 +478,37 @@ def save_pytorch_embeddings(model_path, embeddings, rank, world_size,
device=th.device('cpu'), node_id_mapping_file=None):
""" Save embeddings through pytorch a distributed way
Example:
--------
The saved node embeddings looks like:
.. code::
PATH_TO_EMB:
|- emb_info.json
|- ntype0_emb.part00000.bin
|- ...
|- ntype1_emb.part00000.bin
|- ...
The emb.info.json contains three information:
* "format", how data are stored, e.g., "pytorch".
* "world_size", the total number of file parts. 0 means there is no partition.
* "emb_name", a list of node types that have embeddings saved.
Example:
--------
.. code::
{
"format": "pytorch",
"world_size": 8,
"emb_name": ["movie", "user"]
}
.. note::
The saved node embeddings are in GraphStorm node ID space.
You need to remap them into raw input
node ID space by following [LINK].
Parameters
----------
model_path : str
@@ -541,11 +575,17 @@ def save_pytorch_embeddings(model_path, embeddings, rank, world_size,
if isinstance(embeddings, dict):
# embedding per node type
for name, emb in embeddings.items():
th.save(emb, os.path.join(model_path, f'{name}_emb.part{pad_file_index(rank)}.bin'))
os.makedirs(os.path.join(model_path, name), exist_ok=True)
th.save(emb, os.path.join(os.path.join(model_path, name),
f'emb.part{pad_file_index(rank)}.bin'))
emb_info["emb_name"].append(name)
else:
th.save(embeddings, os.path.join(model_path, f'emb.part{pad_file_index(rank)}.bin'))
emb_info["emb_name"] = None
os.makedirs(os.path.join(model_path, NTYPE), exist_ok=True)
# There is no ntype for the embedding
# use NTYPE
th.save(embeddings, os.path.join(os.path.join(model_path, NTYPE),
f'emb.part{pad_file_index(rank)}.bin'))
emb_info["emb_name"] = NTYPE

if rank == 0:
with open(os.path.join(model_path, "emb_info.json"), 'w', encoding='utf-8') as f:
@@ -593,7 +633,7 @@ def save_embeddings(model_path, embeddings, rank, world_size,
----------
model_path : str
The path of the folder where the model is saved.
embeddings : DistTensor
embeddings : dict of DistTensor or DistTensor
Embeddings to save
rank : int
Rank of the current process in a distributed environment.
26 changes: 13 additions & 13 deletions tests/end2end-tests/check_infer.py
@@ -50,23 +50,23 @@
assert len(train_emb_info["emb_name"]) >= len(info_emb_info["emb_name"])

# feats are same
train_emb_files = os.listdir(args.train_embout)
train_emb_files = sorted(train_emb_files)
info_emb_files = os.listdir(args.infer_embout)
info_emb_files = sorted(info_emb_files)
for name in info_emb_info["emb_name"]:
for ntype in info_emb_info["emb_name"]:
train_emb = []
for f in train_emb_files:
if f.startswith(f'{name}_emb.part'):
# Only work with torch 1.13+
train_emb.append(th.load(os.path.join(args.train_embout, f),weights_only=True))
ntype_emb_path = os.path.join(args.train_embout, ntype)
ntype_emb_files = os.listdir(ntype_emb_path)
ntype_emb_files = sorted(ntype_emb_files)
for f in ntype_emb_files:
# Only work with torch 1.13+
train_emb.append(th.load(os.path.join(ntype_emb_path, f),weights_only=True))
train_emb = th.cat(train_emb, dim=0)

infer_emb = []
for f in info_emb_files:
if f.startswith(f'{name}_emb.part'):
# Only work with torch 1.13+
infer_emb.append(th.load(os.path.join(args.infer_embout, f), weights_only=True))
ntype_emb_path = os.path.join(args.infer_embout, ntype)
ntype_emb_files = os.listdir(ntype_emb_path)
ntype_emb_files = sorted(ntype_emb_files)
for f in ntype_emb_files:
# Only work with torch 1.13+
infer_emb.append(th.load(os.path.join(ntype_emb_path, f), weights_only=True))
infer_emb = th.cat(infer_emb, dim=0)

assert train_emb.shape[0] == infer_emb.shape[0]
32 changes: 21 additions & 11 deletions tests/unit-tests/test_utils.py
@@ -26,7 +26,7 @@
from numpy.testing import assert_equal, assert_almost_equal
from dgl.distributed import DistTensor
from graphstorm.model.utils import save_embeddings, LazyDistTensor, remove_saved_models, TopKList
from graphstorm.model.utils import _get_data_range
from graphstorm.model.utils import _get_data_range, NTYPE
from graphstorm.model.utils import _exchange_node_id_mapping, distribute_nid_map
from graphstorm.model.utils import shuffle_predict
from graphstorm.model.utils import pad_file_index
@@ -400,8 +400,10 @@ def test_save_embeddings_with_id_mapping(num_embs, backend):
assert p1.exitcode == 0

# Load saved embeddings
emb0 = th.load(os.path.join(tmpdirname, f'emb.part{pad_file_index(0)}.bin'), weights_only=True)
emb1 = th.load(os.path.join(tmpdirname, f'emb.part{pad_file_index(1)}.bin'), weights_only=True)
emb0 = th.load(os.path.join(os.path.join(tmpdirname, NTYPE),
f'emb.part{pad_file_index(0)}.bin'), weights_only=True)
emb1 = th.load(os.path.join(os.path.join(tmpdirname, NTYPE),
f'emb.part{pad_file_index(1)}.bin'), weights_only=True)
saved_emb = th.cat([emb0, emb1], dim=0)
assert len(saved_emb) == len(emb)
assert_equal(emb[nid_mapping].numpy(), saved_emb.numpy())
@@ -440,20 +442,26 @@ def test_save_embeddings_with_id_mapping(num_embs, backend):
assert p1.exitcode == 0

# Load saved embeddings
emb0 = th.load(os.path.join(tmpdirname, f'n0_emb.part{pad_file_index(0)}.bin'), weights_only=True)
emb1 = th.load(os.path.join(tmpdirname, f'n0_emb.part{pad_file_index(1)}.bin'), weights_only=True)
emb0 = th.load(os.path.join(os.path.join(tmpdirname, 'n0'),
f'emb.part{pad_file_index(0)}.bin'), weights_only=True)
emb1 = th.load(os.path.join(os.path.join(tmpdirname, 'n0'),
f'emb.part{pad_file_index(1)}.bin'), weights_only=True)
saved_emb = th.cat([emb0, emb1], dim=0)
assert len(saved_emb) == len(embs['n0'])
assert_equal(embs['n0'][nid_mappings['n0']].numpy(), saved_emb.numpy())

emb0 = th.load(os.path.join(tmpdirname, f'n1_emb.part{pad_file_index(0)}.bin'), weights_only=True)
emb1 = th.load(os.path.join(tmpdirname, f'n1_emb.part{pad_file_index(1)}.bin'), weights_only=True)
emb0 = th.load(os.path.join(os.path.join(tmpdirname, 'n1'),
f'emb.part{pad_file_index(0)}.bin'), weights_only=True)
emb1 = th.load(os.path.join(os.path.join(tmpdirname, 'n1'),
f'emb.part{pad_file_index(1)}.bin'), weights_only=True)
saved_emb = th.cat([emb0, emb1], dim=0)
assert len(saved_emb) == len(embs['n1'])
assert_equal(embs['n1'][nid_mappings['n1']].numpy(), saved_emb.numpy())

emb0 = th.load(os.path.join(tmpdirname, f'n2_emb.part{pad_file_index(0)}.bin'), weights_only=True)
emb1 = th.load(os.path.join(tmpdirname, f'n2_emb.part{pad_file_index(1)}.bin'), weights_only=True)
emb0 = th.load(os.path.join(os.path.join(tmpdirname, 'n2'),
f'emb.part{pad_file_index(0)}.bin'), weights_only=True)
emb1 = th.load(os.path.join(os.path.join(tmpdirname, 'n2'),
f'emb.part{pad_file_index(1)}.bin'), weights_only=True)
saved_emb = th.cat([emb0, emb1], dim=0)
assert len(saved_emb) == len(embs['n2'])
assert_equal(embs['n2'][nid_mappings['n2']].numpy(), saved_emb.numpy())
@@ -470,11 +478,13 @@ def test_save_embeddings():
type0_random_emb, type1_random_emb = helper_save_embedding(tmpdirname)

# Only work with torch 1.13+
feats_type0 = [th.load(os.path.join(tmpdirname, f"type0_emb.part{pad_file_index(i)}.bin"),
feats_type0 = [th.load(os.path.join(os.path.join(tmpdirname, "type0"),
f"emb.part{pad_file_index(i)}.bin"),
weights_only=True) for i in range(4)]
feats_type0 = th.cat(feats_type0, dim=0)
# Only work with torch 1.13+
feats_type1 = [th.load(os.path.join(tmpdirname, f"type1_emb.part{pad_file_index(i)}.bin"),
feats_type1 = [th.load(os.path.join(os.path.join(tmpdirname, "type1"),
f"emb.part{pad_file_index(i)}.bin"),
weights_only=True) for i in range(4)]
feats_type1 = th.cat(feats_type1, dim=0)

