Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FLAIR#2 Dataset and Datamodule Integration #2394

Open
wants to merge 78 commits into
base: main
Choose a base branch
from

Conversation

MathiasBaumgartinger
Copy link
Contributor

@MathiasBaumgartinger MathiasBaumgartinger commented Nov 5, 2024

FLAIR#2 dataset

The FLAIR #2 <https://github.com/IGNF/FLAIR-2> dataset is an extensive dataset from the French National Institute of Geographical and Forest Information (IGN) that provides a unique and rich resource for large-scale geospatial analysis.
The dataset is sampled countrywide and is composed of over 20 billion annotated pixels of very high resolution aerial imagery at 0.2 m spatial resolution, acquired over three years and different months (spatio-temporal domains).

The FLAIR2 dataset is a dataset for semantic segmentation of aerial images. It contains aerial images, sentinel-2 images and masks for 13 classes.
The dataset is split into a training and test set.

Dataset features:

* over 20 billion annotated pixels
* aerial imagery
    * 5x512x512
    * 0.2m spatial resolution
    * 5 channels (RGB-NIR-Elevation)
* Sentinel-2 imagery
    * 10-20m spatial resolution
    * 10 spectral bands
    * snow/cloud masks (with 0-100 probability)
    * multiple time steps (T)
    * Tx10xWxH, T, W, H are variable
* label (masks)
    * 512x512
    * 13 classes

Dataset classes:

0: "building",
1: "pervious surface",
2: "impervious surface",
3: "bare soil",
4: "water",
5: "coniferous",
6: "deciduous",
7: "brushwood",
8: "vineyard",
9: "herbaceous vegetation",
10: "agricultural land",
11: "plowed land",
12: "other"  

If you use this dataset in your research, please cite the following paper:

* https://doi.org/10.48550/arXiv.2310.13336

image

Implementation Details

NonGeoDataset, __init()__

After discussions following #2303, we decided that at least until faulty mask data are fixed the flair2 ds will be of type NonGeoDataset. Other than with common NonGeoDatasets, FLAIR2 exposes a use_toy and use_sentinel argument. The use_toy-flag will instead use the toy data which is a small subset of data. The use_sentinel argument on the other hand decides whether a sample includes the augmented sentinel data provided by the maintainers of FLAIR2.

_verify, _download, _extract

