Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MMFlood dataset #2450

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ L8 Biome

.. autoclass:: L8BiomeDataModule

MMFlood
^^^^^^^

.. autoclass:: MMFloodDataModule

NAIP
^^^^

Expand Down
4 changes: 4 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ Landsat
.. autoclass:: Landsat2
.. autoclass:: Landsat1

MMFlood
^^^^^^^

.. autoclass:: MMFlood

NAIP
^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/datasets/geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Dataset,Type,Source,License,Size (px),Resolution (m)
`L8 Biome`_,"Imagery, Masks",Landsat,"CC0-1.0","8,900x8,900","15, 30"
`LandCover.ai Geo`_,"Imagery, Masks",Aerial,"CC-BY-NC-SA-4.0","4,200--9,500",0.25--0.5
`Landsat`_,Imagery,Landsat,"public domain","8,900x8,900",30
`MMFlood`_,"Imagery, DEM, Masks","Sentinel, MapZen/TileZen, OpenStreetMap",CC-BY-4.0,"2,147x2,313",20
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
`NAIP`_,Imagery,Aerial,"public domain","6,100x7,600",0.3--2
`NCCM`_,Masks,Sentinel-2,"CC-BY-4.0",-,10
`NLCD`_,Masks,Landsat,"public domain",-,30
Expand Down
1 change: 1 addition & 0 deletions tests/data/mmflood/activations.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"EMSR000": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "train", "delineations": ["EMSR000_00"]}, "EMSR001": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "train", "delineations": ["EMSR001_00"]}, "EMSR003": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "val", "delineations": ["EMSR003_00"]}, "EMSR004": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "test", "delineations": ["EMSR004_00"]}}
Binary file added tests/data/mmflood/activations.tar.000.gz.part
Binary file not shown.
Binary file added tests/data/mmflood/activations.tar.001.gz.part
Binary file not shown.
127 changes: 127 additions & 0 deletions tests/data/mmflood/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import json
import os
import shutil
import tarfile

import numpy as np
import rasterio
from rasterio.crs import CRS
from rasterio.transform import Affine


def generate_data(path: str, filename: str, height: int, width: int) -> None:
    """Write one random fake sample (SAR image, DEM and mask) under ``path``.

    A GeoTIFF named ``filename`` is written into each of the ``s1_raw``,
    ``DEM`` and ``mask`` sub-folders of ``path``. A ``hydro`` sub-folder is
    created as well, but no file is written into it.

    Args:
        path: directory that receives the sub-folders
        filename: name of the GeoTIFF written into every sub-folder
        height: raster height in pixels
        width: raster width in pixels
    """
    MAX_VALUE = 1000.0
    MIN_VALUE = 0.0
    RANGE = MAX_VALUE - MIN_VALUE
    profile = {
        'driver': 'GTiff',
        'dtype': 'float32',
        'nodata': None,
        'crs': CRS.from_epsg(4326),
        'transform': Affine(
            0.0001287974837883981,
            0.0,
            14.438064999669106,
            0.0,
            -8.989523639880024e-05,
            45.71617928533084,
        ),
        'blockysize': 1,
        'tiled': False,
        'interleave': 'pixel',
        'height': height,
        'width': width,
    }
    # Uniform samples are mapped onto [MIN_VALUE, MAX_VALUE): scale by the
    # range, then *add* the lower bound (the offset was previously subtracted,
    # which only worked because MIN_VALUE happens to be 0).
    data = {
        's1_raw': np.random.rand(2, height, width).astype(np.float32) * RANGE
        + MIN_VALUE,
        'DEM': np.random.rand(1, height, width).astype(np.float32) * RANGE + MIN_VALUE,
        'mask': np.random.randint(low=0, high=2, size=(1, height, width)).astype(
            np.uint8
        ),
    }

    os.makedirs(os.path.join(path, 'hydro'), exist_ok=True)

    # NOTE(review): every file, including the uint8 mask, is written with the
    # float32 profile above -- confirm this matches the real dataset.
    for folder, bands in data.items():
        folder_path = os.path.join(path, folder)
        os.makedirs(folder_path, exist_ok=True)
        filepath = os.path.join(folder_path, filename)
        # The band count comes from the array itself, so it cannot drift
        # from the data being written.
        folder_profile = {**profile, 'count': bands.shape[0]}
        with rasterio.open(filepath, mode='w', **folder_profile) as src:
            src.write(bands)
