Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MMFlood dataset #2450

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ L8 Biome

.. autoclass:: L8BiomeDataModule

MMFlood
^^^^^^^^

.. autoclass:: MMFloodDataModule

NAIP
^^^^

Expand Down
4 changes: 4 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ Landsat
.. autoclass:: Landsat2
.. autoclass:: Landsat1

MMFlood
^^^^^^^
.. autoclass:: MMFlood

NAIP
^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/datasets/geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Dataset,Type,Source,License,Size (px),Resolution (m)
`L8 Biome`_,"Imagery, Masks",Landsat,"CC0-1.0","8,900x8,900","15, 30"
`LandCover.ai Geo`_,"Imagery, Masks",Aerial,"CC-BY-NC-SA-4.0","4,200--9,500",0.25--0.5
`Landsat`_,Imagery,Landsat,"public domain","8,900x8,900",30
`MMFlood`_,"Imagery,DEM,Masks","Sentinel, MapZen/TileZen, OpenStreetMap",MIT,"2,147x2,313",20
`NAIP`_,Imagery,Aerial,"public domain","6,100x7,600",0.3--2
`NCCM`_,Masks,Sentinel-2,"CC-BY-4.0",-,10
`NLCD`_,Masks,Landsat,"public domain",-,30
Expand Down
20 changes: 20 additions & 0 deletions tests/conf/mmflood.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: 'ce'
model: 'unet'
backbone: 'resnet18'
in_channels: 4
num_classes: 2
num_filters: 1
ignore_index: 255
data:
class_path: MMFloodDataModule
init_args:
batch_size: 1
dict_kwargs:
root: 'tests/data/mmflood'
patch_size: 8
normalization: 'median'
include_dem: True
include_hydro: True
1 change: 1 addition & 0 deletions tests/data/mmflood/activations.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"EMSR000": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "train", "delineations": ["EMSR000_00"]}, "EMSR001": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "train", "delineations": ["EMSR001_00"]}, "EMSR003": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "val", "delineations": ["EMSR003_00"]}, "EMSR004": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "test", "delineations": ["EMSR004_00"]}}
Binary file added tests/data/mmflood/activations.tar.000.gz.part
Binary file not shown.
Binary file added tests/data/mmflood/activations.tar.001.gz.part
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
142 changes: 142 additions & 0 deletions tests/data/mmflood/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import json
import os
import tarfile

import numpy as np
import rasterio
from rasterio.crs import CRS
from rasterio.transform import Affine


def generate_data(
path: str, filename: str, height: int, width: int, include_hydro: bool = False
) -> None:
max_value = 1000.0
min_value = 0.0
interval = max_value - min_value
folders = ['s1_raw', 'DEM', 'mask', 'hydro']
profile = {
'driver': 'GTiff',
'dtype': 'float32',
'nodata': None,
'crs': CRS.from_epsg(4326),
'transform': Affine(
0.0001287974837883981,
0.0,
14.438064999669106,
0.0,
-8.989523639880024e-05,
45.71617928533084,
),
'blockysize': 1,
'tiled': False,
'interleave': 'pixel',
'height': height,
'width': width,
}
data = {
's1_raw': np.random.rand(2, height, width).astype(np.float32) * interval
- min_value,
'DEM': np.random.rand(1, height, width).astype(np.float32) * interval
- min_value,
'mask': np.random.randint(low=0, high=2, size=(1, height, width)).astype(
np.uint8
),
}

if include_hydro:
data['hydro'] = (
np.random.rand(1, height, width).astype(np.float32) * interval - min_value
)

for folder in folders:
folder_path = os.path.join(path, folder)
os.makedirs(folder_path, exist_ok=True)
filepath = os.path.join(folder_path, filename)
profile2 = profile.copy()
profile2['count'] = 2 if folder == 's1_raw' else 1
if folder in data:
with rasterio.open(filepath, mode='w', **profile2) as src:
src.write(data[folder])

return


def generate_tar_gz(src: str, dst: str) -> None:
with tarfile.open(dst, 'w:gz') as tar:
tar.add(src, arcname=src)
return


def split_tar(path: str, dst: str, nparts: int) -> None:
fstats = os.stat(path)
size = fstats.st_size
chunk = size // nparts

with open(path, 'rb') as fp:
for idx in range(nparts):
part_path = os.path.join(dst, f'activations.tar.{idx:03}.gz.part')

bytes_to_write = chunk if idx < nparts - 1 else size - fp.tell()
with open(part_path, 'wb') as dst_fp:
dst_fp.write(fp.read(bytes_to_write))

return


def generate_folders_and_metadata(datapath: str, metadatapath: str) -> None:
folders_splits = [
('EMSR000', 'train'),
('EMSR001', 'train'),
('EMSR003', 'val'),
('EMSR004', 'test'),
]
num_files = {'EMSR000': 3, 'EMSR001': 2, 'EMSR003': 2, 'EMSR004': 1}
num_hydro = {'EMSR001': 2, 'EMSR003': 1, 'EMSR004': 1}
metadata = {}
for folder, split in folders_splits:
data = {}
data['title'] = 'Test flood'
data['type'] = 'Flood'
data['country'] = 'N/A'
data['start'] = '2014-11-06T17:57:00'
data['end'] = '2015-01-29T12:47:04'
data['lat'] = 45.82427031690563
data['lon'] = 14.484407562009336
data['subset'] = split
data['delineations'] = [f'{folder}_00']

count_hydro = 0

