diff --git a/tests/data/eurosat/EuroSAT100.zip b/tests/data/eurosat/EuroSAT100.zip new file mode 100644 index 00000000000..ed2eb18d324 Binary files /dev/null and b/tests/data/eurosat/EuroSAT100.zip differ diff --git a/tests/data/eurosat/eurosat-100-test.txt b/tests/data/eurosat/eurosat-100-test.txt new file mode 100644 index 00000000000..debeff4c852 --- /dev/null +++ b/tests/data/eurosat/eurosat-100-test.txt @@ -0,0 +1,2 @@ +AnnualCrop_1.tif +Forest_1.tif diff --git a/tests/data/eurosat/eurosat-100-train.txt b/tests/data/eurosat/eurosat-100-train.txt new file mode 100644 index 00000000000..debeff4c852 --- /dev/null +++ b/tests/data/eurosat/eurosat-100-train.txt @@ -0,0 +1,2 @@ +AnnualCrop_1.tif +Forest_1.tif diff --git a/tests/data/eurosat/eurosat-100-val.txt b/tests/data/eurosat/eurosat-100-val.txt new file mode 100644 index 00000000000..debeff4c852 --- /dev/null +++ b/tests/data/eurosat/eurosat-100-val.txt @@ -0,0 +1,2 @@ +AnnualCrop_1.tif +Forest_1.tif diff --git a/tests/data/eurosat/eurosat-spatial-test.txt b/tests/data/eurosat/eurosat-spatial-test.txt new file mode 100644 index 00000000000..debeff4c852 --- /dev/null +++ b/tests/data/eurosat/eurosat-spatial-test.txt @@ -0,0 +1,2 @@ +AnnualCrop_1.tif +Forest_1.tif diff --git a/tests/data/eurosat/eurosat-spatial-train.txt b/tests/data/eurosat/eurosat-spatial-train.txt new file mode 100644 index 00000000000..debeff4c852 --- /dev/null +++ b/tests/data/eurosat/eurosat-spatial-train.txt @@ -0,0 +1,2 @@ +AnnualCrop_1.tif +Forest_1.tif diff --git a/tests/data/eurosat/eurosat-spatial-val.txt b/tests/data/eurosat/eurosat-spatial-val.txt new file mode 100644 index 00000000000..debeff4c852 --- /dev/null +++ b/tests/data/eurosat/eurosat-spatial-val.txt @@ -0,0 +1,2 @@ +AnnualCrop_1.tif +Forest_1.tif diff --git a/tests/datasets/test_eurosat.py b/tests/datasets/test_eurosat.py index a9574505ade..282ff581931 100644 --- a/tests/datasets/test_eurosat.py +++ b/tests/datasets/test_eurosat.py @@ -32,34 +32,10 @@ def dataset( ) -> EuroSAT: base_class: type[EuroSAT] = request.param[0] split: str = request.param[1] - md5 = 'aa051207b0547daba0ac6af57808d68e' - monkeypatch.setattr(base_class, 'md5', md5) - url = os.path.join('tests', 'data', 'eurosat', 'EuroSATallBands.zip') + url = os.path.join('tests', 'data', 'eurosat') + os.sep monkeypatch.setattr(base_class, 'url', url) - monkeypatch.setattr(base_class, 'filename', 'EuroSATallBands.zip') - monkeypatch.setattr( - base_class, - 'split_urls', - { - 'train': os.path.join('tests', 'data', 'eurosat', 'eurosat-train.txt'), - 'val': os.path.join('tests', 'data', 'eurosat', 'eurosat-val.txt'), - 'test': os.path.join('tests', 'data', 'eurosat', 'eurosat-test.txt'), - }, - ) - monkeypatch.setattr( - base_class, - 'split_md5s', - { - 'train': '4af60a00fdfdf8500572ae5360694b71', - 'val': '4af60a00fdfdf8500572ae5360694b71', - 'test': '4af60a00fdfdf8500572ae5360694b71', - }, - ) - root = tmp_path transforms = nn.Identity() - return base_class( - root=root, split=split, transforms=transforms, download=True, checksum=True - ) + return base_class(tmp_path, split=split, transforms=transforms, download=True) def test_getitem(self, dataset: EuroSAT) -> None: x = dataset[0] @@ -84,14 +60,14 @@ def test_add(self, dataset: EuroSAT) -> None: assert len(ds) == 4 def test_already_downloaded(self, dataset: EuroSAT, tmp_path: Path) -> None: - EuroSAT(root=tmp_path, download=True) + type(dataset)(tmp_path) def test_already_downloaded_not_extracted( self, dataset: EuroSAT, tmp_path: Path ) -> None: shutil.rmtree(dataset.root) - shutil.copy(dataset.url, tmp_path) - EuroSAT(root=tmp_path, download=False) + shutil.copy(dataset.url + dataset.filename, tmp_path) + type(dataset)(tmp_path) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): @@ -108,7 +84,7 @@ def test_plot(self, dataset: EuroSAT) -> None: plt.close() def test_plot_rgb(self, dataset: EuroSAT, tmp_path: Path) -> None: - dataset = EuroSAT(root=tmp_path, bands=('B03',)) + dataset = type(dataset)(tmp_path, bands=('B03',)) with pytest.raises( RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py index 26c1c860cf8..45c0d708f92 100644 --- a/torchgeo/datasets/eurosat.py +++ b/torchgeo/datasets/eurosat.py @@ -54,7 +54,7 @@ class EuroSAT(NonGeoClassificationDataset): * https://ieeexplore.ieee.org/document/8519248 """ - url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSATallBands.zip' + url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/1ce6f1bfb56db63fd91b6ecc466ea67f2509774c/' filename = 'EuroSATallBands.zip' md5 = '5ac12b3b2557aa56e1826e981e8e200e' @@ -64,10 +64,10 @@ class EuroSAT(NonGeoClassificationDataset): ) splits = ('train', 'val', 'test') - split_urls: ClassVar[dict[str, str]] = { - 'train': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-train.txt', - 'val': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-val.txt', - 'test': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-test.txt', + split_filenames: ClassVar[dict[str, str]] = { + 'train': 'eurosat-train.txt', + 'val': 'eurosat-val.txt', + 'test': 'eurosat-test.txt', } split_md5s: ClassVar[dict[str, str]] = { 'train': '908f142e73d6acdf3f482c5e80d851b1', @@ -141,7 +141,7 @@ def __init__( self._verify() valid_fns = set() - with open(os.path.join(self.root, f'eurosat-{split}.txt')) as f: + with open(os.path.join(self.root, self.split_filenames[split])) as f: for fn in f: valid_fns.add(fn.strip().replace('.jpg', '.tif')) @@ -207,16 +207,12 @@ def _verify(self) -> None: def _download(self) -> None: """Download the dataset.""" download_url( - self.url, - self.root, - filename=self.filename, - md5=self.md5 if self.checksum else None, + self.url + self.filename, self.root, md5=self.md5 if self.checksum else None ) for split in self.splits: download_url( - self.split_urls[split], + self.url + self.split_filenames[split], self.root, - filename=f'eurosat-{split}.txt', md5=self.split_md5s[split] if self.checksum else None, ) @@ -305,10 +301,10 @@ class EuroSATSpatial(EuroSAT): .. versionadded:: 0.6 """ - split_urls: ClassVar[dict[str, str]] = { - 'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-train.txt', - 'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-val.txt', - 'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-test.txt', + split_filenames: ClassVar[dict[str, str]] = { + 'train': 'eurosat-spatial-train.txt', + 'val': 'eurosat-spatial-val.txt', + 'test': 'eurosat-spatial-test.txt', } split_md5s: ClassVar[dict[str, str]] = { 'train': '7be3254be39f23ce4d4d144290c93292', @@ -328,14 +324,13 @@ class EuroSAT100(EuroSAT): .. versionadded:: 0.5 """ - url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSAT100.zip' filename = 'EuroSAT100.zip' md5 = 'c21c649ba747e86eda813407ef17d596' - split_urls: ClassVar[dict[str, str]] = { - 'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-train.txt', - 'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-val.txt', - 'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-test.txt', + split_filenames: ClassVar[dict[str, str]] = { + 'train': 'eurosat-100-train.txt', + 'val': 'eurosat-100-val.txt', + 'test': 'eurosat-100-test.txt', } split_md5s: ClassVar[dict[str, str]] = { 'train': '033d0c23e3a75e3fa79618b0e35fe1c7',