
Commit

[Bug fix] Fix a bug when --column-names does not cover all builtin names (#682)

*Issue #, if available:*
When --column-names does not cover all builtin column names, the remap step
crashes: the old code looked up every column name in the user-supplied mapping,
so any builtin name missing from the mapping raised a KeyError.

Related issue: #674 
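
For context, a minimal sketch of the failure mode (illustrative; the toy values
below are hypothetical, not from the PR):

    # Before this fix, the remap code renamed columns with a plain dict
    # comprehension, so any builtin name missing from the user-supplied
    # mapping raised a KeyError.
    col_name_map = {"emb": "new_emb"}          # --column-names covers only "emb"
    data = {"emb": [0.1, 0.2], "nid": [0, 1]}  # "nid" is also a builtin name
    data = {col_name_map[key]: val for key, val in data.items()}  # KeyError: 'nid'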

*Description of changes:*


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Xiang Song <[email protected]>
classicsong and Xiang Song authored Dec 11, 2023
1 parent ca0ed1f commit 5e7cbb8
Showing 2 changed files with 124 additions and 2 deletions.
17 changes: 15 additions & 2 deletions python/graphstorm/gconstruct/remap_result.py
@@ -79,7 +79,14 @@ def write_data_parquet_file(data, file_prefix, col_name_map=None):
         A mapping from builtin column name to user defined column name.
     """
     if col_name_map is not None:
-        data = {col_name_map[key]: val for key, val in data.items()}
+        updated_data = {}
+        for key, val in data.items():
+            if key in col_name_map:
+                updated_data[col_name_map[key]] = val
+            else:
+                updated_data[key] = val
+        data = updated_data
+
     output_fname = f"{file_prefix}.parquet"
     write_data_parquet(data, output_fname)

@@ -107,7 +114,13 @@ def write_data_csv_file(data, file_prefix, delimiter=",", col_name_map=None):
         A mapping from builtin column name to user defined column name.
     """
     if col_name_map is not None:
-        data = {col_name_map[key]: val for key, val in data.items()}
+        updated_data = {}
+        for key, val in data.items():
+            if key in col_name_map:
+                updated_data[col_name_map[key]] = val
+            else:
+                updated_data[key] = val
+        data = updated_data

     output_fname = f"{file_prefix}.csv"
     csv_data = {}
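For illustration, a minimal usage sketch of the fixed behavior (the output path
and toy data are hypothetical; it assumes graphstorm is installed so that
graphstorm.gconstruct.remap_result is importable):

    import numpy as np
    from graphstorm.gconstruct.remap_result import write_data_parquet_file

    # With the fix, a partial col_name_map renames only the columns it
    # covers; "nid" keeps its builtin name instead of triggering a KeyError.
    data = {"emb": np.random.rand(4, 8), "nid": np.arange(4)}
    write_data_parquet_file(data, "/tmp/remap_demo", {"emb": "new_emb"})
    # Writes /tmp/remap_demo.parquet with columns "new_emb" and "nid".
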
109 changes: 109 additions & 0 deletions tests/unit-tests/gconstruct/test_remap_result.py
@@ -278,7 +278,116 @@ def test__get_file_range():
     assert start == 7
     assert end == 10

+def test_write_data_parquet_file():
+    data = {"emb": np.random.rand(10, 10),
+            "nid": np.arange(10),
+            "pred": np.random.rand(10, 10)}
+
+    def check_write_content(fname, col_names):
+        # col_names should be in the order of emb, nid and pred
+        parq_data = read_data_parquet(fname, col_names)
+        assert_almost_equal(data["emb"], parq_data[col_names[0]])
+        assert_equal(data["nid"], parq_data[col_names[1]])
+        assert_almost_equal(data["pred"], parq_data[col_names[2]])
+
+    # without renaming columns
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        file_prefix = os.path.join(tmpdirname, "test")
+        write_data_parquet_file(data, file_prefix, None)
+        output_fname = f"{file_prefix}.parquet"
+
+        check_write_content(output_fname, ["emb", "nid", "pred"])
+
+    # rename all column names
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        col_name_map = {
+            "emb": "new_emb",
+            "nid": "new_nid",
+            "pred": "new_pred"
+        }
+        file_prefix = os.path.join(tmpdirname, "test")
+        write_data_parquet_file(data, file_prefix, col_name_map)
+        output_fname = f"{file_prefix}.parquet"
+
+        check_write_content(output_fname, ["new_emb", "new_nid", "new_pred"])
+
+    # rename part of the column names
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        col_name_map = {
+            "emb": "new_emb",
+            "nid": "new_nid",
+        }
+        file_prefix = os.path.join(tmpdirname, "test")
+        write_data_parquet_file(data, file_prefix, col_name_map)
+        output_fname = f"{file_prefix}.parquet"
+
+        check_write_content(output_fname, ["new_emb", "new_nid", "pred"])
+
+def test_write_data_csv_file():
+    data = {"emb": np.random.rand(10, 10),
+            "nid": np.arange(10),
+            "pred": np.random.rand(10, 10)}
+
+    def check_write_content(fname, col_names):
+        # col_names should be in the order of emb, nid and pred
+        csv_data = pd.read_csv(fname, delimiter=",")
+        # emb
+        assert col_names[0] in csv_data
+        csv_emb_data = csv_data[col_names[0]].values.tolist()
+        csv_emb_data = [d.split(";") for d in csv_emb_data]
+        csv_emb_data = np.array(csv_emb_data, dtype=np.float32)
+        assert_almost_equal(data["emb"], csv_emb_data)
+
+        # nid
+        assert col_names[1] in csv_data
+        csv_nid_data = csv_data[col_names[1]].values.tolist()
+        csv_nid_data = np.array(csv_nid_data, dtype=np.int32)
+        assert_equal(data["nid"], csv_nid_data)
+
+        # pred
+        assert col_names[2] in csv_data
+        csv_pred_data = csv_data[col_names[2]].values.tolist()
+        csv_pred_data = [d.split(";") for d in csv_pred_data]
+        csv_pred_data = np.array(csv_pred_data, dtype=np.float32)
+        assert_almost_equal(data["pred"], csv_pred_data)
+
+    # without renaming columns
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        file_prefix = os.path.join(tmpdirname, "test")
+        write_data_csv_file(data, file_prefix, col_name_map=None)
+        output_fname = f"{file_prefix}.csv"
+
+        check_write_content(output_fname, ["emb", "nid", "pred"])
+
+    # rename all column names
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        col_name_map = {
+            "emb": "new_emb",
+            "nid": "new_nid",
+            "pred": "new_pred"
+        }
+        file_prefix = os.path.join(tmpdirname, "test")
+        write_data_csv_file(data, file_prefix, col_name_map=col_name_map)
+        output_fname = f"{file_prefix}.csv"
+
+        check_write_content(output_fname, ["new_emb", "new_nid", "new_pred"])
+
+    # rename part of the column names
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        col_name_map = {
+            "emb": "new_emb",
+            "nid": "new_nid",
+        }
+        file_prefix = os.path.join(tmpdirname, "test")
+        write_data_csv_file(data, file_prefix, col_name_map=col_name_map)
+        output_fname = f"{file_prefix}.csv"
+
+        check_write_content(output_fname, ["new_emb", "new_nid", "pred"])
+
+
 if __name__ == '__main__':
+    test_write_data_csv_file()
+    test_write_data_parquet_file()
     test__get_file_range()
     test_worker_remap_edge_pred()
     test_worker_remap_node_data("pred")
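
(The unit tests above can also be run directly, e.g. python tests/unit-tests/gconstruct/test_remap_result.py from the repository root, since the module invokes them from its __main__ block; this assumes the package and its test dependencies are installed.)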