As each of the splits/sample-types (i.e. [train, test], [aerial, sentinel, labels] are contained in a individual zip download, download and extraction has to happen multiple times. On the other hand, the toy dataset is contained in a singular zip. Furthermore, to map the super-patches of the sentinel data to the actual input image, a flair-2_centroids_sp_to_patch.json is required, which has to be equally has to be downloaded as an individual zip.

_load_image, _load_sentinel, _load_target

For storage reasons, the elevation (5th band) of the image is stored as a uint. The original height thus is multiplied by 5. We decided to divide the height by 5 to get the original height, to make the trained model more usable for other data. See Questions please.

As mentioned previously, additional metadata has to be used to get from the sentinel.npy to the actual area. Initially for debugging reasons, we implemented to return not the cropped image but the original data and the cropping-slices (i.e. indices). Consequently, the images can be plot in a more meaningful matter. Otherwise, the resolution is so low that one can hardly recognize features. This was crucial for debugging to find the correct logic (classic y, x instead of x, y ordering mistake). We do not know if this is smart for "production code". See Questions please.
Moreover, the dimensions of the sentinel data $T \times C=10 \times W \times H$ vary both $T$ and $W$, $H$. This is problematic for the datamodule. We have not done extensive research, but the varying dimensions seem to bug the module. Disabling the use_sentinel-flag will make the module work.

The labels include values from 1 to 19. The datapaper clearly mentions grouping classes $&gt; 13$ into one class other due to underrepresentation. We followed this suggestion. Furthermore, rescaling from 0 to 12 was applied. See Questions please.

Questions

  • Do you consider the Elevation rescaling as distortion of the dataset? Shall I exclude it? The argument for it would be easier re-usability on new datasets.

For storage optimization reasons, this elevation information is multiplied by a factor of 5 and encoded as a 8bit unsigned integer datatype.

  • How shall we load/provide sentinel data? As cropped data or any other way. I do not see the current implementation as fit for production.

    • Also, how do we want to plot it? The small red rectangle in the example plot above is the actual region. The low resolution is quite observable there.
  • Shall we rescale the Classes to start from 0? Shall we group the classes as suggested in the datapaper?

  • Check integrity in download_url does not seem to work (in unit-tests), why?

    • I have to call an own check_integrity call otherwise it passes, even if md5s do not match.
  • The github actions on the forked repo produce a magic ruff error (https://github.com/MathiasBaumgartinger/torchgeo/actions/runs/11687694109/job/32556175383#step:7:1265). Can you help me resolve this mystery?

TODOs/FIXMEs

  • Extend tests for toy datasets and apply md5 check
  • Find correct band for plotting sentinel
  • Datamodule cannot handle sentinel data yet

Mathias Baumgartinger and others added 30 commits September 13, 2024 12:14
…mg and msk)

Updates in the custom raster dataset tutorial and the actual file documentation. The previous recommended approach (overriding `__get_item__`) is outdated.

Refs: microsoft#2292 (reply in thread)
Co-authored-by: Adam J. Stewart <[email protected]>
Co-authored-by: Adam J. Stewart <[email protected]>
Not fully functioning yet, contains copy paste from other datasets
Additionally, some small code refactors are done
Using the entire sentinel-2 image and a matplotlib patch to debug, otherwise it is really hard to find correct spot due to low resolution
…y()` for sentinel

With the nested dict, it was not possible to download dynamically
md5s might change due to timestamps, this eases the process of changing md5
"""Get statistics (min, max, means, stdvs) for each used band in order.

Args:
split (str): Split for which to get statistics (currently only for train)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the docstring notion in torchgeo, we do not include the type in the docstring again, only function arguments.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this will also resolve the failing docs test

return tensor

def _load_sentinel(self, path: Path) -> Tensor:
# FIXME: should this really be returned as a tuple?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be removed, right?

self.root,
md5=self.md5s.get(url, None) if self.checksum else None,
)
# FIXME: Why is download_url not checking integrity (tests run through)?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this fixed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was still an issue last time i checked. As mentioned in the first text of the PRQ, omehow, when I run the pytests with wrong md5 hashes, the integrity is not checked unless I explicitly call it here again (it is implicity called in download_url).

self.root,
md5=self.md5s.get(url, None) if self.checksum else None,
)
# FIXME: Why is download_url not checking integrity (tests run through)?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about this FIXME?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left this here as somewhat of a reminder, because it has the same behavior as the parent class. I.e. currently all tests pass even with wrong md5s.

"""
super().__init__(FLAIR2, batch_size, num_workers, **kwargs)

self.patch_size = _to_tuple(patch_size)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's a good idea. Could either be included here, or in a separate PR.

torchgeo/datamodules/flair2.py Outdated Show resolved Hide resolved
@@ -0,0 +1,8 @@
/home/mathias/Dev/forks/torchgeo/tests/data/flair2/FLAIR2/flair_2_labels_test.zip: b13c4a3cb7ebb5cadddc36474bb386f9
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe remove the personal directory from this text file.


rgb_indices = [self.all_bands.index(band) for band in self.rgb_bands]
# Check if RGB bands are present in self.bands
if not all([band in self.bands for band in self.rgb_bands]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Codecoverage is indicating that the RGB Band Missing is not being hit, so I think you just need to add a separate plot test similar to

def test_plot_rgb(self, dataset: EuroSAT, tmp_path: Path) -> None:
for example.

K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
)

self.augs = augs if augs is not None else self.aug

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is self.augs intended to act on the data somewhere? My immediate thought was that it would be related to the augmentations part of the base datamodule:

# Data augmentation
Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]]
self.aug: Transform = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)
self.train_aug: Transform | None = None
self.val_aug: Transform | None = None
self.test_aug: Transform | None = None
self.predict_aug: Transform | None = None
But it's not in there, and setting it to an arbitrary value doesn't seem to do anything. Or perhaps I'm missing something?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.aug is applied here if the split specific augmentations are not specified

Copy link

@JacobJeppesen JacobJeppesen Dec 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was also my understanding, and perhaps it's a typo where self.augs was intended to be self.aug, such that the augmentations would be applied automatically through the base datamodule. However, with the current implementation, if the user provides augmentations through the augs parameter in the FLAIR2 datamodule, they won't have an effect, as they are being added to self.augs, which doesn't seem to be applied to the data (as far as I can tell). I.e., maybe the intention was self.aug = augs if augs is not None else self.aug(?)

Comment on lines 424 to 440
def _load_image(self, path: Path) -> Tensor:
"""Load a single image.

Args:
path: path to the image

Returns:
Tensor: the loaded image
"""
with rasterio.open(path) as f:
array: np.typing.NDArray[np.int_] = f.read()
tensor = torch.from_numpy(array).float()
if 'B05' in self.bands:
# Height channel will always be the last dimension
tensor[-1] = torch.div(tensor[-1], 5)

return tensor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the bands perhaps be extracted here based on self.bands? E.g., something like

    def _load_image(self, path: Path) -> Tensor:
        """Load a single image.
        Args:
            path: path to the image
        Returns:
            Tensor: the loaded image
        """
        with rasterio.open(path) as f:
            array: np.typing.NDArray[np.int_] = f.read()
            tensor = torch.from_numpy(array).float()
            if 'B05' in self.bands:
                # Height channel will always be the last dimension
                tensor[-1] = torch.div(tensor[-1], 5)

        # Extract the bands to be used. E.g., self.bands=("B01", "B02", "B03") will extract the RGB bands.
        tensor = tensor[[int(band[-2:]) - 1 for band in self.bands]]

        return tensor

Then when a user has defined n bands in self.bands, the returned sample will only contain those bands, instead of all five. Perhaps self.bands should also be renamed to self.aerial_bands, and a self.sentinel_bands should be added(?)

@MathiasBaumgartinger
Copy link
Contributor Author

Thanks everyone for reviewing. Tried to apply all suggested changes.

As for the checks, according to the log we face the following error: /home/docs/checkouts/readthedocs.org/user_builds/torchgeo/checkouts/2394/docs/api/datasets.rst:194: ERROR: "csv-table" widths do not match the number of columns in table (10).
Truly, I have messed something up here (i.e. something with line endings, separations and quotation marks). However, I resolved this problem and the error in the check seems to persist.

Error: reference before assignment
Copy link
Collaborator

@adamjstewart adamjstewart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does anyone have general opinions on whether we should call this:

FLAIR2()

or:

FLAIR(version=2)

docs/api/datasets/non_geo_datasets.csv Outdated Show resolved Hide resolved
tests/data/flair2/data.py Show resolved Hide resolved
torchgeo/datamodules/flair2.py Outdated Show resolved Hide resolved
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can delete this file and instead test the dataset using a tests/conf/flair2.yaml file and 1 line of code in tests/trainers/test_segmentation.py. We are actively trying to get rid of tests/datamodules since it doesn't test compatibility.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do the tests in tests/trainers/test_segmentation.py actually work already? For me, 35 fail and 14 pass

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They work in CI. What error messages are you seeing? Note that PyTorch doesn't yet fully support Python 3.13 so some of the checkpointing tests will fail.

def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> FLAIR2:
md5s = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could also just skip checksum=True and not bother monkeypatching any MD5s

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I will leave it for now. In fact this is the reason i detected torchgeo/datasets/flair2.py#L567-L571

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason that download_url isn't checking the MD5 is because we monkeypatch it to avoid downloading anything.

torchgeo/datasets/flair2.py Outdated Show resolved Hide resolved
torchgeo/datasets/flair2.py Outdated Show resolved Hide resolved
torchgeo/datasets/flair2.py Outdated Show resolved Hide resolved
a matplotlib Figure with the rendered sample
"""

def normalize_plot(tensor: Tensor) -> Tensor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would instead use .utils.percentile_normalization

torchgeo/datasets/flair2.py Outdated Show resolved Hide resolved
@JacobJeppesen
Copy link

Does anyone have general opinions on whether we should call this:

FLAIR2()

or:

FLAIR(version=2)

I think it'd make sense to use FLAIR(version=2), as it seems like each new version is a superset of the previous version. Most users will probably use the latest version, so if they are individual datasets, FLAIR1() and FLAIR2() might end up as somewhat unused datasets once FLAIR3() is released. As I understand #2303 (comment), once version 3 is released, version 1 and 2 data can be directly loaded from version 3 by filtering the files. So the lowest complexity solution would probably be a FLAIR() dataset, where version 3 is the full dataset, version 2 is reduced coverage by filtering files/area, and version 1 is the same reduced coverage, but only aerial.

@MathiasBaumgartinger
Copy link
Contributor Author

MathiasBaumgartinger commented Dec 13, 2024

Does anyone have general opinions on whether we should call this:

FLAIR2()

or:

FLAIR(version=2)

I think it'd make sense to use FLAIR(version=2), as it seems like each new version is a superset of the previous version. Most users will probably use the latest version, so if they are individual datasets, FLAIR1() and FLAIR2() might end up as somewhat unused datasets once FLAIR3() is released. As I understand #2303 (comment), once version 3 is released, version 1 and 2 data can be directly loaded from version 3 by filtering the files. So the lowest complexity solution would probably be a FLAIR() dataset, where version 3 is the full dataset, version 2 is reduced coverage by filtering files/area, and version 1 is the same reduced coverage, but only aerial.

I am in contact with @agarioud. If I do understand him correctly, I doubt that new datasets will have the requirement to be backward-compatible.

Unfortunately there is no direct compatibility with FLAIR#1/#2 apart the aerial images as we reworked the supervision (land cover and now LPIS)

But I agree. Adding versioning to FLAIR() will probably result in less dead code. From a design perspective:

  1. We have a parent dataset FLAIR handling all the common logic. By passing a specific version (default will always be latest) a new child class is initialized which overrides version specific logic.
  2. We have a parent datamodule FLAIRModule. Which passes a version down to datasets. I.e. probably only a single FLAIRModule is necessary (no inheritance).
  3. Points 1 and 2 will be applied for FLAIRToy and FLAIRToyModule too (probably both will inherit from the corresponding non-toy classes.)

@agarioud
Copy link

agarioud commented Dec 13, 2024

The upcoming FLAIR-INC dataset will include all data from FLAIR#1 (aerial images) and FLAIR#2 (which added Sentinel-2 data that were previously NPY files covering larger extents but will now have the same spatial extent as the aerial patches and be in TIFF format). Additionally, it will introduce five new modalities and a second supervision dataset.
If needed, a toy dataset can already be shared with @MathiasBaumgartinger

Backward compatibility is not fully ensured. For example, aerial images will have one less channel, as DSM/DTM has been introduced as a separate modality. Additionally, the supervision dataset regarding land-cover has reordered classes.

Therefore, I also believe a FLAIR() dataset would probably be more efficient ?

@JacobJeppesen
Copy link

But I agree. Adding versioning to FLAIR() will probably result in less dead code. From a design perspective:

  1. We have a parent dataset FLAIR handling all the common logic. By passing a specific version (default will always be latest) a new child class is initialized which overrides version specific logic.
  2. We have a parent datamodule FLAIRModule. Which passes a version down to datasets. I.e. probably only a single FLAIRModule is necessary (no inheritance).
  3. Points 1 and 2 will be applied for FLAIRToy and FLAIRToyModule too (probably both will inherit from the corresponding non-toy classes.)

This sounds like a good approach 👍

@agarioud sounds great. Looking forward to the release 🙂

@adamjstewart
Copy link
Collaborator

It's not really a matter of compatibility. Basically, we either use:

class FLAIR(NonGeoDataset, abc.ABC):
    # shared base class

class FLAIR2(FLAIR):
    # override specific stuff

or:

class FLAIR(NonGeoDataset):
    def __init__(self, version=2, ...):
        if version == 2:
            # override specific stuff

It's more a question of whether the devs think of this as a new version of an existing dataset or a new dataset. We usually go with the former because it offers a bit better reproducibility (if the default version gets changed, then reproducibility is broken). See our EuroSAT dataset for an example of this. I think the only cases where we use version instead are in MoCoTask and SimCLRTask. I'm fine with both solutions, just want to make sure we're all on the same page.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
datamodules PyTorch Lightning datamodules datasets Geospatial or benchmark datasets documentation Improvements or additions to documentation testing Continuous integration testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants