diff --git a/docs/source/gs-processing/developer/input-configuration.rst b/docs/source/gs-processing/developer/input-configuration.rst index b667cdb8d2..702d3caad1 100644 --- a/docs/source/gs-processing/developer/input-configuration.rst +++ b/docs/source/gs-processing/developer/input-configuration.rst @@ -173,6 +173,15 @@ objects: assign to the validation set [0.0, 1.0). - ``test``: The percentage of the data with available labels to assign to the test set [0.0, 1.0). + - ``custom_split_filenames`` (JSON object, optional): Specifies the customized + training/validation/test mask. Once it is defined, GSProcessing will ignore + the ``split_rate``. + - ``train``: Path of the training mask file such that each line contains + the original ID for node tasks, or the pair [source_id, destination_id] for edge tasks. + - ``val``: Path of the validation mask file such that each line contains + the original ID for node tasks, or the pair [source_id, destination_id] for edge tasks. + - ``test``: Path of the test mask file such that each line contains + the original ID for node tasks, or the pair [source_id, destination_id] for edge tasks. - ``features`` (List of JSON objects, optional)\ **:** Describes the set of features for the current edge type. See the :ref:`features-object` section for details. diff --git a/graphstorm-processing/tests/test_converter.py b/graphstorm-processing/tests/test_converter.py index 54078e4700..6052dec9d0 100644 --- a/graphstorm-processing/tests/test_converter.py +++ b/graphstorm-processing/tests/test_converter.py @@ -166,6 +166,45 @@ def test_read_node_gconstruct(converter: GConstructConfigConverter, node_dict: d } ] + node_dict["nodes"].append( + { + "node_type": "paper_custom", + "format": {"name": "parquet"}, + "files": ["/tmp/acm_raw/nodes/paper_custom.parquet"], + "node_id_col": "node_id", + "labels": [ + { + "label_col": "label", + "task_type": "classification", + "custom_split_filenames": {"train": "customized_label/node_train_idx.parquet", + "valid": "customized_label/node_val_idx.parquet", + "test": "customized_label/node_test_idx.parquet", + "column": ["ID"]}, + "label_stats_type": "frequency_cnt"} + ], + } + ) + + # nodes with all elements + # [self.type, self.format, self.files, self.separator, self.column, self.features, self.labels] + node_config = converter.convert_nodes(node_dict["nodes"])[2] + assert len(converter.convert_nodes(node_dict["nodes"])) == 3 + assert node_config.node_type == "paper_custom" + assert node_config.file_format == "parquet" + assert node_config.files == ["/tmp/acm_raw/nodes/paper_custom.parquet"] + assert node_config.separator is None + assert node_config.column == "node_id" + assert node_config.labels == [ + { + "column": "label", + "type": "classification", + "custom_split_filenames": {"train": "customized_label/node_train_idx.parquet", + "valid": "customized_label/node_val_idx.parquet", + "test": "customized_label/node_test_idx.parquet", + "column": ["ID"]} + } + ] + @pytest.mark.parametrize("col_name", ["author", ["author"]]) def test_read_edge_gconstruct(converter: GConstructConfigConverter, col_name):