diff --git a/tests/unit-tests/gconstruct/test_remap_result.py b/tests/unit-tests/gconstruct/test_remap_result.py index 84dc3a7da4..6d6c3f59af 100644 --- a/tests/unit-tests/gconstruct/test_remap_result.py +++ b/tests/unit-tests/gconstruct/test_remap_result.py @@ -351,6 +351,39 @@ def check_write_content(fname, col_names): csv_pred_data = np.array(csv_pred_data, dtype=np.float32) assert_almost_equal(data["pred"], csv_pred_data) + # without renaming columns + with tempfile.TemporaryDirectory() as tmpdirname: + file_prefix = os.path.join(tmpdirname, "test") + write_data_csv_file(data, file_prefix, col_name_map=None) + output_fname = f"{file_prefix}.csv" + + check_write_content(output_fname, ["emb", "nid", "pred"]) + + # rename all column names + with tempfile.TemporaryDirectory() as tmpdirname: + col_name_map = { + "emb": "new_emb", + "nid": "new_nid", + "pred": "new_pred" + } + file_prefix = os.path.join(tmpdirname, "test") + write_data_csv_file(data, file_prefix, col_name_map=col_name_map) + output_fname = f"{file_prefix}.csv" + + check_write_content(output_fname, ["new_emb", "new_nid", "new_pred"]) + + # rename part of column names + with tempfile.TemporaryDirectory() as tmpdirname: + col_name_map = { + "emb": "new_emb", + "nid": "new_nid", + } + file_prefix = os.path.join(tmpdirname, "test") + write_data_csv_file(data, file_prefix, col_name_map=col_name_map) + output_fname = f"{file_prefix}.csv" + + check_write_content(output_fname, ["new_emb", "new_nid", "pred"]) + if __name__ == '__main__': test_write_data_csv_file()