Skip to content

Commit

Permalink
Add handling for different column names and paths for mappings
Browse files Browse the repository at this point in the history
  • Loading branch information
thvasilo committed Nov 22, 2023
1 parent 15255d3 commit 9212da3
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@

FORMAT_NAME = "parquet"
DELIMITER = "" if FORMAT_NAME == "parquet" else ","
NODE_MAPPING_STR = "node_str_id"
NODE_MAPPING_INT = "node_int_id"
NODE_MAPPING_STR = "orig"
NODE_MAPPING_INT = "new"


class DistHeterogeneousGraphLoader(HeterogeneousGraphLoader):
Expand Down
10 changes: 8 additions & 2 deletions python/graphstorm/gconstruct/id_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,14 @@ class IdReverseMap:
"""
def __init__(self, id_map_prefix):
assert os.path.exists(id_map_prefix), \
f"{id_map_prefix} does not exits."
data = read_data_parquet(id_map_prefix, ["orig", "new"])
f"{id_map_prefix} does not exist."
try:
data = read_data_parquet(id_map_prefix, ["orig", "new"])
except AssertionError:
data = read_data_parquet(id_map_prefix, ["node_str_id", "node_int_id"])
data["new"] = data["node_int_id"]
data["orig"] = data["node_str_id"]

sort_idx = np.argsort(data['new'])
self._ids = data['orig'][sort_idx]

Expand Down
6 changes: 5 additions & 1 deletion python/graphstorm/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,13 @@ def download_graph(graph_data_s3, graph_name, part_id, world_size,
id_map_files = S3Downloader.list(
raw_node_mapping_prefix_s3, sagemaker_session=sagemaker_session)
for mapping_file in id_map_files:
# The expected layout for mapping files on S3 is:
# The expected layout for GConstruct mapping files on S3 is:
# raw_id_mappings/node_type/part-xxxxx.parquet
ntype = mapping_file.split("/")[-2]
# This is the case where the output was generated by GSProcessing
if ntype == "parquet":
# Then we have raw_id_mappings/node_type/parquet/part-xxxxx.parquet
ntype = mapping_file.split("/")[-3]
os.makedirs(os.path.join(graph_path, "raw_id_mappings", ntype), exist_ok=True)
try:
S3Downloader.download(
Expand Down

0 comments on commit 9212da3

Please sign in to comment.