Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Command Line for Embedding Generating #525

Merged
merged 43 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
1d6432f
initial commit
jalencato Oct 3, 2023
a841911
first commit - no test
jalencato Oct 3, 2023
7607069
remove unnecessary dependency
jalencato Oct 3, 2023
4621c6e
change config
jalencato Oct 3, 2023
d8c0309
fix lint
jalencato Oct 3, 2023
323dbb0
test
jalencato Oct 4, 2023
05807e3
fix save_embed path
jalencato Oct 4, 2023
6514afe
add test
jalencato Oct 4, 2023
249df54
fix lint
jalencato Oct 4, 2023
c9de5e7
temp fix
jalencato Oct 4, 2023
2b25576
fix
jalencato Oct 4, 2023
63fbb6f
fix typo
jalencato Oct 4, 2023
b2ac45b
fix test
jalencato Oct 5, 2023
699dafa
fix test
jalencato Oct 5, 2023
c8991b0
fix
jalencato Oct 5, 2023
3442970
change test
jalencato Oct 5, 2023
273846c
rename the gs_gen_embedding to ge_gen_node_embedding
jalencato Oct 5, 2023
ce05d94
fix test
jalencato Oct 5, 2023
6196d19
Update mgpu_test.sh
jalencato Oct 5, 2023
264c80e
fix bug
jalencato Oct 5, 2023
96bdaf8
fix
jalencato Oct 5, 2023
ebe0d4b
fix embedding bug on link prediction
jalencato Oct 5, 2023
70feebd
use entire graph for embedding generation
jalencato Oct 6, 2023
a38df65
fix whole code structure
jalencato Oct 9, 2023
f642cf0
fix import bug
jalencato Oct 9, 2023
34e22e7
fix lint
jalencato Oct 9, 2023
0d9347f
fix lint
jalencato Oct 9, 2023
3c30b1b
fix bug for not restoring model
jalencato Oct 9, 2023
9c0be98
remove relation embedding
jalencato Oct 11, 2023
6822ef3
remove redundant dependency
jalencato Oct 11, 2023
2b79a47
fix lint
jalencato Oct 11, 2023
9c119de
change to trival version
jalencato Oct 12, 2023
45038f9
add doc string
jalencato Oct 12, 2023
6ceb0d0
fix edge task mini batch
jalencato Oct 12, 2023
5e39786
add
jalencato Oct 12, 2023
5704472
fix sorted bug
jalencato Oct 12, 2023
6de76ff
finish pruning
jalencato Oct 13, 2023
788297a
fix typo
jalencato Oct 13, 2023
0cd315f
apply comment
jalencato Oct 13, 2023
ada55ae
test
jalencato Oct 13, 2023
774187c
add embs
jalencato Oct 13, 2023
017e119
Merge branch 'main' into gen_embedding
jalencato Oct 13, 2023
6906543
fix typo
jalencato Oct 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/graphstorm/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
from .lp_infer import GSgnnLinkPredictionInferrer
from .np_infer import GSgnnNodePredictionInferrer
from .ep_infer import GSgnnEdgePredictionInferrer
from .emb_infer import GSgnnEmbGenInferer
144 changes: 144 additions & 0 deletions python/graphstorm/inference/emb_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""
Copyright 2023 Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Inferrer wrapper for embedding generation.
"""
import logging
from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION,
BUILTIN_TASK_NODE_REGRESSION,
BUILTIN_TASK_EDGE_CLASSIFICATION,
BUILTIN_TASK_EDGE_REGRESSION,
BUILTIN_TASK_LINK_PREDICTION)
from .graphstorm_infer import GSInferrer
classicsong marked this conversation as resolved.
Show resolved Hide resolved
from ..model.utils import save_embeddings as save_gsgnn_embeddings
from ..model.gnn import do_full_graph_inference, do_mini_batch_inference
from ..model.node_gnn import node_mini_batch_gnn_predict

from ..utils import sys_tracker, get_rank, get_world_size, barrier, create_dist_tensor


class GSgnnEmbGenInferer(GSInferrer):
    """ Inferrer for node embedding generation.

    This is a high-level inferrer wrapper that can be used directly
    to generate node embeddings with a trained model, whichever
    built-in task the model was trained on.

    Parameters
    ----------
    model : GSgnnNodeModel
        The GNN model of a built-in task.
    """

    def infer(self, data, task_type, save_embed_path, loader,
            use_mini_batch_infer=False,
            node_id_mapping_file=None,
            return_proba=True,
            save_embed_format="pytorch"):
        """ Generate node embeddings and save them to ``save_embed_path``.

        Parameters
        ----------
        data: GSgnnData
            The GraphStorm dataset.
        task_type : str
            task_type must be one of graphstorm builtin task types.
        save_embed_path : str
            The path where the GNN embeddings will be saved.
        loader : GSEdgeDataLoader/GSNodeDataLoader
            The mini-batch sampler for built-in graphstorm task.
        use_mini_batch_infer : bool
            Whether or not to use mini-batch inference when computing node embeddings.
        node_id_mapping_file: str
            Path to the file storing node id mapping generated by the
            graph partition algorithm.
        return_proba: bool
            Forwarded to the node prediction path; whether the prediction head
            returns probabilities. Labels are never returned here.
        save_embed_format : str
            Specify the format of saved embeddings.
        """

        device = self.device
        # Mini-batch inference on edge tasks only computes embeddings for nodes
        # touched by the sampled edges, so saving would produce a partial
        # embedding table. Reject that combination up front.
        if use_mini_batch_infer and \
            task_type in {BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION}:
            assert save_embed_path is None, \
                "Unable to save the node embeddings when using mini batch inference " \
                "when doing edge task." \
                "It is not guaranteed that mini-batch prediction will cover all the nodes."

        if task_type in {BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_NODE_CLASSIFICATION}:
            assert len(loader.data.eval_ntypes) == 1, \
                "GraphStorm only support single target node type for training and inference"

        assert save_embed_path is not None, \
            "save_embed_path must be provided for embedding generation."

        sys_tracker.check('start embedding generation')
        self._model.eval()

        if task_type == BUILTIN_TASK_LINK_PREDICTION:
            # for embedding generation, it is preferred to use full graph
            if use_mini_batch_infer:
                embs = do_mini_batch_inference(self._model, data, fanout=loader.fanout,
                                               edge_mask=None,
                                               task_tracker=self.task_tracker)
            else:
                embs = do_full_graph_inference(self._model, data, fanout=loader.fanout,
                                               edge_mask=None,
                                               task_tracker=self.task_tracker)
        elif task_type in {BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_NODE_CLASSIFICATION}:
            # only generate embeddings on the target node type
            ntype = loader.data.eval_ntypes[0]
            if use_mini_batch_infer:
                inter_embs = node_mini_batch_gnn_predict(self._model, loader, return_proba,
                                                         return_label=False)[1]
                inter_embs = {ntype: inter_embs[ntype]} if isinstance(inter_embs, dict) \
                    else {ntype: inter_embs}
                g = loader.data.g
                # Mini-batch prediction only covers loader.target_nidx; scatter
                # the results into a graph-wide distributed tensor so saving
                # behaves the same as the full-graph path.
                ntype_emb = create_dist_tensor((g.num_nodes(ntype), inter_embs[ntype].shape[1]),
                                               dtype=inter_embs[ntype].dtype,
                                               name=f'gen-emb-{ntype}',
                                               part_policy=g.get_node_partition_policy(ntype),
                                               persistent=True)
                ntype_emb[loader.target_nidx[ntype]] = inter_embs[ntype]
                embs = {ntype: ntype_emb}
            else:
                embs = do_full_graph_inference(self._model, data, fanout=loader.fanout,
                                               task_tracker=self.task_tracker)
                ntype_emb = embs[ntype]
                embs = {ntype: ntype_emb}
        elif task_type in {BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION}:
            # Currently it is not allowed to do mini-batch inference
            # and save embedding on edge tasks
            embs = do_full_graph_inference(self._model, loader.data, fanout=loader.fanout,
                                           task_tracker=self.task_tracker)
            # Keep only the node types incident to the evaluation edge types
            # (source and destination of each canonical etype).
            target_ntypes = set()
            for etype in loader.data.eval_etypes:
                target_ntypes.add(etype[0])
                target_ntypes.add(etype[2])

            embs = {ntype: embs[ntype] for ntype in sorted(target_ntypes)}
        else:
            # Raise with a formatted message instead of a tuple payload.
            raise TypeError(f"Not supported for task type: {task_type}")

        if get_rank() == 0:
            logging.info("save embeddings to %s", save_embed_path)

        save_gsgnn_embeddings(save_embed_path, embs, get_rank(),
                              get_world_size(),
                              device=device,
                              node_id_mapping_file=node_id_mapping_file,
                              save_embed_format=save_embed_format)
        barrier()
        sys_tracker.check('save embeddings')
46 changes: 46 additions & 0 deletions python/graphstorm/run/gs_gen_node_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
Copyright 2023 Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Entry point for running embedding generating tasks.

Run as:
python3 -m graphstorm.run.gs_gen_node_embedding <Launch args>
"""
import os
import logging

