Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiang Song committed Dec 11, 2023
1 parent b964123 commit b36656d
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions tests/unit-tests/gconstruct/test_remap_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,39 @@ def check_write_content(fname, col_names):
csv_pred_data = np.array(csv_pred_data, dtype=np.float32)
assert_almost_equal(data["pred"], csv_pred_data)

# without renaming columns
with tempfile.TemporaryDirectory() as tmpdirname:
file_prefix = os.path.join(tmpdirname, "test")
write_data_csv_file(data, file_prefix, col_name_map=None)
output_fname = f"{file_prefix}.csv"

check_write_content(output_fname, ["emb", "nid", "pred"])

# rename all column names
with tempfile.TemporaryDirectory() as tmpdirname:
col_name_map = {
"emb": "new_emb",
"nid": "new_nid",
"pred": "new_pred"
}
file_prefix = os.path.join(tmpdirname, "test")
write_data_csv_file(data, file_prefix, col_name_map=col_name_map)
output_fname = f"{file_prefix}.csv"

check_write_content(output_fname, ["new_emb", "new_nid", "new_pred"])

# rename part of column names
with tempfile.TemporaryDirectory() as tmpdirname:
col_name_map = {
"emb": "new_emb",
"nid": "new_nid",
}
file_prefix = os.path.join(tmpdirname, "test")
write_data_csv_file(data, file_prefix, col_name_map=col_name_map)
output_fname = f"{file_prefix}.csv"

check_write_content(output_fname, ["new_emb", "new_nid", "pred"])


if __name__ == '__main__':
test_write_data_csv_file()
Expand Down

0 comments on commit b36656d

Please sign in to comment.