Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
jalencato committed Jul 31, 2024
1 parent 9ff46da commit c56e61a
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions graphstorm-processing/tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,65 @@ def test_try_read_unsupported_feature(converter: GConstructConfigConverter, node
_ = converter.convert_nodes(node_dict["nodes"])


def test_try_read_invalid_gconstruct_config(converter: GConstructConfigConverter, node_dict: dict):
"""Custom Split Columns"""
node_dict["nodes"][0]["labels"] = [
{
"label_col": "label",
"task_type": "classification",
"custom_split_filenames": {
"column": ["src", "dst", "inter"],
},
"label_stats_type": "frequency_cnt",
}
]

with pytest.raises(AssertionError):
_ = converter.convert_nodes(node_dict["nodes"])

"""Feature Name must exist for multiple feature columns"""
node_dict["nodes"][0]["features"] = [
{
"feature_col": ["feature_1", "feature_2"]
}
]

with pytest.raises(AssertionError):
_ = converter.convert_nodes(node_dict["nodes"])

"""Unsupported output dtype"""
node_dict["nodes"][0]["features"] = [
{
"feature_col": ["feature_1"],
"out_dtype": "float16"
}
]

with pytest.raises(AssertionError):
_ = converter.convert_nodes(node_dict["nodes"])

"""Unsupported format type"""
node_dict["nodes"][0]["format"] = \
{"name": "txt", "separator": ","}

with pytest.raises(AssertionError):
_ = converter.convert_nodes(node_dict["nodes"])


def test_try_read_multi_task_gconstruct_config(converter: GConstructConfigConverter, node_dict: dict):
"""Check unsupported mask column """
node_dict["nodes"][0]["labels"] = [
{
"label_col": "label",
"task_type": "classification",
"mask_field_names": "train_mask"
}
]

with pytest.raises(AssertionError):
_ = converter.convert_nodes(node_dict["nodes"])


@pytest.mark.parametrize("transform", ["max_min_norm", "rank_gauss"])
@pytest.mark.parametrize("out_dtype", ["float16", "float32", "float64"])
def test_try_convert_out_dtype(
Expand Down

0 comments on commit c56e61a

Please sign in to comment.