diff --git a/python/graphstorm/model/utils.py b/python/graphstorm/model/utils.py index 5a44bb5d62..6396d42669 100644 --- a/python/graphstorm/model/utils.py +++ b/python/graphstorm/model/utils.py @@ -918,8 +918,10 @@ def load_gsgnn_embeddings(emb_path, g): ntype_emb_path = os.path.join(emb_path, ntype) emb_files = os.listdir(ntype_emb_path) - ntype_emb_files = [file for file in emb_files if file.endswith(".pt") and file.startswith("emb")] - ntype_nid_files = [file for file in emb_files if file.endswith(".pt") and file.startswith("nids")] + ntype_emb_files = [file for file in emb_files if file.endswith(".pt") and + file.startswith("emb")] + ntype_nid_files = [file for file in emb_files if file.endswith(".pt") and + file.startswith("nids")] ntype_emb_files = sorted(ntype_emb_files) ntype_nid_files = sorted(ntype_nid_files) part_policy = g.get_node_partition_policy(ntype)