[GSProcessing] PR2 - Custom Data Split for MT #1039

Merged · 14 commits · Oct 9, 2024
@@ -15,6 +15,7 @@
"""

import abc
+import logging
from typing import Any, Dict, Optional


@@ -42,6 +43,11 @@ def __init__(self, config_dict: Dict[str, Any]):
{"train": 0.8, "val": 0.1, "test": 0.1},
)
else:
+            if "split_rate" in config_dict:
+                logging.warning(
+                    "Both custom_split_filenames and split_rate are set; "
+                    "ignoring split_rate and using the custom data split."
+                )
self._custom_split_filenames = config_dict["custom_split_filenames"]
if "mask_field_names" in config_dict:
self._mask_field_names: Optional[list[str]] = config_dict["mask_field_names"]
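Note on the warning added above: when both `split_rate` and `custom_split_filenames` appear in a label config, the custom split wins. A sketch of a config dict that would hit this path (the exact keys of the custom split entry are assumptions here, and the file names, id column, and mask names are illustrative, not taken from this PR):

```python
# Hypothetical GSProcessing label config: both split_rate and
# custom_split_filenames are set, so the code above logs a warning
# and proceeds with the custom data split.
config_dict = {
    "column": "label",
    "type": "classification",
    "split_rate": {"train": 0.8, "val": 0.1, "test": 0.1},
    "custom_split_filenames": {
        "train": "train_ids.parquet",   # illustrative file names
        "valid": "val_ids.parquet",
        "test": "test_ids.parquet",
        "column": ["node_id"],
    },
    "mask_field_names": ["train_mask_mt", "val_mask_mt", "test_mask_mt"],
}
```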
@@ -36,7 +36,7 @@
ArrayType,
ByteType,
)
-from pyspark.sql.functions import col, when
+from pyspark.sql.functions import col, when, monotonically_increasing_id
from numpy.random import default_rng

from graphstorm_processing.constants import (
@@ -73,6 +73,7 @@
DELIMITER = "" if FORMAT_NAME == "parquet" else ","
NODE_MAPPING_STR = "orig"
NODE_MAPPING_INT = "new"
+CUSTOM_DATA_SPLIT_ORDER = "custom_split_order_flag"


@dataclass
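A note on why `CUSTOM_DATA_SPLIT_ORDER` is introduced: Spark joins do not guarantee output row order, so the loader tags each input row with `monotonically_increasing_id()` before joining against the split file and sorts on that tag afterwards. A minimal, self-contained sketch of that pattern with toy data (not code from this PR):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import monotonically_increasing_id

spark = SparkSession.builder.master("local[2]").getOrCreate()

# Toy node table and a custom train-split file matching one node id.
nodes = spark.createDataFrame([("a",), ("b",), ("c",)], ["orig"])
split = spark.createDataFrame([("b",)], ["custom_train_mask"])

# Tag each row with an increasing (not necessarily consecutive) id.
nodes_id = nodes.withColumn("custom_split_order_flag", monotonically_increasing_id())

# The join may shuffle rows across partitions...
joined = nodes_id.join(
    split, nodes_id["orig"] == split["custom_train_mask"], "left_outer"
)

# ...so sort on the tag to restore the original row order before
# deriving the 0/1 mask column.
joined.orderBy("custom_split_order_flag").show()
```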
@@ -1946,12 +1947,9 @@ def _create_split_files(
input_df, label_column, split_rates, seed, mask_field_names
)
else:
-            if mask_field_names:
-                raise NotImplementedError(
-                    "Custom split files with custom mask field names currently not supported."
-                )
-
-            mask_dfs = self._create_split_files_custom_split(input_df, custom_split_file)
+            mask_dfs = self._create_split_files_custom_split(
+                input_df, custom_split_file, mask_field_names
+            )

def create_metadata_entry(path_list):
return {
@@ -2060,7 +2058,10 @@ def multinomial_sample(label_col: str) -> Sequence[int]:
return train_mask_df, val_mask_df, test_mask_df

def _create_split_files_custom_split(
-        self, input_df: DataFrame, custom_split_file: CustomSplit
+        self,
+        input_df: DataFrame,
+        custom_split_file: CustomSplit,
+        mask_field_names: Optional[tuple[str, str, str]] = None,
) -> tuple[DataFrame, DataFrame, DataFrame]:
"""
Creates the train/val/test mask dataframe based on custom split files.
@@ -2074,6 +2075,10 @@ def _create_split_files_custom_split(
training/validation/test.
mask_type: str
The type of mask to create, value can be train, val or test.
+        mask_field_names: Optional[tuple[str, str, str]]
+            An optional tuple of field names to use for the split masks.
+            If not provided, the default field names "train_mask",
+            "val_mask", and "test_mask" are used.

Returns
-------
@@ -2083,7 +2088,39 @@ def _create_split_files_custom_split(

# custom node/edge label
# create custom mask dataframe for one of the types: train, val, test
-        def process_custom_mask_df(input_df: DataFrame, split_file: CustomSplit, mask_type: str):
+        def process_custom_mask_df(
+            input_df: DataFrame, split_file: CustomSplit, mask_name: str, mask_type: str
+        ):
+            """
+            Creates the mask dataframe for one mask type based on custom split files.
+
+            Parameters
+            ----------
+            input_df: DataFrame
+                Input dataframe for which we will add integer mapping.
+            split_file: CustomSplit
+                A CustomSplit object including paths to the custom split files for
+                training/validation/test.
+            mask_name: str
+                Mask field name for the mask type.
+            mask_type: str
+                The type of mask to create; the value can be train, val, or test.
+            """

+            def create_mapping(input_df):
+                """
+                Adds an integer id column used to preserve the original row order.
+
+                Parameters
+                ----------
+                input_df: DataFrame
+                    Input dataframe to which the integer id column is added.
+                """
+                return_df = input_df.withColumn(
+                    CUSTOM_DATA_SPLIT_ORDER, monotonically_increasing_id()
+                )
+                return return_df

if mask_type == "train":
file_path = split_file.train
elif mask_type == "val":
@@ -2093,22 +2130,26 @@ def process_custom_mask_df(input_df: DataFrame, split_file: CustomSplit, mask_type: str):
else:
raise ValueError("Unknown mask type")

+            # Custom data split should only be considered
+            # in cases with a limited number of labels.
if len(split_file.mask_columns) == 1:
# custom split on node original id
custom_mask_df = self.spark.read.parquet(
os.path.join(self.input_prefix, file_path)
).select(col(split_file.mask_columns[0]).alias(f"custom_{mask_type}_mask"))
-                mask_df = input_df.join(
+                input_df_id = create_mapping(input_df)
+                mask_df = input_df_id.join(
                    custom_mask_df,
-                    input_df[NODE_MAPPING_STR] == custom_mask_df[f"custom_{mask_type}_mask"],
+                    input_df_id[NODE_MAPPING_STR] == custom_mask_df[f"custom_{mask_type}_mask"],
"left_outer",
)
+                mask_df = mask_df.orderBy(CUSTOM_DATA_SPLIT_ORDER)
mask_df = mask_df.select(
"*",
when(mask_df[f"custom_{mask_type}_mask"].isNotNull(), 1)
.otherwise(0)
-                    .alias(f"{mask_type}_mask"),
-                ).select(f"{mask_type}_mask")
+                    .alias(mask_name),
+                ).select(mask_name)
elif len(split_file.mask_columns) == 2:
                # custom split on edge (src, dst) original ids
custom_mask_df = self.spark.read.parquet(
@@ -2117,10 +2158,12 @@ def process_custom_mask_df(input_df: DataFrame, split_file: CustomSplit, mask_type: str):
col(split_file.mask_columns[0]).alias(f"custom_{mask_type}_mask_src"),
col(split_file.mask_columns[1]).alias(f"custom_{mask_type}_mask_dst"),
)
+                input_df_id = create_mapping(input_df)
join_condition = (
-                    input_df["src_str_id"] == custom_mask_df[f"custom_{mask_type}_mask_src"]
-                ) & (input_df["dst_str_id"] == custom_mask_df[f"custom_{mask_type}_mask_dst"])
-                mask_df = input_df.join(custom_mask_df, join_condition, "left_outer")
+                    input_df_id["src_str_id"] == custom_mask_df[f"custom_{mask_type}_mask_src"]
+                ) & (input_df_id["dst_str_id"] == custom_mask_df[f"custom_{mask_type}_mask_dst"])
+                mask_df = input_df_id.join(custom_mask_df, join_condition, "left_outer")
+                mask_df = mask_df.orderBy(CUSTOM_DATA_SPLIT_ORDER)
mask_df = mask_df.select(
"*",
when(
@@ -2129,17 +2172,21 @@ def process_custom_mask_df(input_df: DataFrame, split_file: CustomSplit, mask_type: str):
1,
)
.otherwise(0)
-                    .alias(f"{mask_type}_mask"),
-                ).select(f"{mask_type}_mask")
+                    .alias(mask_name),
+                ).select(mask_name)
else:
            raise ValueError("The number of columns should be only 1 or 2.")

return mask_df

+        if mask_field_names:
+            mask_names = mask_field_names
+        else:
+            mask_names = ("train_mask", "val_mask", "test_mask")
train_mask_df, val_mask_df, test_mask_df = (
-            process_custom_mask_df(input_df, custom_split_file, "train"),
-            process_custom_mask_df(input_df, custom_split_file, "val"),
-            process_custom_mask_df(input_df, custom_split_file, "test"),
+            process_custom_mask_df(input_df, custom_split_file, mask_names[0], "train"),
+            process_custom_mask_df(input_df, custom_split_file, mask_names[1], "val"),
+            process_custom_mask_df(input_df, custom_split_file, mask_names[2], "test"),
)
return train_mask_df, val_mask_df, test_mask_df
