Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
jalencato committed Oct 3, 2024
1 parent 73ffd69 commit 393a126
Showing 1 changed file with 116 additions and 0 deletions.
116 changes: 116 additions & 0 deletions graphstorm-processing/tests/test_dist_heterogenous_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,3 +1033,119 @@ def test_edge_custom_label(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp
assert train_total_ones == 3
assert val_total_ones == 3
assert test_total_ones == 3


def test_node_custom_label_multitask(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp_path):
    """Test custom label splits for nodes with multi-task ``mask_field_names``.

    Writes custom train/val/test split membership files to parquet, then
    verifies that node label processing produces per-task mask columns named
    after ``mask_field_names`` (instead of the default train/val/test mask
    names) and that each mask marks exactly the rows listed in its split file.
    """
    # Ten nodes with ids 1..10 in column "orig".
    nodes_df = spark.createDataFrame([(i,) for i in range(1, 11)], ["orig"])

    # Custom split membership: 5 train / 3 val / 2 test node ids.
    train_df = spark.createDataFrame([(i,) for i in range(1, 6)], ["mask_id"])
    val_df = spark.createDataFrame([(i,) for i in range(6, 9)], ["mask_id"])
    test_df = spark.createDataFrame([(i,) for i in range(9, 11)], ["mask_id"])

    train_df.repartition(1).write.parquet(f"{tmp_path}/train.parquet")
    val_df.repartition(1).write.parquet(f"{tmp_path}/val.parquet")
    test_df.repartition(1).write.parquet(f"{tmp_path}/test.parquet")

    # Custom mask column names, as used in multi-task configurations.
    class_mask_names = [
        "custom_split_train_mask",
        "custom_split_val_mask",
        "custom_split_test_mask",
    ]
    config_dict = {
        "column": "orig",
        "type": "classification",
        # split_rate should be ignored when custom_split_filenames is present.
        "split_rate": {"train": 0.8, "val": 0.1, "test": 0.1},
        "custom_split_filenames": {
            "train": f"{tmp_path}/train.parquet",
            "valid": f"{tmp_path}/val.parquet",
            "test": f"{tmp_path}/test.parquet",
            "column": ["mask_id"],
        },
        "mask_field_names": class_mask_names,
    }
    # The split file paths above are absolute, so no input prefix is needed.
    dghl_loader.input_prefix = ""
    label_configs = [NodeLabelConfig(config_dict)]
    label_metadata_dicts = dghl_loader._process_node_labels(label_configs, nodes_df, "orig")

    # Output metadata should use the custom mask names plus the label column.
    assert label_metadata_dicts.keys() == {
        "custom_split_train_mask",
        "custom_split_val_mask",
        "custom_split_test_mask",
        "orig",
    }

    train_mask_df, val_mask_df, test_mask_df = read_masks_from_disk(
        spark, dghl_loader, label_metadata_dicts, class_mask_names
    )

    # Each mask should have exactly as many ones as its custom split file has rows.
    train_total_ones = train_mask_df.agg(F.sum("custom_split_train_mask")).collect()[0][0]
    val_total_ones = val_mask_df.agg(F.sum("custom_split_val_mask")).collect()[0][0]
    test_total_ones = test_mask_df.agg(F.sum("custom_split_test_mask")).collect()[0][0]
    assert train_total_ones == 5
    assert val_total_ones == 3
    assert test_total_ones == 2


def test_edge_custom_label_multitask(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp_path):
    """Test custom label splits for edges with multi-task ``mask_field_names``.

    Writes custom train/val/test edge lists (src/dst pairs) to parquet, then
    verifies that link-prediction label processing produces per-task mask
    columns named after ``mask_field_names`` and that each mask marks exactly
    the edges listed in its split file.
    """
    # 3x3 = 9 edges: sources 1..3, destinations 11..13.
    data = [(i, j) for i in range(1, 4) for j in range(11, 14)]
    edges_df = spark.createDataFrame(data, ["src_str_id", "dst_str_id"])

    # Custom split membership: 3 edges per split, keyed by (src, dst) pair.
    train_df = spark.createDataFrame(
        [(i, j) for i in range(1, 2) for j in range(11, 14)],
        ["mask_src_id", "mask_dst_id"],
    )
    val_df = spark.createDataFrame(
        [(i, j) for i in range(2, 3) for j in range(11, 14)],
        ["mask_src_id", "mask_dst_id"],
    )
    test_df = spark.createDataFrame(
        [(i, j) for i in range(3, 4) for j in range(11, 14)],
        ["mask_src_id", "mask_dst_id"],
    )

    train_df.repartition(1).write.parquet(f"{tmp_path}/train.parquet")
    val_df.repartition(1).write.parquet(f"{tmp_path}/val.parquet")
    test_df.repartition(1).write.parquet(f"{tmp_path}/test.parquet")

    # Custom mask column names, as used in multi-task configurations.
    class_mask_names = [
        "custom_split_train_mask",
        "custom_split_val_mask",
        "custom_split_test_mask",
    ]
    config_dict = {
        # Link prediction has no label column.
        "column": "",
        "type": "link_prediction",
        "custom_split_filenames": {
            "train": f"{tmp_path}/train.parquet",
            "valid": f"{tmp_path}/val.parquet",
            "test": f"{tmp_path}/test.parquet",
            "column": ["mask_src_id", "mask_dst_id"],
        },
        "mask_field_names": class_mask_names,
    }
    # The split file paths above are absolute, so no input prefix is needed.
    dghl_loader.input_prefix = ""
    label_configs = [EdgeLabelConfig(config_dict)]
    label_metadata_dicts = dghl_loader._process_edge_labels(
        label_configs, edges_df, "src_str_id:to:dst_str_id", ""
    )

    # Output metadata should contain only the custom mask names (no label col).
    assert label_metadata_dicts.keys() == {
        "custom_split_train_mask",
        "custom_split_val_mask",
        "custom_split_test_mask",
    }

    train_mask_df, val_mask_df, test_mask_df = read_masks_from_disk(
        spark, dghl_loader, label_metadata_dicts, class_mask_names
    )

    # Each mask should have exactly as many ones as its custom split file has rows.
    train_total_ones = train_mask_df.agg(F.sum("custom_split_train_mask")).collect()[0][0]
    val_total_ones = val_mask_df.agg(F.sum("custom_split_val_mask")).collect()[0][0]
    test_total_ones = test_mask_df.agg(F.sum("custom_split_test_mask")).collect()[0][0]
    assert train_total_ones == 3
    assert val_total_ones == 3
    assert test_total_ones == 3

0 comments on commit 393a126

Please sign in to comment.