diff --git a/python/graphstorm/gconstruct/remap_result.py b/python/graphstorm/gconstruct/remap_result.py
index ec3f97da48..2fc0ea606f 100644
--- a/python/graphstorm/gconstruct/remap_result.py
+++ b/python/graphstorm/gconstruct/remap_result.py
@@ -79,7 +79,14 @@ def write_data_parquet_file(data, file_prefix, col_name_map=None):
         A mapping from builtin column name to user defined column name.
     """
     if col_name_map is not None:
-        data = {col_name_map[key]: val for key, val in data.items()}
+        updated_data = {}
+        for key, val in data.items():
+            if key in col_name_map:
+                updated_data[col_name_map[key]] = val
+            else:
+                updated_data[key] = val
+        data = updated_data
+
     output_fname = f"{file_prefix}.parquet"
     write_data_parquet(data, output_fname)
 
@@ -107,7 +114,13 @@ def write_data_csv_file(data, file_prefix, delimiter=",", col_name_map=None):
         A mapping from builtin column name to user defined column name.
     """
     if col_name_map is not None:
-        data = {col_name_map[key]: val for key, val in data.items()}
+        updated_data = {}
+        for key, val in data.items():
+            if key in col_name_map:
+                updated_data[col_name_map[key]] = val
+            else:
+                updated_data[key] = val
+        data = updated_data
 
     output_fname = f"{file_prefix}.csv"
     csv_data = {}
diff --git a/tests/unit-tests/gconstruct/test_remap_result.py b/tests/unit-tests/gconstruct/test_remap_result.py
index 1421c8f5b5..6d6c3f59af 100644
--- a/tests/unit-tests/gconstruct/test_remap_result.py
+++ b/tests/unit-tests/gconstruct/test_remap_result.py
@@ -278,7 +278,116 @@ def test__get_file_range():
     assert start == 7
     assert end == 10
 
+def test_write_data_parquet_file():
+    data = {"emb": np.random.rand(10, 10),
+            "nid": np.arange(10),
+            "pred": np.random.rand(10, 10)}
+
+    def check_write_content(fname, col_names):
+        # col_names should be in the order: emb, nid, pred
+        parq_data = read_data_parquet(fname, col_names)
+        assert_almost_equal(data["emb"], parq_data[col_names[0]])
+        assert_equal(data["nid"], parq_data[col_names[1]])
+        assert_almost_equal(data["pred"], parq_data[col_names[2]])
+
+    # without renaming columns
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        file_prefix = os.path.join(tmpdirname, "test")
+        write_data_parquet_file(data, file_prefix, None)
+        output_fname = f"{file_prefix}.parquet"
+
+        check_write_content(output_fname, ["emb", "nid", "pred"])
+
+    # rename all column names
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        col_name_map = {
+            "emb": "new_emb",
+            "nid": "new_nid",
+            "pred": "new_pred"
+        }
+        file_prefix = os.path.join(tmpdirname, "test")
+        write_data_parquet_file(data, file_prefix, col_name_map)
+        output_fname = f"{file_prefix}.parquet"
+
+        check_write_content(output_fname, ["new_emb", "new_nid", "new_pred"])
+
+    # rename only a subset of the column names
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        col_name_map = {
+            "emb": "new_emb",
+            "nid": "new_nid",
+        }
+        file_prefix = os.path.join(tmpdirname, "test")
+        write_data_parquet_file(data, file_prefix, col_name_map)
+        output_fname = f"{file_prefix}.parquet"
+
+        check_write_content(output_fname, ["new_emb", "new_nid", "pred"])
+
+def test_write_data_csv_file():
+    data = {"emb": np.random.rand(10, 10),
+            "nid": np.arange(10),
+            "pred": np.random.rand(10, 10)}
+
+    def check_write_content(fname, col_names):
+        # col_names should be in the order: emb, nid, pred
+        csv_data = pd.read_csv(fname, delimiter=",")
+        # emb
+        assert col_names[0] in csv_data
+        csv_emb_data = csv_data[col_names[0]].values.tolist()
+        csv_emb_data = [d.split(";") for d in csv_emb_data]
+        csv_emb_data = np.array(csv_emb_data, dtype=np.float32)
+        assert_almost_equal(data["emb"], csv_emb_data)
+
+        # nid
+        assert col_names[1] in csv_data
+        csv_nid_data = csv_data[col_names[1]].values.tolist()
+        csv_nid_data = np.array(csv_nid_data, dtype=np.int32)
+        assert_equal(data["nid"], csv_nid_data)
+
+        # pred
+        assert col_names[2] in csv_data
+        csv_pred_data = csv_data[col_names[2]].values.tolist()
+        csv_pred_data = [d.split(";") for d in csv_pred_data]
+        csv_pred_data = np.array(csv_pred_data, dtype=np.float32)
+        assert_almost_equal(data["pred"], csv_pred_data)
+
+    # without renaming columns
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        file_prefix = os.path.join(tmpdirname, "test")
+        write_data_csv_file(data, file_prefix, col_name_map=None)
+        output_fname = f"{file_prefix}.csv"
+
+        check_write_content(output_fname, ["emb", "nid", "pred"])
+
+    # rename all column names
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        col_name_map = {
+            "emb": "new_emb",
+            "nid": "new_nid",
+            "pred": "new_pred"
+        }
+        file_prefix = os.path.join(tmpdirname, "test")
+        write_data_csv_file(data, file_prefix, col_name_map=col_name_map)
+        output_fname = f"{file_prefix}.csv"
+
+        check_write_content(output_fname, ["new_emb", "new_nid", "new_pred"])
+
+    # rename only a subset of the column names
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        col_name_map = {
+            "emb": "new_emb",
+            "nid": "new_nid",
+        }
+        file_prefix = os.path.join(tmpdirname, "test")
+        write_data_csv_file(data, file_prefix, col_name_map=col_name_map)
+        output_fname = f"{file_prefix}.csv"
+
+        check_write_content(output_fname, ["new_emb", "new_nid", "pred"])
+
+
 if __name__ == '__main__':
+    test_write_data_csv_file()
+    test_write_data_parquet_file()
     test__get_file_range()
     test_worker_remap_edge_pred()
     test_worker_remap_node_data("pred")
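
Note: the snippet below is a standalone sketch, not part of the patch, illustrating why the dict comprehension was replaced. With a partial col_name_map (only some columns renamed), the old comprehension raises KeyError for any column missing from the map, while the new loop falls back to the original column name. The variable names mirror the patched code; the array shapes are illustrative only.

    # Standalone sketch of the fallback logic added to
    # write_data_parquet_file / write_data_csv_file (assumes numpy is installed).
    import numpy as np

    data = {"emb": np.random.rand(4, 4),
            "nid": np.arange(4),
            "pred": np.random.rand(4, 4)}

    # Partial rename: "pred" is deliberately missing from the map.
    col_name_map = {"emb": "new_emb", "nid": "new_nid"}

    # Old behavior (removed line): raises KeyError("pred") because every key
    # must appear in col_name_map:
    #   data = {col_name_map[key]: val for key, val in data.items()}

    # New behavior: keys absent from col_name_map keep their original name.
    updated_data = {}
    for key, val in data.items():
        if key in col_name_map:
            updated_data[col_name_map[key]] = val
        else:
            updated_data[key] = val

    print(list(updated_data.keys()))  # ['new_emb', 'new_nid', 'pred']

The same fallback could be written more compactly as updated_data[col_name_map.get(key, key)] = val; the explicit if/else used in the patch keeps the rename-or-keep intent obvious.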