Skip to content

Commit

Permalink
Merge pull request #74 from ANTsX/ShivaOutput
Browse files Browse the repository at this point in the history
ENH:  Generalize number of outputs.
  • Loading branch information
ntustison authored Oct 7, 2024
2 parents 94d1349 + 4b77b46 commit 1644d76
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions R/createCustomUnetModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ createHippMapp3rUnetModel3D <- function( inputImageSize,
#'
#' @param numberOfModalities Specifies number of channels in the
#' architecture.
#' @param numberOfOutputs Specifies the number of outputs per voxel.
#' Determines final activation function (1 = sigmoid, >1 = softmax).
#' @return a u-net keras model
#' @author Tustison NJ
#' @examples
Expand All @@ -322,7 +324,8 @@ createHippMapp3rUnetModel3D <- function( inputImageSize,
#' }
#' @import keras
#' @export
createShivaUnetModel3D <- function( numberOfModalities = 1 )
createShivaUnetModel3D <- function( numberOfModalities = 1,
numberOfOutputs = 1 )
{
K <- tensorflow::tf$keras$backend

Expand Down Expand Up @@ -424,7 +427,12 @@ createShivaUnetModel3D <- function( numberOfModalities = 1 )
outputs <- outputs %>% layer_batch_normalization()
outputs <- outputs %>% layer_activation( "swish" )

outputs <- outputs %>% layer_conv_3d( 1, kernel_size = 1L, activation = "sigmoid", padding = 'same' )
activation = 'sigmoid'
if( numberOfOutputs > 1 )
{
activation = 'softmax'
}
outputs <- outputs %>% layer_conv_3d( 1, kernel_size = numberOfOutputs, activation = activation, padding = 'same' )

unetModel <- keras_model( inputs = inputs, outputs = outputs )

Expand Down

0 comments on commit 1644d76

Please sign in to comment.