Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GSProcessing] Supplement for GConstruct Config Check for GSProcessing #937

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from typing import Any
from collections.abc import Mapping

from graphstorm_processing.constants import SUPPORTED_FILE_TYPES, VALID_OUTDTYPE

from .converter_base import ConfigConverter
from .meta_configuration import NodeConfig, EdgeConfig

Expand Down Expand Up @@ -67,6 +69,11 @@ def _convert_label(labels: list[dict]) -> list[dict]:
}
else:
label_custom_split_filenames = label["custom_split_filenames"]
if isinstance(label_custom_split_filenames["column"], list):
assert len(label_custom_split_filenames["column"]) <= 2, (
"Custom split filenames should have one column for node labels, "
"and two columns for edges labels exactly"
)
label_dict["custom_split_filenames"] = {
"train": label_custom_split_filenames["train"],
"valid": label_custom_split_filenames["valid"],
Expand All @@ -76,6 +83,10 @@ def _convert_label(labels: list[dict]) -> list[dict]:
if "separator" in label:
label_sep = label["separator"]
label_dict["separator"] = label_sep
# Not supported for multi-task config for GSProcessing
assert "mask_field_names" not in label, (
"GSProcessing currently cannot " "construct labels for multi-task learning"
)
labels_list.append(label_dict)
except KeyError as exc:
raise KeyError(f"A required key was missing from label input {label}") from exc
Expand Down Expand Up @@ -103,6 +114,11 @@ def _convert_feature(feats: list[Mapping[str, Any]]) -> list[dict]:
gsp_feat_dict["column"] = gconstruct_feat_dict["feature_col"]
elif isinstance(gconstruct_feat_dict["feature_col"], list):
gsp_feat_dict["column"] = gconstruct_feat_dict["feature_col"][0]
if len(gconstruct_feat_dict["feature_col"]) >= 2:
assert "feature_name" in gconstruct_feat_dict, (
"feature_name should be in the gconstruct "
"feature field when feature_col is a list"
)
if "feature_name" in gconstruct_feat_dict:
gsp_feat_dict["name"] = gconstruct_feat_dict["feature_name"]

Expand All @@ -117,7 +133,7 @@ def _convert_feature(feats: list[Mapping[str, Any]]) -> list[dict]:
"imputer": "none",
}

if gconstruct_transform_dict.get("out_dtype") in ["float32", "float64"]:
if gconstruct_transform_dict.get("out_dtype") in VALID_OUTDTYPE:
gsp_transformation_dict["kwargs"]["out_dtype"] = gconstruct_transform_dict[
"out_dtype"
]
Expand Down Expand Up @@ -184,8 +200,8 @@ def _convert_feature(feats: list[Mapping[str, Any]]) -> list[dict]:

if "out_dtype" in gconstruct_feat_dict:
assert (
gconstruct_feat_dict["out_dtype"] == "float32"
), "GSProcessing currently only supports float32 features"
gconstruct_feat_dict["out_dtype"] in VALID_OUTDTYPE
), "GSProcessing currently only supports float32 or float64 features"

gsp_feat_dict["transformation"] = gsp_transformation_dict
gsp_feats_list.append(gsp_feat_dict)
Expand All @@ -200,6 +216,9 @@ def convert_nodes(nodes_entries):
node_type, node_col = n["node_type"], n["node_id_col"]
# format
node_format = n["format"]["name"]
assert (
node_format in SUPPORTED_FILE_TYPES
), "GSProcessing only supports parquet files and csv files."
if "separator" not in n["format"]:
node_separator = None
else:
Expand Down Expand Up @@ -249,6 +268,9 @@ def convert_edges(edges_entries):

# format
edge_format = e["format"]["name"]
assert (
edge_format in SUPPORTED_FILE_TYPES
), "GSProcessing only supports parquet files and csv files."
if "separator" not in e["format"]:
edge_separator = None
else:
Expand Down
47 changes: 47 additions & 0 deletions graphstorm-processing/tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,53 @@ def test_try_read_unsupported_feature(converter: GConstructConfigConverter, node
_ = converter.convert_nodes(node_dict["nodes"])


def test_try_read_invalid_gconstruct_config(converter: GConstructConfigConverter, node_dict: dict):
    """Check that unsupported GConstruct node settings are rejected.

    Four invalid settings are tried and ``convert_nodes`` is expected to raise
    ``AssertionError`` for each of them. Every scenario is applied to its own deep
    copy of the fixture's node entries, so an invalid value injected by one scenario
    cannot satisfy ``pytest.raises`` on behalf of a later one.
    """
    import copy  # local import: only this test needs it

    def expect_rejected(**node_overrides):
        """Override keys of the first node entry on a fresh copy and expect a failure."""
        nodes = copy.deepcopy(node_dict["nodes"])
        nodes[0].update(node_overrides)
        with pytest.raises(AssertionError):
            _ = converter.convert_nodes(nodes)

    # Custom split columns: three columns are listed for a node label
    expect_rejected(
        labels=[
            {
                "label_col": "label",
                "task_type": "classification",
                "custom_split_filenames": {
                    "column": ["src", "dst", "inter"],
                },
                "label_stats_type": "frequency_cnt",
            }
        ]
    )

    # Feature name must exist for multiple feature columns
    expect_rejected(features=[{"feature_col": ["feature_1", "feature_2"]}])

    # Unsupported output dtype
    expect_rejected(features=[{"feature_col": ["feature_1"], "out_dtype": "float16"}])

    # Unsupported format type
    expect_rejected(format={"name": "txt", "separator": ","})


def test_try_read_multi_task_gconstruct_config(
    converter: GConstructConfigConverter, node_dict: dict
):
    """A label that carries ``mask_field_names`` must make ``convert_nodes`` raise."""
    # Label entry that includes the multi-task mask field
    masked_label = {
        "label_col": "label",
        "task_type": "classification",
        "mask_field_names": "train_mask",
    }
    first_node = node_dict["nodes"][0]
    first_node["labels"] = [masked_label]

    with pytest.raises(AssertionError):
        converter.convert_nodes(node_dict["nodes"])


@pytest.mark.parametrize("transform", ["max_min_norm", "rank_gauss"])
@pytest.mark.parametrize("out_dtype", ["float16", "float32", "float64"])
def test_try_convert_out_dtype(
Expand Down
Loading