From e9b58b6b42d62e9e6bc56472507dc180ccfd4d31 Mon Sep 17 00:00:00 2001 From: Theodore Vasiloudis Date: Thu, 16 Nov 2023 22:19:00 +0000 Subject: [PATCH] Rename `node_id_mappings` to `raw_id_mappings` --- .../dist_heterogeneous_loader.py | 6 +-- .../graph_loaders/row_count_utils.py | 11 ++--- .../test_dist_category_transformation.py | 4 +- .../tests/test_dist_heterogenous_loader.py | 43 +++++++++++++++---- .../graphstorm/gconstruct/construct_graph.py | 12 +++--- python/graphstorm/gconstruct/remap_result.py | 4 +- python/graphstorm/sagemaker/utils.py | 30 +++++++------ tests/end2end-tests/data_process/test_data.py | 6 +-- .../gconstruct/test_construct_graph.py | 2 +- 9 files changed, 76 insertions(+), 42 deletions(-) diff --git a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py index bd00c74c58..3ae49a3b79 100644 --- a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py +++ b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py @@ -489,13 +489,13 @@ def _add_node_mappings_to_metadata(self, metadata_dict: Dict) -> Dict: """ Adds node mappings to the metadata dict that is eventually written to disk. """ - metadata_dict["node_id_mappings"] = {} + metadata_dict["raw_id_mappings"] = {} for node_type in metadata_dict["node_type"]: node_mapping_metadata_dict = { "format": {"name": "parquet", "delimiter": ""}, "data": self.node_mapping_paths[node_type], } - metadata_dict["node_id_mappings"][node_type] = node_mapping_metadata_dict + metadata_dict["raw_id_mappings"][node_type] = node_mapping_metadata_dict return metadata_dict @@ -761,7 +761,7 @@ def _write_nodeid_mapping_and_update_state( Also modifies the loader's state to add the mapping path to the node_mapping_paths member variable. """ - mapping_output_path = f"{self.output_prefix}/node_id_mappings/{node_type}" + mapping_output_path = f"{self.output_prefix}/raw_id_mappings/{node_type}" # TODO: For node-file-exists path: Test to see if it's better to keep these in memory # until needed instead of writing out now i.e. 
we can maintain a dict of DFs instead diff --git a/graphstorm-processing/graphstorm_processing/graph_loaders/row_count_utils.py b/graphstorm-processing/graphstorm_processing/graph_loaders/row_count_utils.py index 6a5a6643f7..2069c487ca 100644 --- a/graphstorm-processing/graphstorm_processing/graph_loaders/row_count_utils.py +++ b/graphstorm-processing/graphstorm_processing/graph_loaders/row_count_utils.py @@ -52,7 +52,8 @@ def __init__(self, metadata_dict: dict, output_prefix: str, filesystem_type: str # Increase default retries because we are likely to run into # throttling errors self.pyarrow_fs = fs.S3FileSystem( - region=bucket_region, retry_strategy=fs.AwsStandardS3RetryStrategy(max_attempts=10) + region=bucket_region, + retry_strategy=fs.AwsStandardS3RetryStrategy(max_attempts=10), ) else: self.pyarrow_fs = fs.LocalFileSystem() @@ -79,7 +80,7 @@ def add_row_counts_to_metadata(self, metadata_dict: dict) -> dict: self._add_counts_for_features(top_level_key="edge_data", edge_or_node_type_key="edge_type") all_node_mapping_counts = self._add_counts_for_graph_structure( - top_level_key="node_id_mappings", edge_or_node_type_key="node_type" + top_level_key="raw_id_mappings", edge_or_node_type_key="node_type" ) self._add_counts_for_features(top_level_key="node_data", edge_or_node_type_key="node_type") @@ -179,7 +180,7 @@ def _add_counts_for_graph_structure( top_level_key : str The top level key that refers to the structure we'll be getting counts for, can be "edges" to get counts for edges structure, - or "node_id_mappings" to get counts for node mappings. + or "raw_id_mappings" to get counts for node mappings. edge_or_node_type_key : str The secondary key we use to iterate over structure types, can be 'edge_type' or 'node_type'. @@ -191,8 +192,8 @@ def _add_counts_for_graph_structure( inner list is a row count. 
""" # We use the order of types in edge_type and node_type to create the counts - assert top_level_key in {"edges", "node_id_mappings"}, ( - "top_level_key needs to be one of 'edges', 'node_id_mappings' " f"got {top_level_key}" + assert top_level_key in {"edges", "raw_id_mappings"}, ( + "top_level_key needs to be one of 'edges', 'raw_id_mappings' " f"got {top_level_key}" ) assert edge_or_node_type_key in {"edge_type", "node_type"}, ( "edge_or_node_type_key needs to be one of 'edge_type', 'node_type' " diff --git a/graphstorm-processing/tests/test_dist_category_transformation.py b/graphstorm-processing/tests/test_dist_category_transformation.py index 662fe32a1a..74d155fb76 100644 --- a/graphstorm-processing/tests/test_dist_category_transformation.py +++ b/graphstorm-processing/tests/test_dist_category_transformation.py @@ -238,7 +238,9 @@ def test_parquet_input_multi_categorical(spark: SparkSession, check_df_schema): df_parquet = spark.read.parquet(parquet_path) # Show the DataFrame loaded from the Parquet file - dist_categorical_transormation = DistMultiCategoryTransformation(cols=["names"], separator=None) + dist_categorical_transormation = DistMultiCategoryTransformation( + cols=["names"], separator=None + ) transformed_df = dist_categorical_transormation.apply(df_parquet) check_df_schema(transformed_df) diff --git a/graphstorm-processing/tests/test_dist_heterogenous_loader.py b/graphstorm-processing/tests/test_dist_heterogenous_loader.py index f67a52babe..bad1e0bb53 100644 --- a/graphstorm-processing/tests/test_dist_heterogenous_loader.py +++ b/graphstorm-processing/tests/test_dist_heterogenous_loader.py @@ -31,13 +31,21 @@ NODE_MAPPING_STR, ) from graphstorm_processing.data_transformations.dist_label_loader import SplitRates -from graphstorm_processing.config.label_config_base import NodeLabelConfig, EdgeLabelConfig +from graphstorm_processing.config.label_config_base import ( + NodeLabelConfig, + EdgeLabelConfig, +) from graphstorm_processing.config.config_parser import ( create_config_objects, EdgeConfig, ) from graphstorm_processing.config.config_conversion import GConstructConfigConverter -from graphstorm_processing.constants import COLUMN_NAME, MIN_VALUE, MAX_VALUE, VALUE_COUNTS +from graphstorm_processing.constants import ( + COLUMN_NAME, + MIN_VALUE, + MAX_VALUE, + VALUE_COUNTS, +) pytestmark = pytest.mark.usefixtures("spark") _ROOT = os.path.abspath(os.path.dirname(__file__)) @@ -176,7 +184,7 @@ def verify_integ_test_output( # TODO: The following Parquet reads assume there's only one file in the output for node_type in metadata["node_type"]: nrows = pq.ParquetFile( - os.path.join(loader.output_path, metadata["node_id_mappings"][node_type]["data"][0]) + os.path.join(loader.output_path, metadata["raw_id_mappings"][node_type]["data"][0]) ).metadata.num_rows assert nrows == expected_node_counts[node_type] @@ -238,7 +246,10 @@ def test_load_dist_heterogen_node_class(dghl_loader: DistHeterogeneousGraphLoade "task_type": "node_class", "label_map": {"male": 0, "female": 1}, "label_properties": { - "user": {"COLUMN_NAME": "gender", "VALUE_COUNTS": {"male": 3, "female": 1, "null": 1}} + "user": { + "COLUMN_NAME": "gender", + "VALUE_COUNTS": {"male": 3, "female": 1, "null": 1}, + } }, } @@ -252,12 +263,16 @@ def test_load_dist_heterogen_node_class(dghl_loader: DistHeterogeneousGraphLoade assert metadata["node_data"][node_type].keys() == expected_node_data[node_type] -def test_load_dist_hgl_without_labels(dghl_loader_no_label: DistHeterogeneousGraphLoader): +def 
test_load_dist_hgl_without_labels( + dghl_loader_no_label: DistHeterogeneousGraphLoader, +): """End 2 end test when no labels are provided""" dghl_loader_no_label.load() with open( - os.path.join(dghl_loader_no_label.output_path, "metadata.json"), "r", encoding="utf-8" + os.path.join(dghl_loader_no_label.output_path, "metadata.json"), + "r", + encoding="utf-8", ) as mfile: metadata = json.load(mfile) @@ -292,7 +307,11 @@ def test_write_edge_structure_no_reverse_edges( dghl_loader_no_reverse_edges.create_node_id_maps_from_edges(edge_configs, missing_node_types) edge_dict: Dict[str, Dict] = { - "data": {"format": "csv", "files": ["edges/user-rated-movie.csv"], "separator": ","}, + "data": { + "format": "csv", + "files": ["edges/user-rated-movie.csv"], + "separator": ",", + }, "source": {"column": "~from", "type": "user"}, "relation": {"type": "rated"}, "dest": {"column": "~to", "type": "movie"}, @@ -307,7 +326,9 @@ def test_write_edge_structure_no_reverse_edges( def test_create_all_mapppings_from_edges( - spark: SparkSession, data_configs_with_label, dghl_loader: DistHeterogeneousGraphLoader + spark: SparkSession, + data_configs_with_label, + dghl_loader: DistHeterogeneousGraphLoader, ): """Test creating all node mappings only from edge files""" edge_configs = data_configs_with_label["edges"] @@ -495,7 +516,11 @@ def test_create_split_files_from_rates( ) ensure_masks_are_correct( - train_mask_df, test_mask_df, val_mask_df, non_missing_data_points, split_rates.tolist() + train_mask_df, + test_mask_df, + val_mask_df, + non_missing_data_points, + split_rates.tolist(), ) diff --git a/python/graphstorm/gconstruct/construct_graph.py b/python/graphstorm/gconstruct/construct_graph.py index a149d17742..2ccc6a391d 100644 --- a/python/graphstorm/gconstruct/construct_graph.py +++ b/python/graphstorm/gconstruct/construct_graph.py @@ -667,18 +667,18 @@ def process_graph(args): if len(output_format) == 1 and output_format[0] == "DistDGL" else None convert2ext_mem = ExtMemArrayMerger(ext_mem_workspace, args.ext_mem_feat_size) - node_id_map, node_data, node_label_stats = \ + raw_node_id_maps, node_data, node_label_stats = \ process_node_data(process_confs['nodes'], convert2ext_mem, args.remap_node_id, ext_mem_workspace, num_processes=num_processes_for_nodes) sys_tracker.check('Process the node data') edges, edge_data, edge_label_stats = \ - process_edge_data(process_confs['edges'], node_id_map, + process_edge_data(process_confs['edges'], raw_node_id_maps, convert2ext_mem, ext_mem_workspace, num_processes=num_processes_for_edges, skip_nonexist_edges=args.skip_nonexist_edges) sys_tracker.check('Process the edge data') - num_nodes = {ntype: len(node_id_map[ntype]) for ntype in node_id_map} + num_nodes = {ntype: len(raw_node_id_maps[ntype]) for ntype in raw_node_id_maps} if args.output_conf_file is not None: # Save the new config file. with open(args.output_conf_file, "w", encoding="utf8") as outfile: @@ -742,9 +742,9 @@ def process_graph(args): if len(edge_label_stats) > 0: save_edge_label_stats(args.output_dir, edge_label_stats) - for ntype, node_id_map in node_id_map.items(): - map_prefix = os.path.join(args.output_dir, "node_id_mappings", ntype) - node_id_map.save(map_prefix) + for ntype, raw_id_map in raw_node_id_maps.items(): + map_prefix = os.path.join(args.output_dir, "raw_id_mappings", ntype) + raw_id_map.save(map_prefix) logging.info("Graph construction generated new node IDs for '%s'. 
" + \ "The ID map is saved under %s.", ntype, map_prefix) diff --git a/python/graphstorm/gconstruct/remap_result.py b/python/graphstorm/gconstruct/remap_result.py index 429c5ebef7..9a64187ed4 100644 --- a/python/graphstorm/gconstruct/remap_result.py +++ b/python/graphstorm/gconstruct/remap_result.py @@ -504,7 +504,7 @@ def _parse_gs_config(config): list of str: etypes that have prediction results """ part_config = config.part_config - node_id_mapping = os.path.join(os.path.dirname(part_config), "node_id_mappings") + node_id_mapping = os.path.join(os.path.dirname(part_config), "raw_id_mappings") predict_dir = config.save_prediction_path emb_dir = config.save_embed_path task_type = config.task_type @@ -690,7 +690,7 @@ def main(args, gs_config_args): id_maps[ntype] = \ IdReverseMap(mapping_prefix) else: - logging.fatal("ID mapping prefix %s does not exist, skipping remapping", + logging.warning("ID mapping prefix %s does not exist, skipping remapping", mapping_prefix) sys.exit(0) diff --git a/python/graphstorm/sagemaker/utils.py b/python/graphstorm/sagemaker/utils.py index e8ca646dc0..6959856152 100644 --- a/python/graphstorm/sagemaker/utils.py +++ b/python/graphstorm/sagemaker/utils.py @@ -212,7 +212,7 @@ def download_model(model_artifact_s3, model_path, sagemaker_session): def download_graph(graph_data_s3, graph_name, part_id, world_size, local_path, sagemaker_session, - node_mapping_prefix_s3=None): + raw_node_mapping_prefix_s3=None): """ download graph data Parameters @@ -229,7 +229,7 @@ def download_graph(graph_data_s3, graph_name, part_id, world_size, Path to store graph data sagemaker_session: sagemaker.session.Session sagemaker_session to run download - node_mapping_prefix_s3: str, optional + raw_node_mapping_prefix_s3: str, optional S3 prefix to where the node_id_mapping data are stored Return @@ -250,14 +250,14 @@ def download_graph(graph_data_s3, graph_name, part_id, world_size, # By default we assume the node mappings exist # under the same path as the rest of the graph data - if not node_mapping_prefix_s3: - node_mapping_prefix_s3 = f"{graph_data_s3}/node_id_mappings" + if not raw_node_mapping_prefix_s3: + raw_node_mapping_prefix_s3 = f"{graph_data_s3}/raw_id_mappings" else: - node_mapping_prefix_s3 = ( - node_mapping_prefix_s3[:-1] if node_mapping_prefix_s3.endswith('/') - else node_mapping_prefix_s3) - assert node_mapping_prefix_s3.endswith("node_id_mappings"), \ - "node_mapping_prefix_s3 must end with 'node_id_mappings'" + raw_node_mapping_prefix_s3 = ( + raw_node_mapping_prefix_s3[:-1] if raw_node_mapping_prefix_s3.endswith('/') + else raw_node_mapping_prefix_s3) + assert raw_node_mapping_prefix_s3.endswith("raw_id_mappings"), \ + "node_mapping_prefix_s3 must end with 'raw_id_mappings'" # We split on '/' to get the bucket, as it's always the third split element in an S3 URI @@ -320,11 +320,17 @@ def download_graph(graph_data_s3, graph_name, part_id, world_size, logging.info("node id mapping file %s does not exist", s3_path) # Try to get GraphStorm ID to Original ID remapping files if any - id_map_files = S3Downloader.list(node_mapping_prefix_s3, sagemaker_session=sagemaker_session) + id_map_files = S3Downloader.list( + raw_node_mapping_prefix_s3, sagemaker_session=sagemaker_session) for mapping_file in id_map_files: + # The expected layout for mapping files on S3 is: + # raw_id_mappings/node_type/part-xxxxx.parquet + ntype = mapping_file.split("/")[-2] try: - S3Downloader.download(mapping_file, graph_path, - sagemaker_session=sagemaker_session) + S3Downloader.download( + 
mapping_file, + os.path.join(graph_path, "raw_id_mappings", ntype), + sagemaker_session=sagemaker_session) except Exception: # pylint: disable=broad-except logging.warning("Could not download node id remap file %s", mapping_file) diff --git a/tests/end2end-tests/data_process/test_data.py b/tests/end2end-tests/data_process/test_data.py index a7c6f9d319..566c1453e3 100644 --- a/tests/end2end-tests/data_process/test_data.py +++ b/tests/end2end-tests/data_process/test_data.py @@ -57,9 +57,9 @@ def read_data_parquet(data_file): else: raise ValueError('Invalid graph format: {}'.format(args.graph_format)) -node1_map = read_data_parquet(os.path.join(out_dir, "node_id_mappings", "node1")) +node1_map = read_data_parquet(os.path.join(out_dir, "raw_id_mappings", "node1")) reverse_node1_map = {val: key for key, val in zip(node1_map['orig'], node1_map['new'])} -node3_map = read_data_parquet(os.path.join(out_dir, "node_id_mappings", "node3")) +node3_map = read_data_parquet(os.path.join(out_dir, "raw_id_mappings", "node3")) reverse_node3_map = {val: key for key, val in zip(node3_map['orig'], node3_map['new'])} # Test the first node data @@ -133,7 +133,7 @@ def read_data_parquet(data_file): assert len(node_conf["features"][0]["transform"]["mapping"]) == 10 # id remap for node4 exists -assert os.path.isdir(os.path.join(out_dir, "node_id_mappings", "node4")) +assert os.path.isdir(os.path.join(out_dir, "raw_id_mappings", "node4")) # Test the edge data of edge type 1 src_ids, dst_ids = g.edges(etype=('node1', 'relation1', 'node2')) diff --git a/tests/unit-tests/gconstruct/test_construct_graph.py b/tests/unit-tests/gconstruct/test_construct_graph.py index b50928308f..9967d537fe 100644 --- a/tests/unit-tests/gconstruct/test_construct_graph.py +++ b/tests/unit-tests/gconstruct/test_construct_graph.py @@ -1723,4 +1723,4 @@ def test_gc(): test_label() test_multicolumn(None) test_multicolumn("/tmp/") - test_feature_wrapper() \ No newline at end of file + test_feature_wrapper()
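
For reference, a minimal sketch of how downstream code might read the renamed mapping output, assuming the layout this change writes (raw_id_mappings/<node_type>/part-*.parquet) and the 'orig'/'new' columns exercised by tests/end2end-tests/data_process/test_data.py above; the helper name load_raw_id_mapping is hypothetical and not part of this patch.

import os

import pyarrow.parquet as pq


def load_raw_id_mapping(output_dir, node_type):
    """Hypothetical helper: load the raw-to-GraphStorm node ID mapping for one node type.

    Assumes the on-disk layout used in this change:
    <output_dir>/raw_id_mappings/<node_type>/part-*.parquet,
    with an 'orig' column (raw input IDs) and a 'new' column (assigned integer IDs),
    as consumed by the end-to-end test above.
    """
    mapping_dir = os.path.join(output_dir, "raw_id_mappings", node_type)
    # pyarrow reads every part-*.parquet file under the directory as a single table.
    table = pq.read_table(mapping_dir)
    data = table.to_pydict()
    return dict(zip(data["orig"], data["new"]))


# Example usage (paths are illustrative only):
# node1_map = load_raw_id_mapping("/tmp/gconstruct_output", "node1")
# reverse_node1_map = {new_id: raw_id for raw_id, new_id in node1_map.items()}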