diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_feature_transformer.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_feature_transformer.py index ec11de7ff5..a42613b26a 100644 --- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_feature_transformer.py +++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_feature_transformer.py @@ -43,7 +43,6 @@ def __init__( feature_config: FeatureConfig, spark: SparkSession, json_representation: dict, - edge_mapping_dict: dict = None, ): feat_type = feature_config.feat_type feat_name = feature_config.feat_name @@ -51,8 +50,6 @@ def __init__( self.transformation: DistributedTransformation # We use this to re-apply transformations self.json_representation = json_representation - # Node Mapping Info for hard negative feature transformation - self.edge_mapping_dict = edge_mapping_dict default_kwargs = { "cols": feature_config.cols, diff --git a/graphstorm-processing/tests/test_dist_hard_negative_transformation.py b/graphstorm-processing/tests/test_dist_hard_negative_transformation.py index 179a65dc3c..f153301eb7 100755 --- a/graphstorm-processing/tests/test_dist_hard_negative_transformation.py +++ b/graphstorm-processing/tests/test_dist_hard_negative_transformation.py @@ -65,9 +65,7 @@ def test_hard_negative_example_list(spark: SparkSession, check_df_schema, tmp_pa expected_output = [[1, -1, -1, -1], [2, 3, -1, -1], [3, 0, 1, -1], [0, -1, -1, -1]] for idx, row in enumerate(output_data): - np.testing.assert_equal( - row[0], expected_output[idx], err_msg=f"Row {idx} is not equal" - ) + np.testing.assert_equal(row[0], expected_output[idx], err_msg=f"Row {idx} is not equal") def test_hard_negative_example_str(spark: SparkSession, check_df_schema, tmp_path): @@ -107,6 +105,4 @@ def test_hard_negative_example_str(spark: SparkSession, check_df_schema, tmp_pat expected_output = [[1, -1, -1, -1], [2, 3, -1, -1], [3, 0, 1, -1], [0, -1, -1, -1]] for idx, row in enumerate(output_data): - np.testing.assert_equal( - row[0], expected_output[idx], err_msg=f"Row {idx} is not equal" - ) + np.testing.assert_equal(row[0], expected_output[idx], err_msg=f"Row {idx} is not equal")