From 4eff98df2741e553796af28e47cbd1e47dc6fd07 Mon Sep 17 00:00:00 2001 From: ntustison Date: Thu, 19 Sep 2024 17:45:55 -0700 Subject: [PATCH] WIP: DA --- R/deepAtropos.R | 439 ++++++++++++++++++++++++++++----------- R/getANTsXNetData.R | 12 +- R/getPretrainedNetwork.R | 16 +- 3 files changed, 347 insertions(+), 120 deletions(-) diff --git a/R/deepAtropos.R b/R/deepAtropos.R index 2bc4b31..aa1e0b7 100644 --- a/R/deepAtropos.R +++ b/R/deepAtropos.R @@ -45,153 +45,360 @@ deepAtropos <- function( t1, doPreprocessing = TRUE, useSpatialPriors = 1, verbose = FALSE, debug = FALSE ) { - if( t1@dimension != 3 ) - { - stop( "Input image dimension must be 3." ) - } + if( ! is.list( t1 ) ) + { + if( t1@dimension != 3 ) + { + stop( "Input image dimension must be 3." ) + } - ################################ - # - # Preprocess image - # - ################################ - - t1Preprocessed <- t1 - if( doPreprocessing ) - { - t1Preprocessing <- preprocessBrainImage( t1, - truncateIntensity = c( 0.01, 0.99 ), - brainExtractionModality = "t1", - template = "croppedMni152", - templateTransformType = "antsRegistrationSyNQuickRepro[a]", - doBiasCorrection = TRUE, - doDenoising = TRUE, - verbose = verbose ) - t1Preprocessed <- t1Preprocessing$preprocessedImage * t1Preprocessing$brainMask - } + ################################ + # + # Preprocess image + # + ################################ - ################################ - # - # Build model and load weights - # - ################################ + t1Preprocessed <- t1 + if( doPreprocessing ) + { + t1Preprocessing <- preprocessBrainImage( t1, + truncateIntensity = c( 0.01, 0.99 ), + brainExtractionModality = "t1", + template = "croppedMni152", + templateTransformType = "antsRegistrationSyNQuickRepro[a]", + doBiasCorrection = TRUE, + doDenoising = TRUE, + verbose = verbose ) + t1Preprocessed <- t1Preprocessing$preprocessedImage * t1Preprocessing$brainMask + } - patchSize <- c( 112L, 112L, 112L ) - strideLength <- dim( t1Preprocessed ) - patchSize + ################################ + # + # Build model and load weights + # + ################################ - classes <- c( "background", "csf", "gray matter", "white matter", - "deep gray matter", "brain stem", "cerebellum" ) + patchSize <- c( 112L, 112L, 112L ) + strideLength <- dim( t1Preprocessed ) - patchSize - mniPriors <- NULL - channelSize <- 1 - if( useSpatialPriors != 0 ) - { - mniPriors <- splitNDImageToList( antsImageRead( getANTsXNetData( "croppedMni152Priors" ) ) ) - for( i in seq.int( length( mniPriors ) ) ) + classes <- c( "background", "csf", "gray matter", "white matter", + "deep gray matter", "brain stem", "cerebellum" ) + + mniPriors <- NULL + channelSize <- 1 + if( useSpatialPriors != 0 ) { - mniPriors[[i]] <- antsCopyImageInfo( t1Preprocessed, mniPriors[[i]] ) + mniPriors <- splitNDImageToList( antsImageRead( getANTsXNetData( "croppedMni152Priors" ) ) ) + for( i in seq.int( length( mniPriors ) ) ) + { + mniPriors[[i]] <- antsCopyImageInfo( t1Preprocessed, mniPriors[[i]] ) + } + channelSize <- 2 } - channelSize <- 2 - } - unetModel <- createUnetModel3D( c( patchSize, channelSize ), - numberOfOutputs = length( classes ), mode = 'classification', - numberOfLayers = 4, numberOfFiltersAtBaseLayer = 16, dropoutRate = 0.0, - convolutionKernelSize = c( 3, 3, 3 ), deconvolutionKernelSize = c( 2, 2, 2 ), - weightDecay = 1e-5, additionalOptions = c( "attentionGating" ) ) + unetModel <- createUnetModel3D( c( patchSize, channelSize ), + numberOfOutputs = length( classes ), mode = 'classification', + numberOfLayers = 4, numberOfFiltersAtBaseLayer = 16, dropoutRate = 0.0, + convolutionKernelSize = c( 3, 3, 3 ), deconvolutionKernelSize = c( 2, 2, 2 ), + weightDecay = 1e-5, additionalOptions = c( "attentionGating" ) ) + + if( verbose ) + { + cat( "DeepAtropos: retrieving model weights.\n" ) + } + weightsFileName <- '' + if( useSpatialPriors == 0 ) + { + weightsFileName <- getPretrainedNetwork( "sixTissueOctantBrainSegmentation" ) + } else if( useSpatialPriors == 1 ) { + weightsFileName <- getPretrainedNetwork( "sixTissueOctantBrainSegmentationWithPriors1" ) + } else { + stop( "useSpatialPriors must be a 0 or 1" ) + } + load_model_weights_hdf5( unetModel, filepath = weightsFileName ) + + ################################ + # + # Do prediction and normalize to native space + # + ################################ + + if( verbose ) + { + message( "Prediction.\n" ) + } + + t1Preprocessed <- ( t1Preprocessed - mean( t1Preprocessed ) ) / sd( t1Preprocessed ) + imagePatches <- extractImagePatches( t1Preprocessed, patchSize, maxNumberOfPatches = "all", + strideLength = strideLength, returnAsArray = TRUE ) + batchX <- array( data = 0, dim = c( dim( imagePatches ), channelSize ) ) + batchX[,,,,1] <- imagePatches + if( channelSize > 1 ) + { + priorPatches <- extractImagePatches( mniPriors[[7]], patchSize, maxNumberOfPatches = "all", + strideLength = strideLength, returnAsArray = TRUE ) + batchX[,,,,2] <- priorPatches + } + predictedData <- unetModel %>% predict( batchX, verbose = verbose ) + + probabilityImages <- list() + for( i in seq.int( dim( predictedData )[5] ) ) + { + if( verbose ) + { + cat( "Reconstructing image ", classes[i], "\n" ) + } + reconstructedImage <- reconstructImageFromPatches( predictedData[,,,,i], + domainImage = t1Preprocessed, strideLength = strideLength ) + if( doPreprocessing ) + { + probabilityImages[[i]] <- antsApplyTransforms( fixed = t1, moving = reconstructedImage, + transformlist = t1Preprocessing$templateTransforms$invtransforms, + whichtoinvert = c( TRUE ), interpolator = "linear", verbose = verbose ) + } else { + probabilityImages[[i]] <- reconstructedImage + } + } + + imageMatrix <- imageListToMatrix( probabilityImages, t1 * 0 + 1 ) + segmentationMatrix <- matrix( apply( imageMatrix, 2, which.max ), nrow = 1 ) + segmentationImage <- matrixToImages( segmentationMatrix, t1 * 0 + 1 )[[1]] - 1 + + results <- list( segmentationImage = segmentationImage, + probabilityImages = probabilityImages ) + + # debugging + + if( debug ) + { + inputImage <- unetModel$input + featureLayer <- unetModel$layers[[length( unetModel$layers ) - 1]] + featureFunction <- keras::backend()$`function`( list( inputImage ), list( featureLayer$output ) ) + featureBatch <- featureFunction( list( batchX[1,,,,,drop = FALSE] ) ) + + featureImagesList <- decodeUnet( featureBatch[[1]], croppedImage ) + + featureImages <- list() + for( i in seq.int( length( featureImagesList[[1]] ) ) ) + { + decroppedImage <- decropImage( featureImagesList[[1]][[i]], t1Preprocessed * 0 ) + if( doPreprocessing ) + { + featureImages[[i]] <- antsApplyTransforms( fixed = t1, moving = decroppedImage, + transformlist = t1Preprocessing$templateTransforms$invtransforms, + whichtoinvert = c( TRUE ), interpolator = "linear", verbose = verbose ) + } else { + featureImages[[i]] <- decroppedImage + } + } + results[['featureImagesLastLayer']] <- featureImages + } + return( results ) - if( verbose ) - { - cat( "DeepAtropos: retrieving model weights.\n" ) - } - weightsFileName <- '' - if( useSpatialPriors == 0 ) - { - weightsFileName <- getPretrainedNetwork( "sixTissueOctantBrainSegmentation" ) - } else if( useSpatialPriors == 1 ) { - weightsFileName <- getPretrainedNetwork( "sixTissueOctantBrainSegmentationWithPriors1" ) } else { - stop( "useSpatialPriors must be a 0 or 1" ) - } - load_model_weights_hdf5( unetModel, filepath = weightsFileName ) - ################################ - # - # Do prediction and normalize to native space - # - ################################ + if( length( t1 ) != 3 ) + { + stop( paste0( "Length of input list must be 3. Input images are (in order): [T1, T2, FA].", + "If a particular modality or modalities is not available, use NULL as a placeholder." ) ) + } + + if( is.null( t1[[1]] ) ) + { + stop( "T1 modality must be specified." ) + } - if( verbose ) - { - message( "Prediction.\n" ) - } + whichNetwork <- "" + inputImages <- list() + inputImages[[1]] <- t1[[1]] + if( ! is.null( t1[[2]] ) && ! is.null( t1[[3]] ) ) + { + whichNetwork = "t1_t2_fa" + inputImages[[2]] <- t1[[2]] + inputImages[[3]] <- t1[[3]] + } else if( ! is.null( t1[[2]] ) && is.null( t1[[3]] ) ) { + whichNetwork = "t1_t2" + inputImages[[2]] <- t1[[2]] + } else if( ! is.null( t1[[2]] ) && is.null( t1[[3]] ) ) { + whichNetwork = "t1_fa" + inputImages[[2]] <- t1[[3]] + } else { + whichNetwork = "t1" + } - t1Preprocessed <- ( t1Preprocessed - mean( t1Preprocessed ) ) / sd( t1Preprocessed ) - imagePatches <- extractImagePatches( t1Preprocessed, patchSize, maxNumberOfPatches = "all", - strideLength = strideLength, returnAsArray = TRUE ) - batchX <- array( data = 0, dim = c( dim( imagePatches ), channelSize ) ) - batchX[,,,,1] <- imagePatches - if( channelSize > 1 ) - { - priorPatches <- extractImagePatches( mniPriors[[7]], patchSize, maxNumberOfPatches = "all", - strideLength = strideLength, returnAsArray = TRUE ) - batchX[,,,,2] <- priorPatches - } - predictedData <- unetModel %>% predict( batchX, verbose = verbose ) + if( verbose ) + { + cat( "Prediction using ", whichNetwork ) + } + + ################################ + # + # Preprocess image + # + ################################ + + hcpT1Template <- antsImageRead( getANTsXNetData( "hcpyaT1Template" ) ) + hcpT2Template <- antsImageRead( getANTsXNetData( "hcpyaT2Template" ) ) + hcpFaTemplate <- antsImageRead( getANTsXNetData( "hcpyaFATemplate" ) ) + hcpTemplateBrainMask <- antsImageRead( getANTsXNetData( "hcpyaTemplateBrainMask" ) ) + hcpTemplateBrainSegmentation <- antsImageRead( getANTsXNetData( "hcpyaTemplateBrainSegmentation" ) ) + + hcpTemplates <- list() + hcpTemplates[[1]] <- hcpT1Template * hcpTemplateBrainMask + hcpTemplates[[2]] <- hcpT2Template * hcpTemplateBrainMask + hcpTemplates[[3]] <- hcpFaTemplate * hcpTemplateBrainMask + + reg <- NULL + t1Mask <- NULL + preprocessedImages <- list() + for( i in seq.int( length( inputImages ) ) ) + { + n4 <- n4BiasFieldCorrection( inputImages[[i]], mask = inputImages[[i]] * 0 + 1, + convergence = list( iters = c( 50, 50, 50, 50 ), tol = 0.0 ), + rescaleIntensities = TRUE, + verbose = verbose ) + if( i == 1 ) + { + t1Mask <- brainExtraction( inputImages[[1]], modality = "t1", verbose = verbose ) + n4 <- n4 * t1Mask + reg <- antsRegistration( hcpTemplates[[i]], n4, + typeofTransform = "antsRegistrationSyNQuick[a]", + verbose = verbose ) + preprocessedImages[[i]] <- antsImageClone( reg$warpedmovout ) + } else { + n4 <- n4 * t1Mask + n4 <- antsApplyTransforms( hcpTemplates[[i]], n4, + transformlist = reg$fwdtransforms, + verbose = verbose ) + preprocessedImages[[i]] <- n4 + } + preprocessedImages[[i]] <- iMath( preprocessedImages[[i]], "Normalize" ) + } + + ################################ + # + # Build model and load weights + # + ################################ + + patchSize <- c( 192L, 224L, 192L ) + strideLength <- c( dim( hcpTemplates[[1]] )[1] - patchSize[1], + dim( hcpTemplates[[1]] )[2] - patchSize[2], + dim( hcpTemplates[[1]] )[3] - patchSize[3] ) + + hcpTemplatePriors <- list() + for( i in seq.int( 6 ) ) + { + prior <- thresholdImage( hcpTemplateBrainSegmentation, i, i, 1, 0 ) + priorSmooth <- smoothImage( prior, 1.0 ) + hcpTemplatePriors[[i]] <- priorSmooth + } + + classes <- c( "background", "csf", "gray matter", "white matter", + "deep gray matter", "brain stem", "cerebellum" ) + numberOfClassificationLabels <- length( classes ) + channelSize <- length( inputImages ) + length( hcpTemplatePriors ) + + unetModel <- createUnetModel3D( c( patchSize, channelSize ), + numberOfOutputs = numberOfClassificationLabels, mode = 'classification', + numberOfFilters = c( 16, 32, 64, 128 ), dropoutRate = 0.0, + convolutionKernelSize = c( 3, 3, 3 ), deconvolutionKernelSize = c( 2, 2, 2 ), + weightDecay = 0.0 ) - probabilityImages <- list() - for( i in seq.int( dim( predictedData )[5] ) ) - { if( verbose ) { - cat( "Reconstructing image ", classes[i], "\n" ) + cat( "DeepAtropos: retrieving model weights.\n" ) } - reconstructedImage <- reconstructImageFromPatches( predictedData[,,,,i], - domainImage = t1Preprocessed, strideLength = strideLength ) - if( doPreprocessing ) + weightsFileName <- '' + if( whichNetwork == "t1" ) { - probabilityImages[[i]] <- antsApplyTransforms( fixed = t1, moving = reconstructedImage, - transformlist = t1Preprocessing$templateTransforms$invtransforms, - whichtoinvert = c( TRUE ), interpolator = "linear", verbose = verbose ) - } else { - probabilityImages[[i]] <- reconstructedImage + weightsFileName <- getPretrainedNetwork( "DeepAtroposHcpT1Weights" ) + } else if( whichNetwork == "t1_t2" ) { + weightsFileName <- getPretrainedNetwork( "DeepAtroposHcpT1T2Weights" ) + } else if( whichNetwork == "t1_fa" ) { + weightsFileName <- getPretrainedNetwork( "DeepAtroposHcpT1FAWeights" ) + } else if( whichNetwork == "t1_t2_fa" ) { + weightsFileName <- getPretrainedNetwork( "DeepAtroposHcpT1T2FAWeights" ) } - } + load_model_weights_hdf5( unetModel, filepath = weightsFileName ) - imageMatrix <- imageListToMatrix( probabilityImages, t1 * 0 + 1 ) - segmentationMatrix <- matrix( apply( imageMatrix, 2, which.max ), nrow = 1 ) - segmentationImage <- matrixToImages( segmentationMatrix, t1 * 0 + 1 )[[1]] - 1 + ################################ + # + # Do prediction and normalize to native space + # + ################################ - results <- list( segmentationImage = segmentationImage, - probabilityImages = probabilityImages ) + if( verbose ) + { + message( "Prediction.\n" ) + } - # debugging + batchX <- array( data = 0, dim = c( 8, patchSize, channelSize ) ) - if( debug ) - { - inputImage <- unetModel$input - featureLayer <- unetModel$layers[[length( unetModel$layers ) - 1]] - featureFunction <- keras::backend()$`function`( list( inputImage ), list( featureLayer$output ) ) - featureBatch <- featureFunction( list( batchX[1,,,,,drop = FALSE] ) ) + imagePatchesList <- list() + for( i in seq.int( length( preprocessedImages ) ) ) + { + imagePatches <- extractImagePatches( preprocessedImages[[i]], + patchSize = patchSize, + maxNumberOfPatches = "all", + strideLength = strideLength, + returnAsArray = TRUE ) + imagePatchesList[[i]] <- imagePatches + } + for( i in seq.int( length( preprocessedImages ) ) ) + { + for( j in seq.int( 8 ) ) + { + batchX[j,,,,i] <- imagePatchesList[[i]][j,,,] + } + } - featureImagesList <- decodeUnet( featureBatch[[1]], croppedImage ) + priorsPatchesList <- list() + for( i in seq.int( length( hcpTemplatePriors ) ) ) + { + priorPatches <- extractImagePatches( hcpTemplatePriors[[i]], + patchSize = patchSize, + maxNumberOfPatches = "all", + strideLength = strideLength, + returnAsArray = TRUE ) + priorsPatchesList[[i]] <- priorPatches + } + for( i in seq.int( length( hcpTemplatePriors ) ) ) + { + for( j in seq.int( 8 ) ) + { + batchX[j,,,,length( preprocessedImages ) + i] <- priorsPatchesList[[i]][j,,,] + } + } + predictedData <- unetModel %>% predict( batchX, verbose = verbose ) - featureImages <- list() - for( i in seq.int( length( featureImagesList[[1]] ) ) ) + probabilityImages <- list() + for( i in seq.int( dim( predictedData )[5] ) ) { - decroppedImage <- decropImage( featureImagesList[[1]][[i]], t1Preprocessed * 0 ) + if( verbose ) + { + cat( "Reconstructing image ", classes[i], "\n" ) + } + reconstructedImage <- reconstructImageFromPatches( predictedData[,,,,i], + domainImage = hcpTemplates[[1]], strideLength = strideLength ) if( doPreprocessing ) { - featureImages[[i]] <- antsApplyTransforms( fixed = t1, moving = decroppedImage, - transformlist = t1Preprocessing$templateTransforms$invtransforms, + probabilityImages[[i]] <- antsApplyTransforms( fixed = inputImages[[1]], + moving = reconstructedImage, + transformlist = reg$invtransforms, whichtoinvert = c( TRUE ), interpolator = "linear", verbose = verbose ) } else { - featureImages[[i]] <- decroppedImage + probabilityImages[[i]] <- reconstructedImage } } - results[['featureImagesLastLayer']] <- featureImages - } - return( results ) -} + imageMatrix <- imageListToMatrix( probabilityImages, inputImages[[1]] * 0 + 1 ) + segmentationMatrix <- matrix( apply( imageMatrix, 2, which.max ), nrow = 1 ) + segmentationImage <- matrixToImages( segmentationMatrix, inputImages[[1]] * 0 + 1 )[[1]] - 1 + results <- list( segmentationImage = segmentationImage, + probabilityImages = probabilityImages ) + + return( results ) + } +} diff --git a/R/getANTsXNetData.R b/R/getANTsXNetData.R index 07022ab..1b034a8 100644 --- a/R/getANTsXNetData.R +++ b/R/getANTsXNetData.R @@ -98,7 +98,12 @@ getANTsXNetData <- function( "DevCCF_P56_MRI_T2_50um_BrainParcellationNickMask", "DevCCF_P56_MRI_T2_50um_BrainParcellationTctMask", "DevCCF_P04_STPT_50um", - "DevCCF_P04_STPT_50um_BrainParcellationJayMask" + "DevCCF_P04_STPT_50um_BrainParcellationJayMask", + "hcpyaT1Template", + "hcpyaT2Template", + "hcpyaFATemplate", + "hcpyaTemplateBrainMask", + "hcpyaTemplateBrainSegmentation" ), targetFileName) { @@ -146,6 +151,11 @@ getANTsXNetData <- function( mraTemplate = "https://figshare.com/ndownloader/files/46406695", mraTemplateBrainMask = "https://figshare.com/ndownloader/files/46406698", mraTemplateVesselPrior = "https://figshare.com/ndownloader/files/46406713", + hcpyaT1Template = "https://figshare.com/ndownloader/files/46746142", + hcpyaT2Template = "https://figshare.com/ndownloader/files/46746334", + hcpyaFATemplate = "https://figshare.com/ndownloader/files/46746349", + hcpyaTemplateBrainMask = "https://figshare.com/ndownloader/files/46746388", + hcpyaTemplateBrainSegmentation = "https://figshare.com/ndownloader/files/46746367", bsplineT2MouseTemplate = "https://figshare.com/ndownloader/files/44706247", bsplineT2MouseTemplateBrainMask = "https://figshare.com/ndownloader/files/44869285", DevCCF_P56_MRI_T2_50um = "https://figshare.com/ndownloader/files/44706244", diff --git a/R/getPretrainedNetwork.R b/R/getPretrainedNetwork.R index 0bcbc58..9b4d73d 100644 --- a/R/getPretrainedNetwork.R +++ b/R/getPretrainedNetwork.R @@ -153,8 +153,13 @@ getPretrainedNetwork <- function( "wholeHeadInpaintingFLAIR", "wholeHeadInpaintingPatchBasedT1", "wholeHeadInpaintingPatchBasedFLAIR", - "wholeTumorSegmentationT2Flair", - "wholeLungMaskFromVentilation" ), + "wholeTumorSegmentationT2Flair", + "wholeLungMaskFromVentilation", + "DeepAtroposHcpT1Weights", + "DeepAtroposHcpT1T2Weights", + "DeepAtroposHcpT1FAWeights", + "DeepAtroposHcpT1T2FAWeights" + ), targetFileName ) { @@ -304,7 +309,12 @@ getPretrainedNetwork <- function( wholeHeadInpaintingPatchBasedT1 = "https://figshare.com/ndownloader/files/39337442", wholeHeadInpaintingPatchBasedFLAIR = "https://figshare.com/ndownloader/files/39337439", wholeTumorSegmentationT2Flair = "https://ndownloader.figshare.com/files/14087045", - wholeLungMaskFromVentilation = "https://ndownloader.figshare.com/files/28914441" + wholeLungMaskFromVentilation = "https://ndownloader.figshare.com/files/28914441", + DeepAtroposHcpT1Weights = "https://figshare.com/ndownloader/files/49132504", + DeepAtroposHcpT1T2Weights = "https://figshare.com/ndownloader/files/49132498", + DeepAtroposHcpT1FAWeights = "https://figshare.com/ndownloader/files/49132507", + DeepAtroposHcpT1T2FAWeights = "https://figshare.com/ndownloader/files/49132501" + ) if( missing( targetFileName ) )