
Commit

Adds the LandCoverAI100 dataset and datamodule for use in semantic segmentation notebooks (#2262)

* Add dataset and datamodule

* Add docs

* Tests

* Ran ruff one time

* Fixture needs a params kwarg

* Make dataset work

* Add versionadded to datamodule

* Add conf file to test new datamodule

* Test datamodule

* Changing dataset URL

* Update main hash
calebrob6 authored Sep 11, 2024
1 parent de31549 commit 94960bb
Showing 9 changed files with 87 additions and 17 deletions.
1 change: 1 addition & 0 deletions docs/api/datamodules.rst
@@ -123,6 +123,7 @@ LandCover.ai
^^^^^^^^^^^^

.. autoclass:: LandCoverAIDataModule
.. autoclass:: LandCoverAI100DataModule

LEVIR-CD
^^^^^^^^
1 change: 1 addition & 0 deletions docs/api/datasets.rst
@@ -316,6 +316,7 @@ LandCover.ai
^^^^^^^^^^^^

.. autoclass:: LandCoverAI
.. autoclass:: LandCoverAI100

LEVIR-CD
^^^^^^^^
16 changes: 16 additions & 0 deletions tests/conf/landcoverai100.yaml
@@ -0,0 +1,16 @@
model:
  class_path: SemanticSegmentationTask
  init_args:
    loss: "ce"
    model: "unet"
    backbone: "resnet18"
    in_channels: 3
    num_classes: 5
    num_filters: 1
    ignore_index: null
data:
  class_path: LandCoverAI100DataModule
  init_args:
    batch_size: 1
  dict_kwargs:
    root: "tests/data/landcoverai"
19 changes: 13 additions & 6 deletions tests/datasets/test_landcoverai.py
@@ -3,6 +3,7 @@

import os
import shutil
+from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
@@ -17,6 +18,7 @@
    BoundingBox,
    DatasetNotFoundError,
    LandCoverAI,
+    LandCoverAI100,
    LandCoverAIGeo,
)

@@ -72,20 +74,25 @@ def test_plot(self, dataset: LandCoverAIGeo) -> None:
class TestLandCoverAI:
    pytest.importorskip('cv2', minversion='4.5.4')

-    @pytest.fixture(params=['train', 'val', 'test'])
+    @pytest.fixture(
+        params=product([LandCoverAI100, LandCoverAI], ['train', 'val', 'test'])
+    )
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> LandCoverAI:
+        base_class: type[LandCoverAI] = request.param[0]
+        split: str = request.param[1]
        md5 = 'ff8998857cc8511f644d3f7d0f3688d0'
-        monkeypatch.setattr(LandCoverAI, 'md5', md5)
+        monkeypatch.setattr(base_class, 'md5', md5)
        url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip')
-        monkeypatch.setattr(LandCoverAI, 'url', url)
+        monkeypatch.setattr(base_class, 'url', url)
        sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
-        monkeypatch.setattr(LandCoverAI, 'sha256', sha256)
+        monkeypatch.setattr(base_class, 'sha256', sha256)
+        if base_class == LandCoverAI100:
+            monkeypatch.setattr(base_class, 'filename', 'landcover.ai.v1.zip')
        root = tmp_path
-        split = request.param
        transforms = nn.Identity()
-        return LandCoverAI(root, split, transforms, download=True, checksum=True)
+        return base_class(root, split, transforms, download=True, checksum=True)

    def test_getitem(self, dataset: LandCoverAI) -> None:
        x = dataset[0]
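Aside, not part of the diff: parametrizing the fixture with itertools.product runs every test in TestLandCoverAI once per (dataset class, split) pair, six combinations in total. A minimal sketch of the expansion:

from itertools import product

pairs = list(product(['LandCoverAI100', 'LandCoverAI'], ['train', 'val', 'test']))
print(len(pairs))  # 6
print(pairs[0])    # ('LandCoverAI100', 'train')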
1 change: 1 addition & 0 deletions tests/trainers/test_segmentation.py
@@ -62,6 +62,7 @@ class TestSemanticSegmentationTask:
        'l7irish',
        'l8biome',
        'landcoverai',
+        'landcoverai100',
        'loveda',
        'naipchesapeake',
        'potsdam2d',
3 changes: 2 additions & 1 deletion torchgeo/datamodules/__init__.py
@@ -23,7 +23,7 @@
from .iobench import IOBenchDataModule
from .l7irish import L7IrishDataModule
from .l8biome import L8BiomeDataModule
-from .landcoverai import LandCoverAIDataModule
+from .landcoverai import LandCoverAI100DataModule, LandCoverAIDataModule
from .levircd import LEVIRCDDataModule, LEVIRCDPlusDataModule
from .loveda import LoveDADataModule
from .naip import NAIPChesapeakeDataModule
@@ -82,6 +82,7 @@
    'GID15DataModule',
    'InriaAerialImageLabelingDataModule',
    'LandCoverAIDataModule',
+    'LandCoverAI100DataModule',
    'LEVIRCDDataModule',
    'LEVIRCDPlusDataModule',
    'LoveDADataModule',
30 changes: 28 additions & 2 deletions torchgeo/datamodules/landcoverai.py
@@ -1,13 +1,13 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""LandCover.ai datamodule."""
"""LandCover.ai datamodules."""

from typing import Any

import kornia.augmentation as K

-from ..datasets import LandCoverAI
+from ..datasets import LandCoverAI, LandCoverAI100
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule

@@ -43,3 +43,29 @@ def __init__(
        self.aug = AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
        )


class LandCoverAI100DataModule(NonGeoDataModule):
    """LightningDataModule implementation for the LandCoverAI100 dataset.

    Uses the train/val/test splits from the dataset.

    .. versionadded:: 0.7
    """

    def __init__(
        self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Initialize a new LandCoverAI100DataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.LandCoverAI100`.
        """
        super().__init__(LandCoverAI100, batch_size, num_workers, **kwargs)

        self.aug = AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
        )
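A hedged usage sketch for the new datamodule (paths and hyperparameters are illustrative; download and checksum are LandCoverAI100 dataset kwargs forwarded via **kwargs):

from lightning.pytorch import Trainer

from torchgeo.datamodules import LandCoverAI100DataModule
from torchgeo.trainers import SemanticSegmentationTask

datamodule = LandCoverAI100DataModule(
    batch_size=16, num_workers=2, root='data/landcoverai100', download=True
)
task = SemanticSegmentationTask(
    model='unet', backbone='resnet18', in_channels=3, num_classes=5
)
trainer = Trainer(fast_dev_run=True)  # smoke test; use max_epochs=... for a real run
trainer.fit(task, datamodule=datamodule)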
3 changes: 2 additions & 1 deletion torchgeo/datasets/__init__.py
@@ -65,7 +65,7 @@
from .iobench import IOBench
from .l7irish import L7Irish
from .l8biome import L8Biome
-from .landcoverai import LandCoverAI, LandCoverAIBase, LandCoverAIGeo
+from .landcoverai import LandCoverAI, LandCoverAI100, LandCoverAIBase, LandCoverAIGeo
from .landsat import (
    Landsat,
    Landsat1,
@@ -224,6 +224,7 @@
    'IDTReeS',
    'InriaAerialImageLabeling',
    'LandCoverAI',
+    'LandCoverAI100',
    'LEVIRCD',
    'LEVIRCDBase',
    'LEVIRCDPlus',
30 changes: 23 additions & 7 deletions torchgeo/datasets/landcoverai.py
@@ -401,10 +401,26 @@ def _extract(self) -> None:
        super()._extract()

        # Generate train/val/test splits
-        # Always check the sha256 of this file before executing
-        # to avoid malicious code injection
-        with working_dir(self.root):
-            with open('split.py') as f:
-                split = f.read().encode('utf-8')
-            assert hashlib.sha256(split).hexdigest() == self.sha256
-            exec(split)
+        # Always check the sha256 of this file before executing to avoid malicious code injection
+        # The LandCoverAI100 dataset doesn't contain split.py, so only run if split.py exists
+        if os.path.exists(os.path.join(self.root, 'split.py')):
+            with working_dir(self.root):
+                with open('split.py') as f:
+                    split = f.read().encode('utf-8')
+                assert hashlib.sha256(split).hexdigest() == self.sha256
+                exec(split)
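The guard above follows a simple pin-and-verify pattern; a generic, self-contained illustration (unrelated to the dataset files):

import hashlib

code = b"print('splits generated')"
pinned = hashlib.sha256(code).hexdigest()  # recorded ahead of time, like the class's sha256

assert hashlib.sha256(code).hexdigest() == pinned  # refuse to run tampered code
exec(code)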


class LandCoverAI100(LandCoverAI):
    """Subset of LandCoverAI containing only 100 images.

    Intended for tutorials and demonstrations, not for benchmarking.

    Maintains the same file structure, classes, and train-val-test split.

    .. versionadded:: 0.7
    """

    url = 'https://huggingface.co/datasets/torchgeo/landcoverai/resolve/5cdf9299bd6c1232506cf79373df01f6e6596b50/landcoverai100.zip'
    filename = 'landcoverai100.zip'
    md5 = '66eb33b5a0cabb631836ce0a4eafb7cd'
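Finally, a quick, hedged example of using the new dataset class directly (root path is illustrative; LandCover.ai patches are 512x512, so expect shapes like the ones in the comment):

from torchgeo.datasets import LandCoverAI100

ds = LandCoverAI100(root='data/landcoverai100', split='train', download=True, checksum=True)
sample = ds[0]
print(sample['image'].shape, sample['mask'].shape)  # e.g. torch.Size([3, 512, 512]) torch.Size([512, 512])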
