[GSProcessing] Custom split for GSProcessing #827

Merged: 20 commits, May 21, 2024
9 changes: 9 additions & 0 deletions docs/source/gs-processing/developer/input-configuration.rst
@@ -173,6 +173,15 @@ objects:
assign to the validation set [0.0, 1.0).
- ``test``: The percentage of the data with available labels to
assign to the test set [0.0, 1.0).
- ``custom_split_filenames`` (JSON object, optional): Specifies files that
  define custom training/validation/test masks. When this is provided,
  GSProcessing ignores ``split_rate``. See the example below.

  - ``train``: Path to the parquet file that lists the training set; each row
    contains the original ID for node tasks, or the pair
    ``[source_id, destination_id]`` for edge tasks.
  - ``valid``: Path to the parquet file that lists the validation set, in the
    same format.
  - ``test``: Path to the parquet file that lists the test set, in the same
    format.
  - ``column``: Names of the ID columns in the mask files: a single column
    for node tasks, or the source and destination columns for edge tasks.

- ``features`` (List of JSON objects, optional)\ **:** Describes
the set of features for the current edge type. See the :ref:`features-object` section for details.
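Putting the new option together, a label object using a custom split could look like the following sketch (the file paths and the ``ID`` column name are illustrative):

```json
{
    "column": "label",
    "type": "classification",
    "custom_split_filenames": {
        "train": "customized_label/node_train_idx.parquet",
        "valid": "customized_label/node_val_idx.parquet",
        "test": "customized_label/node_test_idx.parquet",
        "column": ["ID"]
    }
}
```

Each referenced parquet file only needs the column(s) named under ``column``; every row identifies one member of that split.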
@@ -52,16 +52,25 @@ def _convert_label(labels: list[dict]) -> list[dict]:
label_column = label["label_col"] if "label_col" in label else ""
label_type = label["task_type"]
label_dict = {"column": label_column, "type": label_type}
if "split_pct" in label:
label_splitrate = label["split_pct"]
# check if split_pct is valid
assert (
sum(label_splitrate) <= 1.0
), "sum of the label split rate should be <=1.0"
label_dict["split_rate"] = {
"train": label_splitrate[0],
"val": label_splitrate[1],
"test": label_splitrate[2],
if "custom_split_filenames" not in label:
if "split_pct" in label:
label_splitrate = label["split_pct"]
# check if split_pct is valid
assert (
sum(label_splitrate) <= 1.0
), "sum of the label split rate should be <=1.0"
label_dict["split_rate"] = {
"train": label_splitrate[0],
"val": label_splitrate[1],
"test": label_splitrate[2],
}
else:
label_custom_split_filenames = label["custom_split_filenames"]
label_dict["custom_split_filenames"] = {
"train": label_custom_split_filenames["train"],
"valid": label_custom_split_filenames["valid"],
"test": label_custom_split_filenames["test"],
"column": label_custom_split_filenames["column"],
}
if "separator" in label:
label_sep = label["separator"]
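For reference, a minimal sketch of the mapping this branch performs on one label entry (the input dict follows the GConstruct format; values are illustrative):

```python
# Sketch of the conversion above: when custom_split_filenames is present,
# any split_pct entry is ignored and the filenames are passed through.
gconstruct_label = {
    "label_col": "label",
    "task_type": "classification",
    "custom_split_filenames": {
        "train": "customized_label/node_train_idx.parquet",
        "valid": "customized_label/node_val_idx.parquet",
        "test": "customized_label/node_test_idx.parquet",
        "column": ["ID"],
    },
}

converted = {
    "column": gconstruct_label["label_col"],
    "type": gconstruct_label["task_type"],
    "custom_split_filenames": gconstruct_label["custom_split_filenames"],
}
```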
@@ -30,19 +30,32 @@ def __init__(self, config_dict: Dict[str, Any]):
self._label_column = ""
assert config_dict["type"] == "link_prediction"
self._task_type: str = config_dict["type"]
self._split: Dict[str, float] = config_dict["split_rate"]
self._separator: str = config_dict["separator"] if "separator" in config_dict else None
self._multilabel = self._separator is not None
if "custom_split_filenames" not in config_dict:
self._split: Dict[str, float] = config_dict["split_rate"]
self._custom_split_filenames = None
else:
self._split = None
self._custom_split_filenames: Dict[str, str] = config_dict["custom_split_filenames"]

def _sanity_check(self):
if self._label_column == "":
assert self._task_type == "link_prediction", (
"When no label column is specified, the task type must be link_prediction, "
f"got {self._task_type}"
)
assert isinstance(self._task_type, str)
assert isinstance(self._split, dict)
assert isinstance(self._separator, str) if self._multilabel else self._separator is None
if "custom_split_filenames" not in self._config:
assert isinstance(self._task_type, str)
assert isinstance(self._split, dict)
assert isinstance(self._separator, str) if self._multilabel else self._separator is None
else:
assert isinstance(self._custom_split_filenames, dict)
assert "train" in self._custom_split_filenames
assert "valid" in self._custom_split_filenames
assert "test" in self._custom_split_filenames
assert "column" in self._custom_split_filenames
assert isinstance(self._separator, str) if self._multilabel else self._separator is None

@property
def label_column(self) -> str:
Expand Down Expand Up @@ -71,6 +84,11 @@ def multilabel(self) -> bool:
"""Whether the task is multilabel classification."""
return self._multilabel

@property
def custom_split_filenames(self) -> Dict[str, str]:
"""The config for custom split labels."""
return self._custom_split_filenames


class EdgeLabelConfig(LabelConfig):
"""Holds the configuration of an edge label.
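For contrast with the custom-split mode, a rate-based label config in the GSProcessing format looks like the sketch below; with it, ``_split`` holds the rates and the ``custom_split_filenames`` property returns ``None`` (illustrative, not part of the diff):

```python
# Rate-based mode: the complement of custom_split_filenames.
# Here _split is populated and custom_split_filenames stays None.
rate_based_config = {
    "column": "label",
    "type": "classification",
    "split_rate": {"train": 0.8, "val": 0.1, "test": 0.1},
}
```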
@@ -55,6 +55,18 @@ def __post_init__(self) -> None:
)


@dataclass
class CustomSplit:
"""
Dataclass holding the paths of the custom split files for the train/val/test splits.

Parameters
----------
train : str
    Path of the parquet file with the training-set IDs.
valid : str
    Path of the parquet file with the validation-set IDs.
test : str
    Path of the parquet file with the test-set IDs.
column : list[str]
    Names of the ID columns in those files: one column for node labels,
    two (source and destination) for edge labels.
"""

train: str
valid: str
test: str
column: list[str]


class DistLabelLoader:
"""Used to transform label columns to conform to downstream GraphStorm expectations.

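A hypothetical instantiation of the CustomSplit dataclass above for an edge-level split (paths and column names are illustrative):

```python
from graphstorm_processing.data_transformations.dist_label_loader import CustomSplit

# Hypothetical edge-level custom split: each parquet row holds a
# (source, destination) pair identifying one edge in that split.
edge_split = CustomSplit(
    train="customized_label/edge_train_idx.parquet",
    valid="customized_label/edge_val_idx.parquet",
    test="customized_label/edge_test_idx.parquet",
    column=["src_id", "dst_id"],
)
```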
@@ -33,6 +33,7 @@
ArrayType,
ByteType,
)
from pyspark.sql.functions import col, when
from numpy.random import default_rng

from graphstorm_processing.constants import (
@@ -48,7 +49,7 @@
from ..config.label_config_base import LabelConfig
from ..config.feature_config_base import FeatureConfig
from ..data_transformations.dist_feature_transformer import DistFeatureTransformer
from ..data_transformations.dist_label_loader import DistLabelLoader, SplitRates
from ..data_transformations.dist_label_loader import DistLabelLoader, SplitRates, CustomSplit
from ..data_transformations import s3_utils, spark_utils
from .heterogeneous_graphloader import HeterogeneousGraphLoader

@@ -1063,8 +1064,21 @@ def _process_node_labels(
)
else:
split_rates = None
if label_conf.custom_split_filenames:
custom_split_filenames = CustomSplit(
train=label_conf.custom_split_filenames["train"],
valid=label_conf.custom_split_filenames["valid"],
test=label_conf.custom_split_filenames["test"],
column=label_conf.custom_split_filenames["column"],
)
else:
custom_split_filenames = None
label_split_dicts = self._create_split_files_from_rates(
nodes_df, label_conf.label_column, split_rates, split_masks_output_prefix
nodes_df,
label_conf.label_column,
split_rates,
split_masks_output_prefix,
custom_split_filenames,
)
node_type_label_metadata.update(label_split_dicts)

@@ -1523,11 +1537,23 @@ def _process_edge_labels(
)
else:
split_rates = None
if label_conf.custom_split_filenames:
custom_split_filenames = CustomSplit(
train=label_conf.custom_split_filenames["train"],
valid=label_conf.custom_split_filenames["valid"],
test=label_conf.custom_split_filenames["test"],
column=label_conf.custom_split_filenames["column"],
)
else:
custom_split_filenames = None
label_split_dicts = self._create_split_files_from_rates(
edges_df, label_conf.label_column, split_rates, split_masks_output_prefix
edges_df,
label_conf.label_column,
split_rates,
split_masks_output_prefix,
custom_split_filenames,
)
label_metadata_dicts.update(label_split_dicts)
# TODO: Support custom_split_filenames

return label_metadata_dicts

@@ -1607,6 +1633,7 @@ def _create_split_files_from_rates(
label_column: str,
split_rates: Optional[SplitRates],
output_path: str,
custom_split_file: Optional[CustomSplit] = None,
seed: Optional[int] = None,
) -> Dict:
"""
Expand Down Expand Up @@ -1637,37 +1664,99 @@ def _create_split_files_from_rates(
"""
# If the user did not provide a split rate we use a default
split_metadata = {}
if split_rates is None:
split_rates = SplitRates(train_rate=0.8, val_rate=0.1, test_rate=0.1)
else:
# TODO: add support for sums <= 1.0, useful for large-scale link prediction
if sum(split_rates.tolist()) != 1.0:
raise RuntimeError(f"Provided split rates do not sum to 1: {split_rates}")

split_list = split_rates.tolist()
logging.info(
"Creating split files for label column '%s' with split rates: %s",
label_column,
split_list,
)
if not custom_split_file:
if split_rates is None:
split_rates = SplitRates(train_rate=0.8, val_rate=0.1, test_rate=0.1)
else:
# TODO: add support for sums <= 1.0, useful for large-scale link prediction
if sum(split_rates.tolist()) != 1.0:
raise RuntimeError(f"Provided split rates do not sum to 1: {split_rates}")

rng = default_rng(seed=seed)
split_list = split_rates.tolist()
logging.info(
"Creating split files for label column '%s' with split rates: %s",
label_column,
split_list,
)

# We use multinomial sampling to create a one-hot
# vector indicating train/test/val membership
def multinomial_sample(label_col: str) -> Sequence[int]:
if label_col in {"", "None", "NaN", None}:
return [0, 0, 0]
return rng.multinomial(1, split_list).tolist()
rng = default_rng(seed=seed)

# We use multinomial sampling to create a one-hot
# vector indicating train/test/val membership
def multinomial_sample(label_col: str) -> Sequence[int]:
if label_col in {"", "None", "NaN", None}:
return [0, 0, 0]
return rng.multinomial(1, split_list).tolist()

group_col_name = "sample_boolean_mask" # TODO: Ensure uniqueness of column?

# TODO: Use PandasUDF and check if it is faster than UDF
split_group = F.udf(multinomial_sample, ArrayType(IntegerType()))
# Convert label col to string and apply UDF
# to create one-hot vector indicating train/test/val membership
input_col = F.col(label_column).astype("string") if label_column else F.lit("dummy")
int_group_df = input_df.select(split_group(input_col).alias(group_col_name))
train_mask_df = int_group_df.select(F.col(group_col_name)[0].alias("train_mask"))
val_mask_df = int_group_df.select(F.col(group_col_name)[1].alias("val_mask"))
test_mask_df = int_group_df.select(F.col(group_col_name)[2].alias("test_mask"))
else:
# custom node/edge label
# create custom mask dataframe for one of the types: train, val, test
def process_custom_mask_df(input_df, split_file, mask_type):
if mask_type == "train":
file_path = split_file.train
elif mask_type == "val":
file_path = split_file.valid
elif mask_type == "test":
file_path = split_file.test
else:
raise ValueError("Unknown mask type")

if len(split_file.column) == 1:
custom_mask_df = self.spark.read.parquet(
self.input_prefix + "/" + file_path
).select(col(custom_split_file.column[0]).alias(f"custom_{mask_type}_mask"))
mask_df = input_df.join(
custom_mask_df,
input_df[NODE_MAPPING_STR] == custom_mask_df[f"custom_{mask_type}_mask"],
"left_outer",
)
mask_df = mask_df.withColumn(
f"{mask_type}_mask",
when(mask_df[f"custom_{mask_type}_mask"].isNotNull(), 1).otherwise(0),
).select(f"{mask_type}_mask")
elif len(split_file.column) == 2:
custom_mask_df = self.spark.read.parquet(
self.input_prefix + "/" + file_path
).select(
col(custom_split_file.column[0]).alias(f"custom_{mask_type}_mask_src"),
col(custom_split_file.column[1]).alias(f"custom_{mask_type}_mask_dst"),
)
join_condition = (
input_df["src_str_id"] == custom_mask_df[f"custom_{mask_type}_mask_src"]
) & (input_df["dst_str_id"] == custom_mask_df[f"custom_{mask_type}_mask_dst"])
mask_df = input_df.join(custom_mask_df, join_condition, "left_outer")
mask_df = mask_df.withColumn(
f"{mask_type}_mask",
when(
(mask_df[f"custom_{mask_type}_mask_src"].isNotNull())
& (mask_df[f"custom_{mask_type}_mask_dst"].isNotNull()),
1,
).otherwise(0),
)
else:
raise ValueError(
"Custom split files must provide one ID column for node labels "
f"or two for edge labels, got {len(split_file.column)}."
)

group_col_name = "sample_boolean_mask" # TODO: Ensure uniqueness of column?
return mask_df

# TODO: Use PandasUDF and check if it is faster than UDF
split_group = F.udf(multinomial_sample, ArrayType(IntegerType()))
# Convert label col to string and apply UDF
# to create one-hot vector indicating train/test/val membership
input_col = F.col(label_column).astype("string") if label_column else F.lit("dummy")
int_group_df = input_df.select(split_group(input_col).alias(group_col_name))
train_mask_df, val_mask_df, test_mask_df = (
process_custom_mask_df(input_df, custom_split_file, "train"),
process_custom_mask_df(input_df, custom_split_file, "val"),
process_custom_mask_df(input_df, custom_split_file, "test"),
)

def create_metadata_entry(path_list):
return {"format": {"name": FORMAT_NAME, "delimiter": DELIMITER}, "data": path_list}
@@ -1679,15 +1768,12 @@ def write_mask(kind: str, mask_df: DataFrame) -> Sequence[str]:
)
return out_path_list

train_mask_df = int_group_df.select(F.col(group_col_name)[0].alias("train_mask"))
out_path_list = write_mask("train", train_mask_df)
split_metadata["train_mask"] = create_metadata_entry(out_path_list)

val_mask_df = int_group_df.select(F.col(group_col_name)[1].alias("val_mask"))
out_path_list = write_mask("val", val_mask_df)
split_metadata["val_mask"] = create_metadata_entry(out_path_list)

test_mask_df = int_group_df.select(F.col(group_col_name)[2].alias("test_mask"))
out_path_list = write_mask("test", test_mask_df)
split_metadata["test_mask"] = create_metadata_entry(out_path_list)

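To make the custom-mask join in process_custom_mask_df easier to follow, here is a self-contained sketch of the node case, using toy in-memory data instead of parquet input (all names are illustrative; "orig" stands in for the NODE_MAPPING_STR column):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when

# Toy stand-ins: "orig" plays the role of the NODE_MAPPING_STR column,
# and train_ids replaces the parquet file listed under "train".
spark = SparkSession.builder.master("local[1]").getOrCreate()
nodes = spark.createDataFrame([("n0",), ("n1",), ("n2",)], ["orig"])
train_ids = spark.createDataFrame([("n0",), ("n2",)], ["ID"])

# Left-join the node IDs against the custom train IDs; matched rows get 1.
custom_mask = train_ids.select(col("ID").alias("custom_train_mask"))
train_mask = (
    nodes.join(custom_mask, nodes["orig"] == custom_mask["custom_train_mask"], "left_outer")
    .withColumn("train_mask", when(col("custom_train_mask").isNotNull(), 1).otherwise(0))
    .select("train_mask")
)
train_mask.show()  # 1 for n0 and n2, 0 for n1 (row order may vary)
```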
44 changes: 44 additions & 0 deletions graphstorm-processing/tests/test_converter.py
@@ -166,6 +166,50 @@ def test_read_node_gconstruct(converter: GConstructConfigConverter, node_dict: dict):
}
]

node_dict["nodes"].append(
{
"node_type": "paper_custom",
"format": {"name": "parquet"},
"files": ["/tmp/acm_raw/nodes/paper_custom.parquet"],
"node_id_col": "node_id",
"labels": [
{
"label_col": "label",
"task_type": "classification",
"custom_split_filenames": {
"train": "customized_label/node_train_idx.parquet",
"valid": "customized_label/node_val_idx.parquet",
"test": "customized_label/node_test_idx.parquet",
"column": ["ID"],
},
"label_stats_type": "frequency_cnt",
}
],
}
)

# nodes with all elements
# [self.type, self.format, self.files, self.separator, self.column, self.features, self.labels]
node_config = converter.convert_nodes(node_dict["nodes"])[2]
assert len(converter.convert_nodes(node_dict["nodes"])) == 3
assert node_config.node_type == "paper_custom"
assert node_config.file_format == "parquet"
assert node_config.files == ["/tmp/acm_raw/nodes/paper_custom.parquet"]
assert node_config.separator is None
assert node_config.column == "node_id"
assert node_config.labels == [
{
"column": "label",
"type": "classification",
"custom_split_filenames": {
"train": "customized_label/node_train_idx.parquet",
"valid": "customized_label/node_val_idx.parquet",
"test": "customized_label/node_test_idx.parquet",
"column": ["ID"],
},
}
]


@pytest.mark.parametrize("col_name", ["author", ["author"]])
def test_read_edge_gconstruct(converter: GConstructConfigConverter, col_name):