Skip to content

Commit

Permalink
add
Browse files Browse the repository at this point in the history
  • Loading branch information
jalencato committed May 16, 2024
1 parent 780285f commit 35bbeb0
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions graphstorm-processing/tests/test_dist_heterogenous_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,3 +680,29 @@ def test_update_label_properties_multilabel(

assert user_properties[COLUMN_NAME] == "multi"
assert user_properties[VALUE_COUNTS] == {str(i): 1 for i in range(1, 11)}


def test_custom_label(spark, user_df: DataFrame, dghl_loader: DistHeterogeneousGraphLoader, tmp_path):
data = [(i,) for i in range(1, 11)]

# Create DataFrame
nodes_df = spark.createDataFrame(data, ["id"])
nodes_df.show()

train_df = spark.createDataFrame([(i,) for i in range(1, 6)], ["mask_id"])
val_df = spark.createDataFrame([(i,) for i in range(6, 9)], ["mask_id"])
test_df = spark.createDataFrame([(i,) for i in range(9, 11)], ["mask_id"])

train_df.repartition(1).write.parquet(f"{tmp_path}/train.parquet")
val_df.repartition(1).write.parquet(f"{tmp_path}/val.parquet")
test_df.repartition(1).write.parquet(f"{tmp_path}/test.parquet")
config_dict = {
"column": "id",
"type": "classification",
"split_rate": {"train": 0.8, "val": 0.1, "test": 0.1},
"custom_split_filenames": {"train": f"{tmp_path}/train.parquet",
"valid": f"{tmp_path}/val.parquet",
"test": f"{tmp_path}/test.parquet",
"column": ["ID"]},
}
label_configs = [NodeLabelConfig(config_dict)]

0 comments on commit 35bbeb0

Please sign in to comment.