Skip to content

Commit

Permalink
[GSProcessing] Ensure we can convert custom split GConstruct data con…
Browse files Browse the repository at this point in the history
…figs whether they are single str or list[str]

NOTE: GSProcessing now requires list[str] for its custom split file config input.
  • Loading branch information
thvasilo committed Dec 19, 2024
1 parent de94bb2 commit dba7f4d
Show file tree
Hide file tree
Showing 11 changed files with 284 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,23 @@ objects:
assign to the validation set [0.0, 1.0).
- ``test``: The percentage of the data with available labels to
assign to the test set [0.0, 1.0).
- ``custom_split_filenames`` (JSON object, optional): Specifies the customized
training/validation/test mask. Once it is defined, GSProcessing will ignore
the ``split_rate``.
- ``train``: Path of the training mask parquet file such that each line contains
the original ID for node tasks, or the pair [source_id, destination_id] for edge tasks.
- ``val``: Path of the validation mask parquet file such that each line contains
the original ID for node tasks, or the pair [source_id, destination_id] for edge tasks.
- ``test``: Path of the test mask parquet file such that each line contains
the original ID for node tasks, or the pair [source_id, destination_id] for edge tasks.
- ``custom_split_filenames`` (JSON object, optional): Specifies pre-assigned
training/validation/test masks. If defined, GSProcessing will ignore
``split_rate`` if provided.

- ``column``: (List[String], optional) A list of length one for node splits, or two for edge splits,
containing the name(s) of the column(s) that contain the node/edge ids for each split. For example,
if the node ids to include in each split exist in column ``"nid"`` of the custom train/val/test files, this
needs to be ``["nid"]``. For edges it would be a value like ``["src_id", "dst_id"]``.
If not provided for nodes we assume the first column in the data contains the node ids to include.
For edges, we assume the first column is the source id and the second the destination id.
- ``train``: (List[String], optional) Paths of the training mask parquet file such that each line contains
the original ID for node tasks, or the pair ``[source_id, destination_id]`` for edge tasks.
- ``val``: (List[String], optional) Paths of the validation mask parquet file such that each line contains
the original ID for node tasks, or the pair ``[source_id, destination_id]`` for edge tasks.
- ``test``: (List[String], optional) Paths of the test mask parquet file such that each line contains
the original ID for node tasks, or the pair ``[source_id, destination_id]`` for edge tasks.
- Note: At least one of the ``["train", "val", "test"]`` keys must be present.

- ``features`` (List of JSON objects, optional)\ **:** Describes
the set of features for the current edge type. See the :ref:`features-object` section for details.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _convert_label(labels: list[dict]) -> list[dict]:
labels_list = []
if labels in [[], [{}]]:
return []
for label in labels:
for label in labels: # pylint: disable=too-many-nested-blocks
try:
label_column = label["label_col"] if "label_col" in label else ""
label_type = label["task_type"]
Expand All @@ -61,19 +61,37 @@ def _convert_label(labels: list[dict]) -> list[dict]:
# check if split_pct is valid
assert (
math.fsum(label_splitrate) == 1.0
), "sum of the label split rate should be ==1.0"
), f"sum of the label split rate should be ==1.0, got {label_splitrate=}"
label_dict["split_rate"] = {
"train": label_splitrate[0],
"val": label_splitrate[1],
"test": label_splitrate[2],
}
else:
label_custom_split_filenames = label["custom_split_filenames"]
if isinstance(label_custom_split_filenames["column"], list):
assert len(label_custom_split_filenames["column"]) <= 2, (
"Custom split filenames should have one column for node labels, "
"and two columns for edges labels exactly"
)
# Ensure at least one of ["train", "valid", "test"] is in the keys
assert any(
x in label_custom_split_filenames.keys() for x in ["train", "valid", "test"]
), ("At least one of ['train', 'valid', 'test'] "
"needs to exist in custom split configs.")

# Fill in missing values if needed
for entry in ["train", "valid", "test", "column"]:
entry_val = label_custom_split_filenames.get(entry, None)
if entry_val:
if isinstance(entry_val, str):
label_custom_split_filenames[entry] = [entry_val]
else:
assert isinstance(
entry_val, list
), "Custom split filenames should be a string or a list of strings"
else:
label_custom_split_filenames[entry] = []
assert len(label_custom_split_filenames["column"]) <= 2, (
"Custom split filenames should have one column for node labels, "
"and two columns for edges labels exactly"
)

