diff --git a/smriprep/workflows/anatomical.py b/smriprep/workflows/anatomical.py index 656f1b4579..95fa8ac52f 100644 --- a/smriprep/workflows/anatomical.py +++ b/smriprep/workflows/anatomical.py @@ -320,22 +320,6 @@ def _check_img(img): omp_nthreads=omp_nthreads, normalization_quality='precise' if not debug else 'testing') - # 3. Brain tissue segmentation - FAST produces: 0 (bg), 1 (wm), 2 (csf), 3 (gm) - t1w_dseg = pe.Node(fsl.FAST(segments=True, no_bias=True, probability_maps=True), - name='t1w_dseg', mem_gb=3) - - # Change LookUp Table - BIDS wants: 0 (bg), 1 (gm), 2 (wm), 3 (csf) - lut_t1w_dseg = pe.Node(niu.Function(function=_apply_bids_lut), - name='lut_t1w_dseg') - lut_t1w_dseg.inputs.lut = (0, 3, 1, 2) # Maps: 0 -> 0, 3 -> 1, 1 -> 2, 2 -> 3. - - workflow.connect([ - (buffernode, t1w_dseg, [('t1w_brain', 'in_files')]), - (t1w_dseg, lut_t1w_dseg, [('tissue_class_map', 'in_dseg')]), - (t1w_dseg, outputnode, [('probability_maps', 't1w_tpms')]), - (lut_t1w_dseg, outputnode, [('out', 't1w_dseg')]), - ]) - # 4. Spatial normalization anat_norm_wf = init_anat_norm_wf( debug=debug, @@ -363,10 +347,6 @@ def _check_img(img): (brain_extraction_wf, anat_norm_wf, [ (('outputnode.bias_corrected', _pop), 'inputnode.moving_image')]), (buffernode, anat_norm_wf, [('t1w_mask', 'inputnode.moving_mask')]), - (lut_t1w_dseg, anat_norm_wf, [ - ('out', 'inputnode.moving_segmentation')]), - (t1w_dseg, anat_norm_wf, [ - ('probability_maps', 'inputnode.moving_tpms')]), (anat_norm_wf, outputnode, [ ('poutputnode.standardized', 'std_preproc'), ('poutputnode.std_mask', 'std_mask'), @@ -378,6 +358,16 @@ def _check_img(img): ]), ]) + # Change LookUp Table - BIDS wants: 0 (bg), 1 (gm), 2 (wm), 3 (csf) + lut_t1w_dseg = pe.Node(niu.Function(function=_apply_bids_lut), + name='lut_t1w_dseg') + + workflow.connect([ + (lut_t1w_dseg, anat_norm_wf, [ + ('out', 'inputnode.moving_segmentation')]), + (lut_t1w_dseg, outputnode, [('out', 't1w_dseg')]), + ]) + # Connect reportlets workflow.connect([ (inputnode, anat_reports_wf, [ @@ -423,13 +413,28 @@ def _check_img(img): ]) if not freesurfer: # Flag --fs-no-reconall is set - return + # Brain tissue segmentation - FAST produces: 0 (bg), 1 (wm), 2 (csf), 3 (gm) + t1w_dseg = pe.Node(fsl.FAST(segments=True, no_bias=True, probability_maps=True), + name='t1w_dseg', mem_gb=3) + lut_t1w_dseg.inputs.lut = (0, 3, 1, 2) # Maps: 0 -> 0, 3 -> 1, 1 -> 2, 2 -> 3. + workflow.connect([ (brain_extraction_wf, buffernode, [ (('outputnode.out_file', _pop), 't1w_brain'), ('outputnode.out_mask', 't1w_mask')]), + (buffernode, t1w_dseg, [('t1w_brain', 'in_files')]), + (t1w_dseg, lut_t1w_dseg, [('tissue_class_map', 'in_dseg')]), + (t1w_dseg, anat_norm_wf, [ + ('probability_maps', 'inputnode.moving_tpms')]), + (t1w_dseg, outputnode, [('probability_maps', 't1w_tpms')]), ]) return workflow + # Map FS' aseg labels onto three-tissue segmentation + lut_t1w_dseg.inputs.lut = _aseg_two_three() + split_seg = pe.Node(niu.Function(function=_split_segments), + name='split_seg') + # 5. Surface reconstruction (--fs-no-reconall not set) surface_recon_wf = init_surface_recon_wf(name='surface_recon_wf', omp_nthreads=omp_nthreads, hires=hires) @@ -449,6 +454,12 @@ def _check_img(img): (('outputnode.bias_corrected', _pop), 'in_file')]), (surface_recon_wf, applyrefined, [ ('outputnode.out_brainmask', 'mask_file')]), + (surface_recon_wf, lut_t1w_dseg, [ + ('outputnode.out_aseg', 'in_dseg')]), + (lut_t1w_dseg, split_seg, [('out', 'in_file')]), + (split_seg, anat_norm_wf, [ + ('out', 'inputnode.moving_tpms')]), + (split_seg, outputnode, [('out', 't1w_tpms')]), (surface_recon_wf, outputnode, [ ('outputnode.subjects_dir', 'subjects_dir'), ('outputnode.subject_id', 'subject_id'), @@ -620,3 +631,55 @@ def _pop(inlist): if isinstance(inlist, (list, tuple)): return inlist[0] return inlist + + +def _aseg_two_three(): + import numpy as np + # Base struct + aseg_lut = np.zeros((256,), dtype="int") + # GM + aseg_lut[3] = 1 + aseg_lut[8:14] = 1 + aseg_lut[17:21] = 1 + aseg_lut[26:40] = 1 + aseg_lut[42] = 1 + aseg_lut[47:73] = 1 + + # CSF + aseg_lut[4:6] = 3 + aseg_lut[14:16] = 3 + aseg_lut[24] = 3 + aseg_lut[43:45] = 3 + aseg_lut[72] = 3 + + # WM + aseg_lut[2] = 2 + aseg_lut[7] = 2 + aseg_lut[16] = 2 + aseg_lut[28] = 2 + aseg_lut[41] = 2 + aseg_lut[46] = 2 + aseg_lut[60] = 2 + aseg_lut[77:80] = 2 + aseg_lut[250:256] = 2 + return tuple(aseg_lut) + + +def _split_segments(in_file): + from pathlib import Path + import numpy as np + import nibabel as nb + + segimg = nb.load(in_file) + data = np.int16(segimg.dataobj) + hdr = segimg.header.copy() + hdr.set_data_dtype('uint8') + + out_files = [] + for i, label in enumerate(("CSF", "WM", "GM")): + out_fname = str((Path() / f"aseg_label-{label}_mask.nii.gz").absolute()) + segment = (data == i + 1).astype('uint8') + segimg.__class__(segment, segimg.affine, hdr).to_filename(out_fname) + out_files.append(out_fname) + + return out_files