Skip to content

Commit

Permalink
Revert "[Bug fix] Fix a bug when --column-names does not cover all builtin names (#682)"
Browse files Browse the repository at this point in the history

This reverts commit 5e7cbb8.
  • Loading branch information
jalencato committed Dec 15, 2023
1 parent a54c829 commit 54087b9
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 124 deletions.
17 changes: 2 additions & 15 deletions python/graphstorm/gconstruct/remap_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,7 @@ def write_data_parquet_file(data, file_prefix, col_name_map=None):
A mapping from builtin column name to user defined column name.
"""
if col_name_map is not None:
updated_data = {}
for key, val in data.items():
if key in col_name_map:
updated_data[col_name_map[key]] = val
else:
updated_data[key] = val
data = updated_data

data = {col_name_map[key]: val for key, val in data.items()}
output_fname = f"{file_prefix}.parquet"
write_data_parquet(data, output_fname)

Expand Down Expand Up @@ -114,13 +107,7 @@ def write_data_csv_file(data, file_prefix, delimiter=",", col_name_map=None):
A mapping from builtin column name to user defined column name.
"""
if col_name_map is not None:
updated_data = {}
for key, val in data.items():
if key in col_name_map:
updated_data[col_name_map[key]] = val
else:
updated_data[key] = val
data = updated_data
data = {col_name_map[key]: val for key, val in data.items()}

output_fname = f"{file_prefix}.csv"
csv_data = {}
Expand Down
109 changes: 0 additions & 109 deletions tests/unit-tests/gconstruct/test_remap_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,116 +278,7 @@ def test__get_file_range():
assert start == 7
assert end == 10

def test_write_data_parquet_file():
    """Test write_data_parquet_file with no, full, and partial column renaming."""
    data = {"emb": np.random.rand(10, 10),
            "nid": np.arange(10),
            "pred": np.random.rand(10, 10)}

    def verify_output(fname, names):
        # `names` must be ordered as (emb, nid, pred).
        stored = read_data_parquet(fname, names)
        assert_almost_equal(data["emb"], stored[names[0]])
        assert_equal(data["nid"], stored[names[1]])
        assert_almost_equal(data["pred"], stored[names[2]])

    # Case 1: no column renaming at all.
    with tempfile.TemporaryDirectory() as workdir:
        prefix = os.path.join(workdir, "test")
        write_data_parquet_file(data, prefix, None)
        verify_output(f"{prefix}.parquet", ["emb", "nid", "pred"])

    # Case 2: every builtin column gets a user-defined name.
    with tempfile.TemporaryDirectory() as workdir:
        mapping = {
            "emb": "new_emb",
            "nid": "new_nid",
            "pred": "new_pred"
        }
        prefix = os.path.join(workdir, "test")
        write_data_parquet_file(data, prefix, mapping)
        verify_output(f"{prefix}.parquet", ["new_emb", "new_nid", "new_pred"])

    # Case 3: only a subset of the columns is renamed; the rest
    # must keep their builtin names.
    with tempfile.TemporaryDirectory() as workdir:
        mapping = {
            "emb": "new_emb",
            "nid": "new_nid",
        }
        prefix = os.path.join(workdir, "test")
        write_data_parquet_file(data, prefix, mapping)
        verify_output(f"{prefix}.parquet", ["new_emb", "new_nid", "pred"])

def test_write_data_csv_file():
    """Test write_data_csv_file with no, full, and partial column renaming."""
    data = {"emb": np.random.rand(10, 10),
            "nid": np.arange(10),
            "pred": np.random.rand(10, 10)}

    def verify_output(fname, names):
        # `names` must be ordered as (emb, nid, pred).
        frame = pd.read_csv(fname, delimiter=",")

        # emb: each row is stored as ";"-joined float strings.
        assert names[0] in frame
        emb_rows = [row.split(";") for row in frame[names[0]].values.tolist()]
        assert_almost_equal(data["emb"], np.array(emb_rows, dtype=np.float32))

        # nid: a plain integer column.
        assert names[1] in frame
        nid_vals = np.array(frame[names[1]].values.tolist(), dtype=np.int32)
        assert_equal(data["nid"], nid_vals)

        # pred: each row is stored as ";"-joined float strings.
        assert names[2] in frame
        pred_rows = [row.split(";") for row in frame[names[2]].values.tolist()]
        assert_almost_equal(data["pred"], np.array(pred_rows, dtype=np.float32))

    # Case 1: no column renaming at all.
    with tempfile.TemporaryDirectory() as workdir:
        prefix = os.path.join(workdir, "test")
        write_data_csv_file(data, prefix, col_name_map=None)
        verify_output(f"{prefix}.csv", ["emb", "nid", "pred"])

    # Case 2: every builtin column gets a user-defined name.
    with tempfile.TemporaryDirectory() as workdir:
        mapping = {
            "emb": "new_emb",
            "nid": "new_nid",
            "pred": "new_pred"
        }
        prefix = os.path.join(workdir, "test")
        write_data_csv_file(data, prefix, col_name_map=mapping)
        verify_output(f"{prefix}.csv", ["new_emb", "new_nid", "new_pred"])

    # Case 3: only a subset of the columns is renamed; the rest
    # must keep their builtin names.
    with tempfile.TemporaryDirectory() as workdir:
        mapping = {
            "emb": "new_emb",
            "nid": "new_nid",
        }
        prefix = os.path.join(workdir, "test")
        write_data_csv_file(data, prefix, col_name_map=mapping)
        verify_output(f"{prefix}.csv", ["new_emb", "new_nid", "pred"])


if __name__ == '__main__':
    # Run the remap-result unit tests directly (outside a test runner).
    test_write_data_csv_file()
    test_write_data_parquet_file()
    test__get_file_range()
    test_worker_remap_edge_pred()
    test_worker_remap_node_data("pred")

0 comments on commit 54087b9

Please sign in to comment.