Skip to content

Commit

Permalink
Fix lint
Browse files — browse the repository at this point in the history
  • Loading branch information
Xiang Song committed May 27, 2024
1 parent 2a9aa88 commit fc45ea2
Showing 1 changed file with 33 additions and 19 deletions.
52 changes: 33 additions & 19 deletions python/graphstorm/run/gsgnn_mt/gsgnn_mt.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def create_task_train_dataloader(task, config, train_data):

logging.info("Create dataloader for %s", task.task_id)
if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]:
train_idxs = train_data.get_node_train_set(task_config.target_ntype, mask=task_config.train_mask)
train_idxs = train_data.get_node_train_set(
task_config.target_ntype,
mask=task_config.train_mask)
# TODO(xiangsx): Support construct feat
return GSgnnNodeDataLoader(train_data,
train_idxs,
Expand All @@ -78,7 +80,9 @@ def create_task_train_dataloader(task, config, train_data):
node_feats=node_feats,
label_field=task_config.label_field)
elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]:
train_idxs = train_data.get_edge_train_set(task_config.target_etype, mask=task_config.train_mask)
train_idxs = train_data.get_edge_train_set(
task_config.target_etype,
mask=task_config.train_mask)
# TODO(xiangsx): Support construct feat
return GSgnnEdgeDataLoader(train_data,
train_idxs,
Expand All @@ -91,7 +95,9 @@ def create_task_train_dataloader(task, config, train_data):
reverse_edge_types_map=task_config.reverse_edge_types_map,
remove_target_edge_type=task_config.remove_target_edge_type)
elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]:
train_idxs = train_data.get_edge_train_set(task_config.train_etype, mask=task_config.train_mask)
train_idxs = train_data.get_edge_train_set(
task_config.train_etype,
mask=task_config.train_mask)
dataloader_cls = gs.get_builtin_lp_train_dataloader_class(task_config)
return dataloader_cls(train_data,
train_idxs,
Expand Down Expand Up @@ -226,7 +232,9 @@ def create_task_test_dataloader(task, config, train_data):
label_field=task_config.label_field)

elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]:
test_idxs = train_data.get_edge_test_set(task_config.target_etype, mask=task_config.test_mask)
test_idxs = train_data.get_edge_test_set(
task_config.target_etype,
mask=task_config.test_mask)
# All tasks share the same GNN model, so the fanout should be the global fanout
fanout = config.eval_fanout if task_config.use_mini_batch_infer else []
if len(test_idxs) > 0:
Expand Down Expand Up @@ -314,18 +322,20 @@ def create_evaluator(task):
        assert len(config.eval_metric) == 1, \
            "GraphStorm does not support computing multiple metrics at the same time."
if config.report_eval_per_type:
return GSgnnPerEtypeMrrLPEvaluator(eval_frequency=config.eval_frequency,
major_etype=config.model_select_etype,
use_early_stop=config.use_early_stop,
early_stop_burnin_rounds=config.early_stop_burnin_rounds,
early_stop_rounds=config.early_stop_rounds,
early_stop_strategy=config.early_stop_strategy)
return GSgnnPerEtypeMrrLPEvaluator(
eval_frequency=config.eval_frequency,
major_etype=config.model_select_etype,
use_early_stop=config.use_early_stop,
early_stop_burnin_rounds=config.early_stop_burnin_rounds,
early_stop_rounds=config.early_stop_rounds,
early_stop_strategy=config.early_stop_strategy)
else:
return GSgnnMrrLPEvaluator(eval_frequency=config.eval_frequency,
use_early_stop=config.use_early_stop,
early_stop_burnin_rounds=config.early_stop_burnin_rounds,
early_stop_rounds=config.early_stop_rounds,
early_stop_strategy=config.early_stop_strategy)
return GSgnnMrrLPEvaluator(
eval_frequency=config.eval_frequency,
use_early_stop=config.use_early_stop,
early_stop_burnin_rounds=config.early_stop_burnin_rounds,
early_stop_rounds=config.early_stop_rounds,
early_stop_strategy=config.early_stop_strategy)
return None

def main(config_args):
Expand Down Expand Up @@ -362,7 +372,10 @@ def main(config_args):
train_dataloaders.append(train_loader)
val_dataloaders.append(val_loader)
test_dataloaders.append(test_loader)
decoder, loss_func = gs.create_task_decoder(task, train_data.g, encoder_out_dims, train_task=True)
decoder, loss_func = gs.create_task_decoder(task,
train_data.g,
encoder_out_dims,
train_task=True)
model.add_task(task.task_id, task.task_type, decoder, loss_func)
if not config.no_validation:
if val_loader is None:
Expand Down Expand Up @@ -426,7 +439,10 @@ def main(config_args):
gs.gsf.set_encoder(model, train_data.g, config, train_task=True)

for task in tasks:
decoder, loss_func = gs.create_task_decoder(task, train_data.g, encoder_out_dims, train_task=True)
decoder, loss_func = gs.create_task_decoder(task,
train_data.g,
encoder_out_dims,
train_task=True)
model.add_task(task.task_id, task.task_type, decoder, loss_func)
best_model_path = trainer.get_best_model_path()
# TODO(zhengda) the model path has to be in a shared filesystem.
Expand Down Expand Up @@ -459,5 +475,3 @@ def generate_parser():
# Ignore unknown args to make script more robust to input arguments
gs_args, _ = arg_parser.parse_known_args()
main(gs_args)


0 comments on commit fc45ea2

Please sign in to comment.