Commit

Add MDAS dataset (#2429)
* mdas dataset

* docs

* mypy

* ruff

* init order

* typo

* test coverage

* coverage

* docs

* comma

* fix

* cmap
nilsleh authored Jan 8, 2025
1 parent 9819625 commit 5cca8e7
Showing 40 changed files with 661 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
@@ -354,6 +354,11 @@ MapInWild

.. autoclass:: MapInWild

MDAS
^^^^

.. autoclass:: MDAS

Million-AID
^^^^^^^^^^^

1 change: 1 addition & 0 deletions docs/api/datasets/non_geo_datasets.csv
@@ -29,6 +29,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`LEVIR-CD+`_,CD,Google Earth,-,985,2,"1,024x1,024",0.5,RGB
`LoveDA`_,S,Google Earth,"CC-BY-NC-SA-4.0","5,987",7,"1,024x1,024",0.3,RGB
`MapInWild`_,S,"Sentinel-1/2, ESA WorldCover, NOAA VIIRS DNB","CC-BY-4.0",1018,1,1920x1920,10--463.83,"SAR, MSI, 2020_Map, avg_rad"
`MDAS`_,S,"Sentinel-1/2,EnMAP,HySpex","CC-BY-SA-4.0",3,20,"100x120, 300x360, 1364x1636, 10000x12000, 15000x18000",0.3--30,HSI
`Million-AID`_,C,Google Earth,-,1M,51--73,,0.5--153,RGB
`MMEarth`_,"C, S","Aster, Sentinel, ERA5","CC-BY-4.0","100K--1M",,"128x128 or 64x64",10,MSI
`NASA Marine Debris`_,OD,PlanetScope,"Apache-2.0",707,1,256x256,3,RGB
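Based on the constructor arguments and sample keys exercised in tests/datasets/test_mdas.py further down in this commit, a minimal usage sketch (the root path and the particular subarea/modality choices here are illustrative only):

from torchgeo.datasets import MDAS

# Illustrative root and selections; see the test file below for the
# subareas and modalities the dataset accepts.
ds = MDAS(
    root='data/mdas',
    subareas=['sub_area_1'],
    modalities=['HySpex', 'Sentinel_2', 'osm_landuse'],
    download=True,
    checksum=True,
)

sample = ds[0]
# Image modalities are keyed '<modality>_image', OSM layers '<modality>_mask'.
print(sample['HySpex_image'].shape, sample['osm_landuse_mask'].shape)
ds.plot(sample, suptitle='MDAS sample')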
Binary file added tests/data/mdas/Augsburg_data_4_publication.zip
Binary files not shown.
161 changes: 161 additions & 0 deletions tests/data/mdas/data.py
@@ -0,0 +1,161 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil

import numpy as np
import rasterio
from rasterio.crs import CRS
from rasterio.transform import from_origin

# Set the random seed for reproducibility
np.random.seed(0)

# Define the root directory, dataset name, subareas, and modalities based on mdas.py
root_dir = '.'
ds_root_name = 'Augsburg_data_4_publication'
subareas = ['sub_area_1', 'sub_area_2', 'sub_area_3']
modalities = [
    '3K_DSM',
    '3K_RGB',
    'HySpex',
    'EeteS_EnMAP_10m',
    'EeteS_EnMAP_30m',
    'EeteS_Sentinel_2_10m',
    'Sentinel_1',
    'Sentinel_2',
    'osm_buildings',
    'osm_landuse',
    'osm_water',
]

landuse_class_codes = [
    -2147483647,  # no label
    7201,  # forest
    7202,  # park
    7203,  # residential
    7204,  # industrial
    7205,  # farm
    7206,  # cemetery
    7207,  # allotments
    7208,  # meadow
    7209,  # commercial
    7210,  # nature reserve
    7211,  # recreation ground
    7212,  # retail
    7213,  # military
    7214,  # quarry
    7215,  # orchard
    7217,  # scrub
    7218,  # grass
    7219,  # heath
]

# Remove existing dummy data if it exists
dataset_path = os.path.join(root_dir, ds_root_name)
if os.path.exists(dataset_path):
    shutil.rmtree(dataset_path)


def create_dummy_geotiff(
    path: str,
    num_bands: int = 3,
    width: int = 32,
    height: int = 32,
    dtype: np.dtype = np.uint16,
    binary: bool = False,
    landuse: bool = False,
) -> None:
    """Create a dummy GeoTIFF file."""
    crs = CRS.from_epsg(32632)
    transform = from_origin(0, 0, 1, 1)

    if binary:
        data = np.random.randint(0, 2, size=(num_bands, height, width)).astype(dtype)
    elif landuse:
        num_pixels = num_bands * height * width
        no_label_ratio = 0.1
        num_no_label = int(no_label_ratio * num_pixels)
        num_labels = num_pixels - num_no_label
        landuse_values = np.random.choice(landuse_class_codes[1:], size=num_labels)
        no_label_values = np.full(num_no_label, landuse_class_codes[0], dtype=dtype)
        combined = np.concatenate([landuse_values, no_label_values])
        np.random.shuffle(combined)
        data = combined.reshape((num_bands, height, width)).astype(dtype)
    else:
        # Generate random data for other modalities
        data = np.random.randint(0, 255, size=(num_bands, height, width)).astype(dtype)

    os.makedirs(os.path.dirname(path), exist_ok=True)

    with rasterio.open(
        path,
        'w',
        driver='GTiff',
        height=height,
        width=width,
        count=num_bands,
        dtype=dtype,
        crs=crs,
        transform=transform,
    ) as dst:
        dst.write(data)


# Create directory structure and dummy data
for subarea in subareas:
    # Format the subarea name for filenames, as in mdas.py _format_subarea method
    parts = subarea.split('_')
    subarea_formatted = parts[0] + '_' + parts[1] + parts[2]  # e.g., 'sub_area1'

    subarea_dir = os.path.join(root_dir, ds_root_name, subarea)

    for modality in modalities:
        filename = f'{modality}_{subarea_formatted}.tif'
        file_path = os.path.join(subarea_dir, filename)

        if modality in ['osm_buildings', 'osm_water']:
            create_dummy_geotiff(file_path, num_bands=1, dtype=np.uint8, binary=True)
        elif modality == 'osm_landuse':
            create_dummy_geotiff(file_path, num_bands=1, dtype=np.float64, landuse=True)
        elif modality == 'HySpex':
            create_dummy_geotiff(file_path, num_bands=368, dtype=np.int16)
        elif modality in ['EeteS_EnMAP_10m', 'EeteS_EnMAP_30m']:
            create_dummy_geotiff(file_path, num_bands=242, dtype=np.uint16)
        elif modality == 'Sentinel_1':
            create_dummy_geotiff(file_path, num_bands=2, dtype=np.float32)
        elif modality in ['Sentinel_2', 'EeteS_Sentinel_2_10m']:
            create_dummy_geotiff(file_path, num_bands=13, dtype=np.uint16)
        elif modality == '3K_DSM':
            create_dummy_geotiff(file_path, num_bands=1, dtype=np.float32)
        elif modality == '3K_RGB':
            create_dummy_geotiff(file_path, num_bands=3, dtype=np.uint8)

print(f'Dummy MDAS dataset created at {os.path.join(root_dir, ds_root_name)}')

# Create a zip archive of the dataset directory
zip_filename = f'{ds_root_name}.zip'
zip_path = os.path.join(root_dir, zip_filename)

shutil.make_archive(
    base_name=os.path.splitext(zip_path)[0],
    format='zip',
    root_dir='.',
    base_dir=ds_root_name,
)


def calculate_md5(filename: str) -> str:
    hash_md5 = hashlib.md5()
    with open(filename, 'rb') as f:
        for chunk in iter(lambda: f.read(4096), b''):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


checksum = calculate_md5(zip_path)
print(f'MD5 checksum: {checksum}')
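For reference, the loop above names each dummy raster {modality}_{sub_areaN}.tif inside Augsburg_data_4_publication/<subarea>/. A small sketch, assuming data.py has just been run in this directory, to list the generated fixture:

import os

root = 'Augsburg_data_4_publication'  # directory produced by the script above
for dirpath, _dirnames, filenames in sorted(os.walk(root)):
    for name in sorted(filenames):
        # e.g. Augsburg_data_4_publication/sub_area_1/HySpex_sub_area1.tif
        print(os.path.join(dirpath, name))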
113 changes: 113 additions & 0 deletions tests/datasets/test_mdas.py
@@ -0,0 +1,113 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

from torchgeo.datasets import MDAS, DatasetNotFoundError


class TestMDAS:
    @pytest.fixture(
        params=[
            {'subareas': ['sub_area_1'], 'modalities': ['HySpex']},
            {
                'subareas': ['sub_area_1', 'sub_area_2'],
                'modalities': ['3K_DSM', 'HySpex', 'osm_water'],
            },
            {
                'subareas': ['sub_area_2', 'sub_area_3'],
                'modalities': [
                    '3K_DSM',
                    '3K_RGB',
                    'HySpex',
                    'EeteS_EnMAP_10m',
                    'EeteS_EnMAP_30m',
                    'EeteS_Sentinel_2_10m',
                    'Sentinel_2',
                    'Sentinel_1',
                    'osm_buildings',
                    'osm_landuse',
                    'osm_water',
                ],
            },
        ]
    )
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> MDAS:
        md5 = '99e1744ca6f19aa19a3aa23a2bbf7bef'
        monkeypatch.setattr(MDAS, 'md5', md5)
        url = os.path.join('tests', 'data', 'mdas', 'Augsburg_data_4_publication.zip')
        monkeypatch.setattr(MDAS, 'url', url)

        params = request.param
        subareas = params['subareas']
        modalities = params['modalities']

        root = tmp_path
        transforms = nn.Identity()

        return MDAS(
            root=root,
            subareas=subareas,
            modalities=modalities,
            transforms=transforms,
            download=True,
            checksum=True,
        )

    def test_getitem(self, dataset: MDAS) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        for key in dataset.modalities:
            if key.startswith('osm'):
                key = f'{key}_mask'
            else:
                key = f'{key}_image'
            assert key in x

        for key, value in x.items():
            assert isinstance(value, torch.Tensor)

    def test_len(self, dataset: MDAS) -> None:
        assert len(dataset) == len(dataset.subareas)

    def test_already_downloaded(self, dataset: MDAS) -> None:
        MDAS(root=dataset.root)

    def test_not_yet_extracted(self, tmp_path: Path) -> None:
        filename = 'Augsburg_data_4_publication.zip'
        dir = os.path.join('tests', 'data', 'mdas')
        shutil.copyfile(
            os.path.join(dir, filename), os.path.join(str(tmp_path), filename)
        )
        MDAS(root=str(tmp_path))

    def test_invalid_subarea(self) -> None:
        with pytest.raises(AssertionError):
            MDAS(subareas=['foo'])

    def test_invalid_modality(self) -> None:
        with pytest.raises(AssertionError):
            MDAS(modalities=['foo'])

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            MDAS(tmp_path)

    def test_plot(self, dataset: MDAS) -> None:
        dataset.plot(dataset[0], suptitle='Test')
        plt.close()

    def test_plot_single_sample(self, dataset: MDAS) -> None:
        dataset.plot(dataset[0], show_titles=False)
        plt.close()
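To exercise only the new tests, one option is to call pytest programmatically (a sketch; running pytest on tests/datasets/test_mdas.py from the command line is equivalent):

import pytest

# pytest.main returns an exit code; 0 means every test passed.
raise SystemExit(pytest.main(['tests/datasets/test_mdas.py', '-v']))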
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
@@ -85,6 +85,7 @@
from .levircd import LEVIRCD, LEVIRCDBase, LEVIRCDPlus
from .loveda import LoveDA
from .mapinwild import MapInWild
from .mdas import MDAS
from .millionaid import MillionAID
from .mmearth import MMEarth
from .naip import NAIP
@@ -159,6 +160,7 @@
    'GBIF',
    'GID15',
    'LEVIRCD',
    'MDAS',
    'NAIP',
    'NCCM',
    'NLCD',