Skip to content

Commit

Permalink
msg
Browse files Browse the repository at this point in the history
  • Loading branch information
jalencato committed Dec 13, 2024
1 parent 10492a5 commit 96d0636
Showing 1 changed file with 1 addition and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2016,20 +2016,6 @@ def _create_split_files_split_rates(
tuple[DataFrame, DataFrame, DataFrame]
Train/val/test mask DataFrames.
"""
def create_mapping(input_df):
    """
    Attach an integer ordering column to a DataFrame.

    Parameters
    ----------
    input_df: DataFrame
        DataFrame that receives the additional ordering column.

    Returns
    -------
    DataFrame
        The result of ``input_df.withColumn`` with a column named by
        ``CUSTOM_DATA_SPLIT_ORDER`` filled from ``monotonically_increasing_id()``.
    """
    # NOTE(review): presumably Spark's monotonically_increasing_id, whose values
    # are increasing and unique but not consecutive -- confirm against imports.
    return input_df.withColumn(CUSTOM_DATA_SPLIT_ORDER, monotonically_increasing_id())

if split_rates is None:
split_rates = SplitRates(train_rate=0.8, val_rate=0.1, test_rate=0.1)
logging.info(
Expand Down Expand Up @@ -2065,9 +2051,7 @@ def multinomial_sample(label_col: str) -> Sequence[int]:
# Convert label col to string and apply UDF
# to create one-hot vector indicating train/test/val membership
input_col = F.col(label_column).astype("string") if label_column else F.lit("dummy")
input_df = create_mapping(input_df)
int_group_df = input_df.select(split_group(input_col).alias(group_col_name),
RANDOM_DATA_SPLIT_ORDER)
int_group_df = input_df.select(split_group(input_col).alias(group_col_name))

# We cache because we re-use this DF 3 times
int_group_df.cache()
Expand All @@ -2077,7 +2061,6 @@ def multinomial_sample(label_col: str) -> Sequence[int]:
else:
mask_names = ("train_mask", "val_mask", "test_mask")

int_group_df = int_group_df.orderBy(RANDOM_DATA_SPLIT_ORDER)
train_mask_df = int_group_df.select(F.col(group_col_name)[0].alias(mask_names[0]))
val_mask_df = int_group_df.select(F.col(group_col_name)[1].alias(mask_names[1]))
test_mask_df = int_group_df.select(F.col(group_col_name)[2].alias(mask_names[2]))
Expand Down

0 comments on commit 96d0636

Please sign in to comment.