Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added substation segementation dataset #2352

Open
wants to merge 73 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
73 commits
Select commit Hold shift + click to select a range
7dff61c
Added substation segementation dataset
rijuld Oct 17, 2024
10637af
resolved bugs
rijuld Oct 21, 2024
2cb0842
a
rijuld Oct 21, 2024
608f76a
Resolved error
rijuld Oct 21, 2024
288e8b1
fixed ruff errors
rijuld Oct 21, 2024
2e9bf83
fixed mypy errors for substation seg py file
rijuld Oct 21, 2024
78c494d
removed more errors
rijuld Oct 21, 2024
75ca32c
resolved ruff errors and mypy errors
rijuld Oct 24, 2024
e2326cc
fixed length and data size along with ruff and mypy errors
rijuld Oct 25, 2024
9832db4
resolved float error
rijuld Oct 25, 2024
ef79cd7
organized imports
rijuld Oct 25, 2024
83f2eb4
changed to float
rijuld Oct 25, 2024
69f5815
resolved mypy errors
rijuld Oct 27, 2024
898e6b3
resolved further tests
rijuld Oct 27, 2024
d14eca6
sorted imports
rijuld Oct 27, 2024
d6ae700
more test coverage
rijuld Oct 30, 2024
8892f0d
ruff format
rijuld Oct 30, 2024
3f135b4
increased test code coverage
rijuld Oct 30, 2024
9a05811
added formatting
rijuld Oct 30, 2024
4e65b04
removed transformations so that I can add them in data module
rijuld Oct 30, 2024
9a9d555
increased underline length
rijuld Oct 30, 2024
3e12e7e
corrected csv row length
rijuld Oct 30, 2024
bbba17b
Update datasets.rst
zijinyin Nov 24, 2024
4fffc1f
Update non_geo_datasets.csv
zijinyin Nov 24, 2024
598c4be
Merge pull request #3 from zijinyin/patch-4
rijuld Nov 25, 2024
15a8881
Merge pull request #1 from zijinyin/patch-2
rijuld Nov 25, 2024
095b7dd
added comment for dataset
rijuld Nov 25, 2024
b503817
changed name to substation
rijuld Nov 25, 2024
f28e30c
added copyright
rijuld Nov 25, 2024
fe1761d
corrected issues
rijuld Nov 25, 2024
c4c3545
added plot and tests
rijuld Nov 25, 2024
1817132
removed pytest
rijuld Nov 25, 2024
28377f8
ruff format
rijuld Nov 25, 2024
5af4e0f
Merge branch 'main' into main
rijuld Nov 26, 2024
a3b95ba
added extract function
rijuld Dec 2, 2024
1216da4
added import
rijuld Dec 2, 2024
b0c3c90
Merge branch 'main' into main
rijuld Dec 2, 2024
545ff66
added datamodule
rijuld Dec 5, 2024
4a6e349
addressed few comments
rijuld Jan 1, 2025
dcc98ef
changed image size
rijuld Jan 1, 2025
d8147ed
removed argument for image files
rijuld Jan 1, 2025
23adef5
added homepage for dataset
rijuld Jan 1, 2025
14e3e51
added ruff format
rijuld Jan 1, 2025
fe12d52
removed mypy errors
rijuld Jan 1, 2025
337b002
fixed the remaining mypy errors
rijuld Jan 1, 2025
4aedf93
Merge branch 'main' into main
rijuld Jan 1, 2025
c7fc761
fixed all the existing tests
rijuld Jan 5, 2025
8e09e8a
added datamodule testing files
rijuld Jan 5, 2025
7626f28
Merge branch 'main' into main
rijuld Jan 8, 2025
d35b435
changed the datatype of bands to list[int] form int
Jan 8, 2025
173a915
changed bands datatype from datamodule
rijuld Jan 8, 2025
cfe800d
changed num of bands variables
rijuld Jan 8, 2025
ebcc36f
Added substation in datamodules.rst and resolved datasets.rst length …
rijuld Jan 8, 2025
f1fcdf0
added substation datamodule in init
rijuld Jan 8, 2025
f00bcd2
chanded the data type of normalizing factor to Any
rijuld Jan 8, 2025
280e32a
[just for testing]
rijuld Jan 8, 2025
743113e
[for testing]
rijuld Jan 8, 2025
85bb9c9
Added parent class
rijuld Jan 8, 2025
de5b337
removed patch size
rijuld Jan 8, 2025
3285346
removed unwanted key
rijuld Jan 8, 2025
d1f062f
resolved errors and tested data module using conf file
rijuld Jan 19, 2025
aebe183
resolved some ruff issues
rijuld Jan 19, 2025
a01c3b4
Merge branch 'main' into main
rijuld Jan 19, 2025
5de36d4
fixed another ruff error
rijuld Jan 19, 2025
7c8c71a
fixed ruff issue
rijuld Jan 19, 2025
6c2b1cb
added more test coverage for extract and verify
rijuld Jan 19, 2025
d4bf9fb
organized imports
rijuld Jan 19, 2025
9a050bd
added more tests for dataset
rijuld Jan 19, 2025
8c918a8
added identity for init values
rijuld Jan 19, 2025
39668ca
ruff format
rijuld Jan 19, 2025
b3af64a
removed pytest command from test file
rijuld Jan 19, 2025
8355860
ruff format
rijuld Jan 19, 2025
5091e16
Merge branch 'main' into main
rijuld Jan 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
removed mypy errors
rijuld committed Jan 1, 2025
commit fe12d523a4766d4826998f9a5d0f7c3aaca291ea
14 changes: 9 additions & 5 deletions tests/datasets/test_substation.py
Original file line number Diff line number Diff line change
@@ -15,9 +15,10 @@
from torchgeo.datasets import Substation


