-
Notifications
You must be signed in to change notification settings - Fork 388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added substation segmentation dataset #2352
base: main
Are you sure you want to change the base?
Changes from 1 commit
7dff61c
10637af
2cb0842
608f76a
288e8b1
2e9bf83
78c494d
75ca32c
e2326cc
9832db4
ef79cd7
83f2eb4
69f5815
898e6b3
d14eca6
d6ae700
8892f0d
3f135b4
9a05811
4e65b04
9a9d555
3e12e7e
bbba17b
4fffc1f
598c4be
15a8881
095b7dd
b503817
f28e30c
fe1761d
c4c3545
1817132
28377f8
5af4e0f
a3b95ba
1216da4
b0c3c90
545ff66
4a6e349
dcc98ef
d8147ed
23adef5
14e3e51
fe12d52
337b002
4aedf93
c7fc761
8e09e8a
7626f28
d35b435
173a915
cfe800d
ebcc36f
f1fcdf0
f00bcd2
280e32a
743113e
85bb9c9
de5b337
3285346
d1f062f
aebe183
a01c3b4
5de36d4
7c8c71a
6c2b1cb
d4bf9fb
9a050bd
8c918a8
39668ca
b3af64a
8355860
5091e16
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,4 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
"""Substation Data Module.""" | ||
|
||
from typing import Any | ||
from typing import Any, List, Optional | ||
Check failure on line 1 in torchgeo/datamodules/substation.py GitHub Actions / ruff: Ruff (UP035)
Check failure on line 1 in torchgeo/datamodules/substation.py GitHub Actions / ruff: Ruff (D100)
Check failure on line 1 in torchgeo/datamodules/substation.py GitHub Actions / ruff: Ruff (F401)
|
||
|
||
import numpy as np | ||
import torch | ||
|
@@ -30,12 +25,12 @@ | |
means: np.ndarray | None = None, | ||
stds: np.ndarray | None = None, | ||
bands: int = 13, | ||
num_of_timepoints: int | None = None, | ||
num_of_timepoints: int = 4, | ||
model_type: str = 'default', | ||
geo_transforms: Any = None, | ||
color_transforms: Any = None, | ||
image_resize: Any = None, | ||
mask_resize: Any = None, | ||
geo_transforms: Any | None = None, | ||
color_transforms: Any | None = None, | ||
image_resize: Any | None = None, | ||
mask_resize: Any | None = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We try to avoid dynamic typing, can you be more specific than `Any`? |
||
**kwargs: Any, | ||
) -> None: | ||
"""Initialize a new SubstationDataModule instance. | ||
|
@@ -58,13 +53,6 @@ | |
mask_resize: Resizing function for the mask. | ||
**kwargs: Additional arguments passed to Substation. | ||
""" | ||
# super().__init__( | ||
# Substation, | ||
# root=root, | ||
# batch_size=batch_size, | ||
# num_workers=num_workers, | ||
# **kwargs, | ||
# ) | ||
self.root = root | ||
self.split_ratio = split_ratio | ||
self.normalizing_type = normalizing_type | ||
|
@@ -79,10 +67,9 @@ | |
self.mask_resize = mask_resize | ||
self.num_of_timepoints = num_of_timepoints | ||
|
||
# Placeholder for datasets | ||
self.train_dataset = None | ||
self.val_dataset = None | ||
self.test_dataset = None | ||
self.train_dataset: Subset[Any] | None = None | ||
self.val_dataset: Subset[Any] | None = None | ||
self.test_dataset: Subset[Any] | None = None | ||
|
||
def setup(self, stage: str) -> None: | ||
"""Set up datasets. | ||
|
@@ -101,35 +88,34 @@ | |
checksum=False, | ||
) | ||
|
||
# Train-test split | ||
total_size = len(dataset) | ||
train_size = int(total_size * self.split_ratio) | ||
indices = list(range(total_size)) | ||
train_indices: Subset[Any] | ||
test_indices: Subset[Any] | ||
train_indices, test_indices = torch.utils.data.random_split( | ||
indices, [train_size, total_size - train_size] | ||
dataset, [train_size, total_size - train_size] | ||
) | ||
|
||
if stage in ['fit', 'validate']: | ||
# Further split train set into train and validation sets | ||
val_split_ratio = 0.2 | ||
val_size = int(len(train_indices) * val_split_ratio) | ||
train_size = len(train_indices) - val_size | ||
val_indices: Subset[Any] | ||
train_indices, val_indices = torch.utils.data.random_split( | ||
train_indices, [train_size, val_size] | ||
) | ||
|
||
self.train_dataset = Subset(dataset, train_indices) | ||
self.val_dataset = Subset(dataset, val_indices) | ||
self.train_dataset = Subset(dataset, train_indices.indices) | ||
self.val_dataset = Subset(dataset, val_indices.indices) | ||
|
||
# Apply preprocessing to train and validation datasets | ||
self.train_dataset = self._apply_transforms(self.train_dataset) | ||
self.val_dataset = self._apply_transforms(self.val_dataset) | ||
|
||
if stage == 'test': | ||
self.test_dataset = Subset(dataset, test_indices) | ||
self.test_dataset = Subset(dataset, test_indices.indices) | ||
self.test_dataset = self._apply_transforms(self.test_dataset) | ||
|
||
def _apply_transforms(self, dataset: Subset) -> Subset: | ||
def _apply_transforms(self, dataset: Subset[Any]) -> Subset[Any]: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You should use |
||
"""Apply preprocessing and transformations to the dataset. | ||
|
||
Args: | ||
|
@@ -141,22 +127,11 @@ | |
for sample in tqdm(dataset, desc='Processing images', unit='sample'): | ||
image, mask = sample['image'], sample['mask'] | ||
|
||
# Standardizing image | ||
# if self.normalizing_type == "percentile": | ||
# image = (image - self.normalizing_factor[:, 0].reshape((-1, 1, 1))) / self.normalizing_factor[:, 2].reshape((-1, 1, 1)) | ||
# elif self.normalizing_type == "zscore": | ||
# image = (image - self.means) / self.stds | ||
# else: | ||
# image = image / self.normalizing_factor | ||
# image = torch.clamp(image, 0, 1) | ||
|
||
# Applying geometric transformations | ||
if self.geo_transforms: | ||
combined = torch.cat((image, mask), 0) | ||
combined = self.geo_transforms(combined) | ||
image, mask = torch.split(combined, [image.shape[0], mask.shape[0]], 0) | ||
|
||
# Applying color transformations | ||
if self.color_transforms: | ||
num_timepoints = image.shape[0] // self.bands | ||
for i in range(num_timepoints): | ||
|
@@ -171,7 +146,6 @@ | |
'Input dimensions must support color transformations.' | ||
) | ||
|
||
# Resizing image and mask | ||
if self.image_resize: | ||
image = self.image_resize(image) | ||
if self.mask_resize: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't you parametrize the fixture instead of the unit test? Then all other unit tests will also be parametrized.