lccol marked this conversation as resolved.
Show resolved Hide resolved


def generate_tar_gz(src: str, dst: str) -> None:
    """Pack ``src`` into a gzip-compressed tar archive written to ``dst``.

    Args:
        src: file or directory to add; it is stored in the archive under
            this same path
        dst: path of the ``.tar.gz`` archive to create
    """
    with tarfile.open(dst, mode='w:gz') as archive:
        archive.add(src, arcname=src)


def split_tar(path: str, dst: str, nparts: int) -> None:
    """Cut the file at ``path`` into ``nparts`` consecutive pieces.

    The pieces are written into ``dst`` as ``activations.tar.NNN.gz.part``,
    where ``NNN`` is the zero-padded piece index. Every piece except the last
    holds ``size // nparts`` bytes; the last one holds all remaining bytes.

    Args:
        path: file to split
        dst: directory that receives the pieces
        nparts: number of pieces to produce
    """
    total = os.stat(path).st_size
    piece = total // nparts
    final_idx = nparts - 1

    with open(path, 'rb') as source:
        for number in range(nparts):
            target = os.path.join(dst, f'activations.tar.{number:03}.gz.part')

            if number < final_idx:
                count = piece
            else:
                # Last piece: take everything not consumed so far
                count = total - source.tell()

            with open(target, 'wb') as sink:
                sink.write(source.read(count))


def generate_folders_and_metadata(datapath: str, metadatapath: str) -> None:
    """Generate the fake activations, archive them and write their metadata.

    For every activation a ``<code>-0`` folder is populated under
    ``datapath`` with 16x16 samples. The archiving steps then operate on the
    literal relative path ``activations`` in the current working directory
    (independent of ``datapath``): it is packed into ``activations.tar.gz``,
    split into two parts, and both the archive and the folder are deleted.
    Finally ``activations.json`` is written into ``metadatapath``.

    Args:
        datapath: directory in which the activation folders are generated
        metadatapath: directory that receives ``activations.json``
    """
    # (activation code, subset, number of samples to generate)
    activations = (
        ('EMSR000', 'train', 3),
        ('EMSR001', 'train', 2),
        ('EMSR003', 'val', 2),
        ('EMSR004', 'test', 1),
    )

    metadata = {}
    for code, subset, count in activations:
        target = os.path.join(datapath, f'{code}-0')
        for sample in range(count):
            generate_data(
                target, filename=f'{code}-{sample}.tif', height=16, width=16
            )

        metadata[code] = {
            'title': 'Test flood',
            'type': 'Flood',
            'country': 'N/A',
            'start': '2014-11-06T17:57:00',
            'end': '2015-01-29T12:47:04',
            'lat': 45.82427031690563,
            'lon': 14.484407562009336,
            'subset': subset,
            'delineations': [f'{code}_00'],
        }

    # Only the split parts are kept; the intermediate archive and the raw
    # folder are removed once the parts exist.
    generate_tar_gz(src='activations', dst='activations.tar.gz')
    split_tar(path='activations.tar.gz', dst='.', nparts=2)
    os.remove('activations.tar.gz')
    shutil.rmtree('activations')

    metadata_file = os.path.join(metadatapath, 'activations.json')
    with open(metadata_file, 'w') as handle:
        json.dump(metadata, handle)