class Substation:
class TestSubstation:
@pytest.fixture
def dataset(
self,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> Generator[Substation, None, None]:
"""Fixture for the Substation."""
@@ -29,6 +30,7 @@ def dataset(
use_timepoints=True,
mask_2d=True,
timepoint_aggregation='median',
num_of_timepoints=4
)

@pytest.mark.parametrize(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't you parametrize the fixture instead of the unit test? Then all other unit tests will also be parametrized.

@@ -74,7 +76,7 @@ def dataset(
},
],
)
def test_getitem_semantic(config: dict[str, Any]) -> None:
def test_getitem_semantic(self, config: dict[str, Any]) -> None:
root = os.path.join(os.getcwd(), 'tests', 'data')
dataset = Substation(root=root, **config)

@@ -85,17 +87,17 @@ def test_getitem_semantic(config: dict[str, Any]) -> None:
), 'Expected image to be a torch.Tensor'
assert isinstance(x['mask'], torch.Tensor), 'Expected mask to be a torch.Tensor'

def test_len(dataset: Substation) -> None:
def test_len(self, dataset: Substation) -> None:
"""Test the length of the dataset."""
assert len(dataset) == 2

def test_output_shape(dataset: Substation) -> None:
def test_output_shape(self, dataset: Substation) -> None:
"""Test the output shape of the dataset."""
x = dataset[0]
assert x['image'].shape == torch.Size([13, 32, 32])
assert x['mask'].shape == torch.Size([2, 32, 32])

def test_plot(dataset: Substation) -> None:
def test_plot(self, dataset: Substation) -> None:
sample = dataset[0]
dataset.plot(sample, suptitle='Test')
plt.close()
@@ -106,6 +108,7 @@ def test_plot(dataset: Substation) -> None:
plt.close()

