diff --git a/graphstorm-processing/graphstorm_processing/distributed_executor.py b/graphstorm-processing/graphstorm_processing/distributed_executor.py index 3bc1bc4936..ee6858b544 100644 --- a/graphstorm-processing/graphstorm_processing/distributed_executor.py +++ b/graphstorm-processing/graphstorm_processing/distributed_executor.py @@ -59,7 +59,7 @@ import tempfile import time from collections.abc import Mapping -from typing import Any, Dict +from typing import Any, Dict, Optional import boto3 import botocore @@ -106,8 +106,8 @@ class ExecutorConfig: The filesystem type, can be LOCAL or S3 add_reverse_edges : bool Whether to create reverse edges for each edge type. - graph_name: str - The name of the graph being processed + graph_name: str, optional + The name of the graph being processed. If not provided, we use part of the input_prefix. do_repartition: bool Whether to apply repartitioning to the graph on the Spark leader. """ @@ -121,7 +121,7 @@ class ExecutorConfig: config_filename: str filesystem_type: FilesystemType add_reverse_edges: bool - graph_name: str + graph_name: Optional[str] do_repartition: bool @@ -135,7 +135,7 @@ class GSProcessingArguments: num_output_files: int add_reverse_edges: bool log_level: str - graph_name: str + graph_name: Optional[str] do_repartition: bool @@ -162,7 +162,14 @@ def __init__( self.filesystem_type = executor_config.filesystem_type self.execution_env = executor_config.execution_env self.add_reverse_edges = executor_config.add_reverse_edges - self.graph_name = executor_config.graph_name + # We use the data location as the graph name if a name is not provided + if executor_config.graph_name: + self.graph_name = executor_config.graph_name + else: + derived_name = s3_utils.s3_path_remove_trailing(self.input_prefix).split("/")[-1] + logging.warning("Setting graph name derived from input path: %s", derived_name) + self.graph_name = derived_name + check_graph_name(self.graph_name) self.repartition_on_leader = 
executor_config.do_repartition # Input config dict using GSProcessing schema self.gsp_config_dict = {} @@ -541,11 +548,14 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--graph-name", type=str, - help="Name for the graph being processed." - "The graph name must adhere to the Python " - "identifier naming rules with the exception " - "that hyphens (-) are permitted and the name " - "can start with numbers", + help=( + "Name for the graph being processed. " + "The graph name must adhere to the Python " + "identifier naming rules with the exception " + "that hyphens (-) are permitted and the name " + "can start with numbers. If not provided, we will use the last " + "section of the input prefix path." + ), required=False, default=None, ) @@ -604,7 +614,6 @@ def main(): level=gsprocessing_args.log_level, format="[GSPROCESSING] %(asctime)s %(levelname)-8s %(message)s", ) - check_graph_name(gsprocessing_args.graph_name) # Determine execution environment if os.path.exists("/opt/ml/config/processingjobconfig.json"): diff --git a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py index 65ef636e5f..e92e1068cc 100644 --- a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py +++ b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py @@ -324,11 +324,6 @@ def process_and_write_graph_data( self.timers["process_edge_data"] = perf_counter() - edges_start_time metadata_dict["edge_data"] = edge_data_dict metadata_dict["edges"] = edge_structure_dict - # We use the data location as the graph name, can also take from user? - # TODO: Fix this, take from config? 
- metadata_dict["graph_name"] = ( - self.graph_name if self.graph_name else self.input_prefix.split("/")[-1] - ) # Ensure output dict has the correct order of keys for edge_type in metadata_dict["edge_type"]: @@ -447,6 +442,8 @@ def _initialize_metadata_dict( metadata_dict["edge_type"] = edge_types metadata_dict["node_type"] = sorted(node_type_set) + metadata_dict["graph_name"] = self.graph_name + return metadata_dict def _finalize_graphinfo_dict(self, metadata_dict: Dict) -> Dict: diff --git a/graphstorm-processing/tests/test_dist_executor.py b/graphstorm-processing/tests/test_dist_executor.py index b88ac38ab6..3290126bb0 100644 --- a/graphstorm-processing/tests/test_dist_executor.py +++ b/graphstorm-processing/tests/test_dist_executor.py @@ -62,8 +62,9 @@ def user_state_categorical_precomp_file_fixture(): os.remove(precomp_file) -def test_dist_executor_run_with_precomputed(tempdir: str, user_state_categorical_precomp_file): - """Test run function with local data""" +@pytest.fixture(name="executor_configuration") +def executor_config_fixture(tempdir: str): + """Create a re-usable ExecutorConfig""" input_path = os.path.join(_ROOT, "resources/small_heterogeneous_graph") executor_configuration = ExecutorConfig( local_config_path=input_path, @@ -79,6 +80,15 @@ def test_dist_executor_run_with_precomputed(tempdir: str, user_state_categorical do_repartition=True, ) + yield executor_configuration + + +def test_dist_executor_run_with_precomputed( + tempdir: str, + user_state_categorical_precomp_file: str, + executor_configuration: ExecutorConfig, +): + """Test run function with local data""" original_precomp_file = user_state_categorical_precomp_file with open(original_precomp_file, "r", encoding="utf-8") as f: @@ -106,23 +116,8 @@ def test_dist_executor_run_with_precomputed(tempdir: str, user_state_categorical # TODO: Verify other metadata files that verify_integ_test_output doesn't check for -def test_merge_input_and_transform_dicts(tempdir: str): +def 
test_merge_input_and_transform_dicts(executor_configuration: ExecutorConfig): """Test the _merge_config_with_transformations function with hardcoded json data""" - input_path = os.path.join(_ROOT, "resources/small_heterogeneous_graph") - executor_configuration = ExecutorConfig( - local_config_path=input_path, - local_metadata_output_path=tempdir, - input_prefix=input_path, - output_prefix=tempdir, - num_output_files=-1, - config_filename="gsprocessing-config.json", - execution_env=ExecutionEnv.LOCAL, - filesystem_type=FilesystemType.LOCAL, - add_reverse_edges=True, - graph_name="small_heterogeneous_graph", - do_repartition=True, - ) - dist_executor = DistributedExecutor(executor_configuration) pre_comp_transormations = { @@ -148,3 +143,28 @@ def test_merge_input_and_transform_dicts(tempdir: str): if "state" == feature["column"]: transform_for_feature = feature["precomputed_transformation"] assert transform_for_feature["transformation_name"] == "categorical" + + +def test_dist_executor_graph_name(executor_configuration: ExecutorConfig): + """Test cases for graph name""" + + # Ensure we can set a valid graph name + executor_configuration.graph_name = "2024-a_valid_name" + dist_executor = DistributedExecutor(executor_configuration) + assert dist_executor.graph_name == "2024-a_valid_name" + + # Ensure default value is used when graph_name is not provided + executor_configuration.graph_name = None + dist_executor = DistributedExecutor(executor_configuration) + assert dist_executor.graph_name == "small_heterogeneous_graph" + + # Ensure we raise when invalid graph name is provided + with pytest.raises(AssertionError): + executor_configuration.graph_name = "graph.name" + dist_executor = DistributedExecutor(executor_configuration) + + # Ensure a valid default graph name is parsed when the input ends in '/' + executor_configuration.graph_name = None + executor_configuration.input_prefix = executor_configuration.input_prefix + "/" + dist_executor = 
DistributedExecutor(executor_configuration) + assert dist_executor.graph_name == "small_heterogeneous_graph"