diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index 92b2da8ce1..e94c59e054 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -77,12 +77,21 @@ def _hash_urls(s: tuple): def _get_split_list(source: str, split: str): + from torchdata.datapipes.iter import ( # noqa + IterableWrapper, + OnlineReader, + ) url_dp = IterableWrapper([SPLIT_LIST[source + "_" + split]]) online_dp = OnlineReader(url_dp) return online_dp.readlines().map(fn=_hash_urls) def _load_stories(root: str, source: str, split: str): + from torchdata.datapipes.iter import ( # noqa + FileOpener, + IterableWrapper, + GDriveReader, + ) split_list = set(_get_split_list(source, split)) story_dp = IterableWrapper([URL[source]]) cache_compressed_dp = story_dp.on_disk_cache( @@ -135,12 +144,6 @@ def CNNDM(root: str, split: Union[Tuple[str], str]): raise ModuleNotFoundError( "Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data" ) - from torchdata.datapipes.iter import ( # noqa - FileOpener, - IterableWrapper, - OnlineReader, - GDriveReader, - ) cnn_dp = _load_stories(root, "cnn", split) dailymail_dp = _load_stories(root, "dailymail", split)