Skip to content

Commit

Permalink
Added RCF model to implement the MOSAIKs method (#176)
Browse files Browse the repository at this point in the history
* Added RCF model

* Rename to RCF

* Update tests/models/test_rcf.py

Co-authored-by: Adam J. Stewart <[email protected]>

Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
calebrob6 and adamjstewart authored Oct 8, 2021
1 parent a5b3099 commit 699dfec
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 0 deletions.
37 changes: 37 additions & 0 deletions tests/models/test_rcf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import pytest
import torch

from torchgeo.models import RCF


class TestRCF:
def test_in_channels(self) -> None:
model = RCF(in_channels=5, features=4, kernel_size=3)
x = torch.randn(2, 5, 64, 64)
model(x)

model = RCF(in_channels=3, features=4, kernel_size=3)
match = "to have 3 channels, but got 5 channels instead"
with pytest.raises(RuntimeError, match=match):
model(x)

def test_num_features(self) -> None:
model = RCF(in_channels=5, features=4, kernel_size=3)
x = torch.randn(2, 5, 64, 64)
y = model(x)
assert y.shape[1] == 4

x = torch.randn(1, 5, 64, 64)
y = model(x)
assert y.shape[0] == 4

def test_untrainable(self) -> None:
model = RCF(in_channels=5, features=4, kernel_size=3)
assert len(list(model.parameters())) == 0

def test_biases(self) -> None:
model = RCF(features=24, bias=10)
assert torch.all(model.biases == 10) # type: ignore[attr-defined]
2 changes: 2 additions & 0 deletions torchgeo/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .farseg import FarSeg
from .fccd import FCEF, FCSiamConc, FCSiamDiff
from .fcn import FCN
from .rcf import RCF

__all__ = (
"ChangeMixin",
Expand All @@ -17,6 +18,7 @@
"FCEF",
"FCSiamConc",
"FCSiamDiff",
"RCF",
)

# https://stackoverflow.com/questions/40018681
Expand Down
99 changes: 99 additions & 0 deletions torchgeo/models/rcf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Implementation of a random convolutional feature projection model."""

from typing import cast

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules import Conv2d, Module

Module.__module__ = "torch.nn"
Conv2d.__module__ = "torch.nn"


class RCF(Module):
"""This model extracts random convolutional features (RCFs) from its input.
RCFs are used in Multi-task Observation using Satellite Imagery & Kitchen Sinks
(MOSAIKS) method proposed in https://www.nature.com/articles/s41467-021-24638-z.
.. note::
This Module is *not* trainable. It is only used as a feature extractor.
"""

def __init__(
self,
in_channels: int = 4,
features: int = 16,
kernel_size: int = 3,
bias: float = -1.0,
) -> None:
"""Initializes the RCF model.
This is a static model that serves to extract fixed length feature vectors from
input patches.
Args:
in_channels: number of input channels
features: number of features to compute, must be divisible by 2
kernel_size: size of the kernel used to compute the RCFs
bias: bias of the convolutional layer
"""
super().__init__() # type: ignore[no-untyped-call]

assert features % 2 == 0

# We register the weight and bias tensors as "buffers". This does two things:
# makes them behave correctly when we call .to(...) on the module, and makes
# them explicitely _not_ Parameters of the model (which might get updated) if
# a user tries to train with this model.
self.register_buffer(
"weights",
torch.randn(
features // 2,
in_channels,
kernel_size,
kernel_size,
requires_grad=False,
),
)
self.register_buffer(
"biases",
torch.zeros( # type: ignore[attr-defined]
features // 2, requires_grad=False
)
+ bias,
)

def forward(self, x: Tensor) -> Tensor:
"""Forward pass of the RCF model.
Args:
x: a tensor with shape (B, C, H, W)
Returns:
a tensor of size (B, ``self.num_features``)
"""
x1a = F.relu(
F.conv2d(x, self.weights, bias=self.biases, stride=1, padding=0),
inplace=True,
)
x1b = F.relu(
-F.conv2d(x, self.weights, bias=self.biases, stride=1, padding=0),
inplace=False,
)

x1a = F.adaptive_avg_pool2d(x1a, (1, 1)).squeeze()
x1b = F.adaptive_avg_pool2d(x1b, (1, 1)).squeeze()

if len(x1a.shape) == 1: # case where we passed a single input
output = torch.cat((x1a, x1b), dim=0) # type: ignore[attr-defined]
return cast(Tensor, output)
else: # case where we passed a batch of > 1 inputs
assert len(x1a.shape) == 2
output = torch.cat((x1a, x1b), dim=1) # type: ignore[attr-defined]
return cast(Tensor, output)

0 comments on commit 699dfec

Please sign in to comment.