Skip to content

Commit

Permalink
EuroSAT: redistribute split files on Hugging Face (#2432)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart authored Dec 3, 2024
1 parent 04cfff1 commit 2dbb039
Show file tree
Hide file tree
Showing 9 changed files with 34 additions and 51 deletions.
Binary file added tests/data/eurosat/EuroSAT100.zip
Binary file not shown.
2 changes: 2 additions & 0 deletions tests/data/eurosat/eurosat-100-test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
AnnualCrop_1.tif
Forest_1.tif
2 changes: 2 additions & 0 deletions tests/data/eurosat/eurosat-100-train.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
AnnualCrop_1.tif
Forest_1.tif
2 changes: 2 additions & 0 deletions tests/data/eurosat/eurosat-100-val.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
AnnualCrop_1.tif
Forest_1.tif
2 changes: 2 additions & 0 deletions tests/data/eurosat/eurosat-spatial-test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
AnnualCrop_1.tif
Forest_1.tif
2 changes: 2 additions & 0 deletions tests/data/eurosat/eurosat-spatial-train.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
AnnualCrop_1.tif
Forest_1.tif
2 changes: 2 additions & 0 deletions tests/data/eurosat/eurosat-spatial-val.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
AnnualCrop_1.tif
Forest_1.tif
36 changes: 6 additions & 30 deletions tests/datasets/test_eurosat.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,34 +32,10 @@ def dataset(
     ) -> EuroSAT:
         base_class: type[EuroSAT] = request.param[0]
         split: str = request.param[1]
-        md5 = 'aa051207b0547daba0ac6af57808d68e'
-        monkeypatch.setattr(base_class, 'md5', md5)
-        url = os.path.join('tests', 'data', 'eurosat', 'EuroSATallBands.zip')
+        url = os.path.join('tests', 'data', 'eurosat') + os.sep
         monkeypatch.setattr(base_class, 'url', url)
-        monkeypatch.setattr(base_class, 'filename', 'EuroSATallBands.zip')
-        monkeypatch.setattr(
-            base_class,
-            'split_urls',
-            {
-                'train': os.path.join('tests', 'data', 'eurosat', 'eurosat-train.txt'),
-                'val': os.path.join('tests', 'data', 'eurosat', 'eurosat-val.txt'),
-                'test': os.path.join('tests', 'data', 'eurosat', 'eurosat-test.txt'),
-            },
-        )
-        monkeypatch.setattr(
-            base_class,
-            'split_md5s',
-            {
-                'train': '4af60a00fdfdf8500572ae5360694b71',
-                'val': '4af60a00fdfdf8500572ae5360694b71',
-                'test': '4af60a00fdfdf8500572ae5360694b71',
-            },
-        )
-        root = tmp_path
         transforms = nn.Identity()
-        return base_class(
-            root=root, split=split, transforms=transforms, download=True, checksum=True
-        )
+        return base_class(tmp_path, split=split, transforms=transforms, download=True)

     def test_getitem(self, dataset: EuroSAT) -> None:
         x = dataset[0]
Expand All @@ -84,14 +60,14 @@ def test_add(self, dataset: EuroSAT) -> None:
         assert len(ds) == 4

     def test_already_downloaded(self, dataset: EuroSAT, tmp_path: Path) -> None:
-        EuroSAT(root=tmp_path, download=True)
+        type(dataset)(tmp_path)

     def test_already_downloaded_not_extracted(
         self, dataset: EuroSAT, tmp_path: Path
     ) -> None:
         shutil.rmtree(dataset.root)
-        shutil.copy(dataset.url, tmp_path)
-        EuroSAT(root=tmp_path, download=False)
+        shutil.copy(dataset.url + dataset.filename, tmp_path)
+        type(dataset)(tmp_path)

     def test_not_downloaded(self, tmp_path: Path) -> None:
         with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
Expand All @@ -108,7 +84,7 @@ def test_plot(self, dataset: EuroSAT) -> None:
         plt.close()

     def test_plot_rgb(self, dataset: EuroSAT, tmp_path: Path) -> None:
-        dataset = EuroSAT(root=tmp_path, bands=('B03',))
+        dataset = type(dataset)(tmp_path, bands=('B03',))
         with pytest.raises(
             RGBBandsMissingError, match='Dataset does not contain some of the RGB bands'
         ):
Expand Down
37 changes: 16 additions & 21 deletions torchgeo/datasets/eurosat.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class EuroSAT(NonGeoClassificationDataset):
     * https://ieeexplore.ieee.org/document/8519248
     """

-    url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSATallBands.zip'
+    url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/1ce6f1bfb56db63fd91b6ecc466ea67f2509774c/'
     filename = 'EuroSATallBands.zip'
     md5 = '5ac12b3b2557aa56e1826e981e8e200e'

Expand All @@ -64,10 +64,10 @@ class EuroSAT(NonGeoClassificationDataset):
     )

     splits = ('train', 'val', 'test')
-    split_urls: ClassVar[dict[str, str]] = {
-        'train': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-train.txt',
-        'val': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-val.txt',
-        'test': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-test.txt',
+    split_filenames: ClassVar[dict[str, str]] = {
+        'train': 'eurosat-train.txt',
+        'val': 'eurosat-val.txt',
+        'test': 'eurosat-test.txt',
     }
     split_md5s: ClassVar[dict[str, str]] = {
         'train': '908f142e73d6acdf3f482c5e80d851b1',
Expand Down Expand Up @@ -141,7 +141,7 @@ def __init__(
         self._verify()

         valid_fns = set()
-        with open(os.path.join(self.root, f'eurosat-{split}.txt')) as f:
+        with open(os.path.join(self.root, self.split_filenames[split])) as f:
             for fn in f:
                 valid_fns.add(fn.strip().replace('.jpg', '.tif'))

Expand Down Expand Up @@ -207,16 +207,12 @@ def _verify(self) -> None:
     def _download(self) -> None:
         """Download the dataset."""
         download_url(
-            self.url,
-            self.root,
-            filename=self.filename,
-            md5=self.md5 if self.checksum else None,
+            self.url + self.filename, self.root, md5=self.md5 if self.checksum else None
         )
         for split in self.splits:
             download_url(
-                self.split_urls[split],
+                self.url + self.split_filenames[split],
                 self.root,
-                filename=f'eurosat-{split}.txt',
                 md5=self.split_md5s[split] if self.checksum else None,
             )

Expand Down Expand Up @@ -305,10 +301,10 @@ class EuroSATSpatial(EuroSAT):
     .. versionadded:: 0.6
     """

-    split_urls: ClassVar[dict[str, str]] = {
-        'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-train.txt',
-        'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-val.txt',
-        'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-test.txt',
+    split_filenames: ClassVar[dict[str, str]] = {
+        'train': 'eurosat-spatial-train.txt',
+        'val': 'eurosat-spatial-val.txt',
+        'test': 'eurosat-spatial-test.txt',
     }
     split_md5s: ClassVar[dict[str, str]] = {
         'train': '7be3254be39f23ce4d4d144290c93292',
Expand All @@ -328,14 +324,13 @@ class EuroSAT100(EuroSAT):
     .. versionadded:: 0.5
     """

-    url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSAT100.zip'
     filename = 'EuroSAT100.zip'
     md5 = 'c21c649ba747e86eda813407ef17d596'

-    split_urls: ClassVar[dict[str, str]] = {
-        'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-train.txt',
-        'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-val.txt',
-        'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-test.txt',
+    split_filenames: ClassVar[dict[str, str]] = {
+        'train': 'eurosat-100-train.txt',
+        'val': 'eurosat-100-val.txt',
+        'test': 'eurosat-100-test.txt',
     }
     split_md5s: ClassVar[dict[str, str]] = {
         'train': '033d0c23e3a75e3fa79618b0e35fe1c7',
Expand Down

0 comments on commit 2dbb039

Please sign in to comment.