Skip to content

Commit

Permalink
Merge pull request #10 from colon3ltocard/unetrpp_add_linear_upsampling
Browse files Browse the repository at this point in the history
Adds Linear Upsampling to UnetR++
  • Loading branch information
LBerth authored Sep 6, 2024
2 parents 7982776 + 1bcd87d commit bc0cdfd
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 22 deletions.
5 changes: 5 additions & 0 deletions config/models/unetrpp8512.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"num_heads": 8,
"hidden_size": 512
}

6 changes: 6 additions & 0 deletions config/models/unetrpp8512_linear_up.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"num_heads": 8,
"hidden_size": 512,
"linear_upsampling": true
}

66 changes: 44 additions & 22 deletions py4cast/models/vision/unetrpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ def __init__(
out_size: int = 0,
depth: int = 3,
conv_decoder: bool = False,
linear_upsampling: bool = False,
) -> None:
"""
Args:
Expand All @@ -427,29 +428,45 @@ def __init__(
super().__init__()
padding = get_padding(upsample_kernel_size, upsample_kernel_size)
if spatial_dims == 2:
self.transp_conv = nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=upsample_kernel_size,
stride=upsample_kernel_size,
padding=padding,
output_padding=get_output_padding(
upsample_kernel_size, upsample_kernel_size, padding
),
dilation=1,
)
if linear_upsampling:
self.transp_conv = nn.Sequential(
nn.UpsamplingBilinear2d(scale_factor=upsample_kernel_size),
nn.Conv2d(
in_channels, out_channels, kernel_size=kernel_size, padding=1
),
)
else:
self.transp_conv = nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=upsample_kernel_size,
stride=upsample_kernel_size,
padding=padding,
output_padding=get_output_padding(
upsample_kernel_size, upsample_kernel_size, padding
),
dilation=1,
)
else:
self.transp_conv = nn.ConvTranspose3d(
in_channels,
out_channels,
kernel_size=upsample_kernel_size,
stride=upsample_kernel_size,
padding=padding,
output_padding=get_output_padding(
upsample_kernel_size, upsample_kernel_size, padding
),
dilation=1,
)
if linear_upsampling:
self.transp_conv = nn.Sequential(
nn.Upsample(scale_factor=upsample_kernel_size, mode="trilinear"),
nn.Conv3d(
in_channels, out_channels, kernel_size=kernel_size, padding=1
),
)
else:
self.transp_conv = nn.ConvTranspose3d(
in_channels,
out_channels,
kernel_size=upsample_kernel_size,
stride=upsample_kernel_size,
padding=padding,
output_padding=get_output_padding(
upsample_kernel_size, upsample_kernel_size, padding
),
dilation=1,
)

# 4 feature resolution stages, each consisting of multiple residual blocks
self.decoder_block = nn.ModuleList()
Expand Down Expand Up @@ -512,6 +529,7 @@ class UNETRPPSettings:
conv_op: str = "Conv2d"
do_ds = False
spatial_dims = 2
linear_upsampling: bool = False


class UNETRPP(ModelABC, nn.Module):
Expand Down Expand Up @@ -605,6 +623,7 @@ def __init__(
upsample_kernel_size=2,
norm_name=settings.norm_name,
out_size=no_pixels // 16,
linear_upsampling=settings.linear_upsampling,
)
self.decoder4 = UnetrUpBlock(
spatial_dims=settings.spatial_dims,
Expand All @@ -614,6 +633,7 @@ def __init__(
upsample_kernel_size=2,
norm_name=settings.norm_name,
out_size=no_pixels // 4,
linear_upsampling=settings.linear_upsampling,
)
self.decoder3 = UnetrUpBlock(
spatial_dims=settings.spatial_dims,
Expand All @@ -623,6 +643,7 @@ def __init__(
upsample_kernel_size=2,
norm_name=settings.norm_name,
out_size=no_pixels,
linear_upsampling=settings.linear_upsampling,
)
self.decoder2 = UnetrUpBlock(
spatial_dims=settings.spatial_dims,
Expand All @@ -633,6 +654,7 @@ def __init__(
norm_name=settings.norm_name,
out_size=no_pixels * 16,
conv_decoder=True,
linear_upsampling=settings.linear_upsampling,
)
self.out1 = UnetOutBlock(
spatial_dims=settings.spatial_dims,
Expand Down

0 comments on commit bc0cdfd

Please sign in to comment.