From dba7f4db66e6eedb5c359612edf5ab9e1c507de2 Mon Sep 17 00:00:00 2001
From: Theodore Vasiloudis
Date: Thu, 19 Dec 2024 01:39:22 +0000
Subject: [PATCH] [GSProcessing] Ensure we can convert custom split GConstruct
 data configs, whether they are single str or list[str]

NOTE: GSProcessing now requires list[str] for its custom split file config
input.
---
 .../gsprocessing/input-configuration.rst       |  26 ++-
 .../config_conversion/gconstruct_converter.py  |  34 +++-
 .../config/config_parser.py                    |   2 +
 .../config/label_config_base.py                |  25 ++-
 .../data_transformations/dist_label_loader.py  |  14 +-
 .../dist_heterogeneous_loader.py               |  17 +-
 graphstorm-processing/pyproject.toml           |   1 +
 graphstorm-processing/tests/test_converter.py  |  34 +++-
 .../tests/test_dist_heterogenous_loader.py     |  24 +--
 .../tests/test_gsprocessing_config.py          | 161 ++++++++++++++++++
 python/graphstorm/gconstruct/file_io.py        |   2 +-
 11 files changed, 284 insertions(+), 56 deletions(-)
 create mode 100644 graphstorm-processing/tests/test_gsprocessing_config.py

diff --git a/docs/source/cli/graph-construction/distributed/gsprocessing/input-configuration.rst b/docs/source/cli/graph-construction/distributed/gsprocessing/input-configuration.rst
index d2074a338f..fefa863eb8 100644
--- a/docs/source/cli/graph-construction/distributed/gsprocessing/input-configuration.rst
+++ b/docs/source/cli/graph-construction/distributed/gsprocessing/input-configuration.rst
@@ -173,15 +173,23 @@ objects:
         assign to the validation set [0.0, 1.0).
       - ``test``: The percentage of the data with available labels to
         assign to the test set [0.0, 1.0).
-   - ``custom_split_filenames`` (JSON object, optional): Specifies the customized
-     training/validation/test mask. Once it is defined, GSProcessing will ignore
-     the ``split_rate``.
-     - ``train``: Path of the training mask parquet file such that each line contains
-       the original ID for node tasks, or the pair [source_id, destination_id] for edge tasks.
-     - ``val``: Path of the validation mask parquet file such that each line contains
-       the original ID for node tasks, or the pair [source_id, destination_id] for edge tasks.
-     - ``test``: Path of the test mask parquet file such that each line contains
-       the original ID for node tasks, or the pair [source_id, destination_id] for edge tasks.
+   - ``custom_split_filenames`` (JSON object, optional): Specifies pre-assigned
+     training/validation/test masks. If defined, GSProcessing ignores
+     ``split_rate``.
+
+     - ``column``: (List[String], optional) A list of length one for node splits, or two for edge splits,
+       containing the name(s) of the column(s) that contain the node/edge ids for each split. For example,
+       if the node ids to include in each split exist in column ``"nid"`` of the custom train/val/test files, this
+       needs to be ``["nid"]``. For edges it would be a value like ``["src_id", "dst_id"]``.
+       If not provided, for nodes we assume the first column in the data contains the node ids to include;
+       for edges, we assume the first column is the source id and the second the destination id.
+     - ``train``: (List[String], optional) Paths of the training mask parquet files, where each row contains
+       the original ID for node tasks, or the pair ``[source_id, destination_id]`` for edge tasks.
+     - ``valid``: (List[String], optional) Paths of the validation mask parquet files, where each row contains
+       the original ID for node tasks, or the pair ``[source_id, destination_id]`` for edge tasks.
+     - ``test``: (List[String], optional) Paths of the test mask parquet files, where each row contains
+       the original ID for node tasks, or the pair ``[source_id, destination_id]`` for edge tasks.
+     - Note: At least one of the ``["train", "valid", "test"]`` keys must be present.
 
   - ``features`` (List of JSON objects, optional)\ **:** Describes
     the set of features for the current edge type. See the :ref:`features-object` section for details.
diff --git a/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py b/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py
index 33fe40f760..7d9e21c5ed 100644
--- a/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py
+++ b/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py
@@ -50,7 +50,7 @@ def _convert_label(labels: list[dict]) -> list[dict]:
         labels_list = []
         if labels in [[], [{}]]:
             return []
-        for label in labels:
+        for label in labels:  # pylint: disable=too-many-nested-blocks
             try:
                 label_column = label["label_col"] if "label_col" in label else ""
                 label_type = label["task_type"]
@@ -61,7 +61,7 @@
                 # check if split_pct is valid
                 assert (
                     math.fsum(label_splitrate) == 1.0
-                ), "sum of the label split rate should be ==1.0"
+                ), f"sum of the label split rate should be ==1.0, got {label_splitrate=}"
                 label_dict["split_rate"] = {
                     "train": label_splitrate[0],
                     "val": label_splitrate[1],
@@ -69,11 +69,29 @@
                 }
             else:
                 label_custom_split_filenames = label["custom_split_filenames"]
-                if isinstance(label_custom_split_filenames["column"], list):
-                    assert len(label_custom_split_filenames["column"]) <= 2, (
-                        "Custom split filenames should have one column for node labels, "
-                        "and two columns for edges labels exactly"
-                    )
+                # Ensure at least one of ["train", "valid", "test"] is in the keys
+                assert any(
+                    x in label_custom_split_filenames.keys() for x in ["train", "valid", "test"]
+                ), ("At least one of ['train', 'valid', 'test'] "
+                    "needs to exist in custom split configs.")
+
+                # Fill in missing values if needed
+                for entry in ["train", "valid", "test", "column"]:
+                    entry_val = label_custom_split_filenames.get(entry, None)
+                    if entry_val:
+                        if isinstance(entry_val, str):
+                            label_custom_split_filenames[entry] = [entry_val]
+                        else:
+                            assert isinstance(
+                                entry_val, list
+                            ), "Custom split filenames should be a string or a list of strings"
+                    else:
+                        label_custom_split_filenames[entry] = []
+                assert len(label_custom_split_filenames["column"]) <= 2, (
+                    "Custom split filenames should have exactly one column for node labels, "
+                    "and two columns for edge labels"
+                )
+
                 label_dict["custom_split_filenames"] = {
                     "train": label_custom_split_filenames["train"],
                     "valid": label_custom_split_filenames["valid"],
@@ -213,7 +231,7 @@ def _convert_feature(feats: list[Mapping[str, Any]]) -> list[dict]:
         return gsp_feats_list
 
     @staticmethod
-    def convert_nodes(nodes_entries):
+    def convert_nodes(nodes_entries) -> list[NodeConfig]:
         res = []
         for n in nodes_entries:
             # type, column id
diff --git a/graphstorm-processing/graphstorm_processing/config/config_parser.py b/graphstorm-processing/graphstorm_processing/config/config_parser.py
index 95f4ab3dd2..4e7c0c557d 100644
--- a/graphstorm-processing/graphstorm_processing/config/config_parser.py
+++ b/graphstorm-processing/graphstorm_processing/config/config_parser.py
@@ -194,6 +194,8 @@ def __init__(self, edge_dict: Dict[str, Dict], data_dict: Dict[str, Any]):
         self._dst_ntype = edge_dict["dest"]["type"]
         self._rel_type = edge_dict["relation"]["type"]
         self._rel_col: Optional[str] = edge_dict["relation"].get("column", None)
+        self._feature_configs = []
+        self._labels = []
 
         if "features" in edge_dict:
             for feature_dict in edge_dict["features"]:
diff --git a/graphstorm-processing/graphstorm_processing/config/label_config_base.py b/graphstorm-processing/graphstorm_processing/config/label_config_base.py
index c98ec81614..000f80c9ca 100644
--- a/graphstorm-processing/graphstorm_processing/config/label_config_base.py
+++ b/graphstorm-processing/graphstorm_processing/config/label_config_base.py
@@ -60,17 +60,24 @@ def _sanity_check(self):
                 "When no label column is specified, the task type must be link_prediction, "
                 f"got {self._task_type}"
             )
+        # Sanity checks for random and custom splits
         if "custom_split_filenames" not in self._config:
             assert isinstance(self._task_type, str)
             assert isinstance(self._split, dict)
-            assert isinstance(self._separator, str) if self._multilabel else self._separator is None
         else:
             assert isinstance(self._custom_split_filenames, dict)
-            assert "train" in self._custom_split_filenames
-            assert "valid" in self._custom_split_filenames
-            assert "test" in self._custom_split_filenames
-            assert "column" in self._custom_split_filenames
-            assert isinstance(self._separator, str) if self._multilabel else self._separator is None
+
+            # Ensure at least one of ["train", "valid", "test"] is in the keys
+            assert any(x in self._custom_split_filenames.keys() for x in ["train", "valid", "test"])
+
+            for entry in ["train", "valid", "test", "column"]:
+                # All existing entries must be lists of strings
+                if entry in self._custom_split_filenames:
+                    assert isinstance(self._custom_split_filenames[entry], list)
+                    assert all(isinstance(x, str) for x in self._custom_split_filenames[entry])
+
+        assert isinstance(self._separator, str) if self._multilabel else self._separator is None
+
         if self._mask_field_names:
             assert isinstance(self._mask_field_names, list)
             assert all(isinstance(x, str) for x in self._mask_field_names)
@@ -111,7 +118,11 @@ def custom_split_filenames(self) -> Dict[str, Any]:
 
     @property
     def mask_field_names(self) -> Optional[tuple[str, str, str]]:
        """Custom names to assign to masks for multi-task learning."""
-        return tuple(self._mask_field_names) if self._mask_field_names else None
+        if self._mask_field_names is None:
+            return None
+        else:
+            assert len(self._mask_field_names) == 3
+            return (self._mask_field_names[0], self._mask_field_names[1], self._mask_field_names[2])
 
 
 class EdgeLabelConfig(LabelConfig):
diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_label_loader.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_label_loader.py
index 956c153320..cd0ba563f0 100644
--- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_label_loader.py
+++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_label_loader.py
@@ -73,19 +73,19 @@ class CustomSplit:
 
     Parameters
     ----------
-    train : str
-        Path of the training mask parquet file.
-    valid : str
-        Path of the validation mask parquet file.
-    test : str
-        Path of the testing mask parquet file.
+    train : list[str]
+        Paths of the training mask parquet files.
+    valid : list[str]
+        Paths of the validation mask parquet files.
+    test : list[str]
+        Paths of the testing mask parquet files.
     mask_columns : list[str]
        List of columns that contain original string ids.
    """
 
-    train: str
-    valid: str
-    test: str
+    train: list[str]
+    valid: list[str]
+    test: list[str]
     mask_columns: list[str]
diff --git a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py
index 2f6e2ebe83..5cea423e56 100644
--- a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py
+++ b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py
@@ -177,7 +177,7 @@ def __init__(
             assert os.path.isabs(loader_config.input_prefix), "We expect an absolute path"
             self.filesystem_type = "local"
 
-        self.spark = spark  # type: SparkSession
+        self.spark: SparkSession = spark
         self.add_reverse_edges = loader_config.add_reverse_edges
         # Remove trailing slash in s3 paths
         if self.filesystem_type == "s3":
@@ -2131,11 +2131,11 @@ def create_mapping(input_df):
             return return_df
 
         if mask_type == "train":
-            file_path = split_file.train
+            file_paths = split_file.train
         elif mask_type == "val":
-            file_path = split_file.valid
+            file_paths = split_file.valid
         elif mask_type == "test":
-            file_path = split_file.test
+            file_paths = split_file.test
         else:
             raise ValueError("Unknown mask type")
 
@@ -2144,7 +2144,7 @@ def create_mapping(input_df):
         if len(split_file.mask_columns) == 1:
             # custom split on node original id
             custom_mask_df = self.spark.read.parquet(
-                os.path.join(self.input_prefix, file_path)
+                *[os.path.join(self.input_prefix, file_path) for file_path in file_paths]
             ).select(col(split_file.mask_columns[0]).alias(f"custom_{mask_type}_mask"))
             input_df_id = create_mapping(input_df)
             mask_df = input_df_id.join(
@@ -2162,7 +2162,7 @@ def create_mapping(input_df):
         elif len(split_file.mask_columns) == 2:
-            # custom split on edge (srd, dst) original ids
+            # custom split on edge (src, dst) original ids
             custom_mask_df = self.spark.read.parquet(
-                os.path.join(self.input_prefix, file_path)
+                *[os.path.join(self.input_prefix, file_path) for file_path in file_paths]
             ).select(
                 col(split_file.mask_columns[0]).alias(f"custom_{mask_type}_mask_src"),
                 col(split_file.mask_columns[1]).alias(f"custom_{mask_type}_mask_dst"),
@@ -2184,7 +2184,10 @@ def create_mapping(input_df):
                 .alias(mask_name),
             ).select(mask_name)
         else:
-            raise ValueError("The number of column should be only 1 or 2.")
+            raise ValueError(
+                "The number of mask columns should be 1 or 2, got: "
+                f"{split_file.mask_columns}"
+            )
 
         return mask_df
diff --git a/graphstorm-processing/pyproject.toml b/graphstorm-processing/pyproject.toml
index 8d9f7573f1..ef19663f25 100644
--- a/graphstorm-processing/pyproject.toml
+++ b/graphstorm-processing/pyproject.toml
@@ -41,6 +41,7 @@ pre-commit = "^3.3.3"
 types-mock = "^5.1.0.1"
 pylint = "~2.17.5"
 diff-cover = "^9.0.0"
+pytest-cov = "^6.0.0"
 
 [project]
 requires-python = ">=3.9" # TODO: Do we need a tilde here?
diff --git a/graphstorm-processing/tests/test_converter.py b/graphstorm-processing/tests/test_converter.py index a4871342d2..55af93197b 100644 --- a/graphstorm-processing/tests/test_converter.py +++ b/graphstorm-processing/tests/test_converter.py @@ -24,7 +24,7 @@ @pytest.fixture(name="converter") def fixture_create_converter() -> GConstructConfigConverter: """Creates a new converter object for each test.""" - yield GConstructConfigConverter() + return GConstructConfigConverter() @pytest.fixture(name="node_dict") @@ -65,8 +65,32 @@ def test_try_read_unsupported_feature(converter: GConstructConfigConverter, node _ = converter.convert_nodes(node_dict["nodes"]) +def test_custom_split_config_conversion(converter: GConstructConfigConverter): + """Test custom split file config conversion""" + gconstruct_label_dicts = [ + { + "label_col": "label", + "task_type": "classification", + "custom_split_filenames": { + "column": ["src", "dst"], + }, + "label_stats_type": "frequency_cnt", + } + ] + + # Should raise when none of train/val/test are in the input + with pytest.raises(AssertionError): + converter._convert_label(gconstruct_label_dicts) + + # Ensure single strings are converted to list of strings + gconstruct_label_dicts[0]["custom_split_filenames"]["train"] = "fake_file" + gsprocessing_label_dict = converter._convert_label(gconstruct_label_dicts)[0] + + assert gsprocessing_label_dict["custom_split_filenames"]["train"] == ["fake_file"] + + def test_try_read_invalid_gconstruct_config(converter: GConstructConfigConverter, node_dict: dict): - """Custom Split Columns""" + """Test various invalid input scenarios""" node_dict["nodes"][0]["labels"] = [ { "label_col": "label", @@ -238,9 +262,9 @@ def test_read_node_gconstruct(converter: GConstructConfigConverter, node_dict: d "column": "label", "type": "classification", "custom_split_filenames": { - "train": "customized_label/node_train_idx.parquet", - "valid": "customized_label/node_val_idx.parquet", - "test": "customized_label/node_test_idx.parquet", + "train": ["customized_label/node_train_idx.parquet"], + "valid": ["customized_label/node_val_idx.parquet"], + "test": ["customized_label/node_test_idx.parquet"], "column": ["ID"], }, } diff --git a/graphstorm-processing/tests/test_dist_heterogenous_loader.py b/graphstorm-processing/tests/test_dist_heterogenous_loader.py index 8fb9a2d501..f9a8521f75 100644 --- a/graphstorm-processing/tests/test_dist_heterogenous_loader.py +++ b/graphstorm-processing/tests/test_dist_heterogenous_loader.py @@ -955,9 +955,9 @@ def test_node_custom_label(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp "type": "classification", "split_rate": {"train": 0.8, "val": 0.1, "test": 0.1}, "custom_split_filenames": { - "train": f"{tmp_path}/train.parquet", - "valid": f"{tmp_path}/val.parquet", - "test": f"{tmp_path}/test.parquet", + "train": [f"{tmp_path}/train.parquet"], + "valid": [f"{tmp_path}/val.parquet"], + "test": [f"{tmp_path}/test.parquet"], "column": ["mask_id"], }, } @@ -1010,9 +1010,9 @@ def test_edge_custom_label(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp "column": "", "type": "link_prediction", "custom_split_filenames": { - "train": f"{tmp_path}/train.parquet", - "valid": f"{tmp_path}/val.parquet", - "test": f"{tmp_path}/test.parquet", + "train": [f"{tmp_path}/train.parquet"], + "valid": [f"{tmp_path}/val.parquet"], + "test": [f"{tmp_path}/test.parquet"], "column": ["mask_src_id", "mask_dst_id"], }, } @@ -1061,9 +1061,9 @@ def test_node_custom_label_multitask(spark, dghl_loader: DistHeterogeneousGraphL 
"type": "classification", "split_rate": {"train": 0.8, "val": 0.1, "test": 0.1}, "custom_split_filenames": { - "train": f"{tmp_path}/train.parquet", - "valid": f"{tmp_path}/val.parquet", - "test": f"{tmp_path}/test.parquet", + "train": [f"{tmp_path}/train.parquet"], + "valid": [f"{tmp_path}/val.parquet"], + "test": [f"{tmp_path}/test.parquet"], "column": ["mask_id"], }, "mask_field_names": class_mask_names, @@ -1173,9 +1173,9 @@ def test_edge_custom_label_multitask(spark, dghl_loader: DistHeterogeneousGraphL "column": "", "type": "link_prediction", "custom_split_filenames": { - "train": f"{tmp_path}/train.parquet", - "valid": f"{tmp_path}/val.parquet", - "test": f"{tmp_path}/test.parquet", + "train": [f"{tmp_path}/train.parquet"], + "valid": [f"{tmp_path}/val.parquet"], + "test": [f"{tmp_path}/test.parquet"], "column": ["mask_src_id", "mask_dst_id"], }, "mask_field_names": class_mask_names, diff --git a/graphstorm-processing/tests/test_gsprocessing_config.py b/graphstorm-processing/tests/test_gsprocessing_config.py new file mode 100644 index 0000000000..7da2f72882 --- /dev/null +++ b/graphstorm-processing/tests/test_gsprocessing_config.py @@ -0,0 +1,161 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"). +You may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import pytest + +from graphstorm_processing.config.config_parser import ( + NodeConfig, + EdgeConfig, + create_config_objects, + parse_feat_config, +) +from graphstorm_processing.config.numerical_configs import ( + NumericalFeatureConfig, +) +from graphstorm_processing.config.label_config_base import ( + EdgeLabelConfig, + NodeLabelConfig, +) + + +def test_parse_edge_lp_config(): + """Test parsing an edge configuration with link prediction task""" + label_config_dict = [ + { + "column": "", + "type": "link_prediction", + "split_rate": {"train": 0.8, "val": 0.1, "test": 0.1}, + } + ] + edge_dict = { + "source": {"column": "~from", "type": "movie"}, + "relation": {"type": "included_in"}, + "dest": {"column": "~to", "type": "genre"}, + "labels": label_config_dict, + } + data_dict = {"format": "csv", "files": ["edges/movie-included_in-genre.csv"], "separator": ","} + edge_config = EdgeConfig(edge_dict, data_dict) + + assert edge_config.src_ntype == "movie" + assert edge_config.dst_ntype == "genre" + assert edge_config.rel_type == "included_in" + assert edge_config.src_col == "~from" + assert edge_config.dst_col == "~to" + assert edge_config.rel_col is None + assert edge_config.format == "csv" + assert edge_config.files == ["edges/movie-included_in-genre.csv"] + assert edge_config.separator == "," + assert edge_config.label_configs + assert len(edge_config.label_configs) == 1 + lp_config = edge_config.label_configs[0] + assert isinstance(lp_config, EdgeLabelConfig) + assert lp_config.task_type == "link_prediction" + assert lp_config.split_rate == {"train": 0.8, "val": 0.1, "test": 0.1} + + +def test_parse_basic_node_config(): + """Test parsing a basic node configuration""" + node_config = {"column": "~id", "type": "user"} + data_config = {"format": "csv", "files": ["nodes/user.csv"], "separator": ","} + node_config_obj = NodeConfig(node_config, data_config) + + assert node_config_obj.ntype == "user" + assert node_config_obj.node_col == "~id" + assert node_config_obj.format == "csv" + assert node_config_obj.files == ["nodes/user.csv"] + assert node_config_obj.separator == "," + + +def test_parse_num_configs(): + """Test parsing a numerical features configuration""" + feature_dict = { + "column": "age", + "transformation": { + "name": "numerical", + "kwargs": {"imputer": "mean", "normalizer": "min-max"}, + }, + } + feature_config = parse_feat_config(feature_dict) + + assert isinstance(feature_config, NumericalFeatureConfig) + assert feature_config._cols == ["age"] + assert feature_config.feat_type == "numerical" + assert feature_config._transformation_kwargs["imputer"] == "mean" + assert feature_config._transformation_kwargs["normalizer"] == "min-max" + + +def test_parse_node_label_configs(): + """Test parsing a node configuration with a classification label""" + label_config_dict = { + "column": "gender", + "type": "classification", + "split_rate": {"train": 0.8, "val": 0.1, "test": 0.1}, + } + node_config_dict = {"column": "~id", "type": "user", "labels": [label_config_dict]} + data_config = {"format": "csv", "files": ["nodes/user.csv"], "separator": ","} + node_config_obj = NodeConfig(node_config_dict, data_config) + + assert node_config_obj.label_configs + assert len(node_config_obj.label_configs) == 1 + label_config = node_config_obj.label_configs[0] + assert isinstance(label_config, NodeLabelConfig) + assert label_config.label_column == "gender" + assert label_config.task_type == "classification" + assert label_config.split_rate == {"train": 0.8, "val": 0.1, "test": 0.1} + + +def 
test_create_config_objects(): + """Test conversion of input dicts to config objects""" + graph_config = { + "edges": [ + { + "data": { + "format": "csv", + "files": ["edges/movie-included_in-genre.csv"], + "separator": ",", + }, + "source": {"column": "~from", "type": "movie"}, + "relation": {"type": "included_in"}, + "dest": {"column": "~to", "type": "genre"}, + } + ], + "nodes": [ + { + "data": {"format": "csv", "files": ["nodes/genre.csv"], "separator": ","}, + "column": "~id", + "type": "genre", + } + ], + } + + config_objects = create_config_objects(graph_config) + + assert len(config_objects["edges"]) == 1 + assert len(config_objects["nodes"]) == 1 + assert isinstance(config_objects["edges"][0], EdgeConfig) + assert isinstance(config_objects["nodes"][0], NodeConfig) + + +def test_unsupported_transformation(): + """Test that an unsupported transformation raises an error""" + + feature_dict = {"column": "feature", "transformation": {"name": "unsupported_transform"}} + + with pytest.raises( + RuntimeError, + match="Unknown transformation name: 'unsupported_transform'", + ): + parse_feat_config(feature_dict) diff --git a/python/graphstorm/gconstruct/file_io.py b/python/graphstorm/gconstruct/file_io.py index 5eed485372..31fce06945 100644 --- a/python/graphstorm/gconstruct/file_io.py +++ b/python/graphstorm/gconstruct/file_io.py @@ -118,7 +118,7 @@ def read_index_parquet(data_file, column): if len(column) == 1: res_array = df[column[0]].to_numpy() - elif len(df.columns) == 2: + elif len(column) == 2: res_array = list(zip(df[column[0]].to_numpy(), df[column[1]].to_numpy())) else: raise ValueError("The Parquet file on node mask must contain exactly one column, "
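
For reference, a short usage sketch of the conversion behavior this patch standardizes (illustrative only: the file paths are invented, the import path follows the module layout in the diff, and the private _convert_label helper is called the same way the unit tests above call it). GConstruct input may mix a single string and a list of strings per split; the emitted GSProcessing config always carries list[str], with missing splits normalized to empty lists:

    from graphstorm_processing.config.config_conversion.gconstruct_converter import (
        GConstructConfigConverter,
    )

    converter = GConstructConfigConverter()
    gconstruct_labels = [
        {
            "label_col": "label",
            "task_type": "classification",
            "custom_split_filenames": {
                "column": ["nid"],
                "train": "splits/train_ids.parquet",  # single str is accepted...
                "valid": ["splits/val_ids.parquet"],  # ...and so is list[str]
            },
            "label_stats_type": "frequency_cnt",
        }
    ]

    converted = converter._convert_label(gconstruct_labels)[0]
    # Both provided splits come out as list[str]; "test" was missing from the
    # input, so it is filled in as an empty list.
    assert converted["custom_split_filenames"]["train"] == ["splits/train_ids.parquet"]
    assert converted["custom_split_filenames"]["valid"] == ["splits/val_ids.parquet"]
    assert converted["custom_split_filenames"]["test"] == []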