From 1bcd87d6d4c2d41f49e50fe8ba4083ca4e678328 Mon Sep 17 00:00:00 2001 From: Frank Guibert Date: Thu, 5 Sep 2024 17:06:39 +0200 Subject: [PATCH] Adds Linear Upsampling to UnetR++ --- config/models/unetrpp8512.json | 5 ++ config/models/unetrpp8512_linear_up.json | 6 +++ py4cast/models/vision/unetrpp.py | 66 ++++++++++++++++-------- 3 files changed, 55 insertions(+), 22 deletions(-) create mode 100644 config/models/unetrpp8512.json create mode 100644 config/models/unetrpp8512_linear_up.json diff --git a/config/models/unetrpp8512.json b/config/models/unetrpp8512.json new file mode 100644 index 00000000..6b325f7c --- /dev/null +++ b/config/models/unetrpp8512.json @@ -0,0 +1,5 @@ +{ +"num_heads": 8, +"hidden_size": 512 +} + diff --git a/config/models/unetrpp8512_linear_up.json b/config/models/unetrpp8512_linear_up.json new file mode 100644 index 00000000..866f9c79 --- /dev/null +++ b/config/models/unetrpp8512_linear_up.json @@ -0,0 +1,6 @@ +{ +"num_heads": 8, +"hidden_size": 512, +"linear_upsampling": true +} + diff --git a/py4cast/models/vision/unetrpp.py b/py4cast/models/vision/unetrpp.py index dd9af363..7328e067 100644 --- a/py4cast/models/vision/unetrpp.py +++ b/py4cast/models/vision/unetrpp.py @@ -409,6 +409,7 @@ def __init__( out_size: int = 0, depth: int = 3, conv_decoder: bool = False, + linear_upsampling: bool = False, ) -> None: """ Args: @@ -427,29 +428,45 @@ def __init__( super().__init__() padding = get_padding(upsample_kernel_size, upsample_kernel_size) if spatial_dims == 2: - self.transp_conv = nn.ConvTranspose2d( - in_channels, - out_channels, - kernel_size=upsample_kernel_size, - stride=upsample_kernel_size, - padding=padding, - output_padding=get_output_padding( - upsample_kernel_size, upsample_kernel_size, padding - ), - dilation=1, - ) + if linear_upsampling: + self.transp_conv = nn.Sequential( + nn.UpsamplingBilinear2d(scale_factor=upsample_kernel_size), + nn.Conv2d( + in_channels, out_channels, kernel_size=kernel_size, padding=1 + ), + ) + else: + self.transp_conv = nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=upsample_kernel_size, + stride=upsample_kernel_size, + padding=padding, + output_padding=get_output_padding( + upsample_kernel_size, upsample_kernel_size, padding + ), + dilation=1, + ) else: - self.transp_conv = nn.ConvTranspose3d( - in_channels, - out_channels, - kernel_size=upsample_kernel_size, - stride=upsample_kernel_size, - padding=padding, - output_padding=get_output_padding( - upsample_kernel_size, upsample_kernel_size, padding - ), - dilation=1, - ) + if linear_upsampling: + self.transp_conv = nn.Sequential( + nn.Upsample(scale_factor=upsample_kernel_size, mode="trilinear"), + nn.Conv3d( + in_channels, out_channels, kernel_size=kernel_size, padding=1 + ), + ) + else: + self.transp_conv = nn.ConvTranspose3d( + in_channels, + out_channels, + kernel_size=upsample_kernel_size, + stride=upsample_kernel_size, + padding=padding, + output_padding=get_output_padding( + upsample_kernel_size, upsample_kernel_size, padding + ), + dilation=1, + ) # 4 feature resolution stages, each consisting of multiple residual blocks self.decoder_block = nn.ModuleList() @@ -512,6 +529,7 @@ class UNETRPPSettings: conv_op: str = "Conv2d" do_ds = False spatial_dims = 2 + linear_upsampling: bool = False class UNETRPP(ModelABC, nn.Module): @@ -605,6 +623,7 @@ def __init__( upsample_kernel_size=2, norm_name=settings.norm_name, out_size=no_pixels // 16, + linear_upsampling=settings.linear_upsampling, ) self.decoder4 = UnetrUpBlock( spatial_dims=settings.spatial_dims, @@ -614,6 +633,7 @@ def __init__( upsample_kernel_size=2, norm_name=settings.norm_name, out_size=no_pixels // 4, + linear_upsampling=settings.linear_upsampling, ) self.decoder3 = UnetrUpBlock( spatial_dims=settings.spatial_dims, @@ -623,6 +643,7 @@ def __init__( upsample_kernel_size=2, norm_name=settings.norm_name, out_size=no_pixels, + linear_upsampling=settings.linear_upsampling, ) self.decoder2 = UnetrUpBlock( spatial_dims=settings.spatial_dims, @@ -633,6 +654,7 @@ def __init__( norm_name=settings.norm_name, out_size=no_pixels * 16, conv_decoder=True, + linear_upsampling=settings.linear_upsampling, ) self.out1 = UnetOutBlock( spatial_dims=settings.spatial_dims,