Add ChangeDetectionTask #2422

Draft
wants to merge 11 commits into base: main
15 changes: 15 additions & 0 deletions tests/conf/oscd.yaml
@@ -0,0 +1,15 @@
model:
class_path: ChangeDetectionTask
init_args:
loss: 'bce'
model: 'unet'
backbone: 'resnet18'
in_channels: 13
data:
class_path: OSCDDataModule
init_args:
batch_size: 2
patch_size: 16
val_split_pct: 0.5
dict_kwargs:
root: 'tests/data/oscd'
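For context, this test config might be exercised roughly as in the sketch below, mirroring what tests/trainers/test_change.py does further down; the accelerator and fast_dev_run flags are illustrative choices, not part of the config itself.

# Hypothetical smoke run of the config above via torchgeo's LightningCLI entry point;
# flag values are illustrative and assume this PR branch is installed.
from torchgeo.main import main

main([
    'fit',
    '--config', 'tests/conf/oscd.yaml',
    '--trainer.accelerator', 'cpu',
    '--trainer.fast_dev_run', 'True',
    '--trainer.max_epochs', '1',
])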
7 changes: 4 additions & 3 deletions tests/trainers/conftest.py
@@ -6,8 +6,8 @@
from pathlib import Path

import pytest
import timm
import torch
import torchvision
from _pytest.fixtures import SubRequest
from torch import Tensor
from torch.nn.modules import Module
@@ -22,8 +22,9 @@ def fast_dev_run(request: SubRequest) -> bool:


@pytest.fixture(scope='package')
def model() -> Module:
model: Module = torchvision.models.resnet18(weights=None)
def model(request: SubRequest) -> Module:
in_channels = getattr(request, 'param', 3)
model: Module = timm.create_model('resnet18', in_chans=in_channels)
return model


169 changes: 169 additions & 0 deletions tests/trainers/test_change.py
@@ -0,0 +1,169 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from pathlib import Path
from typing import Any, cast

import pytest
import segmentation_models_pytorch as smp
import timm
import torch
import torch.nn as nn
from pytest import MonkeyPatch
from torch.nn.modules import Module
from torchvision.models._api import WeightsEnum

from torchgeo.datamodules import MisconfigurationException
from torchgeo.main import main
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import ChangeDetectionTask


class ChangeDetectionTestModel(Module):
def __init__(self, in_channels: int = 3, classes: int = 3, **kwargs: Any) -> None:
super().__init__()
self.conv1 = nn.Conv2d(
in_channels=in_channels, out_channels=classes, kernel_size=1, padding=0
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return cast(torch.Tensor, self.conv1(x))


def create_model(**kwargs: Any) -> Module:
return ChangeDetectionTestModel(**kwargs)


class TestChangeDetectionTask:
@pytest.mark.parametrize('name', ['oscd'])
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
config = os.path.join('tests', 'conf', name + '.yaml')

monkeypatch.setattr(smp, 'Unet', create_model)

args = [
'--config',
config,
'--trainer.accelerator',
'cpu',
'--trainer.fast_dev_run',
str(fast_dev_run),
'--trainer.max_epochs',
'1',
'--trainer.log_every_n_steps',
'1',
]

main(['fit', *args])
try:
main(['test', *args])
except MisconfigurationException:
pass
try:
main(['predict', *args])
except MisconfigurationException:
pass

@pytest.fixture
def weights(self) -> WeightsEnum:
return ResNet18_Weights.SENTINEL2_ALL_MOCO

@pytest.fixture
def mocked_weights(
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
# multiply in_chans by 2 since images are concatenated
Collaborator: How hard would it be to do late fusion, so pass each image through the encoder separately, then concatenate them, then pass them through the decoder? This would make it easier to use pre-trained models.

Author: It's definitely possible, although I think we would need a custom Unet implementation in torchgeo/models to do this. It would simplify using the pretrained weights, but is late fusion a common enough approach that many people would find this useful?

model = timm.create_model(
weights.meta['model'], in_chans=weights.meta['in_chans'] * 2
)
torch.save(model.state_dict(), path)
try:
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
return weights
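Regarding the late-fusion idea discussed in the thread above, a rough sketch (a hypothetical module, not part of this PR) could look like the following: both images pass through a shared encoder, the deepest feature maps are concatenated, and a small head predicts the change mask. Sharing the encoder is what would let standard 3-channel pretrained weights be reused directly; timm's features_only API is used here purely for illustration.

import torch
import torch.nn as nn
import timm


class LateFusionChangeModel(nn.Module):
    """Illustrative late-fusion change detector (not the PR's implementation)."""

    def __init__(self, backbone: str = 'resnet18', in_channels: int = 3) -> None:
        super().__init__()
        # features_only returns a list of feature maps; only the deepest one is used here.
        self.encoder = timm.create_model(
            backbone, in_chans=in_channels, features_only=True, pretrained=False
        )
        channels = self.encoder.feature_info.channels()[-1]
        self.head = nn.Sequential(
            nn.Conv2d(channels * 2, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=1),
        )

    def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor:
        # Shared weights: both timestamps go through the same encoder.
        feats1 = self.encoder(image1)[-1]
        feats2 = self.encoder(image2)[-1]
        fused = torch.cat([feats1, feats2], dim=1)  # concatenate after encoding
        logits = self.head(fused)
        # Upsample back to the input resolution for a per-pixel change mask.
        return nn.functional.interpolate(
            logits, size=image1.shape[-2:], mode='bilinear', align_corners=False
        )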

@pytest.mark.parametrize('model', [6], indirect=True)
Collaborator: Remind me what [6] means here?

Author: The number of input channels (two 3-channel images stacked).

def test_weight_file(self, checkpoint: str) -> None:
ChangeDetectionTask(backbone='resnet18', weights=checkpoint)
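To illustrate the question above about [6]: with indirect=True, pytest routes the value through request.param into the model fixture in conftest.py, so 6 simply means a 6-channel first layer (two RGB images concatenated). A standalone sketch of the mechanism (hypothetical test, names assumed):

import pytest
import timm
from _pytest.fixtures import SubRequest
from torch.nn.modules import Module


@pytest.fixture
def model(request: SubRequest) -> Module:
    # request.param is only set when the test is parametrized with indirect=True.
    in_channels = getattr(request, 'param', 3)
    return timm.create_model('resnet18', in_chans=in_channels)


@pytest.mark.parametrize('model', [6], indirect=True)  # 6 = 2 images x 3 channels each
def test_six_channel_first_layer(model: Module) -> None:
    assert model.conv1.in_channels == 6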

def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
ChangeDetectionTask(
backbone=mocked_weights.meta['model'],
weights=mocked_weights,
in_channels=mocked_weights.meta['in_chans'],
)

def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
ChangeDetectionTask(
backbone=mocked_weights.meta['model'],
weights=str(mocked_weights),
in_channels=mocked_weights.meta['in_chans'],
)

@pytest.mark.slow
def test_weight_enum_download(self, weights: WeightsEnum) -> None:
ChangeDetectionTask(
backbone=weights.meta['model'],
weights=weights,
in_channels=weights.meta['in_chans'],
)

@pytest.mark.slow
def test_weight_str_download(self, weights: WeightsEnum) -> None:
ChangeDetectionTask(
backbone=weights.meta['model'],
weights=str(weights),
in_channels=weights.meta['in_chans'],
)

def test_invalid_model(self) -> None:
match = "Model type 'invalid_model' is not valid."
with pytest.raises(ValueError, match=match):
ChangeDetectionTask(model='invalid_model')

def test_invalid_loss(self) -> None:
match = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=match):
ChangeDetectionTask(loss='invalid_loss')

@pytest.mark.parametrize('model_name', ['unet'])
@pytest.mark.parametrize(
'backbone', ['resnet18', 'mobilenet_v2', 'efficientnet-b0']
)
def test_freeze_backbone(self, model_name: str, backbone: str) -> None:
model = ChangeDetectionTask(
model=model_name, backbone=backbone, freeze_backbone=True
)
assert all(
[param.requires_grad is False for param in model.model.encoder.parameters()]
)
assert all([param.requires_grad for param in model.model.decoder.parameters()])
assert all(
[
param.requires_grad
for param in model.model.segmentation_head.parameters()
]
)

@pytest.mark.parametrize('model_name', ['unet'])
def test_freeze_decoder(self, model_name: str) -> None:
model = ChangeDetectionTask(model=model_name, freeze_decoder=True)
assert all(
[param.requires_grad is False for param in model.model.decoder.parameters()]
)
assert all([param.requires_grad for param in model.model.encoder.parameters()])
assert all(
[
param.requires_grad
for param in model.model.segmentation_head.parameters()
]
)
8 changes: 5 additions & 3 deletions torchgeo/datamodules/oscd.py
@@ -86,9 +86,11 @@ def __init__(
self.std = torch.tensor([STD[b] for b in self.bands])

self.aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
_RandomNCrop(self.patch_size, batch_size),
data_keys=['image1', 'image2', 'mask'],
K.VideoSequential(
K.Normalize(mean=self.mean, std=self.std),
_RandomNCrop(self.patch_size, batch_size),
),
data_keys=['image', 'mask'],
)

def setup(self, stage: str) -> None:
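For context on the switch to K.VideoSequential above: the OSCD sample now stacks both timestamps into a single (T=2, C, H, W) tensor, and VideoSequential applies the same spatial transform to every frame so the two images and the mask stay registered. A minimal illustration of VideoSequential on a raw tensor (outside torchgeo's AugmentationSequential; shapes and crop size are arbitrary):

import kornia.augmentation as K
import torch

# The same random crop is applied to both frames of each bi-temporal sample.
aug = K.VideoSequential(K.RandomCrop((16, 16)), data_format='BTCHW', same_on_frame=True)
video = torch.rand(4, 2, 13, 64, 64)  # batch of 4 bi-temporal 13-band samples
print(aug(video).shape)  # torch.Size([4, 2, 13, 16, 16])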
8 changes: 5 additions & 3 deletions torchgeo/datasets/oscd.py
@@ -150,7 +150,8 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
image1 = self._load_image(files['images1'])
image2 = self._load_image(files['images2'])
mask = self._load_target(str(files['mask']))
sample = {'image1': image1, 'image2': image2, 'mask': mask}
image = torch.stack(tensors=[image1, image2], dim=0)
sample = {'image': image, 'mask': mask}

if self.transforms is not None:
sample = self.transforms(sample)
@@ -169,7 +170,8 @@ def _load_files(self) -> list[dict[str, str | Sequence[str]]]:
regions = []
labels_root = os.path.join(
self.root,
f'Onera Satellite Change Detection dataset - {self.split.capitalize()} '
f'Onera Satellite Change Detection dataset - {
self.split.capitalize()} '
+ 'Labels',
)
images_root = os.path.join(
@@ -240,7 +242,7 @@ def _load_target(self, path: Path) -> Tensor:
array: np.typing.NDArray[np.int_] = np.array(img.convert('L'))
tensor = torch.from_numpy(array)
tensor = torch.clamp(tensor, min=0, max=1)
tensor = tensor.to(torch.long)
tensor = tensor.to(torch.float)
Collaborator: Why would the target be a float?

Author: The loss function BCEWithLogitsLoss expects the target to be a float.

Collaborator: Why do we have to use BCEWithLogitsLoss? Can we use BCELoss instead?

Author: Both BCELoss and BCEWithLogitsLoss require float targets. Here's a brief explanation I found as to why: https://discuss.pytorch.org/t/inconsistency-between-loss-functions-input-types/138942. Is there any issue with the target being converted to a float here?

Collaborator: Ah, I see. For our binary classification datasets, we convert the label to a float in MultiLabelClassificationTask instead of in the dataset. I would kind of like our datasets to be consistent (int for classification and float for regression). Let's change it in ChangeDetectionTask instead.

return tensor

def _verify(self) -> None:
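A minimal sketch (assumed method body, not the PR's actual implementation) of the change the maintainer suggests in the thread above: keep integer masks in the dataset and cast to float inside ChangeDetectionTask, mirroring how MultiLabelClassificationTask handles its labels.

import torch
import torch.nn as nn
from torch import Tensor


def training_step_sketch(model: nn.Module, batch: dict[str, Tensor]) -> Tensor:
    x, y = batch['image'], batch['mask']
    y_hat = model(x)
    # BCEWithLogitsLoss (and BCELoss) require a floating-point target, so cast here
    # rather than in OSCD._load_target, keeping the dataset's mask dtype as long.
    loss: Tensor = nn.functional.binary_cross_entropy_with_logits(
        y_hat, y.to(torch.float)
    )
    return loss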
4 changes: 3 additions & 1 deletion torchgeo/losses/__init__.py
@@ -3,6 +3,8 @@

"""TorchGeo losses."""

from .focaljaccard import BinaryFocalJaccardLoss
from .qr import QRLoss, RQLoss
from .xentjaccard import BinaryXEntJaccardLoss

__all__ = ('QRLoss', 'RQLoss')
__all__ = ('QRLoss', 'RQLoss', 'BinaryFocalJaccardLoss', 'BinaryXEntJaccardLoss')
27 changes: 27 additions & 0 deletions torchgeo/losses/focaljaccard.py
@@ -0,0 +1,27 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Focal Jaccard loss functions."""

from typing import cast

import segmentation_models_pytorch as smp
import torch
import torch.nn as nn


class BinaryFocalJaccardLoss(nn.Module):
"""Binary Focal Jaccard Loss."""

def __init__(self) -> None:
"""Initialize a BinaryFocalJaccardLoss instance."""
super().__init__()
self.focal_loss = smp.losses.FocalLoss(mode='binary', normalized=True)
self.jaccard_loss = smp.losses.JaccardLoss(mode='binary')

def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
"""Compute the loss."""
return cast(
torch.Tensor,
self.focal_loss(preds, targets) + self.jaccard_loss(preds, targets),
)
27 changes: 27 additions & 0 deletions torchgeo/losses/xentjaccard.py
@@ -0,0 +1,27 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Cross-Entropy Jaccard loss functions."""

from typing import cast

import segmentation_models_pytorch as smp
import torch
import torch.nn as nn


class BinaryXEntJaccardLoss(nn.Module):
"""Binary Cross-Entropy Jaccard Loss."""

def __init__(self) -> None:
"""Initialize a BinaryXEntJaccardLoss instance."""
super().__init__()
self.bce_loss = nn.BCEWithLogitsLoss()
self.jaccard_loss = smp.losses.JaccardLoss(mode='binary')

def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
"""Compute the loss."""
return cast(
torch.Tensor,
self.bce_loss(preds, targets) + self.jaccard_loss(preds, targets),
)
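For reference, both new losses operate on raw logits and floating-point binary targets; a quick usage sketch (shapes are illustrative and assume this branch is installed):

import torch
from torchgeo.losses import BinaryFocalJaccardLoss, BinaryXEntJaccardLoss

preds = torch.randn(2, 1, 16, 16)                      # raw logits from the model
targets = torch.randint(0, 2, (2, 1, 16, 16)).float()  # binary change mask

print(BinaryFocalJaccardLoss()(preds, targets))
print(BinaryXEntJaccardLoss()(preds, targets))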
2 changes: 2 additions & 0 deletions torchgeo/trainers/__init__.py
@@ -5,6 +5,7 @@

from .base import BaseTask
from .byol import BYOLTask
from .change import ChangeDetectionTask
from .classification import ClassificationTask, MultiLabelClassificationTask
from .detection import ObjectDetectionTask
from .iobench import IOBenchTask
@@ -15,6 +16,7 @@

__all__ = (
# Supervised
'ChangeDetectionTask',
'ClassificationTask',
'MultiLabelClassificationTask',
'ObjectDetectionTask',