dst_folder = os.path.join(datapath, f'{folder}-0')
for idx in range(num_files[folder]):
include_hydro = count_hydro < num_hydro.get(folder, 0)
generate_data(
dst_folder,
filename=f'{folder}-{idx}.tif',
height=16,
width=16,
include_hydro=include_hydro,
)
if include_hydro:
count_hydro += 1

metadata[folder] = data

generate_tar_gz(src='activations', dst='activations.tar.gz')
split_tar(path='activations.tar.gz', dst='.', nparts=2)
os.remove('activations.tar.gz')
with open(os.path.join(metadatapath, 'activations.json'), 'w') as fp:
json.dump(metadata, fp)

return


if __name__ == '__main__':
datapath = os.path.join(os.getcwd(), 'activations')
metadatapath = os.getcwd()

generate_folders_and_metadata(datapath, metadatapath)
151 changes: 151 additions & 0 deletions tests/datasets/test_mmflood.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from rasterio.crs import CRS

from torchgeo.datasets import (
BoundingBox,
DatasetNotFoundError,
IntersectionDataset,
MMFlood,
UnionDataset,
)
from torchgeo.datasets.mmflood import MMFloodComponent, MMFloodIntersection


class TestMMFlood:
@pytest.fixture(
params=product([True, False], [True, False], ['train', 'val', 'test'])
)
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> MMFlood:
dataset_root = os.path.join('tests', 'data', 'mmflood/')
url = os.path.join(dataset_root)

monkeypatch.setattr(MMFlood, 'url', url)
monkeypatch.setattr(MMFlood, '_nparts', 2)

include_dem, include_hydro, split = request.param
root = tmp_path
return MMFlood(
root,
split=split,
include_dem=include_dem,
include_hydro=include_hydro,
transforms=nn.Identity(),
download=True,
checksum=True,
)

def test_getitem(self, dataset: MMFlood) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x['crs'], CRS)
assert isinstance(x['image'], torch.Tensor)
assert isinstance(x['mask'], torch.Tensor)
nchannels = 2

# If DEM is included and hydro is included, check if 4 channels are present,
# If only one between DEM or hydro is included, check if 3 channels are present
# 2 otherwise
if dataset.include_dem:
nchannels += 1
if dataset.include_hydro:
nchannels += 1
assert x['image'].size(0) == nchannels

@pytest.fixture
def mock_intersection_dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path
) -> MMFlood:
class MockIntersection(MMFloodIntersection):
def __init__(
self,
dataset1: MMFloodIntersection | MMFloodComponent,
dataset2: MMFloodIntersection | MMFloodComponent,
) -> None:
super().__init__(dataset2, dataset1)

monkeypatch.setattr(
'torchgeo.datasets.mmflood.MMFloodIntersection', MockIntersection
)
dataset_root = os.path.join('tests', 'data', 'mmflood/')
url = os.path.join(dataset_root)
monkeypatch.setattr(MMFlood, 'url', url)
monkeypatch.setattr(MMFlood, '_nparts', 2)
return MMFlood(
tmp_path,
split='train',
include_dem=True,
include_hydro=True,
transforms=nn.Identity(),
download=True,
checksum=True,
)

def test_swap_dataset(self, mock_intersection_dataset: MMFlood) -> None:
d = MMFlood(
mock_intersection_dataset.root,
split='train',
include_dem=True,
include_hydro=True,
)
assert len(d) == 2

def test_len(self, dataset: MMFlood) -> None:
if dataset.split == 'train':
if not dataset.include_hydro:
assert len(dataset) == 5
else:
assert len(dataset) == 2
elif dataset.split == 'val':
if not dataset.include_hydro:
assert len(dataset) == 2
else:
assert len(dataset) == 1
else:
assert len(dataset) == 1

def test_and(self, dataset: MMFlood) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)

def test_or(self, dataset: MMFlood) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)

def test_already_downloaded(self, dataset: MMFlood) -> None:
MMFlood(root=dataset.root)

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
MMFlood(tmp_path)

def test_plot(self, dataset: MMFlood) -> None:
x = dataset[dataset.bounds]
dataset.plot(x, suptitle='Test')
plt.close()

def test_plot_prediction(self, dataset: MMFlood) -> None:
x = dataset[dataset.bounds]
x['prediction'] = x['mask'].clone()
dataset.plot(x, suptitle='Prediction')
plt.close()

def test_invalid_query(self, dataset: MMFlood) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
IndexError, match='query: .* not found in index with bounds:'
):
dataset[query]
1 change: 1 addition & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class TestSemanticSegmentationTask:
'landcoverai',
'landcoverai100',
'loveda',
'mmflood',
'naipchesapeake',
'potsdam2d',
'sen12ms_all',
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .landcoverai import LandCoverAI100DataModule, LandCoverAIDataModule
from .levircd import LEVIRCDDataModule, LEVIRCDPlusDataModule
from .loveda import LoveDADataModule
from .mmflood import MMFloodDataModule
from .naip import NAIPChesapeakeDataModule
from .nasa_marine_debris import NASAMarineDebrisDataModule
from .oscd import OSCDDataModule
Expand Down Expand Up @@ -87,6 +88,7 @@
'LandCoverAI100DataModule',
'LandCoverAIDataModule',
'LoveDADataModule',
'MMFloodDataModule',
'MisconfigurationException',
'NAIPChesapeakeDataModule',
'NASAMarineDebrisDataModule',
Expand Down
Loading
Loading