diff --git a/R/createCustomUnetModel.R b/R/createCustomUnetModel.R index 150bdd1..4788abe 100644 --- a/R/createCustomUnetModel.R +++ b/R/createCustomUnetModel.R @@ -312,6 +312,8 @@ createHippMapp3rUnetModel3D <- function( inputImageSize, #' #' @param numberOfModalities Specifies number of channels in the #' architecture. +#' @param numberOfOutputs Specifies the number of outputs per voxel. +#' Determines final activation function (1 = sigmoid, >1 = softmax). #' @return a u-net keras model #' @author Tustison NJ #' @examples @@ -322,7 +324,8 @@ createHippMapp3rUnetModel3D <- function( inputImageSize, #' } #' @import keras #' @export -createShivaUnetModel3D <- function( numberOfModalities = 1 ) +createShivaUnetModel3D <- function( numberOfModalities = 1, + numberOfOutputs = 1 ) { K <- tensorflow::tf$keras$backend @@ -424,7 +427,12 @@ createShivaUnetModel3D <- function( numberOfModalities = 1 ) outputs <- outputs %>% layer_batch_normalization() outputs <- outputs %>% layer_activation( "swish" ) - outputs <- outputs %>% layer_conv_3d( 1, kernel_size = 1L, activation = "sigmoid", padding = 'same' ) + activation = 'sigmoid' + if( numberOfOutputs > 1 ) + { + activation = 'softmax' + } + outputs <- outputs %>% layer_conv_3d( 1, kernel_size = numberOfOutputs, activation = activation, padding = 'same' ) unetModel <- keras_model( inputs = inputs, outputs = outputs )