Skip to content

Commit

Permalink
add gconstruct converter
Browse files (browse the repository at this point in the history)
  • Loading branch information
jalencato committed Nov 2, 2023
1 parent 7191452 commit 80d4762
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -112,6 +112,13 @@ def _convert_feature(feats: list[dict]) -> list[dict]:
"slide_window_size": gconstruct_transform_dict["slide_window_size"],
"imputer": "none",
}
elif gconstruct_transform_dict["name"] == "rank_gauss":
gsp_transformation_dict["name"] = "rank-gauss"
if "epsilon" in gconstruct_transform_dict:
gsp_transformation_dict["kwargs"] = {"epsilon": gconstruct_transform_dict["epsilon"],
"normalizer": "none", "imputer": "none"}
else:
gsp_transformation_dict["kwargs"] = {"normalizer": "none", "imputer": "none"}
# TODO: Add support for other common transformations here
else:
raise ValueError(
Expand Down
26 changes: 26 additions & 0 deletions graphstorm-processing/tests/test_converter.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -217,6 +217,16 @@ def test_convert_gsprocessing(converter: GConstructConfigConverter):
"slide_window_size": 5,
},
},
{
"feature_col": ["num_citations"],
"feature_name": "rank_gauss1",
"transform": {"name": "rank_gauss"}
},
{
"feature_col": ["num_citations"],
"feature_name": "rank_gauss2",
"transform": {"name": "rank_gauss", "epsilon": 0.1}
},
],
"labels": [
{"label_col": "label", "task_type": "classification", "split_pct": [0.8, 0.1, 0.1]}
Expand Down Expand Up @@ -273,6 +283,22 @@ def test_convert_gsprocessing(converter: GConstructConfigConverter):
},
},
},
{
"column": "num_citations",
'name': 'rank_gauss1',
"transformation": {
"name": "rank-gauss",
"kwargs": {"normalizer": "none", "imputer": "none"},
},
},
{
"column": "num_citations",
'name': 'rank_gauss2',
"transformation": {
"name": "rank-gauss",
"kwargs": {"epsilon": 0.1, "normalizer": "none", "imputer": "none"},
},
},
]
assert nodes_output["labels"] == [
{
Expand Down

0 comments on commit 80d4762

Please sign in to comment.