diff --git a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py index 559cf3f5b7..668c36e0c0 100644 --- a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py +++ b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py @@ -36,7 +36,7 @@ ArrayType, ByteType, ) -from pyspark.sql.functions import col, when +from pyspark.sql.functions import col, when, monotonically_increasing_id from numpy.random import default_rng from graphstorm_processing.constants import ( @@ -73,6 +73,7 @@ DELIMITER = "" if FORMAT_NAME == "parquet" else "," NODE_MAPPING_STR = "orig" NODE_MAPPING_INT = "new" +CUSTOM_DATA_SPLIT_ORDER = "custom_split_order_flag" @dataclass @@ -2090,6 +2091,18 @@ def _create_split_files_custom_split( def process_custom_mask_df( input_df: DataFrame, split_file: CustomSplit, mask_name: str, mask_type: str ): + def create_mapping(input_df): + """ + Creates the integer mapping for order maintaining. + + Parameters + ---------- + input_df: DataFrame + Input dataframe for which we will add integer mapping. + """ + return_df = input_df.withColumn(CUSTOM_DATA_SPLIT_ORDER, monotonically_increasing_id()) + return return_df + if mask_type == "train": file_path = split_file.train elif mask_type == "val": @@ -2106,12 +2119,13 @@ def process_custom_mask_df( custom_mask_df = self.spark.read.parquet( os.path.join(self.input_prefix, file_path) ).select(col(split_file.mask_columns[0]).alias(f"custom_{mask_type}_mask")) - mask_df = input_df.join( + input_df_id = create_mapping(input_df) + mask_df = input_df_id.join( custom_mask_df, - input_df[NODE_MAPPING_STR] == custom_mask_df[f"custom_{mask_type}_mask"], + input_df_id[NODE_MAPPING_STR] == custom_mask_df[f"custom_{mask_type}_mask"], "left_outer", ) - mask_df = mask_df.orderBy(NODE_MAPPING_STR) + mask_df = mask_df.orderBy(CUSTOM_DATA_SPLIT_ORDER) mask_df = mask_df.select( "*", when(mask_df[f"custom_{mask_type}_mask"].isNotNull(), 1) @@ -2126,11 +2140,12 @@ def process_custom_mask_df( col(split_file.mask_columns[0]).alias(f"custom_{mask_type}_mask_src"), col(split_file.mask_columns[1]).alias(f"custom_{mask_type}_mask_dst"), ) + input_df_id = create_mapping(input_df) join_condition = ( - input_df["src_str_id"] == custom_mask_df[f"custom_{mask_type}_mask_src"] - ) & (input_df["dst_str_id"] == custom_mask_df[f"custom_{mask_type}_mask_dst"]) - mask_df = input_df.join(custom_mask_df, join_condition, "left_outer") - mask_df = mask_df.orderBy(NODE_MAPPING_STR) + input_df_id["src_str_id"] == custom_mask_df[f"custom_{mask_type}_mask_src"] + ) & (input_df_id["dst_str_id"] == custom_mask_df[f"custom_{mask_type}_mask_dst"]) + mask_df = input_df_id.join(custom_mask_df, join_condition, "left_outer") + mask_df = mask_df.orderBy(CUSTOM_DATA_SPLIT_ORDER) mask_df = mask_df.select( "*", when(