Skip to content

Commit

Permalink
Fix CNNDM dataset tests (pytorch#2246)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored and atalman committed Apr 11, 2024
1 parent de3e711 commit 80e6fbf
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions torchtext/datasets/cnndm.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,21 @@ def _hash_urls(s: tuple):


def _get_split_list(source: str, split: str):
from torchdata.datapipes.iter import ( # noqa
IterableWrapper,
OnlineReader,
)
url_dp = IterableWrapper([SPLIT_LIST[source + "_" + split]])
online_dp = OnlineReader(url_dp)
return online_dp.readlines().map(fn=_hash_urls)


def _load_stories(root: str, source: str, split: str):
from torchdata.datapipes.iter import ( # noqa
FileOpener,
IterableWrapper,
GDriveReader,
)
split_list = set(_get_split_list(source, split))
story_dp = IterableWrapper([URL[source]])
cache_compressed_dp = story_dp.on_disk_cache(
Expand Down Expand Up @@ -135,12 +144,6 @@ def CNNDM(root: str, split: Union[Tuple[str], str]):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)
from torchdata.datapipes.iter import ( # noqa
FileOpener,
IterableWrapper,
OnlineReader,
GDriveReader,
)

cnn_dp = _load_stories(root, "cnn", split)
dailymail_dp = _load_stories(root, "dailymail", split)
Expand Down

0 comments on commit 80e6fbf

Please sign in to comment.