def test_already_downloaded(
self,
dataset: Substation, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Test that the dataset doesn't re-download if already present."""
@@ -125,6 +128,7 @@ def test_already_downloaded(
dataset._download() # This will now call the mocked method

def test_download(
self,
dataset: Substation, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Test the _download method of the dataset."""
60 changes: 17 additions & 43 deletions torchgeo/datamodules/substation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Substation Data Module."""

from typing import Any
from typing import Any, List, Optional

Check failure on line 1 in torchgeo/datamodules/substation.py

GitHub Actions / ruff

Ruff (UP035)

torchgeo/datamodules/substation.py:1:1: UP035 `typing.List` is deprecated, use `list` instead

Check failure on line 1 in torchgeo/datamodules/substation.py

GitHub Actions / ruff

Ruff (D100)

torchgeo/datamodules/substation.py:1:1: D100 Missing docstring in public module

Check failure on line 1 in torchgeo/datamodules/substation.py

GitHub Actions / ruff

Ruff (F401)

torchgeo/datamodules/substation.py:1:25: F401 `typing.List` imported but unused

Check failure on line 1 in torchgeo/datamodules/substation.py

GitHub Actions / ruff

Ruff (F401)

torchgeo/datamodules/substation.py:1:31: F401 `typing.Optional` imported but unused

import numpy as np
import torch
@@ -30,12 +25,12 @@
means: np.ndarray | None = None,
stds: np.ndarray | None = None,
bands: int = 13,
num_of_timepoints: int | None = None,
num_of_timepoints: int = 4,
model_type: str = 'default',
geo_transforms: Any = None,
color_transforms: Any = None,
image_resize: Any = None,
mask_resize: Any = None,
geo_transforms: Any | None = None,
color_transforms: Any | None = None,
image_resize: Any | None = None,
mask_resize: Any | None = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We try to avoid dynamic typing, can you be more specific than Any which disables all type checking?

**kwargs: Any,
) -> None:
"""Initialize a new SubstationDataModule instance.
@@ -58,13 +53,6 @@
mask_resize: Resizing function for the mask.
**kwargs: Additional arguments passed to Substation.
"""
# super().__init__(
# Substation,
# root=root,
# batch_size=batch_size,
# num_workers=num_workers,
# **kwargs,
# )
self.root = root
self.split_ratio = split_ratio
self.normalizing_type = normalizing_type
@@ -79,10 +67,9 @@
self.mask_resize = mask_resize
self.num_of_timepoints = num_of_timepoints

# Placeholder for datasets
self.train_dataset = None
self.val_dataset = None
self.test_dataset = None
self.train_dataset: Subset[Any] | None = None
self.val_dataset: Subset[Any] | None = None
self.test_dataset: Subset[Any] | None = None

def setup(self, stage: str) -> None:
"""Set up datasets.
@@ -101,35 +88,34 @@
checksum=False,
)

# Train-test split
total_size = len(dataset)
train_size = int(total_size * self.split_ratio)
indices = list(range(total_size))
train_indices: Subset[Any]
test_indices: Subset[Any]
train_indices, test_indices = torch.utils.data.random_split(
indices, [train_size, total_size - train_size]
dataset, [train_size, total_size - train_size]
)

if stage in ['fit', 'validate']:
# Further split train set into train imageand validation sets
val_split_ratio = 0.2
val_size = int(len(train_indices) * val_split_ratio)
train_size = len(train_indices) - val_size
val_indices: Subset[Any]
train_indices, val_indices = torch.utils.data.random_split(
train_indices, [train_size, val_size]
)

self.train_dataset = Subset(dataset, train_indices)
self.val_dataset = Subset(dataset, val_indices)
self.train_dataset = Subset(dataset, train_indices.indices)
self.val_dataset = Subset(dataset, val_indices.indices)

# Apply preprocessing to train and validation datasets
self.train_dataset = self._apply_transforms(self.train_dataset)
self.val_dataset = self._apply_transforms(self.val_dataset)

if stage == 'test':
self.test_dataset = Subset(dataset, test_indices)
self.test_dataset = Subset(dataset, test_indices.indices)
self.test_dataset = self._apply_transforms(self.test_dataset)

def _apply_transforms(self, dataset: Subset) -> Subset:
def _apply_transforms(self, dataset: Subset[Any]) -> Subset[Any]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should use self.aug, self.train_aug, etc. to store the augmentations you want to perform. These are then automatically run on the GPU for you, no need to run them manually. See https://torchgeo.readthedocs.io/en/latest/tutorials/contribute_datamodule.html, torchgeo/datamodules/geo.py, and every other builtin datamodule for examples of this.

"""Apply preprocessing and transformations to the dataset.

Args:
@@ -141,22 +127,11 @@
for sample in tqdm(dataset, desc='Processing images', unit='sample'):
image, mask = sample['image'], sample['mask']

# Standardizing image
# if self.normalizing_type == "percentile":
# image = (image - self.normalizing_factor[:, 0].reshape((-1, 1, 1))) / self.normalizing_factor[:, 2].reshape((-1, 1, 1))
# elif self.normalizing_type == "zscore":
# image = (image - self.means) / self.stds
# else:
# image = image / self.normalizing_factor
# image = torch.clamp(image, 0, 1)

# Applying geometric transformations
if self.geo_transforms:
combined = torch.cat((image, mask), 0)
combined = self.geo_transforms(combined)
image, mask = torch.split(combined, [image.shape[0], mask.shape[0]], 0)

# Applying color transformations
if self.color_transforms:
num_timepoints = image.shape[0] // self.bands
for i in range(num_timepoints):
@@ -171,7 +146,6 @@
'Input dimensions must support color transformations.'
)

# Resizing image and mask
if self.image_resize:
image = self.image_resize(image)
if self.mask_resize:
2 changes: 1 addition & 1 deletion torchgeo/datasets/substation.py
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@
url_for_images = 'https://storage.googleapis.com/tz-ml-public/substation-over-10km2-csv-main-444e360fd2b6444b9018d509d0e4f36e/image_stack.tar.gz'
url_for_masks = 'https://storage.googleapis.com/tz-ml-public/substation-over-10km2-csv-main-444e360fd2b6444b9018d509d0e4f36e/mask.tar.gz'

def __init__(

Check failure on line 46 in torchgeo/datasets/substation.py

GitHub Actions / ruff

Ruff (D417)

torchgeo/datasets/substation.py:46:9: D417 Missing argument description in the docstring for `__init__`: `num_of_timepoints`
self,
root: str,
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
bands: int,
@@ -79,7 +79,7 @@
self._verify()
self.load_image_filenames()

def load_image_filenames(self) -> list[str]:
def load_image_filenames(self) -> None:
"""Load image filenames from the image directory."""
self.image_filenames = os.listdir(self.image_dir)