Merge branch 'eurocrops_handlenones' into tutorial_be
burakekim committed Jan 11, 2025
2 parents d5038a2 + 863aaa9 commit 65fd5a1
Showing 41 changed files with 661 additions and 3 deletions.
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
@@ -354,6 +354,11 @@ MapInWild

.. autoclass:: MapInWild

MDAS
^^^^

.. autoclass:: MDAS

Million-AID
^^^^^^^^^^^

1 change: 1 addition & 0 deletions docs/api/datasets/non_geo_datasets.csv
@@ -29,6 +29,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`LEVIR-CD+`_,CD,Google Earth,-,985,2,"1,024x1,024",0.5,RGB
`LoveDA`_,S,Google Earth,"CC-BY-NC-SA-4.0","5,987",7,"1,024x1,024",0.3,RGB
`MapInWild`_,S,"Sentinel-1/2, ESA WorldCover, NOAA VIIRS DNB","CC-BY-4.0",1018,1,1920x1920,10--463.83,"SAR, MSI, 2020_Map, avg_rad"
`MDAS`_,S,"Sentinel-1/2, EnMAP, HySpex","CC-BY-SA-4.0",3,20,"100x120, 300x360, 1364x1636, 10000x12000, 15000x18000",0.3--30,"SAR, MSI, HSI"
`Million-AID`_,C,Google Earth,-,1M,51--73,,0.5--153,RGB
`MMEarth`_,"C, S","Aster, Sentinel, ERA5","CC-BY-4.0","100K--1M",,"128x128 or 64x64",10,MSI
`NASA Marine Debris`_,OD,PlanetScope,"Apache-2.0",707,1,256x256,3,RGB
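For orientation, a minimal usage sketch of the new dataset class. It assumes only the constructor arguments and sample-key conventions exercised in tests/datasets/test_mdas.py below; the root path and the chosen modalities are illustrative.

import torch.nn as nn

from torchgeo.datasets import MDAS

# 'data/mdas' is a hypothetical local directory; download=True fetches the archive
# and checksum=True verifies it, mirroring the test fixture below.
ds = MDAS(
    root='data/mdas',
    subareas=['sub_area_1'],
    modalities=['HySpex', 'Sentinel_2', 'osm_landuse'],
    transforms=nn.Identity(),
    download=True,
    checksum=True,
)

sample = ds[0]
# Image modalities are keyed '<modality>_image'; OSM layers are keyed '<modality>_mask'.
print(sorted(sample.keys()))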
Binary file added tests/data/mdas/Augsburg_data_4_publication.zip
Binary file not shown.
33 more binary files added (contents not shown).
161 changes: 161 additions & 0 deletions tests/data/mdas/data.py
@@ -0,0 +1,161 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil

import numpy as np
import rasterio
from rasterio.crs import CRS
from rasterio.transform import from_origin

# Set the random seed for reproducibility
np.random.seed(0)

# Define the root directory, dataset name, subareas, and modalities based on mdas.py
root_dir = '.'
ds_root_name = 'Augsburg_data_4_publication'
subareas = ['sub_area_1', 'sub_area_2', 'sub_area_3']
modalities = [
'3K_DSM',
'3K_RGB',
'HySpex',
'EeteS_EnMAP_10m',
'EeteS_EnMAP_30m',
'EeteS_Sentinel_2_10m',
'Sentinel_1',
'Sentinel_2',
'osm_buildings',
'osm_landuse',
'osm_water',
]

landuse_class_codes = [
-2147483647, # no label
7201, # forest
7202, # park
7203, # residential
7204, # industrial
7205, # farm
7206, # cemetery
7207, # allotments
7208, # meadow
7209, # commercial
7210, # nature reserve
7211, # recreation ground
7212, # retail
7213, # military
7214, # quarry
7215, # orchard
7217, # scrub
7218, # grass
7219, # heath
]

# Remove existing dummy data if it exists
dataset_path = os.path.join(root_dir, ds_root_name)
if os.path.exists(dataset_path):
shutil.rmtree(dataset_path)


def create_dummy_geotiff(
path: str,
num_bands: int = 3,
width: int = 32,
height: int = 32,
dtype: np.dtype = np.uint16,
binary: bool = False,
landuse: bool = False,
) -> None:
"""Create a dummy GeoTIFF file."""
crs = CRS.from_epsg(32632)
transform = from_origin(0, 0, 1, 1)

if binary:
data = np.random.randint(0, 2, size=(num_bands, height, width)).astype(dtype)
elif landuse:
num_pixels = num_bands * height * width
no_label_ratio = 0.1
num_no_label = int(no_label_ratio * num_pixels)
num_labels = num_pixels - num_no_label
landuse_values = np.random.choice(landuse_class_codes[1:], size=num_labels)
no_label_values = np.full(num_no_label, landuse_class_codes[0], dtype=dtype)
combined = np.concatenate([landuse_values, no_label_values])
np.random.shuffle(combined)
data = combined.reshape((num_bands, height, width)).astype(dtype)
else:
# Generate random data for other modalities
data = np.random.randint(0, 255, size=(num_bands, height, width)).astype(dtype)

os.makedirs(os.path.dirname(path), exist_ok=True)

with rasterio.open(
path,
'w',
driver='GTiff',
height=height,
width=width,
count=num_bands,
dtype=dtype,
crs=crs,
transform=transform,
) as dst:
dst.write(data)


# Create directory structure and dummy data
for subarea in subareas:
# Format the subarea name for filenames, as in mdas.py _format_subarea method
parts = subarea.split('_')
subarea_formatted = parts[0] + '_' + parts[1] + parts[2] # e.g., 'sub_area1'

subarea_dir = os.path.join(root_dir, ds_root_name, subarea)

for modality in modalities:
filename = f'{modality}_{subarea_formatted}.tif'
file_path = os.path.join(subarea_dir, filename)

if modality in ['osm_buildings', 'osm_water']:
create_dummy_geotiff(file_path, num_bands=1, dtype=np.uint8, binary=True)
elif modality == 'osm_landuse':
create_dummy_geotiff(file_path, num_bands=1, dtype=np.float64, landuse=True)
elif modality == 'HySpex':
create_dummy_geotiff(file_path, num_bands=368, dtype=np.int16)
elif modality in ['EeteS_EnMAP_10m', 'EeteS_EnMAP_30m']:
create_dummy_geotiff(file_path, num_bands=242, dtype=np.uint16)
elif modality == 'Sentinel_1':
create_dummy_geotiff(file_path, num_bands=2, dtype=np.float32)
elif modality in ['Sentinel_2', 'EeteS_Sentinel_2_10m']:
create_dummy_geotiff(file_path, num_bands=13, dtype=np.uint16)
elif modality == '3K_DSM':
create_dummy_geotiff(file_path, num_bands=1, dtype=np.float32)
elif modality == '3K_RGB':
create_dummy_geotiff(file_path, num_bands=3, dtype=np.uint8)

print(f'Dummy MDAS dataset created at {os.path.join(root_dir, ds_root_name)}')

# Create a zip archive of the dataset directory
zip_filename = f'{ds_root_name}.zip'
zip_path = os.path.join(root_dir, zip_filename)

shutil.make_archive(
base_name=os.path.splitext(zip_path)[0],
format='zip',
root_dir='.',
base_dir=ds_root_name,
)


def calculate_md5(filename: str) -> str:
hash_md5 = hashlib.md5()
with open(filename, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b''):
hash_md5.update(chunk)
return hash_md5.hexdigest()


checksum = calculate_md5(zip_path)
print(f'MD5 checksum: {checksum}')
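The printed checksum is the value that tests/datasets/test_mdas.py (below) pins onto MDAS via monkeypatch. As a sketch, one optional guard that could be appended here to flag a stale test value whenever the fixture is regenerated; the expected hash is simply the one currently used in that test.

# Optional guard: the hash below is the one monkeypatched in tests/datasets/test_mdas.py;
# if the regenerated archive differs, that test value needs updating as well.
expected_md5 = '99e1744ca6f19aa19a3aa23a2bbf7bef'
if checksum != expected_md5:
    print(f'Checksum changed (expected {expected_md5}); update tests/datasets/test_mdas.py')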
113 changes: 113 additions & 0 deletions tests/datasets/test_mdas.py
@@ -0,0 +1,113 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

from torchgeo.datasets import MDAS, DatasetNotFoundError


class TestMDAS:
@pytest.fixture(
params=[
{'subareas': ['sub_area_1'], 'modalities': ['HySpex']},
{
'subareas': ['sub_area_1', 'sub_area_2'],
'modalities': ['3K_DSM', 'HySpex', 'osm_water'],
},
{
'subareas': ['sub_area_2', 'sub_area_3'],
'modalities': [
'3K_DSM',
'3K_RGB',
'HySpex',
'EeteS_EnMAP_10m',
'EeteS_EnMAP_30m',
'EeteS_Sentinel_2_10m',
'Sentinel_2',
'Sentinel_1',
'osm_buildings',
'osm_landuse',
'osm_water',
],
},
]
)
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> MDAS:
md5 = '99e1744ca6f19aa19a3aa23a2bbf7bef'
monkeypatch.setattr(MDAS, 'md5', md5)
url = os.path.join('tests', 'data', 'mdas', 'Augsburg_data_4_publication.zip')
monkeypatch.setattr(MDAS, 'url', url)

params = request.param
subareas = params['subareas']
modalities = params['modalities']

root = tmp_path
transforms = nn.Identity()

return MDAS(
root=root,
subareas=subareas,
modalities=modalities,
transforms=transforms,
download=True,
checksum=True,
)

def test_getitem(self, dataset: MDAS) -> None:
x = dataset[0]
assert isinstance(x, dict)
for key in dataset.modalities:
if key.startswith('osm'):
key = f'{key}_mask'
else:
key = f'{key}_image'
assert key in x

for key, value in x.items():
assert isinstance(value, torch.Tensor)

def test_len(self, dataset: MDAS) -> None:
assert len(dataset) == len(dataset.subareas)

def test_already_downloaded(self, dataset: MDAS) -> None:
MDAS(root=dataset.root)

def test_not_yet_extracted(self, tmp_path: Path) -> None:
filename = 'Augsburg_data_4_publication.zip'
dir = os.path.join('tests', 'data', 'mdas')
shutil.copyfile(
os.path.join(dir, filename), os.path.join(str(tmp_path), filename)
)
MDAS(root=str(tmp_path))

def test_invalid_subarea(self) -> None:
with pytest.raises(AssertionError):
MDAS(subareas=['foo'])

def test_invalid_modality(self) -> None:
with pytest.raises(AssertionError):
MDAS(modalities=['foo'])

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
MDAS(tmp_path)

def test_plot(self, dataset: MDAS) -> None:
dataset.plot(dataset[0], suptitle='Test')
plt.close()

def test_plot_single_sample(self, dataset: MDAS) -> None:
dataset.plot(dataset[0], show_titles=False)
plt.close()
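To run just these tests locally, one option (a sketch assuming a standard development checkout of TorchGeo with pytest installed) is to invoke pytest programmatically:

import pytest

# Run only the new MDAS tests; '-q' keeps the report short.
raise SystemExit(pytest.main(['-q', 'tests/datasets/test_mdas.py']))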
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
@@ -85,6 +85,7 @@
from .levircd import LEVIRCD, LEVIRCDBase, LEVIRCDPlus
from .loveda import LoveDA
from .mapinwild import MapInWild
from .mdas import MDAS
from .millionaid import MillionAID
from .mmearth import MMEarth
from .naip import NAIP
@@ -159,6 +160,7 @@
'GBIF',
'GID15',
'LEVIRCD',
'MDAS',
'NAIP',
'NCCM',
'NLCD',
3 changes: 0 additions & 3 deletions torchgeo/datasets/eurocrops.py
@@ -205,9 +205,6 @@ def get_label(self, feature: 'fiona.model.Feature') -> int:
# (Parent code is computed by replacing rightmost non-0 character with 0.)
hcat_code = feature['properties'][self.label_name]
if hcat_code is None:
print(
f"Feature does not contain the label '{self.label_name}'. Skip rendering."
)
return 0

while True:
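With the print removed, a missing HCAT label now silently maps to class 0. For the parent-code rule referenced in the surviving comment ("replace the rightmost non-0 character with 0"), a standalone sketch; the function name and the example code value are illustrative, not taken from EuroCrops:

def parent_code(code: str) -> str:
    """Replace the rightmost non-'0' character with '0' (see comment above)."""
    for i in range(len(code) - 1, -1, -1):
        if code[i] != '0':
            return code[:i] + '0' + code[i + 1:]
    return code  # already all zeros

print(parent_code('3301990002'))  # -> '3301990000'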
