From 09cd6d08ba131749505d4a447e91412e9ed39d3a Mon Sep 17 00:00:00 2001
From: "xiang song(charlie.song)"
Date: Thu, 5 Oct 2023 17:35:33 -0700
Subject: [PATCH] Change how node embeddings are saved (#527)

Change how node embeddings are saved.

From

    PATH_TO_EMB:
    |- emb_info.json
    |- ntype0_emb.part00000.bin
    |- ...
    |- ntype1_emb.part00000.bin
    |- ...

To

    PATH_TO_EMB:
    |- emb_info.json
    |- ntype0
        |- emb.part00000.bin
        |- emb.part00001.bin
        |- ...
    |- ntype1
        |- emb.part00000.bin

*Issue #, if available:* #508

By submitting this pull request, I confirm that you can use, modify, copy, and
redistribute this contribution, under the terms of your choice.

---------

Co-authored-by: Xiang Song
---
 python/graphstorm/model/utils.py   | 48 +++++++++++++++++++++++++++---
 tests/end2end-tests/check_infer.py | 26 ++++++++--------
 tests/unit-tests/test_utils.py     | 32 +++++++++++++-------
 3 files changed, 78 insertions(+), 28 deletions(-)

diff --git a/python/graphstorm/model/utils.py b/python/graphstorm/model/utils.py
index 355a90015c..2867016fc8 100644
--- a/python/graphstorm/model/utils.py
+++ b/python/graphstorm/model/utils.py
@@ -31,6 +31,9 @@
 from ..utils import get_rank, barrier, get_world_size
 from ..data.utils import alltoallv_cpu, alltoallv_nccl
 
+# placeholder of the ntype for homogeneous graphs
+NTYPE = dgl.NTYPE
+
 def pad_file_index(file_index, width=5):
     """ Left pad file_index with zerros.
@@ -475,6 +478,37 @@ def save_pytorch_embeddings(model_path, embeddings, rank, world_size,
                             device=th.device('cpu'), node_id_mapping_file=None):
     """ Save embeddings through pytorch a distributed way
 
+    Example:
+    --------
+    The saved node embeddings look like:
+
+    .. code::
+
+        PATH_TO_EMB:
+        |- emb_info.json
+        |- ntype0
+            |- emb.part00000.bin
+            |- emb.part00001.bin
+            |- ...
+        |- ntype1
+            |- emb.part00000.bin
+            |- ...
+
+    The emb_info.json file contains three fields:
+
+    * "format", how the data are stored, e.g., "pytorch".
+    * "world_size", the total number of file parts; 0 means there is no partition.
+    * "emb_name", a list of node types that have embeddings saved.
+
+    Example:
+    --------
+    .. code::
+
+        {
+            "format": "pytorch",
+            "world_size": 8,
+            "emb_name": ["movie", "user"]
+        }
+
+    .. note::
+        The saved node embeddings are in the GraphStorm node ID space.
+        You need to remap them into the raw input node ID space by following [LINK].
+
     Parameters
     ----------
     model_path : str
@@ -541,11 +575,17 @@
     if isinstance(embeddings, dict):
         # embedding per node type
         for name, emb in embeddings.items():
-            th.save(emb, os.path.join(model_path, f'{name}_emb.part{pad_file_index(rank)}.bin'))
+            os.makedirs(os.path.join(model_path, name), exist_ok=True)
+            th.save(emb, os.path.join(os.path.join(model_path, name),
+                                      f'emb.part{pad_file_index(rank)}.bin'))
             emb_info["emb_name"].append(name)
     else:
-        th.save(embeddings, os.path.join(model_path, f'emb.part{pad_file_index(rank)}.bin'))
-        emb_info["emb_name"] = None
+        os.makedirs(os.path.join(model_path, NTYPE), exist_ok=True)
+        # There is no ntype for the embedding,
+        # so use the NTYPE placeholder.
+        th.save(embeddings, os.path.join(os.path.join(model_path, NTYPE),
+                                         f'emb.part{pad_file_index(rank)}.bin'))
+        emb_info["emb_name"] = NTYPE
 
     if rank == 0:
         with open(os.path.join(model_path, "emb_info.json"), 'w', encoding='utf-8') as f:
@@ -593,7 +633,7 @@ def save_embeddings(model_path, embeddings, rank, world_size,
     ----------
     model_path : str
         The path of the folder where the model is saved.
-    embeddings : DistTensor
+    embeddings : dict of DistTensor or DistTensor
         Embeddings to save
     rank : int
         Rank of the current process in a distributed environment.
diff --git a/tests/end2end-tests/check_infer.py b/tests/end2end-tests/check_infer.py
index 810c251f2d..d9cd646f0c 100644
--- a/tests/end2end-tests/check_infer.py
+++ b/tests/end2end-tests/check_infer.py
@@ -50,23 +50,23 @@
     assert len(train_emb_info["emb_name"]) >= len(info_emb_info["emb_name"])
 
     # feats are same
-    train_emb_files = os.listdir(args.train_embout)
-    train_emb_files = sorted(train_emb_files)
-    info_emb_files = os.listdir(args.infer_embout)
-    info_emb_files = sorted(info_emb_files)
-    for name in info_emb_info["emb_name"]:
+    for ntype in info_emb_info["emb_name"]:
         train_emb = []
-        for f in train_emb_files:
-            if f.startswith(f'{name}_emb.part'):
-                # Only work with torch 1.13+
-                train_emb.append(th.load(os.path.join(args.train_embout, f),weights_only=True))
+        ntype_emb_path = os.path.join(args.train_embout, ntype)
+        ntype_emb_files = os.listdir(ntype_emb_path)
+        ntype_emb_files = sorted(ntype_emb_files)
+        for f in ntype_emb_files:
+            # Only work with torch 1.13+
+            train_emb.append(th.load(os.path.join(ntype_emb_path, f),weights_only=True))
         train_emb = th.cat(train_emb, dim=0)
 
         infer_emb = []
-        for f in info_emb_files:
-            if f.startswith(f'{name}_emb.part'):
-                # Only work with torch 1.13+
-                infer_emb.append(th.load(os.path.join(args.infer_embout, f), weights_only=True))
+        ntype_emb_path = os.path.join(args.infer_embout, ntype)
+        ntype_emb_files = os.listdir(ntype_emb_path)
+        ntype_emb_files = sorted(ntype_emb_files)
+        for f in ntype_emb_files:
+            # Only work with torch 1.13+
+            infer_emb.append(th.load(os.path.join(ntype_emb_path, f), weights_only=True))
         infer_emb = th.cat(infer_emb, dim=0)
 
         assert train_emb.shape[0] == infer_emb.shape[0]
diff --git a/tests/unit-tests/test_utils.py b/tests/unit-tests/test_utils.py
index 6309e4f422..c8c22e8211 100644
--- a/tests/unit-tests/test_utils.py
+++ b/tests/unit-tests/test_utils.py
@@ -26,7 +26,7 @@
 from numpy.testing import assert_equal, assert_almost_equal
 from dgl.distributed import DistTensor
 from graphstorm.model.utils import save_embeddings, LazyDistTensor, remove_saved_models, TopKList
-from graphstorm.model.utils import _get_data_range
+from graphstorm.model.utils import _get_data_range, NTYPE
 from graphstorm.model.utils import _exchange_node_id_mapping, distribute_nid_map
 from graphstorm.model.utils import shuffle_predict
 from graphstorm.model.utils import pad_file_index
@@ -400,8 +400,10 @@ def test_save_embeddings_with_id_mapping(num_embs, backend):
     assert p1.exitcode == 0
 
     # Load saved embeddings
-    emb0 = th.load(os.path.join(tmpdirname, f'emb.part{pad_file_index(0)}.bin'), weights_only=True)
-    emb1 = th.load(os.path.join(tmpdirname, f'emb.part{pad_file_index(1)}.bin'), weights_only=True)
+    emb0 = th.load(os.path.join(os.path.join(tmpdirname, NTYPE),
+                                f'emb.part{pad_file_index(0)}.bin'), weights_only=True)
+    emb1 = th.load(os.path.join(os.path.join(tmpdirname, NTYPE),
+                                f'emb.part{pad_file_index(1)}.bin'), weights_only=True)
     saved_emb = th.cat([emb0, emb1], dim=0)
     assert len(saved_emb) == len(emb)
     assert_equal(emb[nid_mapping].numpy(), saved_emb.numpy())
@@ -440,20 +442,26 @@ def test_save_embeddings_with_id_mapping(num_embs, backend):
     assert p1.exitcode == 0
 
     # Load saved embeddings
-    emb0 = th.load(os.path.join(tmpdirname, f'n0_emb.part{pad_file_index(0)}.bin'), weights_only=True)
-    emb1 = th.load(os.path.join(tmpdirname, f'n0_emb.part{pad_file_index(1)}.bin'),
-                   weights_only=True)
+    emb0 = th.load(os.path.join(os.path.join(tmpdirname, 'n0'),
+                                f'emb.part{pad_file_index(0)}.bin'), weights_only=True)
+    emb1 = th.load(os.path.join(os.path.join(tmpdirname, 'n0'),
+                                f'emb.part{pad_file_index(1)}.bin'), weights_only=True)
     saved_emb = th.cat([emb0, emb1], dim=0)
     assert len(saved_emb) == len(embs['n0'])
     assert_equal(embs['n0'][nid_mappings['n0']].numpy(), saved_emb.numpy())
 
-    emb0 = th.load(os.path.join(tmpdirname, f'n1_emb.part{pad_file_index(0)}.bin'), weights_only=True)
-    emb1 = th.load(os.path.join(tmpdirname, f'n1_emb.part{pad_file_index(1)}.bin'), weights_only=True)
+    emb0 = th.load(os.path.join(os.path.join(tmpdirname, 'n1'),
+                                f'emb.part{pad_file_index(0)}.bin'), weights_only=True)
+    emb1 = th.load(os.path.join(os.path.join(tmpdirname, 'n1'),
+                                f'emb.part{pad_file_index(1)}.bin'), weights_only=True)
     saved_emb = th.cat([emb0, emb1], dim=0)
     assert len(saved_emb) == len(embs['n1'])
     assert_equal(embs['n1'][nid_mappings['n1']].numpy(), saved_emb.numpy())
 
-    emb0 = th.load(os.path.join(tmpdirname, f'n2_emb.part{pad_file_index(0)}.bin'), weights_only=True)
-    emb1 = th.load(os.path.join(tmpdirname, f'n2_emb.part{pad_file_index(1)}.bin'), weights_only=True)
+    emb0 = th.load(os.path.join(os.path.join(tmpdirname, 'n2'),
+                                f'emb.part{pad_file_index(0)}.bin'), weights_only=True)
+    emb1 = th.load(os.path.join(os.path.join(tmpdirname, 'n2'),
+                                f'emb.part{pad_file_index(1)}.bin'), weights_only=True)
     saved_emb = th.cat([emb0, emb1], dim=0)
     assert len(saved_emb) == len(embs['n2'])
    assert_equal(embs['n2'][nid_mappings['n2']].numpy(), saved_emb.numpy())
@@ -470,11 +478,13 @@ def test_save_embeddings():
         type0_random_emb, type1_random_emb = helper_save_embedding(tmpdirname)
 
         # Only work with torch 1.13+
-        feats_type0 = [th.load(os.path.join(tmpdirname, f"type0_emb.part{pad_file_index(i)}.bin"),
+        feats_type0 = [th.load(os.path.join(os.path.join(tmpdirname, "type0"),
+                                            f"emb.part{pad_file_index(i)}.bin"),
                                weights_only=True) for i in range(4)]
         feats_type0 = th.cat(feats_type0, dim=0)
         # Only work with torch 1.13+
-        feats_type1 = [th.load(os.path.join(tmpdirname, f"type1_emb.part{pad_file_index(i)}.bin"),
+        feats_type1 = [th.load(os.path.join(os.path.join(tmpdirname, "type1"),
+                                            f"emb.part{pad_file_index(i)}.bin"),
                                weights_only=True) for i in range(4)]
         feats_type1 = th.cat(feats_type1, dim=0)
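
For reference, a minimal sketch (not part of the patch) of how downstream code can load embeddings saved in the new per-node-type layout, mirroring the loading logic in check_infer.py above. PATH_TO_EMB is a placeholder for the directory passed as model_path.

.. code::

    import json
    import os

    import torch as th

    emb_path = "PATH_TO_EMB"  # placeholder for the embedding output directory

    # emb_info.json records the format, world_size, and the saved node types.
    with open(os.path.join(emb_path, "emb_info.json"), "r", encoding="utf-8") as f:
        emb_info = json.load(f)

    ntypes = emb_info["emb_name"]
    if not isinstance(ntypes, list):
        # Homogeneous graph: embeddings are stored under a single placeholder ntype.
        ntypes = [ntypes]

    embs = {}
    for ntype in ntypes:
        ntype_dir = os.path.join(emb_path, ntype)
        part_files = sorted(f for f in os.listdir(ntype_dir) if f.startswith("emb.part"))
        # weights_only=True requires torch 1.13+
        embs[ntype] = th.cat(
            [th.load(os.path.join(ntype_dir, f), weights_only=True) for f in part_files],
            dim=0)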