# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import pytest
import torch

from torchgeo.models import RCF


class TestRCF:
    """Unit tests for the random convolutional feature (RCF) extractor."""

    def test_in_channels(self) -> None:
        """A matching channel count runs; a mismatched one raises RuntimeError."""
        extractor = RCF(in_channels=5, features=4, kernel_size=3)
        five_band = torch.randn(2, 5, 64, 64)
        extractor(five_band)

        # Same input, but the model now expects 3 channels instead of 5.
        extractor = RCF(in_channels=3, features=4, kernel_size=3)
        match = "to have 3 channels, but got 5 channels instead"
        with pytest.raises(RuntimeError, match=match):
            extractor(five_band)

    def test_num_features(self) -> None:
        """The feature axis has ``features`` entries for batched and single inputs."""
        extractor = RCF(in_channels=5, features=4, kernel_size=3)

        batched_out = extractor(torch.randn(2, 5, 64, 64))
        assert batched_out.shape[1] == 4

        # A batch of one is returned without its batch axis.
        single_out = extractor(torch.randn(1, 5, 64, 64))
        assert single_out.shape[0] == 4

    def test_untrainable(self) -> None:
        """The model exposes no trainable parameters."""
        extractor = RCF(in_channels=5, features=4, kernel_size=3)
        assert not list(extractor.parameters())

    def test_biases(self) -> None:
        """Every bias entry equals the value passed to the constructor."""
        extractor = RCF(features=24, bias=10)
        assert torch.all(extractor.biases == 10)  # type: ignore[attr-defined]
"""Implementation of a random convolutional feature projection model."""

from typing import cast

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules import Conv2d, Module

# Make these classes report ``torch.nn`` as their module.
Module.__module__ = "torch.nn"
Conv2d.__module__ = "torch.nn"


class RCF(Module):
    """This model extracts random convolutional features (RCFs) from its input.

    RCFs are used in the Multi-task Observation using Satellite Imagery & Kitchen
    Sinks (MOSAIKS) method proposed in
    https://www.nature.com/articles/s41467-021-24638-z.

    .. note::

       This Module is *not* trainable. It is only used as a feature extractor.
    """

    def __init__(
        self,
        in_channels: int = 4,
        features: int = 16,
        kernel_size: int = 3,
        bias: float = -1.0,
    ) -> None:
        """Initializes the RCF model.

        This is a static model that serves to extract fixed length feature vectors
        from input patches.

        Args:
            in_channels: number of input channels
            features: number of features to compute, must be divisible by 2
            kernel_size: size of the kernel used to compute the RCFs
            bias: bias of the convolutional layer

        Raises:
            AssertionError: if ``features`` is not divisible by 2
        """
        super().__init__()  # type: ignore[no-untyped-call]

        assert features % 2 == 0, "features must be divisible by 2"

        # Length of the feature axis produced by ``forward``.
        self.num_features = features

        # Each random filter yields two features (a positive and a negative
        # response), so only half as many filters as features are needed.
        num_filters = features // 2

        # We register the weight and bias tensors as "buffers". This does two
        # things: makes them behave correctly when we call .to(...) on the module,
        # and makes them explicitly _not_ Parameters of the model (which might get
        # updated) if a user tries to train with this model.
        self.register_buffer(
            "weights",
            torch.randn(
                num_filters,
                in_channels,
                kernel_size,
                kernel_size,
                requires_grad=False,
            ),
        )
        self.register_buffer(
            "biases",
            torch.zeros(  # type: ignore[attr-defined]
                num_filters, requires_grad=False
            )
            + bias,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the RCF model.

        Args:
            x: a tensor with shape (B, C, H, W)

        Returns:
            a tensor of size (B, ``self.num_features``), or of size
            (``self.num_features``,) when a single input (B == 1) is passed
        """
        # The positive and negative halves share the same biased convolution, so
        # it is computed once and reused.
        response = F.conv2d(x, self.weights, bias=self.biases, stride=1, padding=0)

        pooled_pos = F.adaptive_avg_pool2d(F.relu(response), (1, 1))
        pooled_neg = F.adaptive_avg_pool2d(F.relu(-response), (1, 1))

        # Drop only the two pooled spatial axes. A bare ``squeeze()`` would also
        # drop the filter axis when ``features == 2`` and the batch axis when
        # B == 1, which made those two cases indistinguishable.
        pooled_pos = pooled_pos.squeeze(-1).squeeze(-1)
        pooled_neg = pooled_neg.squeeze(-1).squeeze(-1)

        output = torch.cat(  # type: ignore[attr-defined]
            (pooled_pos, pooled_neg), dim=-1
        )

        # A batch containing a single input is returned without its batch axis.
        if output.dim() == 2 and output.shape[0] == 1:
            output = output.squeeze(0)

        return cast(Tensor, output)