From fc45ea215961d2030436381e5d3e07cc996240f8 Mon Sep 17 00:00:00 2001
From: Xiang Song
Date: Mon, 27 May 2024 00:14:57 -0700
Subject: [PATCH] Fix lint

---
 python/graphstorm/run/gsgnn_mt/gsgnn_mt.py | 52 ++++++++++++++--------
 1 file changed, 33 insertions(+), 19 deletions(-)

diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py
index df7f0532b3..d264297474 100644
--- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py
+++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py
@@ -68,7 +68,9 @@ def create_task_train_dataloader(task, config, train_data):
 
     logging.info("Create dataloader for %s", task.task_id)
     if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]:
-        train_idxs = train_data.get_node_train_set(task_config.target_ntype, mask=task_config.train_mask)
+        train_idxs = train_data.get_node_train_set(
+            task_config.target_ntype,
+            mask=task_config.train_mask)
         # TODO(xiangsx): Support construct feat
         return GSgnnNodeDataLoader(train_data,
                                    train_idxs,
@@ -78,7 +80,9 @@ def create_task_train_dataloader(task, config, train_data):
                                    node_feats=node_feats,
                                    label_field=task_config.label_field)
     elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]:
-        train_idxs = train_data.get_edge_train_set(task_config.target_etype, mask=task_config.train_mask)
+        train_idxs = train_data.get_edge_train_set(
+            task_config.target_etype,
+            mask=task_config.train_mask)
         # TODO(xiangsx): Support construct feat
         return GSgnnEdgeDataLoader(train_data,
                                    train_idxs,
@@ -91,7 +95,9 @@ def create_task_train_dataloader(task, config, train_data):
                                    reverse_edge_types_map=task_config.reverse_edge_types_map,
                                    remove_target_edge_type=task_config.remove_target_edge_type)
     elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]:
-        train_idxs = train_data.get_edge_train_set(task_config.train_etype, mask=task_config.train_mask)
+        train_idxs = train_data.get_edge_train_set(
+            task_config.train_etype,
+            mask=task_config.train_mask)
         dataloader_cls = gs.get_builtin_lp_train_dataloader_class(task_config)
         return dataloader_cls(train_data,
                               train_idxs,
@@ -226,7 +232,9 @@ def create_task_test_dataloader(task, config, train_data):
                                        label_field=task_config.label_field)
 
     elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]:
-        test_idxs = train_data.get_edge_test_set(task_config.target_etype, mask=task_config.test_mask)
+        test_idxs = train_data.get_edge_test_set(
+            task_config.target_etype,
+            mask=task_config.test_mask)
         # All tasks share the same GNN model, so the fanout should be the global fanout
         fanout = config.eval_fanout if task_config.use_mini_batch_infer else []
         if len(test_idxs) > 0:
@@ -314,18 +322,20 @@ def create_evaluator(task):
         assert len(config.eval_metric) == 1, \
             "GraphStorm doees not support computing multiple metrics at the same time."
         if config.report_eval_per_type:
-            return GSgnnPerEtypeMrrLPEvaluator(eval_frequency=config.eval_frequency,
-                                               major_etype=config.model_select_etype,
-                                               use_early_stop=config.use_early_stop,
-                                               early_stop_burnin_rounds=config.early_stop_burnin_rounds,
-                                               early_stop_rounds=config.early_stop_rounds,
-                                               early_stop_strategy=config.early_stop_strategy)
+            return GSgnnPerEtypeMrrLPEvaluator(
+                eval_frequency=config.eval_frequency,
+                major_etype=config.model_select_etype,
+                use_early_stop=config.use_early_stop,
+                early_stop_burnin_rounds=config.early_stop_burnin_rounds,
+                early_stop_rounds=config.early_stop_rounds,
+                early_stop_strategy=config.early_stop_strategy)
         else:
-            return GSgnnMrrLPEvaluator(eval_frequency=config.eval_frequency,
-                                       use_early_stop=config.use_early_stop,
-                                       early_stop_burnin_rounds=config.early_stop_burnin_rounds,
-                                       early_stop_rounds=config.early_stop_rounds,
-                                       early_stop_strategy=config.early_stop_strategy)
+            return GSgnnMrrLPEvaluator(
+                eval_frequency=config.eval_frequency,
+                use_early_stop=config.use_early_stop,
+                early_stop_burnin_rounds=config.early_stop_burnin_rounds,
+                early_stop_rounds=config.early_stop_rounds,
+                early_stop_strategy=config.early_stop_strategy)
     return None
 
 def main(config_args):
@@ -362,7 +372,10 @@ def main(config_args):
         train_dataloaders.append(train_loader)
         val_dataloaders.append(val_loader)
         test_dataloaders.append(test_loader)
-        decoder, loss_func = gs.create_task_decoder(task, train_data.g, encoder_out_dims, train_task=True)
+        decoder, loss_func = gs.create_task_decoder(task,
+                                                    train_data.g,
+                                                    encoder_out_dims,
+                                                    train_task=True)
         model.add_task(task.task_id, task.task_type, decoder, loss_func)
         if not config.no_validation:
             if val_loader is None:
@@ -426,7 +439,10 @@ def main(config_args):
 
     gs.gsf.set_encoder(model, train_data.g, config, train_task=True)
     for task in tasks:
-        decoder, loss_func = gs.create_task_decoder(task, train_data.g, encoder_out_dims, train_task=True)
+        decoder, loss_func = gs.create_task_decoder(task,
+                                                    train_data.g,
+                                                    encoder_out_dims,
+                                                    train_task=True)
         model.add_task(task.task_id, task.task_type, decoder, loss_func)
     best_model_path = trainer.get_best_model_path()
     # TODO(zhengda) the model path has to be in a shared filesystem.
@@ -459,5 +475,3 @@ def generate_parser():
     # Ignore unknown args to make script more robust to input arguments
     gs_args, _ = arg_parser.parse_known_args()
     main(gs_args)
-
-