label_dict["custom_split_filenames"] = {
"train": label_custom_split_filenames["train"],
"valid": label_custom_split_filenames["valid"],
Expand Down Expand Up @@ -213,7 +231,7 @@ def _convert_feature(feats: list[Mapping[str, Any]]) -> list[dict]:
return gsp_feats_list

@staticmethod
def convert_nodes(nodes_entries):
def convert_nodes(nodes_entries) -> list[NodeConfig]:
res = []
for n in nodes_entries:
# type, column id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ def __init__(self, edge_dict: Dict[str, Dict], data_dict: Dict[str, Any]):
self._dst_ntype = edge_dict["dest"]["type"]
self._rel_type = edge_dict["relation"]["type"]
self._rel_col: Optional[str] = edge_dict["relation"].get("column", None)
self._feature_configs = []
self._labels = []

if "features" in edge_dict:
for feature_dict in edge_dict["features"]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,24 @@ def _sanity_check(self):
"When no label column is specified, the task type must be link_prediction, "
f"got {self._task_type}"
)
# Sanity checks for random and custom splits
if "custom_split_filenames" not in self._config:
assert isinstance(self._task_type, str)
assert isinstance(self._split, dict)
assert isinstance(self._separator, str) if self._multilabel else self._separator is None
else:
assert isinstance(self._custom_split_filenames, dict)
assert "train" in self._custom_split_filenames
assert "valid" in self._custom_split_filenames
assert "test" in self._custom_split_filenames
assert "column" in self._custom_split_filenames
assert isinstance(self._separator, str) if self._multilabel else self._separator is None

# Ensure at least one of ["train", "valid", "test"] is in the keys
assert any(x in self._custom_split_filenames.keys() for x in ["train", "valid", "test"])

for entry in ["train", "valid", "test", "column"]:
# All existing entries must be lists of strings
if entry in self._custom_split_filenames:
assert isinstance(self._custom_split_filenames[entry], list)
assert all(isinstance(x, str) for x in self._custom_split_filenames[entry])

assert isinstance(self._separator, str) if self._multilabel else self._separator is None

if self._mask_field_names:
assert isinstance(self._mask_field_names, list)
assert all(isinstance(x, str) for x in self._mask_field_names)
Expand Down Expand Up @@ -111,7 +118,11 @@ def custom_split_filenames(self) -> Dict[str, Any]:
@property
def mask_field_names(self) -> Optional[tuple[str, str, str]]:
"""Custom names to assign to masks for multi-task learning."""
return tuple(self._mask_field_names) if self._mask_field_names else None
if self._mask_field_names is None:
return None
else:
assert len(self._mask_field_names) == 3
return (self._mask_field_names[0], self._mask_field_names[1], self._mask_field_names[2])


