Rename node_id_mappings to raw_id_mappings
thvasilo committed Nov 17, 2023
1 parent c06b4c2 commit e9b58b6
Showing 9 changed files with 76 additions and 42 deletions.
@@ -489,13 +489,13 @@ def _add_node_mappings_to_metadata(self, metadata_dict: Dict) -> Dict:
         """
         Adds node mappings to the metadata dict that is eventually written to disk.
         """
-        metadata_dict["node_id_mappings"] = {}
+        metadata_dict["raw_id_mappings"] = {}
         for node_type in metadata_dict["node_type"]:
             node_mapping_metadata_dict = {
                 "format": {"name": "parquet", "delimiter": ""},
                 "data": self.node_mapping_paths[node_type],
             }
-            metadata_dict["node_id_mappings"][node_type] = node_mapping_metadata_dict
+            metadata_dict["raw_id_mappings"][node_type] = node_mapping_metadata_dict
 
         return metadata_dict
 
@@ -761,7 +761,7 @@ def _write_nodeid_mapping_and_update_state(
         Also modifies the loader's state to add the mapping path to
         the node_mapping_paths member variable.
         """
-        mapping_output_path = f"{self.output_prefix}/node_id_mappings/{node_type}"
+        mapping_output_path = f"{self.output_prefix}/raw_id_mappings/{node_type}"
 
         # TODO: For node-file-exists path: Test to see if it's better to keep these in memory
         # until needed instead of writing out now i.e. we can maintain a dict of DFs instead
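For orientation, the two hunks above rename both the metadata key and the output prefix that the distributed loader writes. Below is a minimal sketch of the resulting metadata.json entry; the "user" node type and the Parquet file name are hypothetical examples, and only the "format"/"data" structure is taken from the diff:

    # Hypothetical illustration of the renamed metadata section; the node type
    # "user" and the part file name are example values, not taken from the repository.
    metadata_dict["raw_id_mappings"] = {
        "user": {
            "format": {"name": "parquet", "delimiter": ""},
            "data": ["raw_id_mappings/user/part-00000.parquet"],
        }
    }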
@@ -52,7 +52,8 @@ def __init__(self, metadata_dict: dict, output_prefix: str, filesystem_type: str
             # Increase default retries because we are likely to run into
             # throttling errors
             self.pyarrow_fs = fs.S3FileSystem(
-                region=bucket_region, retry_strategy=fs.AwsStandardS3RetryStrategy(max_attempts=10)
+                region=bucket_region,
+                retry_strategy=fs.AwsStandardS3RetryStrategy(max_attempts=10),
             )
         else:
             self.pyarrow_fs = fs.LocalFileSystem()
@@ -79,7 +80,7 @@ def add_row_counts_to_metadata(self, metadata_dict: dict) -> dict:
         self._add_counts_for_features(top_level_key="edge_data", edge_or_node_type_key="edge_type")
 
         all_node_mapping_counts = self._add_counts_for_graph_structure(
-            top_level_key="node_id_mappings", edge_or_node_type_key="node_type"
+            top_level_key="raw_id_mappings", edge_or_node_type_key="node_type"
         )
         self._add_counts_for_features(top_level_key="node_data", edge_or_node_type_key="node_type")
 
@@ -179,7 +180,7 @@ def _add_counts_for_graph_structure(
         top_level_key : str
             The top level key that refers to the structure we'll be getting
             counts for, can be "edges" to get counts for edges structure,
-            or "node_id_mappings" to get counts for node mappings.
+            or "raw_id_mappings" to get counts for node mappings.
         edge_or_node_type_key : str
             The secondary key we use to iterate over structure types,
             can be 'edge_type' or 'node_type'.
@@ -191,8 +192,8 @@
             inner list is a row count.
         """
         # We use the order of types in edge_type and node_type to create the counts
-        assert top_level_key in {"edges", "node_id_mappings"}, (
-            "top_level_key needs to be one of 'edges', 'node_id_mappings' " f"got {top_level_key}"
+        assert top_level_key in {"edges", "raw_id_mappings"}, (
+            "top_level_key needs to be one of 'edges', 'raw_id_mappings' " f"got {top_level_key}"
         )
         assert edge_or_node_type_key in {"edge_type", "node_type"}, (
             "edge_or_node_type_key needs to be one of 'edge_type', 'node_type' "
@@ -238,7 +238,9 @@ def test_parquet_input_multi_categorical(spark: SparkSession, check_df_schema):
     df_parquet = spark.read.parquet(parquet_path)
 
     # Show the DataFrame loaded from the Parquet file
-    dist_categorical_transormation = DistMultiCategoryTransformation(cols=["names"], separator=None)
+    dist_categorical_transormation = DistMultiCategoryTransformation(
+        cols=["names"], separator=None
+    )
 
     transformed_df = dist_categorical_transormation.apply(df_parquet)
     check_df_schema(transformed_df)
43 changes: 34 additions & 9 deletions graphstorm-processing/tests/test_dist_heterogenous_loader.py
@@ -31,13 +31,21 @@
     NODE_MAPPING_STR,
 )
 from graphstorm_processing.data_transformations.dist_label_loader import SplitRates
-from graphstorm_processing.config.label_config_base import NodeLabelConfig, EdgeLabelConfig
+from graphstorm_processing.config.label_config_base import (
+    NodeLabelConfig,
+    EdgeLabelConfig,
+)
 from graphstorm_processing.config.config_parser import (
     create_config_objects,
     EdgeConfig,
 )
 from graphstorm_processing.config.config_conversion import GConstructConfigConverter
-from graphstorm_processing.constants import COLUMN_NAME, MIN_VALUE, MAX_VALUE, VALUE_COUNTS
+from graphstorm_processing.constants import (
+    COLUMN_NAME,
+    MIN_VALUE,
+    MAX_VALUE,
+    VALUE_COUNTS,
+)
 
 pytestmark = pytest.mark.usefixtures("spark")
 _ROOT = os.path.abspath(os.path.dirname(__file__))
@@ -176,7 +184,7 @@ def verify_integ_test_output(
     # TODO: The following Parquet reads assume there's only one file in the output
     for node_type in metadata["node_type"]:
         nrows = pq.ParquetFile(
-            os.path.join(loader.output_path, metadata["node_id_mappings"][node_type]["data"][0])
+            os.path.join(loader.output_path, metadata["raw_id_mappings"][node_type]["data"][0])
         ).metadata.num_rows
         assert nrows == expected_node_counts[node_type]
 
@@ -238,7 +246,10 @@ def test_load_dist_heterogen_node_class(dghl_loader: DistHeterogeneousGraphLoader):
         "task_type": "node_class",
         "label_map": {"male": 0, "female": 1},
         "label_properties": {
-            "user": {"COLUMN_NAME": "gender", "VALUE_COUNTS": {"male": 3, "female": 1, "null": 1}}
+            "user": {
+                "COLUMN_NAME": "gender",
+                "VALUE_COUNTS": {"male": 3, "female": 1, "null": 1},
+            }
         },
     }
 
@@ -252,12 +263,16 @@ def test_load_dist_heterogen_node_class(dghl_loader: DistHeterogeneousGraphLoader):
         assert metadata["node_data"][node_type].keys() == expected_node_data[node_type]
 
 
-def test_load_dist_hgl_without_labels(dghl_loader_no_label: DistHeterogeneousGraphLoader):
+def test_load_dist_hgl_without_labels(
+    dghl_loader_no_label: DistHeterogeneousGraphLoader,
+):
     """End 2 end test when no labels are provided"""
     dghl_loader_no_label.load()
 
     with open(
-        os.path.join(dghl_loader_no_label.output_path, "metadata.json"), "r", encoding="utf-8"
+        os.path.join(dghl_loader_no_label.output_path, "metadata.json"),
+        "r",
+        encoding="utf-8",
     ) as mfile:
         metadata = json.load(mfile)
 
@@ -292,7 +307,11 @@ def test_write_edge_structure_no_reverse_edges(
     dghl_loader_no_reverse_edges.create_node_id_maps_from_edges(edge_configs, missing_node_types)
 
     edge_dict: Dict[str, Dict] = {
-        "data": {"format": "csv", "files": ["edges/user-rated-movie.csv"], "separator": ","},
+        "data": {
+            "format": "csv",
+            "files": ["edges/user-rated-movie.csv"],
+            "separator": ",",
+        },
         "source": {"column": "~from", "type": "user"},
         "relation": {"type": "rated"},
         "dest": {"column": "~to", "type": "movie"},
@@ -307,7 +326,9 @@
 
 
 def test_create_all_mapppings_from_edges(
-    spark: SparkSession, data_configs_with_label, dghl_loader: DistHeterogeneousGraphLoader
+    spark: SparkSession,
+    data_configs_with_label,
+    dghl_loader: DistHeterogeneousGraphLoader,
 ):
     """Test creating all node mappings only from edge files"""
     edge_configs = data_configs_with_label["edges"]
@@ -495,7 +516,11 @@ def test_create_split_files_from_rates(
     )
 
     ensure_masks_are_correct(
-        train_mask_df, test_mask_df, val_mask_df, non_missing_data_points, split_rates.tolist()
+        train_mask_df,
+        test_mask_df,
+        val_mask_df,
+        non_missing_data_points,
+        split_rates.tolist(),
     )
 
 
12 changes: 6 additions & 6 deletions python/graphstorm/gconstruct/construct_graph.py
@@ -667,18 +667,18 @@ def process_graph(args):
         if len(output_format) == 1 and output_format[0] == "DistDGL" else None
     convert2ext_mem = ExtMemArrayMerger(ext_mem_workspace, args.ext_mem_feat_size)
 
-    node_id_map, node_data, node_label_stats = \
+    raw_node_id_maps, node_data, node_label_stats = \
         process_node_data(process_confs['nodes'], convert2ext_mem,
                           args.remap_node_id, ext_mem_workspace,
                           num_processes=num_processes_for_nodes)
     sys_tracker.check('Process the node data')
     edges, edge_data, edge_label_stats = \
-        process_edge_data(process_confs['edges'], node_id_map,
+        process_edge_data(process_confs['edges'], raw_node_id_maps,
                           convert2ext_mem, ext_mem_workspace,
                           num_processes=num_processes_for_edges,
                           skip_nonexist_edges=args.skip_nonexist_edges)
     sys_tracker.check('Process the edge data')
-    num_nodes = {ntype: len(node_id_map[ntype]) for ntype in node_id_map}
+    num_nodes = {ntype: len(raw_node_id_maps[ntype]) for ntype in raw_node_id_maps}
     if args.output_conf_file is not None:
         # Save the new config file.
         with open(args.output_conf_file, "w", encoding="utf8") as outfile:
@@ -742,9 +742,9 @@ def process_graph(args):
     if len(edge_label_stats) > 0:
         save_edge_label_stats(args.output_dir, edge_label_stats)
 
-    for ntype, node_id_map in node_id_map.items():
-        map_prefix = os.path.join(args.output_dir, "node_id_mappings", ntype)
-        node_id_map.save(map_prefix)
+    for ntype, raw_id_map in raw_node_id_maps.items():
+        map_prefix = os.path.join(args.output_dir, "raw_id_mappings", ntype)
+        raw_id_map.save(map_prefix)
         logging.info("Graph construction generated new node IDs for '%s'. " + \
                      "The ID map is saved under %s.", ntype, map_prefix)
 
4 changes: 2 additions & 2 deletions python/graphstorm/gconstruct/remap_result.py
@@ -504,7 +504,7 @@ def _parse_gs_config(config):
         list of str: etypes that have prediction results
     """
     part_config = config.part_config
-    node_id_mapping = os.path.join(os.path.dirname(part_config), "node_id_mappings")
+    node_id_mapping = os.path.join(os.path.dirname(part_config), "raw_id_mappings")
     predict_dir = config.save_prediction_path
     emb_dir = config.save_embed_path
     task_type = config.task_type
@@ -690,7 +690,7 @@ def main(args, gs_config_args):
             id_maps[ntype] = \
                 IdReverseMap(mapping_prefix)
         else:
-            logging.fatal("ID mapping prefix %s does not exist, skipping remapping",
+            logging.warning("ID mapping prefix %s does not exist, skipping remapping",
                 mapping_prefix)
             sys.exit(0)
 
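Both gconstruct files above rely on the same path convention: construct_graph.py saves one ID map per node type under <output-dir>/raw_id_mappings/<ntype>, and remap_result.py later looks the maps up in a raw_id_mappings directory next to the partition config. A rough sketch of that convention follows; the paths, the "user" node type, and the assumption that the partition config sits in the same output directory are placeholders, not taken from the repository:

    # Rough sketch of the shared path convention; "/data/out", "graph.json" and
    # "user" are placeholder values, not paths used by the repository.
    import os

    output_dir = "/data/out"                                          # gconstruct output directory
    map_prefix = os.path.join(output_dir, "raw_id_mappings", "user")  # where construct_graph.py saves the map
    part_config = os.path.join(output_dir, "graph.json")              # assumed partition config location
    # remap_result resolves the mappings relative to the partition config:
    assert os.path.join(os.path.dirname(part_config), "raw_id_mappings", "user") == map_prefix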
30 changes: 18 additions & 12 deletions python/graphstorm/sagemaker/utils.py
@@ -212,7 +212,7 @@ def download_model(model_artifact_s3, model_path, sagemaker_session):
 
 def download_graph(graph_data_s3, graph_name, part_id, world_size,
                    local_path, sagemaker_session,
-                   node_mapping_prefix_s3=None):
+                   raw_node_mapping_prefix_s3=None):
     """ download graph data
     Parameters
@@ -229,7 +229,7 @@ def download_graph(graph_data_s3, graph_name, part_id, world_size,
         Path to store graph data
     sagemaker_session: sagemaker.session.Session
         sagemaker_session to run download
-    node_mapping_prefix_s3: str, optional
+    raw_node_mapping_prefix_s3: str, optional
         S3 prefix to where the node_id_mapping data are stored
     Return
@@ -250,14 +250,14 @@ def download_graph(graph_data_s3, graph_name, part_id, world_size,
 
     # By default we assume the node mappings exist
     # under the same path as the rest of the graph data
-    if not node_mapping_prefix_s3:
-        node_mapping_prefix_s3 = f"{graph_data_s3}/node_id_mappings"
+    if not raw_node_mapping_prefix_s3:
+        raw_node_mapping_prefix_s3 = f"{graph_data_s3}/raw_id_mappings"
     else:
-        node_mapping_prefix_s3 = (
-            node_mapping_prefix_s3[:-1] if node_mapping_prefix_s3.endswith('/')
-            else node_mapping_prefix_s3)
-        assert node_mapping_prefix_s3.endswith("node_id_mappings"), \
-            "node_mapping_prefix_s3 must end with 'node_id_mappings'"
+        raw_node_mapping_prefix_s3 = (
+            raw_node_mapping_prefix_s3[:-1] if raw_node_mapping_prefix_s3.endswith('/')
+            else raw_node_mapping_prefix_s3)
+        assert raw_node_mapping_prefix_s3.endswith("raw_id_mappings"), \
+            "node_mapping_prefix_s3 must end with 'raw_id_mappings'"
 
 
     # We split on '/' to get the bucket, as it's always the third split element in an S3 URI
@@ -320,11 +320,17 @@ def download_graph(graph_data_s3, graph_name, part_id, world_size,
             logging.info("node id mapping file %s does not exist", s3_path)
 
     # Try to get GraphStorm ID to Original ID remapping files if any
-    id_map_files = S3Downloader.list(node_mapping_prefix_s3, sagemaker_session=sagemaker_session)
+    id_map_files = S3Downloader.list(
+        raw_node_mapping_prefix_s3, sagemaker_session=sagemaker_session)
     for mapping_file in id_map_files:
+        # The expected layout for mapping files on S3 is:
+        # raw_id_mappings/node_type/part-xxxxx.parquet
+        ntype = mapping_file.split("/")[-2]
         try:
-            S3Downloader.download(mapping_file, graph_path,
-                                  sagemaker_session=sagemaker_session)
+            S3Downloader.download(
+                mapping_file,
+                os.path.join(graph_path, "raw_id_mappings", ntype),
+                sagemaker_session=sagemaker_session)
         except Exception: # pylint: disable=broad-except
             logging.warning("Could not download node id remap file %s",
                 mapping_file)
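The new comment in the hunk above spells out the expected S3 layout, raw_id_mappings/node_type/part-xxxxx.parquet, and the download now places each file in a matching per-type directory. A minimal sketch of how the destination is derived, using a made-up bucket, node type, and file name:

    # Hypothetical S3 key and local path; only the "raw_id_mappings/<ntype>/<file>"
    # layout is taken from the diff above.
    import os

    mapping_file = "s3://my-bucket/out/raw_id_mappings/user/part-00000.parquet"
    graph_path = "/tmp/graph"
    ntype = mapping_file.split("/")[-2]                       # -> "user"
    dest = os.path.join(graph_path, "raw_id_mappings", ntype)
    print(dest)                                               # /tmp/graph/raw_id_mappings/user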
6 changes: 3 additions & 3 deletions tests/end2end-tests/data_process/test_data.py
@@ -57,9 +57,9 @@ def read_data_parquet(data_file):
 else:
     raise ValueError('Invalid graph format: {}'.format(args.graph_format))
 
-node1_map = read_data_parquet(os.path.join(out_dir, "node_id_mappings", "node1"))
+node1_map = read_data_parquet(os.path.join(out_dir, "raw_id_mappings", "node1"))
 reverse_node1_map = {val: key for key, val in zip(node1_map['orig'], node1_map['new'])}
-node3_map = read_data_parquet(os.path.join(out_dir, "node_id_mappings", "node3"))
+node3_map = read_data_parquet(os.path.join(out_dir, "raw_id_mappings", "node3"))
 reverse_node3_map = {val: key for key, val in zip(node3_map['orig'], node3_map['new'])}
 
 # Test the first node data
@@ -133,7 +133,7 @@ def read_data_parquet(data_file):
 assert len(node_conf["features"][0]["transform"]["mapping"]) == 10
 
 # id remap for node4 exists
-assert os.path.isdir(os.path.join(out_dir, "node_id_mappings", "node4"))
+assert os.path.isdir(os.path.join(out_dir, "raw_id_mappings", "node4"))
 
 # Test the edge data of edge type 1
 src_ids, dst_ids = g.edges(etype=('node1', 'relation1', 'node2'))
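The end-to-end test reads each per-type mapping back from Parquet and inverts it using its 'orig' and 'new' columns. A self-contained sketch of the same idea using pyarrow directly (the test itself goes through its read_data_parquet helper); the output directory is a placeholder:

    # Sketch only: "/tmp/out" is a placeholder output directory; the 'orig' and
    # 'new' column names come from the test code above.
    import os
    import pyarrow.parquet as pq

    table = pq.read_table(os.path.join("/tmp/out", "raw_id_mappings", "node1"))
    mapping = table.to_pydict()
    reverse_map = {new: orig for orig, new in zip(mapping["orig"], mapping["new"])}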
2 changes: 1 addition & 1 deletion tests/unit-tests/gconstruct/test_construct_graph.py
@@ -1723,4 +1723,4 @@ def test_gc():
     test_label()
     test_multicolumn(None)
     test_multicolumn("/tmp/")
-    test_feature_wrapper()
\ No newline at end of file
+    test_feature_wrapper()
