Skip to content

Commit

Permalink
fix order
Browse files Browse the repository at this point in the history
  • Loading branch information
jalencato committed Oct 2, 2024
1 parent 0d03e80 commit aee3537
Showing 1 changed file with 23 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
ArrayType,
ByteType,
)
from pyspark.sql.functions import col, when
from pyspark.sql.functions import col, when, monotonically_increasing_id
from numpy.random import default_rng

from graphstorm_processing.constants import (
Expand Down Expand Up @@ -73,6 +73,7 @@
# Empty string when FORMAT_NAME is "parquet", a comma otherwise.
# NOTE(review): presumably the field separator for delimited output -- confirm.
DELIMITER = "" if FORMAT_NAME == "parquet" else ","
# Column name that process_custom_mask_df joins on when matching node rows
# against the ids read from a custom split file.
NODE_MAPPING_STR = "orig"
# NOTE(review): presumably the integer-id counterpart of NODE_MAPPING_STR;
# not used in the visible code -- confirm against the rest of the module.
NODE_MAPPING_INT = "new"
# Name of the helper column that create_mapping() fills with
# monotonically_increasing_id() and that process_custom_mask_df orders by.
CUSTOM_DATA_SPLIT_ORDER = "custom_split_order_flag"


@dataclass
Expand Down Expand Up @@ -2090,6 +2091,18 @@ def _create_split_files_custom_split(
def process_custom_mask_df(
input_df: DataFrame, split_file: CustomSplit, mask_name: str, mask_type: str
):
def create_mapping(input_df):
    """
    Attach an ordering column to the input dataframe.

    The new column is named by ``CUSTOM_DATA_SPLIT_ORDER`` and holds the
    value of ``monotonically_increasing_id()`` for each row. The caller
    orders the joined dataframe by this column after the mask join.

    Parameters
    ----------
    input_df: DataFrame
        Input dataframe to which the ordering column is added.

    Returns
    -------
    DataFrame
        ``input_df`` with the extra ``CUSTOM_DATA_SPLIT_ORDER`` column.
    """
    return input_df.withColumn(CUSTOM_DATA_SPLIT_ORDER, monotonically_increasing_id())

if mask_type == "train":
file_path = split_file.train
elif mask_type == "val":
Expand All @@ -2106,12 +2119,13 @@ def process_custom_mask_df(
custom_mask_df = self.spark.read.parquet(
os.path.join(self.input_prefix, file_path)
).select(col(split_file.mask_columns[0]).alias(f"custom_{mask_type}_mask"))
mask_df = input_df.join(
input_df_id = create_mapping(input_df)
mask_df = input_df_id.join(
custom_mask_df,
input_df[NODE_MAPPING_STR] == custom_mask_df[f"custom_{mask_type}_mask"],
input_df_id[NODE_MAPPING_STR] == custom_mask_df[f"custom_{mask_type}_mask"],
"left_outer",
)
mask_df = mask_df.orderBy(NODE_MAPPING_STR)
mask_df = mask_df.orderBy(CUSTOM_DATA_SPLIT_ORDER)
mask_df = mask_df.select(
"*",
when(mask_df[f"custom_{mask_type}_mask"].isNotNull(), 1)
Expand All @@ -2126,11 +2140,12 @@ def process_custom_mask_df(
col(split_file.mask_columns[0]).alias(f"custom_{mask_type}_mask_src"),
col(split_file.mask_columns[1]).alias(f"custom_{mask_type}_mask_dst"),
)
input_df_id = create_mapping(input_df)
join_condition = (
input_df["src_str_id"] == custom_mask_df[f"custom_{mask_type}_mask_src"]
) & (input_df["dst_str_id"] == custom_mask_df[f"custom_{mask_type}_mask_dst"])
mask_df = input_df.join(custom_mask_df, join_condition, "left_outer")
mask_df = mask_df.orderBy(NODE_MAPPING_STR)
input_df_id["src_str_id"] == custom_mask_df[f"custom_{mask_type}_mask_src"]
) & (input_df_id["dst_str_id"] == custom_mask_df[f"custom_{mask_type}_mask_dst"])
mask_df = input_df_id.join(custom_mask_df, join_condition, "left_outer")
mask_df = mask_df.orderBy(CUSTOM_DATA_SPLIT_ORDER)
mask_df = mask_df.select(
"*",
when(
Expand Down

0 comments on commit aee3537

Please sign in to comment.