diff --git a/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py b/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py index 56d5edeea0..9edf300389 100644 --- a/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py +++ b/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py @@ -98,6 +98,17 @@ def _convert_feature(feats: list[dict]) -> list[dict]: if gconstruct_transform_dict["name"] == "max_min_norm": gsp_transformation_dict["name"] = "numerical" gsp_transformation_dict["kwargs"] = {"normalizer": "min-max", "imputer": "mean"} + elif gconstruct_transform_dict["name"] == "bucket_numerical": + gsp_transformation_dict["name"] = "numerical" + assert "bucket_cnt" in gconstruct_transform_dict, \ + "bucket_cnt should be in the gconstruct bucket feature transform field" + assert "range" in gconstruct_transform_dict, \ + "range should be in the gconstruct bucket feature transform field" + gsp_transformation_dict["kwargs"] = {"normalizer": "bucket_numerical", + "bucket_cnt": gconstruct_transform_dict['bucket_cnt'], + "range": gconstruct_transform_dict['range'], + "slide_window_size": gconstruct_transform_dict['slide_window_size'], + "imputer": "mean"} # TODO: Add support for other common transformations here else: raise ValueError( diff --git a/graphstorm-processing/tests/test_converter.py b/graphstorm-processing/tests/test_converter.py index e2935160c9..9c7fec6299 100644 --- a/graphstorm-processing/tests/test_converter.py +++ b/graphstorm-processing/tests/test_converter.py @@ -208,6 +208,10 @@ def test_convert_gsprocessing(converter: GConstructConfigConverter): "features": [ {"feature_col": ["citation_time"], "feature_name": "feat"}, {"feature_col": ["num_citations"], "transform": {"name": "max_min_norm"}}, + {"feature_col": ["num_citations"], "transform": {"name": "bucket_numerical", + "bucket_cnt": 9, + "range": [10, 100], + "slide_window_size": 5}}, ], "labels": [ {"label_col": "label", "task_type": "classification", "split_pct": [0.8, 0.1, 0.1]} @@ -252,6 +256,17 @@ def test_convert_gsprocessing(converter: GConstructConfigConverter): "kwargs": {"normalizer": "min-max", "imputer": "mean"}, }, }, + { + "column": "num_citations", + "transformation": { + "name": "numerical", + "kwargs": {"normalizer": "bucket_numerical", + "bucket_cnt": 9, + "range": [10, 100], + "slide_window_size": 5, + "imputer": "mean"}, + }, + }, ] assert nodes_output["labels"] == [ {