Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GSProcessing] Custom split for GSProcessing #827

Merged
merged 20 commits into from
May 21, 2024
Merged
9 changes: 9 additions & 0 deletions docs/source/gs-processing/developer/input-configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,15 @@ objects:
assign to the validation set [0.0, 1.0).
- ``test``: The percentage of the data with available labels to
assign to the test set [0.0, 1.0).
- ``custom_split_filenames`` (JSON object, optional): Specifies the customized
training/validation/test mask. Once it is defined, GSProcessing will ignore
the ``split_rate``.
- ``train``: Path of the training mask parquet file such that each line contains
the original ID for node tasks, or the pair [source_id, destination_id] for edge tasks.
- ``valid``: Path of the validation mask parquet file such that each line contains
the original ID for node tasks, or the pair [source_id, destination_id] for edge tasks.
- ``test``: Path of the test mask parquet file such that each line contains
the original ID for node tasks, or the pair [source_id, destination_id] for edge tasks.

- ``features`` (List of JSON objects, optional)\ **:** Describes
the set of features for the current edge type. See the :ref:`features-object` section for details.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
limitations under the License.
"""

import math
from typing import Any
from collections.abc import Mapping

Expand Down Expand Up @@ -52,16 +53,25 @@ def _convert_label(labels: list[dict]) -> list[dict]:
label_column = label["label_col"] if "label_col" in label else ""
label_type = label["task_type"]
label_dict = {"column": label_column, "type": label_type}
if "split_pct" in label:
label_splitrate = label["split_pct"]
# check if split_pct is valid
assert (
sum(label_splitrate) <= 1.0
), "sum of the label split rate should be <=1.0"
label_dict["split_rate"] = {
"train": label_splitrate[0],
"val": label_splitrate[1],
"test": label_splitrate[2],
if "custom_split_filenames" not in label:
if "split_pct" in label:
label_splitrate = label["split_pct"]
# check if split_pct is valid
assert (
math.fsum(label_splitrate) == 1.0
), "sum of the label split rate should be ==1.0"
label_dict["split_rate"] = {
"train": label_splitrate[0],
"val": label_splitrate[1],
"test": label_splitrate[2],
}
else:
label_custom_split_filenames = label["custom_split_filenames"]
label_dict["custom_split_filenames"] = {
"train": label_custom_split_filenames["train"],
"valid": label_custom_split_filenames["valid"],
"test": label_custom_split_filenames["test"],
"column": label_custom_split_filenames["column"],
}
if "separator" in label:
label_sep = label["separator"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,32 @@ def __init__(self, config_dict: Dict[str, Any]):
self._label_column = ""
assert config_dict["type"] == "link_prediction"
self._task_type: str = config_dict["type"]
self._split: Dict[str, float] = config_dict["split_rate"]
self._separator: str = config_dict["separator"] if "separator" in config_dict else None
self._multilabel = self._separator is not None
if "custom_split_filenames" not in config_dict:
self._split: Dict[str, float] = config_dict["split_rate"]
self._custom_split_filenames = None
else:
self._split = None
self._custom_split_filenames: Dict[str, str] = config_dict["custom_split_filenames"]

def _sanity_check(self):
    """Validate the parsed label configuration with assertions.

    Checks performed:

    * With an empty label column, the task type must be ``link_prediction``.
    * Without a ``custom_split_filenames`` entry in the config, the task type
      must be a string and the split rates must be a dict.
    * With a ``custom_split_filenames`` entry, it must be a dict providing the
      ``train``, ``valid``, ``test`` and ``column`` keys. The split rates are
      not checked in this case.
    * The separator must be a string for multilabel tasks and ``None`` otherwise.

    Raises
    ------
    AssertionError
        If any of the checks above fails.
    """
    if self._label_column == "":
        assert self._task_type == "link_prediction", (
            "When no label column is specified, the task type must be link_prediction, "
            f"got {self._task_type}"
        )

    # NOTE(review): self._config is not assigned in the code visible here;
    # presumably the constructor stores the raw config dict -- confirm.
    if "custom_split_filenames" not in self._config:
        assert isinstance(self._task_type, str)
        assert isinstance(self._split, dict)
    else:
        # The split rates are intentionally not validated here: they are
        # ignored whenever a custom split is configured.
        assert isinstance(self._custom_split_filenames, dict)
        for required_key in ("train", "valid", "test", "column"):
            assert (
                required_key in self._custom_split_filenames
            ), f"custom_split_filenames is missing the '{required_key}' entry"

    # Applies to both split modes, so it is checked once after the branch.
    assert isinstance(self._separator, str) if self._multilabel else self._separator is None

@property
def label_column(self) -> str:
Expand Down Expand Up @@ -71,6 +84,11 @@ def multilabel(self) -> bool:
"""Whether the task is multilabel classification."""
return self._multilabel

@property
def custom_split_filenames(self) -> Dict[str, str]:
    """The config for custom split labels.

    Returns the value stored in ``_custom_split_filenames`` unchanged. This
    is ``None`` when the label config has no ``custom_split_filenames``
    entry, so callers should check for ``None`` before indexing into it.
    """
    return self._custom_split_filenames


class EdgeLabelConfig(LabelConfig):
"""Holds the configuration of an edge label.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,29 @@ def __post_init__(self) -> None:
)


@dataclass
class CustomSplit:
    """
    Dataclass to hold the custom split for each of the train/val/test splits.

    The fields are plain values; no validation is performed on construction.

    Parameters
    ----------
    train : str
        Path of the training mask parquet file.
    valid : str
        Path of the validation mask parquet file.
    test : str
        Path of the testing mask parquet file.
    mask_columns : list[str]
        List of columns that contain original string ids.
    """

    train: str
    valid: str
    test: str
    mask_columns: list[str]


class DistLabelLoader:
"""Used to transform label columns to conform to downstream GraphStorm expectations.

Expand Down
Loading
Loading