From 4b77b4666e7eade8ffeabdffb6a1c6b1affecbc8 Mon Sep 17 00:00:00 2001 From: Nick Tustison Date: Mon, 7 Oct 2024 09:31:05 -0700 Subject: [PATCH] ENH: Generalize number of outputs. --- R/createCustomUnetModel.R | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/R/createCustomUnetModel.R b/R/createCustomUnetModel.R index 150bdd1..4788abe 100644 --- a/R/createCustomUnetModel.R +++ b/R/createCustomUnetModel.R @@ -312,6 +312,8 @@ createHippMapp3rUnetModel3D <- function( inputImageSize, #' #' @param numberOfModalities Specifies number of channels in the #' architecture. +#' @param numberOfOutputs Specifies the number of outputs per voxel. +#' Determines final activation function (1 = sigmoid, >1 = softmax). #' @return a u-net keras model #' @author Tustison NJ #' @examples @@ -322,7 +324,8 @@ createHippMapp3rUnetModel3D <- function( inputImageSize, #' } #' @import keras #' @export -createShivaUnetModel3D <- function( numberOfModalities = 1 ) +createShivaUnetModel3D <- function( numberOfModalities = 1, + numberOfOutputs = 1 ) { K <- tensorflow::tf$keras$backend @@ -424,7 +427,12 @@ createShivaUnetModel3D <- function( numberOfModalities = 1 ) outputs <- outputs %>% layer_batch_normalization() outputs <- outputs %>% layer_activation( "swish" ) - outputs <- outputs %>% layer_conv_3d( 1, kernel_size = 1L, activation = "sigmoid", padding = 'same' ) + activation = 'sigmoid' + if( numberOfOutputs > 1 ) + { + activation = 'softmax' + } + outputs <- outputs %>% layer_conv_3d( 1, kernel_size = numberOfOutputs, activation = activation, padding = 'same' ) unetModel <- keras_model( inputs = inputs, outputs = outputs )