class EdgeLabelConfig(LabelConfig):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,19 @@ class CustomSplit:
Parameters
----------
train : str
Path of the training mask parquet file.
train : list[str]
Paths of the training mask parquet files.
valid : str
Path of the validation mask parquet file.
Paths of the validation mask parquet files.
test : str
Path of the testing mask parquet file.
Paths of the testing mask parquet files.
mask_columns : list[str]
List of columns that contain original string ids.
"""

train: str
valid: str
test: str
train: list[str]
valid: list[str]
test: list[str]
mask_columns: list[str]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __init__(
assert os.path.isabs(loader_config.input_prefix), "We expect an absolute path"
self.filesystem_type = "local"

self.spark = spark # type: SparkSession
self.spark: SparkSession = spark
self.add_reverse_edges = loader_config.add_reverse_edges
# Remove trailing slash in s3 paths
if self.filesystem_type == "s3":
Expand Down Expand Up @@ -2131,11 +2131,11 @@ def create_mapping(input_df):
return return_df

if mask_type == "train":
file_path = split_file.train
file_paths = split_file.train
elif mask_type == "val":
file_path = split_file.valid
file_paths = split_file.valid
elif mask_type == "test":
file_path = split_file.test
file_paths = split_file.test
else:
raise ValueError("Unknown mask type")

Expand All @@ -2144,7 +2144,7 @@ def create_mapping(input_df):
if len(split_file.mask_columns) == 1:
# custom split on node original id
custom_mask_df = self.spark.read.parquet(
os.path.join(self.input_prefix, file_path)
*[os.path.join(self.input_prefix, file_path) for file_path in file_paths]
).select(col(split_file.mask_columns[0]).alias(f"custom_{mask_type}_mask"))
input_df_id = create_mapping(input_df)
mask_df = input_df_id.join(
Expand All @@ -2162,7 +2162,7 @@ def create_mapping(input_df):
elif len(split_file.mask_columns) == 2:
# custom split on edge (srd, dst) original ids
custom_mask_df = self.spark.read.parquet(
os.path.join(self.input_prefix, file_path)
*[os.path.join(self.input_prefix, file_path) for file_path in file_paths]
).select(
col(split_file.mask_columns[0]).alias(f"custom_{mask_type}_mask_src"),
col(split_file.mask_columns[1]).alias(f"custom_{mask_type}_mask_dst"),
Expand All @@ -2184,7 +2184,10 @@ def create_mapping(input_df):
.alias(mask_name),
).select(mask_name)
else:
raise ValueError("The number of column should be only 1 or 2.")
raise ValueError(
"The number of column should be only 1 or 2, got columns: "
f"{split_file.mask_columns}"
)

return mask_df

Expand Down
1 change: 1 addition & 0 deletions graphstorm-processing/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ pre-commit = "^3.3.3"
types-mock = "^5.1.0.1"
pylint = "~2.17.5"
diff-cover = "^9.0.0"
pytest-cov = "^6.0.0"

[project]
requires-python = ">=3.9" # TODO: Do we need a tilde here?
Expand Down
34 changes: 29 additions & 5 deletions graphstorm-processing/tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
@pytest.fixture(name="converter")
def fixture_create_converter() -> GConstructConfigConverter:
"""Creates a new converter object for each test."""
yield GConstructConfigConverter()
return GConstructConfigConverter()


@pytest.fixture(name="node_dict")
Expand Down Expand Up @@ -65,8 +65,32 @@ def test_try_read_unsupported_feature(converter: GConstructConfigConverter, node
_ = converter.convert_nodes(node_dict["nodes"])


def test_custom_split_config_conversion(converter: GConstructConfigConverter):
"""Test custom split file config conversion"""
gconstruct_label_dicts = [
{
"label_col": "label",
"task_type": "classification",
"custom_split_filenames": {
"column": ["src", "dst"],
},
"label_stats_type": "frequency_cnt",
}
]

# Should raise when none of train/val/test are in the input
with pytest.raises(AssertionError):
converter._convert_label(gconstruct_label_dicts)

# Ensure single strings are converted to list of strings
gconstruct_label_dicts[0]["custom_split_filenames"]["train"] = "fake_file"
gsprocessing_label_dict = converter._convert_label(gconstruct_label_dicts)[0]

assert gsprocessing_label_dict["custom_split_filenames"]["train"] == ["fake_file"]


def test_try_read_invalid_gconstruct_config(converter: GConstructConfigConverter, node_dict: dict):
"""Custom Split Columns"""
"""Test various invalid input scenarios"""
node_dict["nodes"][0]["labels"] = [
{
"label_col": "label",
Expand Down Expand Up @@ -238,9 +262,9 @@ def test_read_node_gconstruct(converter: GConstructConfigConverter, node_dict: d
"column": "label",
"type": "classification",
"custom_split_filenames": {
"train": "customized_label/node_train_idx.parquet",
"valid": "customized_label/node_val_idx.parquet",
"test": "customized_label/node_test_idx.parquet",
"train": ["customized_label/node_train_idx.parquet"],
"valid": ["customized_label/node_val_idx.parquet"],
"test": ["customized_label/node_test_idx.parquet"],
"column": ["ID"],
},
}
Expand Down
24 changes: 12 additions & 12 deletions graphstorm-processing/tests/test_dist_heterogenous_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,9 +955,9 @@ def test_node_custom_label(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp
"type": "classification",
"split_rate": {"train": 0.8, "val": 0.1, "test": 0.1},
"custom_split_filenames": {
"train": f"{tmp_path}/train.parquet",
"valid": f"{tmp_path}/val.parquet",
"test": f"{tmp_path}/test.parquet",
"train": [f"{tmp_path}/train.parquet"],
"valid": [f"{tmp_path}/val.parquet"],
"test": [f"{tmp_path}/test.parquet"],
"column": ["mask_id"],
},
}
Expand Down Expand Up @@ -1010,9 +1010,9 @@ def test_edge_custom_label(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp
"column": "",
"type": "link_prediction",
"custom_split_filenames": {
"train": f"{tmp_path}/train.parquet",
"valid": f"{tmp_path}/val.parquet",
"test": f"{tmp_path}/test.parquet",
"train": [f"{tmp_path}/train.parquet"],
"valid": [f"{tmp_path}/val.parquet"],
"test": [f"{tmp_path}/test.parquet"],
"column": ["mask_src_id", "mask_dst_id"],
},
}
Expand Down Expand Up @@ -1061,9 +1061,9 @@ def test_node_custom_label_multitask(spark, dghl_loader: DistHeterogeneousGraphL
"type": "classification",
"split_rate": {"train": 0.8, "val": 0.1, "test": 0.1},
"custom_split_filenames": {
"train": f"{tmp_path}/train.parquet",
"valid": f"{tmp_path}/val.parquet",
"test": f"{tmp_path}/test.parquet",
"train": [f"{tmp_path}/train.parquet"],
"valid": [f"{tmp_path}/val.parquet"],
"test": [f"{tmp_path}/test.parquet"],
"column": ["mask_id"],
},
"mask_field_names": class_mask_names,
Expand Down Expand Up @@ -1173,9 +1173,9 @@ def test_edge_custom_label_multitask(spark, dghl_loader: DistHeterogeneousGraphL
"column": "",
"type": "link_prediction",
"custom_split_filenames": {
"train": f"{tmp_path}/train.parquet",
"valid": f"{tmp_path}/val.parquet",
"test": f"{tmp_path}/test.parquet",
"train": [f"{tmp_path}/train.parquet"],
"valid": [f"{tmp_path}/val.parquet"],
"test": [f"{tmp_path}/test.parquet"],
"column": ["mask_src_id", "mask_dst_id"],
},
"mask_field_names": class_mask_names,
Expand Down
Loading

0 comments on commit dba7f4d

Please sign in to comment.