if __name__ == '__main__':
    # Everything is generated relative to the current working directory:
    # metadata goes into it, data into its ``activations`` sub-folder.
    metadatapath = os.getcwd()
    datapath = os.path.join(metadatapath, 'activations')

    generate_folders_and_metadata(datapath=datapath, metadatapath=metadatapath)
72 changes: 72 additions & 0 deletions tests/datamodules/test_mmflood.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import os
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
from itertools import product
from pathlib import Path

import pytest
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torch import nn

from torchgeo.datamodules import MMFloodDataModule
from torchgeo.datasets import MMFlood


class TestMMFloodDataModule:
    """Tests for ``MMFloodDataModule`` using the fake data in tests/data/mmflood."""

    @pytest.fixture(params=product([True, False], ['mean', 'median']))
    def datamodule(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> MMFloodDataModule:
        """Build a datamodule for each (include_dem, normalization) pair."""
        # Point the dataset at the checked-in fake archive parts, but use a
        # temporary root so that downloaded/extracted files never end up in
        # the repository's test data directory.
        url = os.path.join('tests', 'data', 'mmflood/')
        monkeypatch.setattr(MMFlood, 'url', url)
        monkeypatch.setattr(MMFlood, '_nparts', 2)

        include_dem, normalization = request.param
        return MMFloodDataModule(
            batch_size=2,
            patch_size=8,
            normalization=normalization,
            root=tmp_path,
            include_dem=include_dem,
            transforms=nn.Identity(),
            download=True,
            checksum=True,
        )

    def test_fit_stage(self, datamodule: MMFloodDataModule) -> None:
        """A training batch has the expected image and mask shapes."""
        # NOTE(review): setup is called twice on purpose in the original
        # test -- presumably to exercise repeated setup; confirm.
        datamodule.setup(stage='fit')
        datamodule.setup(stage='fit')
        if datamodule.trainer:
            datamodule.trainer.training = True
        batch = next(iter(datamodule.train_dataloader()))
        batch = datamodule.on_after_batch_transfer(batch, 0)
        # 2 image channels, plus one more when the DEM is included
        nchannels = 3 if datamodule.kwargs['include_dem'] else 2
        assert batch['image'].shape == (2, nchannels, 8, 8)
        assert batch['mask'].shape == (2, 8, 8)

    def test_validate_stage(self, datamodule: MMFloodDataModule) -> None:
        """A validation batch has the expected image and mask shapes."""
        datamodule.setup(stage='validate')
        datamodule.setup(stage='validate')
        if datamodule.trainer:
            datamodule.trainer.validating = True
        batch = next(iter(datamodule.val_dataloader()))
        batch = datamodule.on_after_batch_transfer(batch, 0)
        nchannels = 3 if datamodule.kwargs['include_dem'] else 2
        assert batch['image'].shape == (2, nchannels, 8, 8)
        assert batch['mask'].shape == (2, 8, 8)

    def test_test_stage(self, datamodule: MMFloodDataModule) -> None:
        """A test batch has the expected image and mask shapes."""
        datamodule.setup(stage='test')
        datamodule.setup(stage='test')
        if datamodule.trainer:
            datamodule.trainer.testing = True
        batch = next(iter(datamodule.test_dataloader()))
        batch = datamodule.on_after_batch_transfer(batch, 0)
        nchannels = 3 if datamodule.kwargs['include_dem'] else 2
        assert batch['image'].shape == (2, nchannels, 8, 8)
        assert batch['mask'].shape == (2, 8, 8)
123 changes: 123 additions & 0 deletions tests/datasets/test_mmflood.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from rasterio.crs import CRS

from torchgeo.datasets import (
BoundingBox,
DatasetNotFoundError,
IntersectionDataset,
MMFlood,
UnionDataset,
)


class TestMMFlood:
    """Tests for the ``MMFlood`` dataset using the fake data in tests/data/mmflood."""

    @staticmethod
    def _use_local_archives(monkeypatch: MonkeyPatch) -> None:
        """Point ``MMFlood`` at the two fake archive parts kept in the repo."""
        dataset_root = os.path.join('tests', 'data', 'mmflood/')
        monkeypatch.setattr(MMFlood, 'url', dataset_root)
        monkeypatch.setattr(MMFlood, '_nparts', 2)

    @pytest.fixture(params=product([True, False], ['train', 'val', 'test']))
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> MMFlood:
        """Build a dataset in a temporary root for each (include_dem, split)."""
        self._use_local_archives(monkeypatch)

        include_dem, split = request.param
        return MMFlood(
            tmp_path,
            split=split,
            include_dem=include_dem,
            transforms=nn.Identity(),
            download=True,
            checksum=True,
        )

    def test_getitem(self, dataset: MMFlood) -> None:
        sample = dataset[dataset.bounds]
        assert isinstance(sample, dict)
        assert isinstance(sample['crs'], CRS)
        assert isinstance(sample['image'], torch.Tensor)
        assert isinstance(sample['mask'], torch.Tensor)

        # Three channels when the DEM is included, two otherwise
        expected_channels = 3 if dataset.include_dem else 2
        assert sample['image'].size(0) == expected_channels

    def test_len(self, dataset: MMFlood) -> None:
        # Any split other than train/val is expected to hold a single sample
        expected = {'train': 5, 'val': 2}.get(dataset.split, 1)
        assert len(dataset) == expected

    def test_and(self, dataset: MMFlood) -> None:
        intersection = dataset & dataset
        assert isinstance(intersection, IntersectionDataset)

    def test_or(self, dataset: MMFlood) -> None:
        union = dataset | dataset
        assert isinstance(union, UnionDataset)

    def test_already_downloaded(self, dataset: MMFlood) -> None:
        # Re-instantiating on an already populated root must not raise
        MMFlood(root=dataset.root)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            MMFlood(tmp_path)

    def test_plot(self, dataset: MMFlood) -> None:
        sample = dataset[dataset.bounds]
        dataset.plot(sample, suptitle='Test')
        plt.close()

    def test_plot_prediction(self, dataset: MMFlood) -> None:
        sample = dataset[dataset.bounds]
        sample['prediction'] = sample['mask'].clone()
        dataset.plot(sample, suptitle='Prediction')
        plt.close()

    def test_invalid_query(self, dataset: MMFlood) -> None:
        empty_box = BoundingBox(0, 0, 0, 0, 0, 0)
        with pytest.raises(
            IndexError, match='query: .* not found in index with bounds:'
        ):
            dataset[empty_box]

    def test_check_folders(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> None:
        class MockMMFlood(MMFlood):
            # Always forwards ``check_folders=False`` to the parent,
            # whatever the caller asked for.
            def _load_folders(
                self, check_folders: bool = False
            ) -> list[dict[str, str]]:
                return super()._load_folders(check_folders=False)

        self._use_local_archives(monkeypatch)

        MockMMFlood(
            tmp_path,
            split='train',
            include_dem=True,
            transforms=nn.Identity(),
            download=True,
            checksum=True,
        )
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .landcoverai import LandCoverAI100DataModule, LandCoverAIDataModule
from .levircd import LEVIRCDDataModule, LEVIRCDPlusDataModule
from .loveda import LoveDADataModule
from .mmflood import MMFloodDataModule
from .naip import NAIPChesapeakeDataModule
from .nasa_marine_debris import NASAMarineDebrisDataModule
from .oscd import OSCDDataModule
Expand Down Expand Up @@ -87,6 +88,7 @@
'LandCoverAI100DataModule',
'LandCoverAIDataModule',
'LoveDADataModule',
'MMFloodDataModule',
'MisconfigurationException',
'NAIPChesapeakeDataModule',
'NASAMarineDebrisDataModule',
Expand Down
Loading
Loading