Skip to content

Commit

Permalink
Fix GConstruct file parsing when input is a list with a single wildcard path (#1090)
Browse files Browse the repository at this point in the history

*Issue #, if available:*

* Reported by customer

*Description of changes:*

* Previously, if a user passed in a list of input files containing a single
wildcard path, the path was treated as a regular (non-wildcard) path,
leading to a file-not-found error.
* We make the file parsing consistent by always converting the input to a
list of str and passing it to `expand_wildcard`.
* Add a test case for the input that used to fail (a list containing a single wildcard path).

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
  • Loading branch information
thvasilo authored Nov 14, 2024
1 parent 7333fa6 commit ab9d143
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 18 deletions.
19 changes: 8 additions & 11 deletions python/graphstorm/gconstruct/file_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import json
import os
import logging
from typing import Union, List

import pyarrow.parquet as pq
import pyarrow as pa
Expand Down Expand Up @@ -76,9 +77,9 @@ def read_index(split_info):
return res[0], res[1], res[2]


def expand_wildcard(data_files):
def expand_wildcard(data_files: List[str]) -> List[str]:
"""
Expand the wildcard to the actual file lists.
Expand a list of paths that can contain wildcards to the actual file lists.
Parameters
----------
Expand Down Expand Up @@ -490,7 +491,7 @@ def _parse_file_format(conf, is_node, in_mem):
parse_node_file_format = partial(_parse_file_format, is_node=True)
parse_edge_file_format = partial(_parse_file_format, is_node=False)

def get_in_files(in_files):
def get_in_files(in_files: Union[List[str], str]) -> List[str]:
""" Get the input files.
The input file string may contain a wildcard. This function
Expand All @@ -505,16 +506,12 @@ def get_in_files(in_files):
-------
a list of str : the full name of input files.
"""
# If the input file has a wildcard, get all files that matches the input file name.
if '*' in in_files:
in_file_list = sorted(glob.glob(in_files))
assert len(in_file_list) > 0, \
f"There is no file matching {in_files} pattern"
in_files = in_file_list
# This is a single file.
elif not isinstance(in_files, list):
# Convert single str to list of str if needed
if isinstance(in_files, str):
in_files = [in_files]

in_files = expand_wildcard(in_files)

# Verify the existence of the input files.
for in_file in in_files:
assert os.path.isfile(in_file), \
Expand Down
18 changes: 11 additions & 7 deletions tests/unit-tests/gconstruct/test_gconstruct_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@

import dgl
import numpy as np
import torch as th
import pytest
import pandas as pd
import pyarrow.parquet as pq
import pyarrow as pa
import torch as th

from numpy.testing import assert_almost_equal

Expand Down Expand Up @@ -292,22 +293,25 @@ def test_get_in_files():
for i in range(10):
data = {"test": np.random.rand(10)}
write_data_parquet(data, files[i])
files.sort()

# Test single string wildcard path
in_files = get_in_files(os.path.join(tmpdirname,"*.parquet"))
assert len(in_files) == 10
files.sort()

assert files == in_files

# Test list of wildcard paths with a single element
in_files = get_in_files([os.path.join(tmpdirname, "*.parquet")])
assert len(in_files) == 10
assert files == in_files

in_files = get_in_files(os.path.join(tmpdirname,"test9.parquet"))
assert len(in_files) == 1
assert os.path.join(tmpdirname,"test9.parquet") == in_files[0]

pass_test = False
try:
with pytest.raises(AssertionError):
in_files = get_in_files(os.path.join(tmpdirname,"test10.parquet"))
except:
pass_test = True
assert pass_test

def test_get_hard_edge_negs_feats():
hard_trans0 = HardEdgeDstNegativeTransform("hard_neg", "hard_neg")
Expand Down

0 comments on commit ab9d143

Please sign in to comment.