diff --git a/tests/data/test_icl_datasets.py b/tests/data/test_icl_datasets.py index d25671159a..1d545c251c 100644 --- a/tests/data/test_icl_datasets.py +++ b/tests/data/test_icl_datasets.py @@ -37,7 +37,7 @@ def run_test( for i, e in enumerate(evaluators): batch = next(e.dataloader.dataloader.__iter__()) # Check that the dataloader is the correct length for the first task. - if i == 0: + if i == 0 and hasattr(e.dataloader.dataloader, '__len__'): if eval_drop_last: assert len( e.dataloader.dataloader,