Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GSProcessing] Supplement for GConstruct Config Check for GSProcessing #937

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ def _convert_label(labels: list[dict]) -> list[dict]:
}
else:
label_custom_split_filenames = label["custom_split_filenames"]
if isinstance(label_custom_split_filenames["column"], list):
assert len(label_custom_split_filenames["column"]) <= 2, (
"Custom split filenames should have one column for node labels, "
"and two columns for edges labels exactly"
)
label_dict["custom_split_filenames"] = {
"train": label_custom_split_filenames["train"],
"valid": label_custom_split_filenames["valid"],
Expand All @@ -76,6 +81,11 @@ def _convert_label(labels: list[dict]) -> list[dict]:
if "separator" in label:
label_sep = label["separator"]
label_dict["separator"] = label_sep
# Not supported for multi-task config for GSProcessing
assert "mask_field_names" not in label, (
"GSProcessing currently do not support to "
"construct labels for multi-task learning"
jalencato marked this conversation as resolved.
Show resolved Hide resolved
)
labels_list.append(label_dict)
except KeyError as exc:
raise KeyError(f"A required key was missing from label input {label}") from exc
Expand Down Expand Up @@ -103,6 +113,11 @@ def _convert_feature(feats: list[Mapping[str, Any]]) -> list[dict]:
gsp_feat_dict["column"] = gconstruct_feat_dict["feature_col"]
elif isinstance(gconstruct_feat_dict["feature_col"], list):
gsp_feat_dict["column"] = gconstruct_feat_dict["feature_col"][0]
if len(gconstruct_feat_dict["feature_col"]) >= 2:
assert "feature_name" in gconstruct_feat_dict, (
"feature_name should be in the gconstruct "
"feature field when feature_col is a list"
)
if "feature_name" in gconstruct_feat_dict:
gsp_feat_dict["name"] = gconstruct_feat_dict["feature_name"]

Expand Down Expand Up @@ -183,9 +198,10 @@ def _convert_feature(feats: list[Mapping[str, Any]]) -> list[dict]:
gsp_transformation_dict["name"] = "no-op"

if "out_dtype" in gconstruct_feat_dict:
assert (
gconstruct_feat_dict["out_dtype"] == "float32"
), "GSProcessing currently only supports float32 features"
assert gconstruct_feat_dict["out_dtype"] in (
"float32",
"float64",
jalencato marked this conversation as resolved.
Show resolved Hide resolved
), "GSProcessing currently only supports float32 or float64 features"

gsp_feat_dict["transformation"] = gsp_transformation_dict
gsp_feats_list.append(gsp_feat_dict)
Expand All @@ -200,6 +216,10 @@ def convert_nodes(nodes_entries):
node_type, node_col = n["node_type"], n["node_id_col"]
# format
node_format = n["format"]["name"]
assert node_format in (
"parquet",
"csv",
jalencato marked this conversation as resolved.
Show resolved Hide resolved
), "GSProcessing only supports parquet files and csv files."
if "separator" not in n["format"]:
node_separator = None
else:
Expand Down Expand Up @@ -249,6 +269,10 @@ def convert_edges(edges_entries):

# format
edge_format = e["format"]["name"]
assert edge_format in (
"parquet",
"csv",
jalencato marked this conversation as resolved.
Show resolved Hide resolved
), "GSProcessing only supports parquet files and csv files."
if "separator" not in e["format"]:
edge_separator = None
else:
Expand Down
Loading