From 96d0636108898d083725bbd7f98c51994bd4c5e4 Mon Sep 17 00:00:00 2001 From: JalenCato Date: Fri, 13 Dec 2024 15:04:04 +0000 Subject: [PATCH] msg --- .../dist_heterogeneous_loader.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py index ec87a4dff..b0622ef8b 100644 --- a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py +++ b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py @@ -2016,20 +2016,6 @@ def _create_split_files_split_rates( tuple[DataFrame, DataFrame, DataFrame] Train/val/test mask DataFrames. """ - def create_mapping(input_df): - """ - Creates the integer mapping for order maintaining. - - Parameters - ---------- - input_df: DataFrame - Input dataframe for which we will add integer mapping. - """ - return_df = input_df.withColumn( - CUSTOM_DATA_SPLIT_ORDER, monotonically_increasing_id() - ) - return return_df - if split_rates is None: split_rates = SplitRates(train_rate=0.8, val_rate=0.1, test_rate=0.1) logging.info( @@ -2065,9 +2051,7 @@ def multinomial_sample(label_col: str) -> Sequence[int]: # Convert label col to string and apply UDF # to create one-hot vector indicating train/test/val membership input_col = F.col(label_column).astype("string") if label_column else F.lit("dummy") - input_df = create_mapping(input_df) - int_group_df = input_df.select(split_group(input_col).alias(group_col_name), - RANDOM_DATA_SPLIT_ORDER) + int_group_df = input_df.select(split_group(input_col).alias(group_col_name)) # We cache because we re-use this DF 3 times int_group_df.cache() @@ -2077,7 +2061,6 @@ def multinomial_sample(label_col: str) -> Sequence[int]: else: mask_names = ("train_mask", "val_mask", "test_mask") - int_group_df = int_group_df.orderBy(RANDOM_DATA_SPLIT_ORDER) train_mask_df = int_group_df.select(F.col(group_col_name)[0].alias(mask_names[0])) val_mask_df = int_group_df.select(F.col(group_col_name)[1].alias(mask_names[1])) test_mask_df = int_group_df.select(F.col(group_col_name)[2].alias(mask_names[2]))