diff --git a/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py b/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py index c49c14dfc0..85e2b9bfdb 100644 --- a/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py +++ b/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py @@ -98,7 +98,13 @@ def _convert_feature(feats: list[dict]) -> list[dict]: if gconstruct_transform_dict["name"] == "max_min_norm": gsp_transformation_dict["name"] = "numerical" - gsp_transformation_dict["kwargs"] = {"normalizer": "min-max", "imputer": "none"} + gsp_transformation_dict["kwargs"] = { + "normalizer": "min-max", + "imputer": "none", + } + + if gconstruct_transform_dict.get("out_dtype") in ["float32", "float64"]: + gsp_transformation_dict["kwargs"]["out_dtype"] = gconstruct_transform_dict["out_dtype"] elif gconstruct_transform_dict["name"] == "bucket_numerical": gsp_transformation_dict["name"] = "bucket-numerical" assert ( @@ -115,17 +121,13 @@ def _convert_feature(feats: list[dict]) -> list[dict]: } elif gconstruct_transform_dict["name"] == "rank_gauss": gsp_transformation_dict["name"] = "numerical" + gsp_transformation_dict["kwargs"] = { + "normalizer": "rank-gauss", + "imputer": "none", + } + if "epsilon" in gconstruct_transform_dict: - gsp_transformation_dict["kwargs"] = { - "epsilon": gconstruct_transform_dict["epsilon"], - "normalizer": "rank-gauss", - "imputer": "none", - } - else: - gsp_transformation_dict["kwargs"] = { - "normalizer": "rank-gauss", - "imputer": "none", - } + gsp_transformation_dict["kwargs"]["epsilon"] = gconstruct_transform_dict["epsilon"] elif gconstruct_transform_dict["name"] == "to_categorical": if "separator" in gconstruct_transform_dict: gsp_transformation_dict["name"] = "multi-categorical" diff --git a/graphstorm-processing/tests/test_converter.py b/graphstorm-processing/tests/test_converter.py index a4f64591ef..16fa1570c2 100644 --- a/graphstorm-processing/tests/test_converter.py +++ b/graphstorm-processing/tests/test_converter.py @@ -53,7 +53,7 @@ def test_try_read_file_with_wildcard( def test_try_read_unsupported_feature(converter: GConstructConfigConverter, node_dict: dict): - """We currently only support no-op and numerical features, so should error out otherwise.""" + """We should test about giving unknown feature transformation type.""" node_dict["nodes"][0]["features"] = [ { "feature_col": ["paper_title"], @@ -64,6 +64,35 @@ def test_try_read_unsupported_feature(converter: GConstructConfigConverter, node with pytest.raises(ValueError): _ = converter.convert_nodes(node_dict["nodes"]) +def test_try_convert_out_dtype(converter: GConstructConfigConverter, node_dict: dict): + node_dict["nodes"][0]["features"] = [ + { + "feature_col": ["paper_title"], + "transform": {"name": "max_min_norm", "out_dtype": "float32"}, + } + ] + + res = converter.convert_nodes(node_dict["nodes"])[0] + assert res.features == [{'column': 'paper_title', 'transformation': {'kwargs': {'imputer': 'none', + 'normalizer': 'min-max', + 'out_dtype': 'float32'}, + 'name': 'numerical'}}] + + node_dict["nodes"][0]["features"][0]["transform"]["out_dtype"] = "float64" + + res = converter.convert_nodes(node_dict["nodes"])[0] + assert res.features == [{'column': 'paper_title', 'transformation': {'kwargs': {'imputer': 'none', + 'normalizer': 'min-max', + 'out_dtype': 'float64'}, + 'name': 'numerical'}}] + + node_dict["nodes"][0]["features"][0]["transform"]["out_dtype"] = "float16" + + res = converter.convert_nodes(node_dict["nodes"])[0] + assert res.features == [{'column': 'paper_title', 'transformation': {'kwargs': {'imputer': 'none', + 'normalizer': 'min-max'}, + 'name': 'numerical'}}] + def test_read_node_gconstruct(converter: GConstructConfigConverter, node_dict: dict): """Multiple test cases for GConstruct node conversion""" diff --git a/graphstorm-processing/tests/test_dist_noop_transformation.py b/graphstorm-processing/tests/test_dist_noop_transformation.py index 8604adf67f..1486609143 100644 --- a/graphstorm-processing/tests/test_dist_noop_transformation.py +++ b/graphstorm-processing/tests/test_dist_noop_transformation.py @@ -76,7 +76,7 @@ def test_noop_floatvector_transformation(spark: SparkSession, check_df_schema): assert_array_equal(expected_values, transformed_values) -def test_noop_largegint_transformation(spark: SparkSession, check_df_schema): +def test_noop_largeint_transformation(spark: SparkSession, check_df_schema): """No-op transformation for long numerical columns""" large_int = 4 * 10**18 data = [ @@ -99,4 +99,6 @@ def test_noop_largegint_transformation(spark: SparkSession, check_df_schema): transformed_values = [row[col_name] for row in transformed_df.collect()] + print(data[0]) + print(vec_df[0]) assert_array_equal([val[0] for val in data], transformed_values)