diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index abb0c6eaefa..fdcef5450d1 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -124,6 +124,11 @@ GID-15 .. autoclass:: GID15DataModule +HySpecNet-11k +^^^^^^^^^^^^^ + +.. autoclass:: HySpecNet11kDataModule + Inria Aerial Image Labeling ^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -213,6 +218,11 @@ SustainBench Crop Yield .. autoclass:: SustainBenchCropYieldDataModule +TreeSatAI +^^^^^^^^^ + +.. autoclass:: TreeSatAIDataModule + Tropical Cyclone ^^^^^^^^^^^^^^^^ diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 96ca225344a..b8f2137c920 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -312,6 +312,11 @@ GID-15 .. autoclass:: GID15 +HySpecNet-11k +^^^^^^^^^^^^^ + +.. autoclass:: HySpecNet11k + IDTReeS ^^^^^^^ @@ -469,6 +474,11 @@ SustainBench Crop Yield .. autoclass:: SustainBenchCropYield +TreeSatAI +^^^^^^^^^ + +.. autoclass:: TreeSatAI + Tropical Cyclone ^^^^^^^^^^^^^^^^ diff --git a/docs/api/datasets/non_geo_datasets.csv b/docs/api/datasets/non_geo_datasets.csv index f000bc1d8da..7d7a17a4b94 100644 --- a/docs/api/datasets/non_geo_datasets.csv +++ b/docs/api/datasets/non_geo_datasets.csv @@ -21,6 +21,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `Forest Damage`_,OD,Drone imagery,"CDLA-Permissive-1.0","1,543",4,"1,500x1,500",,RGB `GeoNRW`_,S,Aerial,"CC-BY-4.0","7,783",11,"1,000x1,000",1,"RGB, DEM" `GID-15`_,S,Gaofen-2,-,150,15,"6,800x7,200",3,RGB +`HySpecNet-11k`_,-,EnMAP,CC0-1.0,11k,-,128,30,HSI `IDTReeS`_,"OD,C",Aerial,"CC-BY-4.0",591,33,200x200,0.1--1,RGB `Inria Aerial Image Labeling`_,S,Aerial,-,360,2,"5,000x5,000",0.3,RGB `LandCover.ai`_,S,Aerial,"CC-BY-NC-SA-4.0","10,674",5,512x512,0.25--0.5,RGB @@ -52,6 +53,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `SSL4EO-L Benchmark`_,S,Lansat & CDL,"CC0-1.0",25K,134,264x264,30,MSI `SSL4EO-L Benchmark`_,S,Lansat & NLCD,"CC0-1.0",25K,17,264x264,30,MSI `SustainBench Crop Yield`_,R,MODIS,"CC-BY-SA-4.0",11k,-,32x32,-,MSI +`TreeSatAI`_,"C, R, S","Aerial, Sentinel-1/2",CC-BY-4.0,50K,"12, 15, 20","6, 20, 304","0.2, 10","CIR, MSI, SAR" `Tropical Cyclone`_,R,GOES 8--16,"CC-BY-4.0","108,110",-,256x256,4K--8K,MSI `UC Merced`_,C,USGS National Map,"public domain","2,100",21,256x256,0.3,RGB `USAVars`_,R,NAIP Aerial,"CC-BY-4.0",100K,-,-,4,"RGB, NIR" diff --git a/package-lock.json b/package-lock.json index 371112f9d23..07cbc636bd9 100644 --- a/package-lock.json +++ b/package-lock.json @@ -5,13 +5,14 @@ "packages": { "": { "dependencies": { - "prettier": ">=3.3.3" + "prettier": ">=3" } }, "node_modules/prettier": { - "version": "3.3.3", - "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.3.3.tgz", - "integrity": "sha512-i2tDNA0O5IrMO757lfrdQZCc2jPNDVntV0m/+4whiDfWaTKfMNgR7Qz0NAeGz/nRqF4m5/6CLzbP4/liHt12Ew==", + "version": "3.4.1", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.4.1.tgz", + "integrity": "sha512-G+YdqtITVZmOJje6QkXQWzl3fSfMxFwm1tjTyo9exhkmWSqC4Yhd1+lug++IlR2mvRVAxEDDWYkQdeSztajqgg==", + "license": "MIT", "bin": { "prettier": "bin/prettier.cjs" }, diff --git a/pyproject.toml b/pyproject.toml index 3541689f5c6..a37a7bd759e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,9 +61,9 @@ dependencies = [ # pyproj 3.3+ required for Python 3.10 wheels "pyproj>=3.3", # rasterio 1.3+ required for Python 3.10 wheels - # rasterio 1.4+ no longer supports merging WarpedVRT objects + # rasterio 1.4.0-1.4.2 lack support for merging WarpedVRT objects # https://github.com/rasterio/rasterio/issues/3196 - "rasterio>=1.3,<1.4", + "rasterio>=1.3,!=1.4.0,!=1.4.1,!=1.4.2", # rtree 1+ required for Python 3.10 wheels "rtree>=1", # segmentation-models-pytorch 0.2+ required for smp.losses module diff --git a/requirements/required.txt b/requirements/required.txt index cee750e8c37..288fc286918 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -5,14 +5,14 @@ setuptools==75.6.0 einops==0.8.0 fiona==1.10.1 kornia==0.7.4 -lightly==1.5.14 +lightly==1.5.15 lightning[pytorch-extra]==2.4.0 -matplotlib==3.9.2 +matplotlib==3.9.3 numpy==2.1.3 pandas==2.2.3 pillow==11.0.0 pyproj==3.7.0 -rasterio==1.3.11 +rasterio==1.4.3 rtree==1.3.0 segmentation-models-pytorch==0.3.4 shapely==2.0.6 diff --git a/requirements/style.txt b/requirements/style.txt index d0eb263ea5a..2734bead562 100644 --- a/requirements/style.txt +++ b/requirements/style.txt @@ -1,3 +1,3 @@ # style mypy==1.13.0 -ruff==0.8.0 +ruff==0.8.1 diff --git a/requirements/tests.txt b/requirements/tests.txt index 0d8b32d5e26..8a4d222b6b3 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -1,4 +1,4 @@ # tests nbmake==1.5.4 -pytest==8.3.3 +pytest==8.3.4 pytest-cov==6.0.0 diff --git a/tests/conf/hyspecnet_byol.yaml b/tests/conf/hyspecnet_byol.yaml new file mode 100644 index 00000000000..5c0fa31d609 --- /dev/null +++ b/tests/conf/hyspecnet_byol.yaml @@ -0,0 +1,11 @@ +model: + class_path: BYOLTask + init_args: + model: 'resnet18' + in_channels: 202 +data: + class_path: HySpecNet11kDataModule + init_args: + batch_size: 2 + dict_kwargs: + root: 'tests/data/hyspecnet' diff --git a/tests/conf/hyspecnet_moco.yaml b/tests/conf/hyspecnet_moco.yaml new file mode 100644 index 00000000000..732b83912c1 --- /dev/null +++ b/tests/conf/hyspecnet_moco.yaml @@ -0,0 +1,11 @@ +model: + class_path: MoCoTask + init_args: + model: 'resnet18' + in_channels: 202 +data: + class_path: HySpecNet11kDataModule + init_args: + batch_size: 2 + dict_kwargs: + root: 'tests/data/hyspecnet' diff --git a/tests/conf/hyspecnet_simclr.yaml b/tests/conf/hyspecnet_simclr.yaml new file mode 100644 index 00000000000..d16e8209326 --- /dev/null +++ b/tests/conf/hyspecnet_simclr.yaml @@ -0,0 +1,11 @@ +model: + class_path: SimCLRTask + init_args: + model: 'resnet18' + in_channels: 202 +data: + class_path: HySpecNet11kDataModule + init_args: + batch_size: 2 + dict_kwargs: + root: 'tests/data/hyspecnet' diff --git a/tests/conf/treesatai.yaml b/tests/conf/treesatai.yaml new file mode 100644 index 00000000000..e605b688b82 --- /dev/null +++ b/tests/conf/treesatai.yaml @@ -0,0 +1,13 @@ +model: + class_path: MultiLabelClassificationTask + init_args: + model: 'resnet18' + in_channels: 19 + num_classes: 15 + loss: 'bce' +data: + class_path: TreeSatAIDataModule + init_args: + batch_size: 1 + dict_kwargs: + root: 'tests/data/treesatai' diff --git a/tests/data/hyspecnet/data.py b/tests/data/hyspecnet/data.py new file mode 100755 index 00000000000..3b4b701106e --- /dev/null +++ b/tests/data/hyspecnet/data.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil + +import numpy as np +import rasterio +from rasterio import Affine +from rasterio.crs import CRS + +SIZE = 32 +DTYPE = 'int16' + +np.random.seed(0) + +# Tile name purposefully shortened to avoid Windows git filename length limit. +tiles = ['ENMAP01_20221103T162438Z'] +patches = ['Y01460273_X05670694', 'Y01460273_X06950822'] + +profile = { + 'driver': 'GTiff', + 'dtype': DTYPE, + 'nodata': -32768.0, + 'width': SIZE, + 'height': SIZE, + 'count': 224, + 'crs': CRS.from_epsg(32618), + 'transform': Affine(30.0, 0.0, 691845.0, 0.0, -30.0, 4561935.0), + 'blockysize': 3, + 'tiled': False, + 'compress': 'deflate', + 'interleave': 'band', +} + +root = 'hyspecnet-11k' +path = os.path.join(root, 'splits', 'easy') +os.makedirs(path, exist_ok=True) +for tile in tiles: + for patch in patches: + # Split CSV + path = os.path.join(tile, f'{tile}-{patch}', f'{tile}-{patch}-DATA.npy') + for split in ['train', 'val', 'test']: + with open(os.path.join(root, 'splits', 'easy', f'{split}.csv'), 'a+') as f: + f.write(f'{path}\n') + + # Spectral image + path = os.path.join(root, 'patches', path) + os.makedirs(os.path.dirname(path), exist_ok=True) + path = path.replace('DATA.npy', 'SPECTRAL_IMAGE.TIF') + Z = np.random.randint( + np.iinfo(DTYPE).min, np.iinfo(DTYPE).max, size=(SIZE, SIZE), dtype=DTYPE + ) + with rasterio.open(path, 'w', **profile) as src: + for i in range(1, profile['count'] + 1): + src.write(Z, i) + +shutil.make_archive(f'{root}-01', 'gztar', '.', os.path.join(root, 'patches')) +shutil.make_archive(f'{root}-splits', 'gztar', '.', os.path.join(root, 'splits')) diff --git a/tests/data/hyspecnet/hyspecnet-11k-01.tar.gz b/tests/data/hyspecnet/hyspecnet-11k-01.tar.gz new file mode 100644 index 00000000000..b5a5ec766a5 Binary files /dev/null and b/tests/data/hyspecnet/hyspecnet-11k-01.tar.gz differ diff --git a/tests/data/hyspecnet/hyspecnet-11k-splits.tar.gz b/tests/data/hyspecnet/hyspecnet-11k-splits.tar.gz new file mode 100644 index 00000000000..152f71c040f Binary files /dev/null and b/tests/data/hyspecnet/hyspecnet-11k-splits.tar.gz differ diff --git a/tests/data/hyspecnet/hyspecnet-11k/patches/ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-SPECTRAL_IMAGE.TIF b/tests/data/hyspecnet/hyspecnet-11k/patches/ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-SPECTRAL_IMAGE.TIF new file mode 100644 index 00000000000..498bf304fa1 Binary files /dev/null and b/tests/data/hyspecnet/hyspecnet-11k/patches/ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-SPECTRAL_IMAGE.TIF differ diff --git a/tests/data/hyspecnet/hyspecnet-11k/patches/ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-SPECTRAL_IMAGE.TIF b/tests/data/hyspecnet/hyspecnet-11k/patches/ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-SPECTRAL_IMAGE.TIF new file mode 100644 index 00000000000..5142ff4fbcf Binary files /dev/null and b/tests/data/hyspecnet/hyspecnet-11k/patches/ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-SPECTRAL_IMAGE.TIF differ diff --git a/tests/data/hyspecnet/hyspecnet-11k/splits/easy/test.csv b/tests/data/hyspecnet/hyspecnet-11k/splits/easy/test.csv new file mode 100644 index 00000000000..14393bce82a --- /dev/null +++ b/tests/data/hyspecnet/hyspecnet-11k/splits/easy/test.csv @@ -0,0 +1,2 @@ +ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-DATA.npy +ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-DATA.npy diff --git a/tests/data/hyspecnet/hyspecnet-11k/splits/easy/train.csv b/tests/data/hyspecnet/hyspecnet-11k/splits/easy/train.csv new file mode 100644 index 00000000000..14393bce82a --- /dev/null +++ b/tests/data/hyspecnet/hyspecnet-11k/splits/easy/train.csv @@ -0,0 +1,2 @@ +ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-DATA.npy +ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-DATA.npy diff --git a/tests/data/hyspecnet/hyspecnet-11k/splits/easy/val.csv b/tests/data/hyspecnet/hyspecnet-11k/splits/easy/val.csv new file mode 100644 index 00000000000..14393bce82a --- /dev/null +++ b/tests/data/hyspecnet/hyspecnet-11k/splits/easy/val.csv @@ -0,0 +1,2 @@ +ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-DATA.npy +ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-DATA.npy diff --git a/tests/data/treesatai/aerial/60m/Acer_pseudoplatanus_3_5758_WEFL_NLF.tif b/tests/data/treesatai/aerial/60m/Acer_pseudoplatanus_3_5758_WEFL_NLF.tif new file mode 100644 index 00000000000..7df6abd74c0 Binary files /dev/null and b/tests/data/treesatai/aerial/60m/Acer_pseudoplatanus_3_5758_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/aerial/60m/Acer_pseudoplatanus_4_6058_WEFL_NLF.tif b/tests/data/treesatai/aerial/60m/Acer_pseudoplatanus_4_6058_WEFL_NLF.tif new file mode 100644 index 00000000000..967876aa69c Binary files /dev/null and b/tests/data/treesatai/aerial/60m/Acer_pseudoplatanus_4_6058_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/aerial/60m/Alnus_spec._5_13114_WEFL_NLF.tif b/tests/data/treesatai/aerial/60m/Alnus_spec._5_13114_WEFL_NLF.tif new file mode 100644 index 00000000000..36c6c049001 Binary files /dev/null and b/tests/data/treesatai/aerial/60m/Alnus_spec._5_13114_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/aerial/60m/Fagus_sylvatica_9_29995_WEFL_NLF.tif b/tests/data/treesatai/aerial/60m/Fagus_sylvatica_9_29995_WEFL_NLF.tif new file mode 100644 index 00000000000..48b36565180 Binary files /dev/null and b/tests/data/treesatai/aerial/60m/Fagus_sylvatica_9_29995_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/aerial/60m/Picea_abies_2_46896_WEFL_NLF.tif b/tests/data/treesatai/aerial/60m/Picea_abies_2_46896_WEFL_NLF.tif new file mode 100644 index 00000000000..3fb7ec2f4b7 Binary files /dev/null and b/tests/data/treesatai/aerial/60m/Picea_abies_2_46896_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/aerial/60m/Picea_abies_3_46636_WEFL_NLF.tif b/tests/data/treesatai/aerial/60m/Picea_abies_3_46636_WEFL_NLF.tif new file mode 100644 index 00000000000..fc3913038bf Binary files /dev/null and b/tests/data/treesatai/aerial/60m/Picea_abies_3_46636_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/aerial/60m/Pseudotsuga_menziesii_1_339575_BI_NLF.tif b/tests/data/treesatai/aerial/60m/Pseudotsuga_menziesii_1_339575_BI_NLF.tif new file mode 100644 index 00000000000..2fdb09a25a2 Binary files /dev/null and b/tests/data/treesatai/aerial/60m/Pseudotsuga_menziesii_1_339575_BI_NLF.tif differ diff --git a/tests/data/treesatai/aerial/60m/Quercus_petraea_2_84375_WEFL_NLF.tif b/tests/data/treesatai/aerial/60m/Quercus_petraea_2_84375_WEFL_NLF.tif new file mode 100644 index 00000000000..f7e0af9eb85 Binary files /dev/null and b/tests/data/treesatai/aerial/60m/Quercus_petraea_2_84375_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/aerial/60m/Quercus_petraea_5_80549_WEFL_NLF.tif b/tests/data/treesatai/aerial/60m/Quercus_petraea_5_80549_WEFL_NLF.tif new file mode 100644 index 00000000000..52889605c84 Binary files /dev/null and b/tests/data/treesatai/aerial/60m/Quercus_petraea_5_80549_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/aerial/60m/Quercus_rubra_1_92184_WEFL_NLF.tif b/tests/data/treesatai/aerial/60m/Quercus_rubra_1_92184_WEFL_NLF.tif new file mode 100644 index 00000000000..cffd19dbffe Binary files /dev/null and b/tests/data/treesatai/aerial/60m/Quercus_rubra_1_92184_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/aerial_60m_acer_pseudoplatanus.zip b/tests/data/treesatai/aerial_60m_acer_pseudoplatanus.zip new file mode 100644 index 00000000000..b24e8514895 Binary files /dev/null and b/tests/data/treesatai/aerial_60m_acer_pseudoplatanus.zip differ diff --git a/tests/data/treesatai/aerial_60m_alnus_spec.zip b/tests/data/treesatai/aerial_60m_alnus_spec.zip new file mode 100644 index 00000000000..15cb0ecb3e2 Binary files /dev/null and b/tests/data/treesatai/aerial_60m_alnus_spec.zip differ diff --git a/tests/data/treesatai/aerial_60m_fagus_sylvatica.zip b/tests/data/treesatai/aerial_60m_fagus_sylvatica.zip new file mode 100644 index 00000000000..42716c30c93 Binary files /dev/null and b/tests/data/treesatai/aerial_60m_fagus_sylvatica.zip differ diff --git a/tests/data/treesatai/aerial_60m_picea_abies.zip b/tests/data/treesatai/aerial_60m_picea_abies.zip new file mode 100644 index 00000000000..33baaf54215 Binary files /dev/null and b/tests/data/treesatai/aerial_60m_picea_abies.zip differ diff --git a/tests/data/treesatai/aerial_60m_pseudotsuga_menziesii.zip b/tests/data/treesatai/aerial_60m_pseudotsuga_menziesii.zip new file mode 100644 index 00000000000..23a3636a759 Binary files /dev/null and b/tests/data/treesatai/aerial_60m_pseudotsuga_menziesii.zip differ diff --git a/tests/data/treesatai/aerial_60m_quercus_petraea.zip b/tests/data/treesatai/aerial_60m_quercus_petraea.zip new file mode 100644 index 00000000000..268ee1134ac Binary files /dev/null and b/tests/data/treesatai/aerial_60m_quercus_petraea.zip differ diff --git a/tests/data/treesatai/aerial_60m_quercus_rubra.zip b/tests/data/treesatai/aerial_60m_quercus_rubra.zip new file mode 100644 index 00000000000..4552c6fc66c Binary files /dev/null and b/tests/data/treesatai/aerial_60m_quercus_rubra.zip differ diff --git a/tests/data/treesatai/data.py b/tests/data/treesatai/data.py new file mode 100755 index 00000000000..dac5337cff8 --- /dev/null +++ b/tests/data/treesatai/data.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import glob +import json +import os +import random +import shutil +import zipfile + +import numpy as np +import rasterio +from rasterio import Affine +from rasterio.crs import CRS + +SIZE = 32 + +random.seed(0) +np.random.seed(0) + +classes = ( + 'Abies', + 'Acer', + 'Alnus', + 'Betula', + 'Cleared', + 'Fagus', + 'Fraxinus', + 'Larix', + 'Picea', + 'Pinus', + 'Populus', + 'Prunus', + 'Pseudotsuga', + 'Quercus', + 'Tilia', +) + +species = ( + 'Acer_pseudoplatanus', + 'Alnus_spec', + 'Fagus_sylvatica', + 'Picea_abies', + 'Pseudotsuga_menziesii', + 'Quercus_petraea', + 'Quercus_rubra', +) + +profile = { + 'aerial': { + 'driver': 'GTiff', + 'dtype': 'uint8', + 'nodata': None, + 'width': SIZE, + 'height': SIZE, + 'count': 4, + 'crs': CRS.from_epsg(25832), + 'transform': Affine( + 0.19999999999977022, 0.0, 552245.4, 0.0, -0.19999999999938728, 5728215.0 + ), + }, + 's1': { + 'driver': 'GTiff', + 'dtype': 'float32', + 'nodata': -9999.0, + 'width': SIZE // 16, + 'height': SIZE // 16, + 'count': 3, + 'crs': CRS.from_epsg(32632), + 'transform': Affine(10.0, 0.0, 552245.0, 0.0, -10.0, 5728215.0), + }, + 's2': { + 'driver': 'GTiff', + 'dtype': 'uint16', + 'nodata': None, + 'width': SIZE // 16, + 'height': SIZE // 16, + 'count': 12, + 'crs': CRS.from_epsg(32632), + 'transform': Affine(10.0, 0.0, 552241.6565, 0.0, -10.0, 5728211.6251), + }, +} + +multi_labels = {} +for split in ['train', 'test']: + with open(f'{split}_filenames.lst') as f: + for filename in f: + filename = filename.strip() + for sensor in ['aerial', 's1', 's2']: + kwargs = profile[sensor] + directory = os.path.join(sensor, '60m') + os.makedirs(directory, exist_ok=True) + if 'int' in kwargs['dtype']: + Z = np.random.randint( + np.iinfo(kwargs['dtype']).min, + np.iinfo(kwargs['dtype']).max, + size=(kwargs['height'], kwargs['width']), + dtype=kwargs['dtype'], + ) + else: + Z = np.random.rand(kwargs['height'], kwargs['width']) + + path = os.path.join(directory, filename) + with rasterio.open(path, 'w', **kwargs) as src: + for i in range(1, kwargs['count'] + 1): + src.write(Z, i) + + k = random.randrange(1, 4) + labels = random.choices(classes, k=k) + pcts = np.random.rand(k) + pcts /= np.sum(pcts) + multi_labels[filename] = list(map(list, zip(labels, map(float, pcts)))) + +os.makedirs('labels', exist_ok=True) +path = os.path.join('labels', 'TreeSatBA_v9_60m_multi_labels.json') +with open(path, 'w') as f: + json.dump(multi_labels, f) + +for sensor in ['s1', 's2', 'labels']: + shutil.make_archive(sensor, 'zip', '.', sensor) + +for spec in species: + path = f'aerial_60m_{spec}.zip'.lower() + with zipfile.ZipFile(path, 'w') as f: + for path in glob.iglob(os.path.join('aerial', '60m', f'{spec}_*.tif')): + filename = os.path.split(path)[-1] + f.write(path, arcname=filename) diff --git a/tests/data/treesatai/labels.zip b/tests/data/treesatai/labels.zip new file mode 100644 index 00000000000..24a773a5ef5 Binary files /dev/null and b/tests/data/treesatai/labels.zip differ diff --git a/tests/data/treesatai/labels/TreeSatBA_v9_60m_multi_labels.json b/tests/data/treesatai/labels/TreeSatBA_v9_60m_multi_labels.json new file mode 100644 index 00000000000..e9f9a12a37b --- /dev/null +++ b/tests/data/treesatai/labels/TreeSatBA_v9_60m_multi_labels.json @@ -0,0 +1 @@ +{"Picea_abies_3_46636_WEFL_NLF.tif": [["Prunus", 0.20692122963708826], ["Fraxinus", 0.7930787703629117]], "Pseudotsuga_menziesii_1_339575_BI_NLF.tif": [["Tilia", 0.4243067837573989], ["Larix", 0.5756932162426011]], "Quercus_rubra_1_92184_WEFL_NLF.tif": [["Tilia", 0.5816157697641007], ["Fagus", 0.4183842302358993]], "Fagus_sylvatica_9_29995_WEFL_NLF.tif": [["Larix", 1.0]], "Quercus_petraea_5_80549_WEFL_NLF.tif": [["Alnus", 0.5749721529276662], ["Acer", 0.4250278470723338]], "Acer_pseudoplatanus_3_5758_WEFL_NLF.tif": [["Tilia", 0.8430361090251272], ["Larix", 0.1569638909748729]], "Alnus_spec._5_13114_WEFL_NLF.tif": [["Pseudotsuga", 0.17881149698366108], ["Quercus", 0.38732907538618866], ["Cleared", 0.4338594276301503]], "Quercus_petraea_2_84375_WEFL_NLF.tif": [["Acer", 0.3909090505343164], ["Pseudotsuga", 0.2628926194326892], ["Cleared", 0.34619833003299444]], "Picea_abies_2_46896_WEFL_NLF.tif": [["Acer", 0.4953810312272686], ["Fraxinus", 0.0006659055704136941], ["Pinus", 0.5039530632023177]], "Acer_pseudoplatanus_4_6058_WEFL_NLF.tif": [["Tilia", 1.0]]} \ No newline at end of file diff --git a/tests/data/treesatai/s1.zip b/tests/data/treesatai/s1.zip new file mode 100644 index 00000000000..052d0dc5553 Binary files /dev/null and b/tests/data/treesatai/s1.zip differ diff --git a/tests/data/treesatai/s1/60m/Acer_pseudoplatanus_3_5758_WEFL_NLF.tif b/tests/data/treesatai/s1/60m/Acer_pseudoplatanus_3_5758_WEFL_NLF.tif new file mode 100644 index 00000000000..e3180fbed8e Binary files /dev/null and b/tests/data/treesatai/s1/60m/Acer_pseudoplatanus_3_5758_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s1/60m/Acer_pseudoplatanus_4_6058_WEFL_NLF.tif b/tests/data/treesatai/s1/60m/Acer_pseudoplatanus_4_6058_WEFL_NLF.tif new file mode 100644 index 00000000000..0d8403f3f3b Binary files /dev/null and b/tests/data/treesatai/s1/60m/Acer_pseudoplatanus_4_6058_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s1/60m/Alnus_spec._5_13114_WEFL_NLF.tif b/tests/data/treesatai/s1/60m/Alnus_spec._5_13114_WEFL_NLF.tif new file mode 100644 index 00000000000..5f73542d330 Binary files /dev/null and b/tests/data/treesatai/s1/60m/Alnus_spec._5_13114_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s1/60m/Fagus_sylvatica_9_29995_WEFL_NLF.tif b/tests/data/treesatai/s1/60m/Fagus_sylvatica_9_29995_WEFL_NLF.tif new file mode 100644 index 00000000000..343126b9235 Binary files /dev/null and b/tests/data/treesatai/s1/60m/Fagus_sylvatica_9_29995_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s1/60m/Picea_abies_2_46896_WEFL_NLF.tif b/tests/data/treesatai/s1/60m/Picea_abies_2_46896_WEFL_NLF.tif new file mode 100644 index 00000000000..b15947f122c Binary files /dev/null and b/tests/data/treesatai/s1/60m/Picea_abies_2_46896_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s1/60m/Picea_abies_3_46636_WEFL_NLF.tif b/tests/data/treesatai/s1/60m/Picea_abies_3_46636_WEFL_NLF.tif new file mode 100644 index 00000000000..c9878414adf Binary files /dev/null and b/tests/data/treesatai/s1/60m/Picea_abies_3_46636_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s1/60m/Pseudotsuga_menziesii_1_339575_BI_NLF.tif b/tests/data/treesatai/s1/60m/Pseudotsuga_menziesii_1_339575_BI_NLF.tif new file mode 100644 index 00000000000..00ba9b03129 Binary files /dev/null and b/tests/data/treesatai/s1/60m/Pseudotsuga_menziesii_1_339575_BI_NLF.tif differ diff --git a/tests/data/treesatai/s1/60m/Quercus_petraea_2_84375_WEFL_NLF.tif b/tests/data/treesatai/s1/60m/Quercus_petraea_2_84375_WEFL_NLF.tif new file mode 100644 index 00000000000..2e4898fb55d Binary files /dev/null and b/tests/data/treesatai/s1/60m/Quercus_petraea_2_84375_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s1/60m/Quercus_petraea_5_80549_WEFL_NLF.tif b/tests/data/treesatai/s1/60m/Quercus_petraea_5_80549_WEFL_NLF.tif new file mode 100644 index 00000000000..0562717348c Binary files /dev/null and b/tests/data/treesatai/s1/60m/Quercus_petraea_5_80549_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s1/60m/Quercus_rubra_1_92184_WEFL_NLF.tif b/tests/data/treesatai/s1/60m/Quercus_rubra_1_92184_WEFL_NLF.tif new file mode 100644 index 00000000000..db825c3ff27 Binary files /dev/null and b/tests/data/treesatai/s1/60m/Quercus_rubra_1_92184_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s2.zip b/tests/data/treesatai/s2.zip new file mode 100644 index 00000000000..eb5dabc8c98 Binary files /dev/null and b/tests/data/treesatai/s2.zip differ diff --git a/tests/data/treesatai/s2/60m/Acer_pseudoplatanus_3_5758_WEFL_NLF.tif b/tests/data/treesatai/s2/60m/Acer_pseudoplatanus_3_5758_WEFL_NLF.tif new file mode 100644 index 00000000000..9d182f62584 Binary files /dev/null and b/tests/data/treesatai/s2/60m/Acer_pseudoplatanus_3_5758_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s2/60m/Acer_pseudoplatanus_4_6058_WEFL_NLF.tif b/tests/data/treesatai/s2/60m/Acer_pseudoplatanus_4_6058_WEFL_NLF.tif new file mode 100644 index 00000000000..d61c7b7a20b Binary files /dev/null and b/tests/data/treesatai/s2/60m/Acer_pseudoplatanus_4_6058_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s2/60m/Alnus_spec._5_13114_WEFL_NLF.tif b/tests/data/treesatai/s2/60m/Alnus_spec._5_13114_WEFL_NLF.tif new file mode 100644 index 00000000000..660f23905de Binary files /dev/null and b/tests/data/treesatai/s2/60m/Alnus_spec._5_13114_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s2/60m/Fagus_sylvatica_9_29995_WEFL_NLF.tif b/tests/data/treesatai/s2/60m/Fagus_sylvatica_9_29995_WEFL_NLF.tif new file mode 100644 index 00000000000..bf8c659fb45 Binary files /dev/null and b/tests/data/treesatai/s2/60m/Fagus_sylvatica_9_29995_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s2/60m/Picea_abies_2_46896_WEFL_NLF.tif b/tests/data/treesatai/s2/60m/Picea_abies_2_46896_WEFL_NLF.tif new file mode 100644 index 00000000000..7bd25b4c837 Binary files /dev/null and b/tests/data/treesatai/s2/60m/Picea_abies_2_46896_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s2/60m/Picea_abies_3_46636_WEFL_NLF.tif b/tests/data/treesatai/s2/60m/Picea_abies_3_46636_WEFL_NLF.tif new file mode 100644 index 00000000000..b62e8364578 Binary files /dev/null and b/tests/data/treesatai/s2/60m/Picea_abies_3_46636_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s2/60m/Pseudotsuga_menziesii_1_339575_BI_NLF.tif b/tests/data/treesatai/s2/60m/Pseudotsuga_menziesii_1_339575_BI_NLF.tif new file mode 100644 index 00000000000..938c8528c28 Binary files /dev/null and b/tests/data/treesatai/s2/60m/Pseudotsuga_menziesii_1_339575_BI_NLF.tif differ diff --git a/tests/data/treesatai/s2/60m/Quercus_petraea_2_84375_WEFL_NLF.tif b/tests/data/treesatai/s2/60m/Quercus_petraea_2_84375_WEFL_NLF.tif new file mode 100644 index 00000000000..69603a72ae3 Binary files /dev/null and b/tests/data/treesatai/s2/60m/Quercus_petraea_2_84375_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s2/60m/Quercus_petraea_5_80549_WEFL_NLF.tif b/tests/data/treesatai/s2/60m/Quercus_petraea_5_80549_WEFL_NLF.tif new file mode 100644 index 00000000000..affe18983a6 Binary files /dev/null and b/tests/data/treesatai/s2/60m/Quercus_petraea_5_80549_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s2/60m/Quercus_rubra_1_92184_WEFL_NLF.tif b/tests/data/treesatai/s2/60m/Quercus_rubra_1_92184_WEFL_NLF.tif new file mode 100644 index 00000000000..ccd44d2b692 Binary files /dev/null and b/tests/data/treesatai/s2/60m/Quercus_rubra_1_92184_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/test_filenames.lst b/tests/data/treesatai/test_filenames.lst new file mode 100644 index 00000000000..9d81989c444 --- /dev/null +++ b/tests/data/treesatai/test_filenames.lst @@ -0,0 +1 @@ +Acer_pseudoplatanus_4_6058_WEFL_NLF.tif diff --git a/tests/data/treesatai/train_filenames.lst b/tests/data/treesatai/train_filenames.lst new file mode 100644 index 00000000000..9a92169b832 --- /dev/null +++ b/tests/data/treesatai/train_filenames.lst @@ -0,0 +1,9 @@ +Picea_abies_3_46636_WEFL_NLF.tif +Pseudotsuga_menziesii_1_339575_BI_NLF.tif +Quercus_rubra_1_92184_WEFL_NLF.tif +Fagus_sylvatica_9_29995_WEFL_NLF.tif +Quercus_petraea_5_80549_WEFL_NLF.tif +Acer_pseudoplatanus_3_5758_WEFL_NLF.tif +Alnus_spec._5_13114_WEFL_NLF.tif +Quercus_petraea_2_84375_WEFL_NLF.tif +Picea_abies_2_46896_WEFL_NLF.tif diff --git a/tests/datasets/test_hyspecnet.py b/tests/datasets/test_hyspecnet.py new file mode 100644 index 00000000000..1e5a646cee6 --- /dev/null +++ b/tests/datasets/test_hyspecnet.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import glob +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch.nn as nn +from pytest import MonkeyPatch +from torch import Tensor + +from torchgeo.datasets import DatasetNotFoundError, HySpecNet11k, RGBBandsMissingError + +root = os.path.join('tests', 'data', 'hyspecnet') +md5s = {'hyspecnet-11k-01.tar.gz': '', 'hyspecnet-11k-splits.tar.gz': ''} + + +class TestHySpecNet11k: + @pytest.fixture + def dataset(self, monkeypatch: MonkeyPatch) -> HySpecNet11k: + monkeypatch.setattr(HySpecNet11k, 'url', root + os.sep) + monkeypatch.setattr(HySpecNet11k, 'md5s', md5s) + transforms = nn.Identity() + return HySpecNet11k(root, transforms=transforms) + + def test_getitem(self, dataset: HySpecNet11k) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x['image'], Tensor) + + def test_len(self, dataset: HySpecNet11k) -> None: + assert len(dataset) == 2 + + def test_download(self, dataset: HySpecNet11k, tmp_path: Path) -> None: + HySpecNet11k(tmp_path, download=True) + + def test_extract(self, dataset: HySpecNet11k, tmp_path: Path) -> None: + for file in glob.iglob(os.path.join(root, '*.tar.gz')): + shutil.copy(file, tmp_path) + HySpecNet11k(tmp_path) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + HySpecNet11k(tmp_path) + + def test_plot(self, dataset: HySpecNet11k) -> None: + x = dataset[0] + dataset.plot(x, suptitle='Test') + plt.close() + + def test_plot_rgb(self, dataset: HySpecNet11k) -> None: + dataset = HySpecNet11k(root=dataset.root, bands=(1, 2, 3)) + match = 'Dataset does not contain some of the RGB bands' + with pytest.raises(RGBBandsMissingError, match=match): + dataset.plot(dataset[0]) diff --git a/tests/datasets/test_treesatai.py b/tests/datasets/test_treesatai.py new file mode 100644 index 00000000000..7788adb9791 --- /dev/null +++ b/tests/datasets/test_treesatai.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import glob +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch.nn as nn +from pytest import MonkeyPatch +from torch import Tensor + +from torchgeo.datasets import DatasetNotFoundError, TreeSatAI + +root = os.path.join('tests', 'data', 'treesatai') +md5s = { + 'aerial_60m_acer_pseudoplatanus.zip': '', + 'labels.zip': '', + 's1.zip': '', + 's2.zip': '', + 'test_filenames.lst': '', + 'train_filenames.lst': '', +} + + +class TestTreeSatAI: + @pytest.fixture + def dataset(self, monkeypatch: MonkeyPatch) -> TreeSatAI: + monkeypatch.setattr(TreeSatAI, 'url', root + os.sep) + monkeypatch.setattr(TreeSatAI, 'md5s', md5s) + transforms = nn.Identity() + return TreeSatAI(root, transforms=transforms) + + def test_getitem(self, dataset: TreeSatAI) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x['label'], Tensor) + for sensor in dataset.sensors: + assert isinstance(x[f'image_{sensor}'], Tensor) + + def test_len(self, dataset: TreeSatAI) -> None: + assert len(dataset) == 9 + + def test_download(self, dataset: TreeSatAI, tmp_path: Path) -> None: + TreeSatAI(tmp_path, download=True) + + def test_extract(self, dataset: TreeSatAI, tmp_path: Path) -> None: + for file in glob.iglob(os.path.join(root, '*.*')): + shutil.copy(file, tmp_path) + TreeSatAI(tmp_path) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + TreeSatAI(tmp_path) + + def test_plot(self, dataset: TreeSatAI) -> None: + x = dataset[0] + x['prediction'] = x['label'] + dataset.plot(x) + plt.close() diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index b0c13b6075b..808bf937220 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -41,6 +41,7 @@ class TestBYOLTask: 'name', [ 'chesapeake_cvpr_prior_byol', + 'hyspecnet_byol', 'seco_byol_1', 'seco_byol_2', 'ssl4eo_l_byol_1', diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index be8132c808b..e2e2d9bb3e5 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -237,7 +237,7 @@ def test_freeze_backbone(self, model_name: str) -> None: class TestMultiLabelClassificationTask: @pytest.mark.parametrize( - 'name', ['bigearthnet_all', 'bigearthnet_s1', 'bigearthnet_s2'] + 'name', ['bigearthnet_all', 'bigearthnet_s1', 'bigearthnet_s2', 'treesatai'] ) def test_trainer( self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool diff --git a/tests/trainers/test_moco.py b/tests/trainers/test_moco.py index 32c002dc573..002944b929e 100644 --- a/tests/trainers/test_moco.py +++ b/tests/trainers/test_moco.py @@ -29,6 +29,7 @@ class TestMoCoTask: 'name', [ 'chesapeake_cvpr_prior_moco', + 'hyspecnet_moco', 'seco_moco_1', 'seco_moco_2', 'ssl4eo_l_moco_1', diff --git a/tests/trainers/test_simclr.py b/tests/trainers/test_simclr.py index 7e1292ab7c0..3924b6e3785 100644 --- a/tests/trainers/test_simclr.py +++ b/tests/trainers/test_simclr.py @@ -29,6 +29,7 @@ class TestSimCLRTask: 'name', [ 'chesapeake_cvpr_prior_simclr', + 'hyspecnet_simclr', 'seco_simclr_1', 'seco_simclr_2', 'ssl4eo_l_simclr_1', diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index dc9513a6524..6dd7231e3df 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -21,6 +21,7 @@ from .geo import BaseDataModule, GeoDataModule, NonGeoDataModule from .geonrw import GeoNRWDataModule from .gid15 import GID15DataModule +from .hyspecnet import HySpecNet11kDataModule from .inria import InriaAerialImageLabelingDataModule from .iobench import IOBenchDataModule from .l7irish import L7IrishDataModule @@ -47,6 +48,7 @@ from .ssl4eo import SSL4EOLDataModule, SSL4EOS12DataModule from .ssl4eo_benchmark import SSL4EOLBenchmarkDataModule from .sustainbench_crop_yield import SustainBenchCropYieldDataModule +from .treesatai import TreeSatAIDataModule from .ucmerced import UCMercedDataModule from .usavars import USAVarsDataModule from .utils import MisconfigurationException @@ -75,6 +77,7 @@ 'GID15DataModule', 'GeoDataModule', 'GeoNRWDataModule', + 'HySpecNet11kDataModule', 'IOBenchDataModule', 'InriaAerialImageLabelingDataModule', 'L7IrishDataModule', @@ -108,6 +111,7 @@ 'SpaceNet6DataModule', 'SpaceNetBaseDataModule', 'SustainBenchCropYieldDataModule', + 'TreeSatAIDataModule', 'TropicalCycloneDataModule', 'UCMercedDataModule', 'USAVarsDataModule', diff --git a/torchgeo/datamodules/hyspecnet.py b/torchgeo/datamodules/hyspecnet.py new file mode 100644 index 00000000000..3e508ef11a7 --- /dev/null +++ b/torchgeo/datamodules/hyspecnet.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""HySpecNet datamodule.""" + +from typing import Any + +import torch + +from ..datasets import HySpecNet11k +from .geo import NonGeoDataModule + + +class HySpecNet11kDataModule(NonGeoDataModule): + """LightningDataModule implementation for the HySpecNet11k dataset. + + .. versionadded:: 0.7 + """ + + # https://git.tu-berlin.de/rsim/hyspecnet-tools/-/blob/main/tif_to_npy.ipynb + mean = torch.tensor(0) + std = torch.tensor(10000) + + def __init__( + self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + ) -> None: + """Initialize a new HySpecNet11kDataModule instance. + + Args: + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.HySpecNet11k`. + """ + super().__init__(HySpecNet11k, batch_size, num_workers, **kwargs) diff --git a/torchgeo/datamodules/treesatai.py b/torchgeo/datamodules/treesatai.py new file mode 100644 index 00000000000..3db24b4724a --- /dev/null +++ b/torchgeo/datamodules/treesatai.py @@ -0,0 +1,142 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""TreeSatAI datamodules.""" + +from typing import Any + +import kornia.augmentation as K +import torch +from torch import Tensor +from torch.utils.data import random_split + +from ..datasets import TreeSatAI +from ..samplers.utils import _to_tuple +from .geo import NonGeoDataModule + +# https://git.tu-berlin.de/rsim/treesat_benchmark/-/blob/master/configs/multimodal/AllModes_Xformer_ResnetScratch_v8.json +means = { + 'aerial': [ + 151.26809261440323, + 93.1159469148246, + 85.05016794624635, + 81.0471576353153, + ], + 's1': [-6.933713050794077, -12.628564056094067, 0.47448312147709354], + 's2': [ + 231.43385024546893, + 376.94788434611434, + 241.03688288984037, + 2809.8421354087955, + 616.5578221193639, + 2104.3826773960823, + 2695.083864757169, + 2969.868417923599, + 1306.0814241837832, + 587.0608264363341, + 249.1888624097736, + 2950.2294375352285, + ], +} +stds = { + 'aerial': [ + 48.70879149145466, + 33.59622314610158, + 28.000497087051126, + 33.683983599997724, + ], + 's1': [87.8762246957811, 47.03070478433704, 1.297291303623673], + 's2': [ + 123.16515044781909, + 139.78991338362886, + 140.6154081184225, + 786.4508872594147, + 202.51268536579394, + 530.7255451201194, + 710.2650071967689, + 777.4421400779165, + 424.30312334282684, + 247.21468849049668, + 122.80062680549261, + 702.7404237034002, + ], +} + + +class TreeSatAIDataModule(NonGeoDataModule): + """LightningDataModule implementation for the TreeSatAI dataset. + + .. versionadded:: 0.7 + """ + + def __init__( + self, + batch_size: int = 64, + patch_size: int | tuple[int, int] = 304, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a new TreeSatAIDataModule instance. + + Args: + batch_size: Size of each mini-batch. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.TreeSatAI`. + """ + super().__init__(TreeSatAI, batch_size, num_workers, **kwargs) + + self.patch_size = _to_tuple(patch_size) + self.sensors = kwargs.get('sensors', TreeSatAI.all_sensors) + + self.train_aug = K.AugmentationSequential( + K.RandomVerticalFlip(p=0.5), + K.RandomHorizontalFlip(p=0.5), + K.Resize(self.patch_size), + data_keys=None, + keepdim=True, + ) + self.aug = K.AugmentationSequential( + K.Resize(self.patch_size), data_keys=None, keepdim=True + ) + + def setup(self, stage: str) -> None: + """Set up datasets. + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + # Convert 90-10 train-test split to 80-10-10 train-val-test split + train_val_dataset = TreeSatAI(split='train', **self.kwargs) + self.test_dataset = TreeSatAI(split='test', **self.kwargs) + generator = torch.Generator().manual_seed(0) + self.train_dataset, self.val_dataset = random_split( + train_val_dataset, + [len(train_val_dataset) - len(self.test_dataset), len(self.test_dataset)], + generator=generator, + ) + + def on_after_batch_transfer( + self, batch: dict[str, Tensor], dataloader_idx: int + ) -> dict[str, Tensor]: + """Apply batch augmentations to the batch after it is transferred to the device. + + Args: + batch: A batch of data that needs to be altered or augmented. + dataloader_idx: The index of the dataloader to which the batch belongs. + + Returns: + A batch of data. + """ + batch = super().on_after_batch_transfer(batch, dataloader_idx) + + images = [] + for sensor in self.sensors: + aug = K.Normalize(mean=means[sensor], std=stds[sensor], keepdim=True) + batch[f'image_{sensor}'] = aug(batch[f'image_{sensor}']) + images.append(batch[f'image_{sensor}']) + + batch['image'] = torch.cat(images, dim=1) + + return batch diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 3016c3af7e2..f55ef3af22c 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -61,6 +61,7 @@ from .geonrw import GeoNRW from .gid15 import GID15 from .globbiomass import GlobBiomass +from .hyspecnet import HySpecNet11k from .idtrees import IDTReeS from .inaturalist import INaturalist from .inria import InriaAerialImageLabeling @@ -131,6 +132,7 @@ from .ssl4eo import SSL4EO, SSL4EOL, SSL4EOS12 from .ssl4eo_benchmark import SSL4EOLBenchmark from .sustainbench_crop_yield import SustainBenchCropYield +from .treesatai import TreeSatAI from .ucmerced import UCMerced from .usavars import USAVars from .utils import ( @@ -214,6 +216,7 @@ 'GeoDataset', 'GeoNRW', 'GlobBiomass', + 'HySpecNet11k', 'IDTReeS', 'INaturalist', 'IOBench', @@ -274,6 +277,7 @@ 'SpaceNet7', 'SpaceNet8', 'SustainBenchCropYield', + 'TreeSatAI', 'TropicalCyclone', 'UCMerced', 'USAVars', diff --git a/torchgeo/datasets/hyspecnet.py b/torchgeo/datasets/hyspecnet.py new file mode 100644 index 00000000000..412ea504b24 --- /dev/null +++ b/torchgeo/datasets/hyspecnet.py @@ -0,0 +1,229 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""HySpecNet dataset.""" + +import os +from collections.abc import Callable, Sequence +from typing import ClassVar + +import rasterio as rio +import torch +from einops import rearrange +from matplotlib import pyplot as plt +from matplotlib.figure import Figure +from torch import Tensor + +from .errors import DatasetNotFoundError, RGBBandsMissingError +from .geo import NonGeoDataset +from .utils import Path, download_url, extract_archive, percentile_normalization + +# https://git.tu-berlin.de/rsim/hyspecnet-tools/-/blob/main/tif_to_npy.ipynb +invalid_channels = [ + 126, + 127, + 128, + 129, + 130, + 131, + 132, + 133, + 134, + 135, + 136, + 137, + 138, + 139, + 140, + 160, + 161, + 162, + 163, + 164, + 165, + 166, +] +valid_channels_ids = [c + 1 for c in range(224) if c not in invalid_channels] + + +class HySpecNet11k(NonGeoDataset): + """HySpecNet-11k dataset. + + `HySpecNet-11k `__ is a large-scale + benchmark dataset for hyperspectral image compression and self-supervised learning. + It is made up of 11,483 nonoverlapping image patches acquired by the + `EnMAP satellite `_. Each patch is a portion of 128 x 128 + pixels with 224 spectral bands and with a ground sample distance of 30 m. + + To construct HySpecNet-11k, a total of 250 EnMAP tiles acquired during the routine + operation phase between 2 November 2022 and 9 November 2022 were considered. The + considered tiles are associated with less than 10% cloud and snow cover. The tiles + were radiometrically, geometrically and atmospherically corrected (L2A water & land + product). Then, the tiles were divided into nonoverlapping image patches. The + cropped patches at the borders of the tiles were eliminated. As a result, more than + 45 patches per tile are obtained, resulting in 11,483 patches for the full dataset. + + We provide predefined splits obtained by randomly dividing HySpecNet into: + + #. a training set that includes 70% of the patches, + #. a validation set that includes 20% of the patches, and + #. a test set that includes 10% of the patches. + + Depending on the way that we used for splitting the dataset, we define two + different splits: + + #. an easy split, where patches from the same tile can be present in different sets + (patchwise splitting); and + #. a hard split, where all patches from one tile belong to the same set + (tilewise splitting). + + If you use this dataset in your research, please cite the following paper: + + * https://arxiv.org/abs/2306.00385 + + .. versionadded:: 0.7 + """ + + url = 'https://hf.co/datasets/torchgeo/hyspecnet/resolve/13e110422a6925cbac0f11edff610219b9399227/' + md5s: ClassVar[dict[str, str]] = { + 'hyspecnet-11k-01.tar.gz': '974aae9197006727b42ec81796049efe', + 'hyspecnet-11k-02.tar.gz': 'f80574485f835b8a263b6c64076c0c62', + 'hyspecnet-11k-03.tar.gz': '6bc1de573f97fa4a75b79719b9270cb3', + 'hyspecnet-11k-04.tar.gz': '2463dc10653cb8be10d44951307c5e7d', + 'hyspecnet-11k-05.tar.gz': '16c1bd9e684673e741c0849bd015c988', + 'hyspecnet-11k-06.tar.gz': '8eef16b67d71af6eb4bc836d294fe3c4', + 'hyspecnet-11k-07.tar.gz': 'f61f0e7d6b05c861e69026b09130a5d6', + 'hyspecnet-11k-08.tar.gz': '19d390bc9e61b85e7d765f3077984976', + 'hyspecnet-11k-09.tar.gz': '197ff47befe5b9de88be5e1321c5ce5d', + 'hyspecnet-11k-10.tar.gz': '9e674cca126a9d139d6584be148d4bac', + 'hyspecnet-11k-splits.tar.gz': '94fad9e3c979c612c29a045406247d6c', + } + + all_bands = valid_channels_ids + rgb_bands = (43, 28, 10) + + def __init__( + self, + root: Path = 'data', + split: str = 'train', + strategy: str = 'easy', + bands: Sequence[int] = all_bands, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new HySpecNet11k instance. + + Args: + root: Root directory where dataset can be found. + split: One of 'train', 'val', or 'test'. + strategy: Either 'easy' for patchwise splitting or 'hard' for tilewise + splitting. + bands: Bands to return. + transforms: A function/transform that takes input sample and its target as + entry and returns a transformed version. + download: If True, download dataset and store it in the root directory. + checksum: If True, check the MD5 of the downloaded files (may be slow). + + Raises: + DatasetNotFoundError: If dataset is not found and *download* is False. + """ + self.root = root + self.split = split + self.strategy = strategy + self.bands = bands + self.transforms = transforms + self.download = download + self.checksum = checksum + + self._verify() + + path = os.path.join(root, 'hyspecnet-11k', 'splits', strategy, f'{split}.csv') + with open(path) as f: + self.files = f.read().strip().split('\n') + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + Length of the dataset. + """ + return len(self.files) + + def __getitem__(self, index: int) -> dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: Index to return. + + Returns: + Data and label at that index. + """ + file = self.files[index].replace('DATA.npy', 'SPECTRAL_IMAGE.TIF') + with rio.open(os.path.join(self.root, 'hyspecnet-11k', 'patches', file)) as src: + sample = {'image': torch.tensor(src.read(self.bands).astype('float32'))} + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the extracted files already exist + exists = [] + for directory in ['patches', 'splits']: + path = os.path.join(self.root, 'hyspecnet-11k', directory) + exists.append(os.path.isdir(path)) + + if all(exists): + return + + for file, md5 in self.md5s.items(): + # Check if the file has already been downloaded + path = os.path.join(self.root, file) + if os.path.isfile(path): + extract_archive(path) + continue + + # Check if the user requested to download the dataset + if self.download: + url = self.url + file + download_url(url, self.root, md5=md5 if self.checksum else None) + extract_archive(path) + continue + + raise DatasetNotFoundError(self) + + def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: A sample returned by :meth:`__getitem__`. + suptitle: optional string to use as a suptitle + + Returns: + A matplotlib Figure with the rendered sample. + + Raises: + RGBBandsMissingError: If *bands* does not include all RGB bands. + """ + rgb_indices = [] + for band in self.rgb_bands: + if band in self.bands: + rgb_indices.append(self.bands.index(band)) + else: + raise RGBBandsMissingError() + + image = sample['image'][rgb_indices].cpu().numpy() + image = rearrange(image, 'c h w -> h w c') + image = percentile_normalization(image) + + fig, ax = plt.subplots() + ax.imshow(image) + ax.axis('off') + + if suptitle: + fig.suptitle(suptitle) + + return fig diff --git a/torchgeo/datasets/treesatai.py b/torchgeo/datasets/treesatai.py new file mode 100644 index 00000000000..5f55f158361 --- /dev/null +++ b/torchgeo/datasets/treesatai.py @@ -0,0 +1,293 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""TreeSatAI datasets.""" + +import json +import os +from collections.abc import Callable, Sequence +from typing import ClassVar + +import rasterio as rio +import torch +from einops import rearrange +from matplotlib import pyplot as plt +from matplotlib.figure import Figure +from torch import Tensor + +from .errors import DatasetNotFoundError +from .geo import NonGeoDataset +from .utils import Path, download_url, extract_archive, percentile_normalization + + +class TreeSatAI(NonGeoDataset): + """TreeSatAI Benchmark Archive. + + `TreeSatAI Benchmark Archive `_ is a + multi-sensor, multi-label dataset for tree species classification in remote + sensing. It was created by combining labels from the federal forest inventory of + Lower Saxony, Germany with 20 cm Color-Infrared (CIR) and 10 m Sentinel imagery. + + The TreeSatAI Benchmark Archive contains: + + * 50,381 image triplets (aerial, Sentinel-1, Sentinel-2) + * synchronized time steps and locations + * all original spectral bands/polarizations from the sensors + * 20 species classes (single labels) + * 12 age classes (single labels) + * 15 genus classes (multi labels) + * 60 m and 200 m patches + * fixed split for train (90%) and test (10%) data + * additional single labels such as English species name, genus, + forest stand type, foliage type, land cover + + If you use this dataset in your research, please cite the following paper: + + * https://doi.org/10.5194/essd-15-681-2023 + + .. versionadded:: 0.7 + """ + + url = 'https://zenodo.org/records/6780578/files/' + md5s: ClassVar[dict[str, str]] = { + 'aerial_60m_abies_alba.zip': '4298b1c9fbf6d0d85f7aa208ff5fe0c9', + 'aerial_60m_acer_pseudoplatanus.zip': '7c31d7ddea841f6509deece8f984a79e', + 'aerial_60m_alnus_spec.zip': '34ea107f43c6172c6d2652dbf26306af', + 'aerial_60m_betula_spec.zip': '69de9373739a027692a823846434fa0c', + 'aerial_60m_cleared.zip': '8dffbb2f6aad17ef83721cffa5b52d96', + 'aerial_60m_fagus_sylvatica.zip': '77b277e69e90bfbd3c5fd15a73d228fe', + 'aerial_60m_fraxinus_excelsior.zip': '9a88a8e6821f8a54ded950de9238831f', + 'aerial_60m_larix_decidua.zip': 'aa0bc5b091b099018a078536ef429031', + 'aerial_60m_larix_kaempferi.zip': '429df073f69f8bbf60aef765e1c925ba', + 'aerial_60m_picea_abies.zip': 'edb9b1bc9a5a7b405f4cbb0d71cedf54', + 'aerial_60m_pinus_nigra.zip': '96bf1798ef82f712ea46c2963ddb7083', + 'aerial_60m_pinus_strobus.zip': '0ff818c6d31f59b8488880e49b300c7a', + 'aerial_60m_pinus_sylvestris.zip': '298cbaac4d9f07a204e1e74e8446798d', + 'aerial_60m_populus_spec.zip': '46fcff76b119cc24f3caf938a0bb433a', + 'aerial_60m_prunus_spec.zip': 'fb1c570d3ea925a049630224ccb354bc', + 'aerial_60m_pseudotsuga_menziesii.zip': '2d05511ceabf4037b869eca928f3c04e', + 'aerial_60m_quercus_petraea.zip': '31f573fb0419b2b453ed7da1c4d2a298', + 'aerial_60m_quercus_robur.zip': 'bcd90506509de26692c043f4c8d73af0', + 'aerial_60m_quercus_rubra.zip': '71d8495725ed1b4f27d9e382409fcc5e', + 'aerial_60m_tilia_spec.zip': 'f81558c9c7189ac8a257d041ee43c1c9', + 'geojson.zip': 'aa749718f3cb76c1dfc9cddc2ed201db', + 'labels.zip': '656f1b68ec9ab70afd02bb127b75bb24', + 's1.zip': 'bed4fc8cb65da46a24ec1bc6cea2763c', + 's2.zip': '453ba69056aa33a3c6b97afb7b6afadb', + 'test_filenames.lst': '2166903d947f0025f61e342da466f917', + 'train_filenames.lst': 'a1a0148e8120b0268f76d2e98a68436f', + } + + # Genus-level classes (species-level labels also exist) + classes = ( + 'Abies', # fir + 'Acer', # maple + 'Alnus', # alder + 'Betula', # birch + 'Cleared', # none + 'Fagus', # beech + 'Fraxinus', # ash + 'Larix', # larch + 'Picea', # spruce + 'Pinus', # pine + 'Populus', # poplar + 'Prunus', # cherry + 'Pseudotsuga', # Douglas fir + 'Quercus', # oak + 'Tilia', # linden + ) + + # https://zenodo.org/records/6780578/files/220629_doc_TreeSatAI_benchmark_archive.pdf + all_sensors = ('aerial', 's1', 's2') + all_bands: ClassVar[dict[str, list[str]]] = { + 'aerial': ['IR', 'G', 'B', 'R'], + 's1': ['VV', 'VH', 'VV/VH'], + 's2': [ + 'B02', + 'B03', + 'B04', + 'B08', + 'B05', + 'B06', + 'B07', + 'B8A', + 'B11', + 'B12', + 'B01', + 'B09', + ], + } + rgb_bands: ClassVar[dict[str, list[str]]] = { + 'aerial': ['R', 'G', 'B'], + 's1': ['VV', 'VH', 'VV/VH'], + 's2': ['B04', 'B03', 'B02'], + } + + def __init__( + self, + root: Path = 'data', + split: str = 'train', + sensors: Sequence[str] = all_sensors, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new TreeSatAI instance. + + Args: + root: Root directory where dataset can be found. + split: Either 'train' or 'test'. + sensors: One or more of 'aerial', 's1', and/or 's2'. + transforms: A function/transform that takes input sample and its target as + entry and returns a transformed version. + download: If True, download dataset and store it in the root directory. + checksum: If True, check the MD5 of the downloaded files (may be slow). + + Raises: + AssertionError: If invalid *sensors* are chosen. + DatasetNotFoundError: If dataset is not found and *download* is False. + """ + assert set(sensors) <= set(self.all_sensors) + + self.root = root + self.split = split + self.sensors = sensors + self.transforms = transforms + self.download = download + self.checksum = checksum + + self._verify() + + path = os.path.join(self.root, f'{split}_filenames.lst') + with open(path) as f: + self.files = f.read().strip().split('\n') + + path = os.path.join(self.root, 'labels', 'TreeSatBA_v9_60m_multi_labels.json') + with open(path) as f: + self.labels = json.load(f) + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + Length of the dataset. + """ + return len(self.files) + + def __getitem__(self, index: int) -> dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: Index to return. + + Returns: + Data and label at that index. + """ + file = self.files[index] + label = torch.zeros(len(self.classes)) + for genus, _ in self.labels[file]: + i = self.classes.index(genus) + label[i] = 1 + + sample = {'label': label} + for directory in self.sensors: + with rio.open(os.path.join(self.root, directory, '60m', file)) as f: + sample[f'image_{directory}'] = torch.tensor(f.read().astype('float32')) + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the extracted files already exist + exists = [] + for directory in self.sensors: + exists.append(os.path.isdir(os.path.join(self.root, directory))) + + if all(exists): + return + + for file, md5 in self.md5s.items(): + # Check if the file has already been downloaded + if os.path.isfile(os.path.join(self.root, file)): + self._extract(file) + continue + + # Check if the user requested to download the dataset + if self.download: + url = self.url + file + download_url(url, self.root, md5=md5 if self.checksum else None) + self._extract(file) + continue + + raise DatasetNotFoundError(self) + + def _extract(self, file: str) -> None: + """Extract file. + + Args: + file: The file to extract. + """ + if not file.endswith('.zip'): + return + + to_path = self.root + if file.startswith('aerial'): + to_path = os.path.join(self.root, 'aerial', '60m') + + extract_archive(os.path.join(self.root, file), to_path) + + def plot(self, sample: dict[str, Tensor], show_titles: bool = True) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: A sample returned by :meth:`__getitem__`. + show_titles: Flag indicating whether to show titles above each panel. + + Returns: + A matplotlib Figure with the rendered sample. + """ + fig, ax = plt.subplots(ncols=len(self.sensors), squeeze=False) + + for i, sensor in enumerate(self.sensors): + image = sample[f'image_{sensor}'].cpu().numpy() + bands = [self.all_bands[sensor].index(b) for b in self.rgb_bands[sensor]] + image = rearrange(image[bands], 'c h w -> h w c') + image = percentile_normalization(image) + ax[0, i].imshow(image) + ax[0, i].axis('off') + + if show_titles: + ax[0, i].set_title(sensor) + + if show_titles: + label = self._multilabel_to_string(sample['label']) + suptitle = f'Label: ({label})' + + if 'prediction' in sample: + prediction = self._multilabel_to_string(sample['prediction']) + suptitle += f'\nPrediction: ({prediction})' + + fig.suptitle(suptitle) + + fig.tight_layout() + return fig + + def _multilabel_to_string(self, multilabel: Tensor) -> str: + """Convert a tensor of multilabel class probabilities to human readable format. + + Args: + multilabel: A tensor of multilabel class probabilities. + + Returns: + Class names and percentages sorted by percentage. + """ + labels: list[tuple[str, float]] = [] + for i, pct in enumerate(multilabel.cpu().numpy()): + if pct > 0.001: + labels.append((self.classes[i], pct)) + + labels.sort(key=lambda label: label[1], reverse=True) + return ', '.join([f'{genus}: {pct:.1%}' for genus, pct in labels])