Skip to content

Commit

Permalink
support run op with different executro
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed Apr 26, 2024
1 parent 22e580a commit ef4ab1d
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 23 deletions.
21 changes: 11 additions & 10 deletions data_juicer/utils/unittest_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,30 +51,31 @@ def tearDownClass(cls, hf_model_name=None) -> None:
shutil.rmtree(transformers.TRANSFORMERS_CACHE)

@classmethod
def generate_dataset(cls, data, type='hf'):
def generate_dataset(cls, data, type='standalone'):
"""Generate dataset for a specific executor.
Args:
type (str, optional): `hf` or `ray`. Defaults to "hf".
"""
if type == 'hf':
if type.startswith('standalone'):
return Dataset.from_list(data)
elif type == 'ray':
elif type.startswith('ray'):
return rd.from_items(data)
else:
raise ValueError("Unsupported type")

@classmethod
def run_single_op(cls, dataset, op, type='hf'):
def run_single_op(cls, dataset, op, type='standalone'):
"""Run operator in the specific executor."""
if type == 'hf':
if type.startswith('standalone'):
if isinstance(op, Filter) and Fields.stats not in dataset.features:
# TODO:
# this is a temp solution,
# only add stats when calling filter op
dataset = dataset.add_column(name=Fields.stats,
column=[{}] * dataset.num_rows)
dataset = dataset.map(op.compute_stats)
dataset = dataset.filter(op.process)
dataset = dataset.select_columns(column_names=['text'])
return dataset.to_list()
elif type == 'ray':
pass
elif type.startswith('ray'):
raise ValueError("Unsupported type")
else:
raise ValueError("Unsupported type")
6 changes: 3 additions & 3 deletions tests/ops/filter/test_alphanumeric_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class AlphanumericFilterTest(DataJuicerTestCaseBase):

@TEST_TAG("single")
@TEST_TAG("standalone")
def test_case(self):

ds_list = [{
Expand Down Expand Up @@ -40,10 +40,10 @@ def test_case(self):
}]
dataset = DataJuicerTestCaseBase.generate_dataset(ds_list)
op = AlphanumericFilter(min_ratio=0.2, max_ratio=0.9)
result = DataJuicerTestCaseBase.run_single_op(dataset, op)
result = DataJuicerTestCaseBase.run_single_op(dataset, op, AlphanumericFilterTest.current_tag,)
self.assertEqual(result, tgt_list)

@TEST_TAG("single")
@TEST_TAG("standalone")
def test_token_case(self):

ds_list = [{
Expand Down
19 changes: 9 additions & 10 deletions tests/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
parser.add_argument('--tag', choices=["standalone", "standalone-gpu", "ray", "ray-gpu"],
default="standalone",
help="the tag of tests being run")
parser.add_argument('--list_tests', action='store_true', help='list all tests')
parser.add_argument('--pattern', default='test_*.py', help='test file pattern')
parser.add_argument('--test_dir',
default='tests',
Expand All @@ -36,7 +35,8 @@ def __init__(self, tag=None):
self.tag = tag

def loadTestsFromTestCase(self, testCaseClass):

# set tag to testcase class
setattr(testCaseClass, 'current_tag', self.tag)
test_names = self.getTestCaseNames(testCaseClass)
loaded_suite = self.suiteClass()
for test_name in test_names:
Expand All @@ -46,7 +46,7 @@ def loadTestsFromTestCase(self, testCaseClass):
loaded_suite.addTest(test_case)
return loaded_suite

def gather_test_cases(test_dir, pattern, list_tests, tag):
def gather_test_cases(test_dir, pattern, tag):
test_to_run = unittest.TestSuite()
test_loader = TaggedTestLoader(tag)
discover = test_loader.discover(test_dir, pattern=pattern, top_level_dir=None)
Expand All @@ -57,20 +57,19 @@ def gather_test_cases(test_dir, pattern, list_tests, tag):
for test_case in test_suite:
if type(test_case) in SKIPPED_TESTS.modules.values():
continue
if list_tests:
logger.info(f'Add test case [{str(test_case)}]')
logger.info(f'Add test case [{test_case._testMethodName}]'
f' from {test_case.__class__.__name__}')
test_to_run.addTest(test_case)
return test_to_run


def main():
runner = unittest.TextTestRunner()
test_suite = gather_test_cases(os.path.abspath(args.test_dir),
args.pattern, args.list_tests, args.tag)
if not args.list_tests:
res = runner.run(test_suite)
if not res.wasSuccessful():
exit(1)
args.pattern, args.tag)
res = runner.run(test_suite)
if not res.wasSuccessful():
exit(1)


if __name__ == '__main__':
Expand Down

0 comments on commit ef4ab1d

Please sign in to comment.