diff --git a/graphstorm-processing/tests/test_dist_heterogenous_loader.py b/graphstorm-processing/tests/test_dist_heterogenous_loader.py
index 401fe6cd41..e64fbbf506 100644
--- a/graphstorm-processing/tests/test_dist_heterogenous_loader.py
+++ b/graphstorm-processing/tests/test_dist_heterogenous_loader.py
@@ -1033,3 +1033,119 @@ def test_edge_custom_label(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp
     assert train_total_ones == 3
     assert val_total_ones == 3
     assert test_total_ones == 3
+
+
+def test_node_custom_label_multitask(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp_path):
+    """Test custom node label splits with multi-task mask field names"""
+    data = [(i,) for i in range(1, 11)]
+
+    # Create a node DataFrame with IDs 1-10
+    nodes_df = spark.createDataFrame(data, ["orig"])
+
+    train_df = spark.createDataFrame([(i,) for i in range(1, 6)], ["mask_id"])
+    val_df = spark.createDataFrame([(i,) for i in range(6, 9)], ["mask_id"])
+    test_df = spark.createDataFrame([(i,) for i in range(9, 11)], ["mask_id"])
+
+    train_df.repartition(1).write.parquet(f"{tmp_path}/train.parquet")
+    val_df.repartition(1).write.parquet(f"{tmp_path}/val.parquet")
+    test_df.repartition(1).write.parquet(f"{tmp_path}/test.parquet")
+    class_mask_names = [
+        "custom_split_train_mask",
+        "custom_split_val_mask",
+        "custom_split_test_mask",
+    ]
+    config_dict = {
+        "column": "orig",
+        "type": "classification",
+        "split_rate": {"train": 0.8, "val": 0.1, "test": 0.1},
+        "custom_split_filenames": {
+            "train": f"{tmp_path}/train.parquet",
+            "valid": f"{tmp_path}/val.parquet",
+            "test": f"{tmp_path}/test.parquet",
+            "column": ["mask_id"],
+        },
+        "mask_field_names": class_mask_names,
+    }
+    dghl_loader.input_prefix = ""
+    label_configs = [NodeLabelConfig(config_dict)]
+    label_metadata_dicts = dghl_loader._process_node_labels(label_configs, nodes_df, "orig")
+
+    assert label_metadata_dicts.keys() == {
+        "custom_split_train_mask",
+        "custom_split_val_mask",
+        "custom_split_test_mask",
+        "orig",
+    }
+
+    train_mask_df, val_mask_df, test_mask_df = read_masks_from_disk(
+        spark, dghl_loader, label_metadata_dicts, class_mask_names
+    )
+
+    train_total_ones = train_mask_df.agg(F.sum("custom_split_train_mask")).collect()[0][0]
+    val_total_ones = val_mask_df.agg(F.sum("custom_split_val_mask")).collect()[0][0]
+    test_total_ones = test_mask_df.agg(F.sum("custom_split_test_mask")).collect()[0][0]
+    assert train_total_ones == 5
+    assert val_total_ones == 3
+    assert test_total_ones == 2
+
+
+def test_edge_custom_label_multitask(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp_path):
+    """Test custom edge label splits with multi-task mask field names"""
+    data = [(i, j) for i in range(1, 4) for j in range(11, 14)]
+    # Create a 3x3 edge DataFrame
+    edges_df = spark.createDataFrame(data, ["src_str_id", "dst_str_id"])
+
+    train_df = spark.createDataFrame(
+        [(i, j) for i in range(1, 2) for j in range(11, 14)],
+        ["mask_src_id", "mask_dst_id"],
+    )
+    val_df = spark.createDataFrame(
+        [(i, j) for i in range(2, 3) for j in range(11, 14)],
+        ["mask_src_id", "mask_dst_id"],
+    )
+    test_df = spark.createDataFrame(
+        [(i, j) for i in range(3, 4) for j in range(11, 14)],
+        ["mask_src_id", "mask_dst_id"],
+    )
+
+    class_mask_names = [
+        "custom_split_train_mask",
+        "custom_split_val_mask",
+        "custom_split_test_mask",
+    ]
+    train_df.repartition(1).write.parquet(f"{tmp_path}/train.parquet")
+    val_df.repartition(1).write.parquet(f"{tmp_path}/val.parquet")
+    test_df.repartition(1).write.parquet(f"{tmp_path}/test.parquet")
+    config_dict = {
+        "column": "",
+        "type": "link_prediction",
+        "custom_split_filenames": {
+            "train": f"{tmp_path}/train.parquet",
+            "valid": f"{tmp_path}/val.parquet",
+            "test": f"{tmp_path}/test.parquet",
+            "column": ["mask_src_id", "mask_dst_id"],
+        },
+        "mask_field_names": class_mask_names,
+    }
+    dghl_loader.input_prefix = ""
+    label_configs = [EdgeLabelConfig(config_dict)]
+    label_metadata_dicts = dghl_loader._process_edge_labels(
+        label_configs, edges_df, "src_str_id:to:dst_str_id", ""
+    )
+
+    assert label_metadata_dicts.keys() == {
+        "custom_split_train_mask",
+        "custom_split_val_mask",
+        "custom_split_test_mask",
+    }
+
+    train_mask_df, val_mask_df, test_mask_df = read_masks_from_disk(
+        spark, dghl_loader, label_metadata_dicts, class_mask_names
+    )
+
+    train_total_ones = train_mask_df.agg(F.sum("custom_split_train_mask")).collect()[0][0]
+    val_total_ones = val_mask_df.agg(F.sum("custom_split_val_mask")).collect()[0][0]
+    test_total_ones = test_mask_df.agg(F.sum("custom_split_test_mask")).collect()[0][0]
+    assert train_total_ones == 3
+    assert val_total_ones == 3
+    assert test_total_ones == 3