from .launch import get_argument_parser
from .launch import check_input_arguments
from .launch import submit_jobs

def main():
    """ Parse launcher arguments and submit the embedding-generation job.
    """
    args, exec_script_args = get_argument_parser().parse_known_args()
    check_input_arguments(args)

    # Resolve the worker script relative to this launcher module so the
    # command works regardless of the current working directory.
    lib_dir = os.path.abspath(os.path.dirname(__file__))
    cmd_path = os.path.join(lib_dir, "gsgnn_emb/gsgnn_node_emb.py")

    submit_jobs(args, [cmd_path] + exec_script_args)

if __name__ == "__main__":
    # Configure root logging before launching so submission progress is visible.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(message)s",
        level=logging.INFO)
    main()

Empty file.
148 changes: 148 additions & 0 deletions python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""
Copyright 2023 Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

GSgnn pure gpu generate embeddings.
"""
import graphstorm as gs
from graphstorm.config import get_argument_parser
from graphstorm.config import GSConfig
from graphstorm.dataloading import (GSgnnEdgeInferData, GSgnnNodeInferData,
GSgnnEdgeDataLoader, GSgnnNodeDataLoader,
GSgnnLinkPredictionTestDataLoader,
GSgnnLinkPredictionJointTestDataLoader)
from graphstorm.utils import rt_profiler, sys_tracker, setup_device, use_wholegraph
from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION,
BUILTIN_TASK_NODE_REGRESSION,
BUILTIN_TASK_EDGE_CLASSIFICATION,
BUILTIN_TASK_EDGE_REGRESSION,
BUILTIN_TASK_LINK_PREDICTION)
from graphstorm.dataloading import (BUILTIN_LP_UNIFORM_NEG_SAMPLER,
BUILTIN_LP_JOINT_NEG_SAMPLER)
from graphstorm.inference import GSgnnEmbGenInferer


def main(config_args):
    """ Entry point: generate node embeddings with a trained GraphStorm model.

    Loads the inference data, restores the trained model, builds the proper
    dataloader for the model's original task type, and runs the embedding
    generation inferrer.

    Parameters
    ----------
    config_args : argparse.Namespace
        Command line arguments parsed by the GraphStorm argument parser.
    """
    config = GSConfig(config_args)
    config.verify_arguments(True)

    gs.initialize(ip_config=config.ip_config, backend=config.backend,
                  use_wholegraph=use_wholegraph(config.part_config))
    rt_profiler.init(config.profile_path, rank=gs.get_rank())
    sys_tracker.init(config.verbose, rank=gs.get_rank())
    device = setup_device(config.local_rank)
    tracker = gs.create_builtin_task_tracker(config)
    if gs.get_rank() == 0:
        tracker.log_params(config.__dict__)

    # Load inference data; the target node/edge types depend on the task
    # the model was originally trained on.
    if config.task_type == BUILTIN_TASK_LINK_PREDICTION:
        input_graph = GSgnnEdgeInferData(config.graph_name,
                                         config.part_config,
                                         eval_etypes=config.eval_etype,
                                         node_feat_field=config.node_feat_name,
                                         decoder_edge_feat=config.decoder_edge_feat)
    elif config.task_type in {BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_NODE_CLASSIFICATION}:
        input_graph = GSgnnNodeInferData(config.graph_name,
                                         config.part_config,
                                         eval_ntypes=config.target_ntype,
                                         node_feat_field=config.node_feat_name,
                                         label_field=config.label_field)
    elif config.task_type in {BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION}:
        input_graph = GSgnnEdgeInferData(config.graph_name,
                                         config.part_config,
                                         eval_etypes=config.target_etype,
                                         node_feat_field=config.node_feat_name,
                                         label_field=config.label_field,
                                         decoder_edge_feat=config.decoder_edge_feat)
    else:
        # Raise with a formatted message instead of a tuple payload.
        raise TypeError(f"Not supported for task type: {config.task_type}")

    # assert the setting for the graphstorm embedding generation.
    assert config.save_embed_path is not None, \
        "save embedded path cannot be none for gs_gen_node_embeddings"
    assert config.restore_model_path is not None, \
        "restore model path cannot be none for gs_gen_node_embeddings"

    # Create the model matching the original training task, then restore
    # the trained weights from disk.
    if config.task_type == BUILTIN_TASK_LINK_PREDICTION:
        model = gs.create_builtin_lp_gnn_model(input_graph.g, config, train_task=False)
    elif config.task_type in {BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_NODE_CLASSIFICATION}:
        model = gs.create_builtin_node_gnn_model(input_graph.g, config, train_task=False)
    elif config.task_type in {BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION}:
        model = gs.create_builtin_edge_gnn_model(input_graph.g, config, train_task=False)
    else:
        raise TypeError(f"Not supported for task type: {config.task_type}")
    model.restore_model(config.restore_model_path,
                        model_layer_to_load=config.restore_model_layers)

    # define the dataloader
    if config.task_type == BUILTIN_TASK_LINK_PREDICTION:
        if config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER:
            link_prediction_loader = GSgnnLinkPredictionTestDataLoader
        elif config.eval_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER:
            link_prediction_loader = GSgnnLinkPredictionJointTestDataLoader
        else:
            raise ValueError('Unknown test negative sampler. '
                             'Supported test negative samplers include '
                             f'[{BUILTIN_LP_UNIFORM_NEG_SAMPLER}, {BUILTIN_LP_JOINT_NEG_SAMPLER}]')

        dataloader = link_prediction_loader(input_graph, input_graph.test_idxs,
                                            batch_size=config.eval_batch_size,
                                            num_negative_edges=config.num_negative_edges_eval,
                                            fanout=config.eval_fanout)
    elif config.task_type in {BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_NODE_CLASSIFICATION}:
        dataloader = GSgnnNodeDataLoader(input_graph, input_graph.infer_idxs,
                                         fanout=config.eval_fanout,
                                         batch_size=config.eval_batch_size, device=device,
                                         train_task=False,
                                         construct_feat_ntype=config.construct_feat_ntype,
                                         construct_feat_fanout=config.construct_feat_fanout)
    elif config.task_type in {BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION}:
        dataloader = GSgnnEdgeDataLoader(input_graph, input_graph.infer_idxs,
                                         fanout=config.eval_fanout,
                                         batch_size=config.eval_batch_size,
                                         device=device, train_task=False,
                                         reverse_edge_types_map=config.reverse_edge_types_map,
                                         remove_target_edge_type=config.remove_target_edge_type,
                                         construct_feat_ntype=config.construct_feat_ntype,
                                         construct_feat_fanout=config.construct_feat_fanout)
    else:
        raise TypeError(f"Not supported for task type: {config.task_type}")

    # start the infer
    emb_generator = GSgnnEmbGenInferer(model)
    emb_generator.setup_device(device=device)

    emb_generator.infer(input_graph, config.task_type,
                        save_embed_path=config.save_embed_path,
                        loader=dataloader,
                        use_mini_batch_infer=config.use_mini_batch_infer,
                        node_id_mapping_file=config.node_id_mapping_file,
                        return_proba=config.return_proba,
                        save_embed_format=config.save_embed_format)

def generate_parser():
    """ Build and return the GraphStorm argument parser.
    """
    return get_argument_parser()


if __name__ == '__main__':
    # Parse command line arguments and run embedding generation.
    main(generate_parser().parse_args())
3 changes: 2 additions & 1 deletion python/graphstorm/run/gsgnn_ep/gsgnn_ep.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ def main(config_args):
save_embeddings(config.save_embed_path, embs, gs.get_rank(),
jalencato marked this conversation as resolved.
Show resolved Hide resolved
gs.get_world_size(),
device=device,
node_id_mapping_file=config.node_id_mapping_file)
node_id_mapping_file=config.node_id_mapping_file,
save_embed_format=config.save_embed_format)

def generate_parser():
""" Generate an argument parser
Expand Down
3 changes: 2 additions & 1 deletion python/graphstorm/run/gsgnn_lp/gsgnn_lp.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ def main(config_args):
save_embeddings(config.save_embed_path, embeddings, gs.get_rank(),
gs.get_world_size(),
device=device,
node_id_mapping_file=config.node_id_mapping_file)
node_id_mapping_file=config.node_id_mapping_file,
save_embed_format=config.save_embed_format)

def generate_parser():
""" Generate an argument parser
Expand Down
3 changes: 2 additions & 1 deletion python/graphstorm/run/gsgnn_np/gsgnn_np.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ def main(config_args):
save_embeddings(config.save_embed_path, embeddings, gs.get_rank(),
gs.get_world_size(),
device=device,
node_id_mapping_file=config.node_id_mapping_file)
node_id_mapping_file=config.node_id_mapping_file,
save_embed_format=config.save_embed_format)

def generate_parser():
""" Generate an argument parser
Expand Down
2 changes: 1 addition & 1 deletion python/graphstorm/run/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,7 @@ def check_input_arguments(args):
), "--num-servers must be a positive number."
assert (
args.part_config is not None
), "A user has to specify a partition configuration file with --part-onfig."
), "A user has to specify a partition configuration file with --part-config."
assert (
args.ip_config is not None
), "A user has to specify an IP configuration file with --ip-config."
Expand Down
Loading