diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 000000000..b2c3c96d9 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,28 @@ +name: Build Docker with selected version + +on: + workflow_dispatch: + inputs: + scilpy_commit: + description: Scilpy commit id + required: true + +jobs: + Build_Docker: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + name: Check out repository + - name: Change scilpy version + run: sed -i '/ENV SCILPY_VERSION=/c\ENV SCILPY_VERSION=${{ github.event.inputs.scilpy_commit }}' containers/Dockerfile + - uses: mr-smithers-excellent/docker-build-push@v3.1 + name: Docker Build & Push + with: + image: scilus/scilpy + tag: dev + dockerfile: containers/Dockerfile + registry: docker.io + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} diff --git a/Jenkinsfile b/Jenkinsfile index dbc7910db..9dc89863a 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -4,21 +4,11 @@ pipeline { stages { stage('Build') { stages { - stage('Python3.6') { - steps { - withPythonEnv('CPython-3.6') { - sh ''' - pip3 install numpy==1.18.* wheel - pip3 install -e . - ''' - } - } - } stage('Python3.7') { steps { withPythonEnv('CPython-3.7') { sh ''' - pip3 install numpy==1.18.* wheel + pip3 install numpy==1.20.* wheel pip3 install -e . ''' } @@ -31,7 +21,7 @@ pipeline { steps { withPythonEnv('CPython-3.7') { sh ''' - pip3 install numpy==1.18.* wheel + pip3 install numpy==1.20.* wheel pip3 install -e . export MPLBACKEND="agg" export OPENBLAS_NUM_THREADS=1 @@ -55,7 +45,12 @@ pipeline { cleanWs() script { if (env.CHANGE_ID) { - pullRequest.createReviewRequests(['arnaudbore']) + if (pullRequest.createdBy != "arnaudbore"){ + pullRequest.createReviewRequests(['arnaudbore']) + } + else{ + pullRequest.createReviewRequests(['GuillaumeTh']) + } } } } diff --git a/docs/requirements.txt b/docs/requirements.txt index 9b8e74a8a..c551aac3e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -6,7 +6,7 @@ matplotlib==2.2.* nibabel==3.0.* nilearn==0.6.* numpy==1.18.* -Pillow==7.1.* +Pillow==8.2.* pybids==0.10.* pyparsing==2.2.* python-dateutil==2.7.* diff --git a/requirements.txt b/requirements.txt index f6923e64f..a494e2559 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,15 +3,15 @@ coloredlogs==10.0.* cycler==0.10.* Cython==0.29.* dipy==1.3.* -fury==0.6.* +fury==0.7.* future==0.17.* h5py==2.10.* kiwisolver==1.0.* matplotlib==2.2.* nibabel==3.0.* nilearn==0.6.* -numpy==1.18.* -Pillow==7.1.* +numpy==1.20.* +Pillow==8.2.* bids-validator==1.6.0 pybids==0.10.* pyparsing==2.2.* diff --git a/scilpy/image/utils.py b/scilpy/image/utils.py index fd7c6f0c6..73141d09d 100644 --- a/scilpy/image/utils.py +++ b/scilpy/image/utils.py @@ -34,7 +34,7 @@ def count_non_zero_voxels(image): return nb_voxels -def volume_iterator(img, blocksize=1): +def volume_iterator(img, blocksize=1, start=0, end=0): """Generator that iterates on volumes of data. 
Parameters @@ -43,24 +43,34 @@ def volume_iterator(img, blocksize=1): Image of a 4D volume with shape X,Y,Z,N blocksize : int, optional Number of volumes to return in a single batch + start : int, optional + Starting iteration index in the 4D volume + end : int, optional + Stopping iteration index in the 4D volume + (the volume at this index is excluded) Yields ------- tuple of (list of int, ndarray) The ids of the selected volumes, and the selected data as a 4D array """ + assert end <= img.shape[-1], "End limit provided is greater than the " \ + "total number of volumes in image" + nb_volumes = img.shape[-1] + end = end if end else img.shape[-1] if blocksize == nb_volumes: - yield list(range(nb_volumes)), img.get_fdata(dtype=np.float32) + yield list(range(start, end)), \ + img.get_fdata(dtype=np.float32)[..., start:end] else: - start, end = 0, 0 - for i in range(0, nb_volumes - blocksize, blocksize): - start, end = i, i + blocksize - logging.info("Loading volumes {} to {}.".format(start, end - 1)) - yield list(range(start, end)), img.dataobj[..., start:end] + stop = start + for i in range(start, end - blocksize, blocksize): + start, stop = i, i + blocksize + logging.info("Loading volumes {} to {}.".format(start, stop - 1)) + yield list(range(start, stop)), img.dataobj[..., start:stop] - if end < nb_volumes: + if stop < end: logging.info( - "Loading volumes {} to {}.".format(end, nb_volumes - 1)) - yield list(range(end, nb_volumes)), img.dataobj[..., end:] + "Loading volumes {} to {}.".format(stop, end - 1)) + yield list(range(stop, end)), img.dataobj[..., stop:end] diff --git a/scilpy/reconst/utils.py b/scilpy/reconst/utils.py index f239a6361..d23a89fc9 100644 --- a/scilpy/reconst/utils.py +++ b/scilpy/reconst/utils.py @@ -15,6 +15,22 @@ def find_order_from_nb_coeff(data): return int((-3 + np.sqrt(1 + 8 * shape[-1])) / 2) +def get_sh_order_and_fullness(ncoeffs): + """ + Get the order of the SH basis from the number of SH coefficients + as well as a boolean indicating if the basis is full. + """ + # the two curves (sym and full) intersect at ncoeffs = 1, in what + # case both bases correspond to order 1. + sym_order = (-3.0 + np.sqrt(1.0 + 8.0 * ncoeffs)) / 2.0 + if sym_order.is_integer(): + return sym_order, False + full_order = np.sqrt(ncoeffs) - 1.0 + if full_order.is_integer(): + return full_order, True + raise ValueError('Invalid number of coefficients for SH basis.') + + def _honor_authorsnames_sh_basis(sh_basis_type): sh_basis = sh_basis_type if sh_basis_type == 'fibernav': diff --git a/scilpy/segment/voting_scheme.py b/scilpy/segment/voting_scheme.py index d12c3efea..f435bb18c 100644 --- a/scilpy/segment/voting_scheme.py +++ b/scilpy/segment/voting_scheme.py @@ -288,13 +288,13 @@ def __call__(self, input_tractogram_path, nbr_processes=1, seeds=None): slr_transform_type, seed]) tmp_dir, tmp_memmap_filenames = streamlines_to_memmap(wb_streamlines) + del wb_streamlines comb_param_cluster = product(self.tractogram_clustering_thr, seeds) # Clustring is now parallelize pool = multiprocessing.Pool(nbr_processes) all_rbx_dict = pool.map(single_clusterize_and_rbx_init, - zip(repeat(wb_streamlines), - repeat(tmp_memmap_filenames), + zip(repeat(tmp_memmap_filenames), comb_param_cluster, repeat(self.nb_points))) pool.close() @@ -363,8 +363,6 @@ def single_clusterize_and_rbx_init(args): Parameters ---------- - wb_streamlines : list or ArraySequence - All streamlines of the tractogram to segment. tmp_memmap_filename: tuple (3) Temporary filename for the data, offsets and lengths. 
@@ -381,11 +379,11 @@ def single_clusterize_and_rbx_init(args): rbx : dict Initialisation of the recobundles class using specific parameters. """ - wb_streamlines = args[0] - tmp_memmap_filename = args[1] - clustering_thr = args[2][0] - seed = args[2][1] - nb_points = args[3] + tmp_memmap_filename = args[0] + wb_streamlines = reconstruct_streamlines_from_memmap(tmp_memmap_filename) + clustering_thr = args[1][0] + seed = args[1][1] + nb_points = args[2] rbx = {} base_thresholds = [45, 35, 25] diff --git a/scilpy/tractanalysis/features.py b/scilpy/tractanalysis/features.py index ea7fe1f46..9d955b0b9 100644 --- a/scilpy/tractanalysis/features.py +++ b/scilpy/tractanalysis/features.py @@ -7,9 +7,52 @@ from dipy.segment.metric import ResampleFeature from dipy.segment.metric import AveragePointwiseEuclideanMetric from dipy.tracking import metrics as tm +from scilpy.tracking.tools import resample_streamlines_num_points import numpy as np +def detect_ushape(sft, minU, maxU): + """ + Extract streamlines depending of their "u-shapeness". + Parameters + ---------- + sft: Statefull tractogram + Tractogram used to extract streamlines depending on their ushapeness. + minU: Float + Minimum ufactor of a streamline. + maxU: Float + Maximum ufactor of a streamline. + + Returns + ------- + list: the ids of clean streamlines + Only the ids are returned so proper filtering can be done afterwards. + """ + ids = [] + new_sft = resample_streamlines_num_points(sft, 4) + for i, s in enumerate(new_sft.streamlines): + if len(s) == 4: + first_point = s[0] + last_point = s[-1] + second_point = s[1] + third_point = s[2] + + v1 = first_point - second_point + v2 = second_point - third_point + v3 = third_point - last_point + + v1 = v1 / np.linalg.norm(v1) + v2 = v2 / np.linalg.norm(v2) + v3 = v3 / np.linalg.norm(v3) + + val = np.dot(np.cross(v1, v2), np.cross(v2, v3)) + + if minU <= val <= maxU: + ids.append(i) + + return ids + + def remove_loops_and_sharp_turns(streamlines, max_angle, use_qb=False, diff --git a/scilpy/tractanalysis/todi.py b/scilpy/tractanalysis/todi.py index 7e12cfa53..a1e385313 100644 --- a/scilpy/tractanalysis/todi.py +++ b/scilpy/tractanalysis/todi.py @@ -200,7 +200,8 @@ def smooth_todi_spatial(self, sigma=0.5): new_todi = deepcopy(tmp_todi) else: new_todi = np.hstack((new_todi, tmp_todi)) - self.todi = np.delete(self.todi, range(0, chunk_size), axis=1) + self.todi = np.delete(self.todi, range( + 0, min(self.todi.shape[1], chunk_size)), axis=1) chunk_count -= 1 self.mask = new_mask diff --git a/scilpy/utils/bvec_bval_tools.py b/scilpy/utils/bvec_bval_tools.py index 788f87175..e115fc04f 100644 --- a/scilpy/utils/bvec_bval_tools.py +++ b/scilpy/utils/bvec_bval_tools.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import logging +from enum import Enum import numpy as np @@ -11,6 +12,12 @@ DEFAULT_B0_THRESHOLD = 20 +class B0ExtractionStrategy(Enum): + FIRST = "first" + MEAN = "mean" + ALL = "all" + + def is_normalized_bvecs(bvecs): """ Check if b-vectors are normalized. @@ -282,11 +289,11 @@ def extract_dwi_shell(dwi, bvals, bvecs, bvals_to_extract, tol=20, bvals_to_extract : list of int The list of b-values to extract. tol : int - Loads the data using this block size. Useful when the data is too - large to be loaded in memory. - block_size : int The tolerated gap between the b-values to extract and the actual b-values. + block_size : int + Load the data using this block size. Useful when the data is too + large to be loaded in memory. 
Returns ------- @@ -335,6 +342,89 @@ def extract_dwi_shell(dwi, bvals, bvecs, bvals_to_extract, tol=20, return indices, shell_data, output_bvals, output_bvecs +def extract_b0(dwi, b0_mask, extract_in_cluster=False, + strategy=B0ExtractionStrategy.MEAN, block_size=None): + """ + Extract a set of b0 volumes from a dwi dataset + + Parameters + ---------- + dwi : nib.Nifti1Image + Original multi-shell volume. + b0_mask: array of bool + Mask over the time dimension (4th) identifying b0 volumes. + extract_in_cluster: bool + Specify to extract b0's in each continuous sets of b0 volumes + appearing in the input data. + strategy: Enum + The extraction strategy, of either select the first b0 found, select + them all or average them. When used in conjunction with the batch + parameter set to True, the strategy is applied individually on each + continuous set found. + block_size : int + Load the data using this block size. Useful when the data is too + large to be loaded in memory. + + Returns + ------- + b0_data : ndarray + Extracted b0 volumes. + """ + + indices = np.where(b0_mask)[0] + + if block_size is None: + block_size = dwi.shape[-1] + + if not extract_in_cluster and strategy == B0ExtractionStrategy.FIRST: + idx = np.min(indices) + output_b0 = dwi.dataobj[..., idx:idx + 1].squeeze() + else: + # Generate list of clustered b0 in the data + mask = np.ma.masked_array(b0_mask) + mask[~b0_mask] = np.ma.masked + b0_clusters = np.ma.notmasked_contiguous(mask, axis=0) + + if extract_in_cluster or strategy == B0ExtractionStrategy.ALL: + if strategy == B0ExtractionStrategy.ALL: + time_d = len(indices) + else: + time_d = len(b0_clusters) + + output_b0 = np.zeros(dwi.shape[:-1] + (time_d,)) + + for idx, cluster in enumerate(b0_clusters): + if strategy == B0ExtractionStrategy.FIRST: + data = dwi.dataobj[..., cluster.start:cluster.start + 1] + output_b0[..., idx] = data.squeeze() + else: + vol_it = volume_iterator(dwi, block_size, + cluster.start, cluster.stop) + + for vi, data in vol_it: + if strategy == B0ExtractionStrategy.ALL: + in_volume = np.array([i in vi for i in indices]) + output_b0[..., in_volume] = data + elif strategy == B0ExtractionStrategy.MEAN: + output_b0[..., idx] += np.sum(data, -1) + + if strategy == B0ExtractionStrategy.MEAN: + output_b0[..., idx] /= cluster.stop - cluster.start + + else: + output_b0 = np.zeros(dwi.shape[:-1]) + for cluster in b0_clusters: + vol_it = volume_iterator(dwi, block_size, + cluster.start, cluster.stop) + + for _, data in vol_it: + output_b0 += np.sum(data, -1) + + output_b0 /= len(indices) + + return output_b0 + + def flip_mrtrix_gradient_sampling(gradient_sampling_filename, gradient_sampling_flipped_filename, axes): """ diff --git a/scilpy/viz/scene_utils.py b/scilpy/viz/scene_utils.py index 2a4f493ae..5c53d4935 100644 --- a/scilpy/viz/scene_utils.py +++ b/scilpy/viz/scene_utils.py @@ -3,7 +3,7 @@ from enum import Enum import numpy as np -from dipy.reconst.shm import sh_to_sf +from dipy.reconst.shm import sh_to_sf_matrix from fury import window, actor from scilpy.io.utils import snapshot @@ -19,64 +19,80 @@ class CamParams(Enum): ZOOM_FACTOR = 'zoom_factor' -def initialize_camera(orientation, volume_shape): +def initialize_camera(orientation, slice_index, volume_shape): """ Initialize a camera for a given orientation. 
""" camera = {} # Tighten the view around the data camera[CamParams.ZOOM_FACTOR] = 2.0 / max(volume_shape) + # heuristic for setting the camera position at a distance + # proportional to the scale of the scene eye_distance = max(volume_shape) if orientation == 'sagittal': + if slice_index is None: + slice_index = volume_shape[0] // 2 camera[CamParams.VIEW_POS] = np.array([-eye_distance, (volume_shape[1] - 1) / 2.0, (volume_shape[2] - 1) / 2.0]) - camera[CamParams.VIEW_CENTER] = np.array([0.0, + camera[CamParams.VIEW_CENTER] = np.array([slice_index, (volume_shape[1] - 1) / 2.0, (volume_shape[2] - 1) / 2.0]) camera[CamParams.VIEW_UP] = np.array([0.0, 0.0, 1.0]) elif orientation == 'coronal': + if slice_index is None: + slice_index = volume_shape[1] // 2 camera[CamParams.VIEW_POS] = np.array([(volume_shape[0] - 1) / 2.0, eye_distance, (volume_shape[2] - 1) / 2.0]) camera[CamParams.VIEW_CENTER] = np.array([(volume_shape[0] - 1) / 2.0, - 0.0, + slice_index, (volume_shape[2] - 1) / 2.0]) camera[CamParams.VIEW_UP] = np.array([0.0, 0.0, 1.0]) elif orientation == 'axial': + if slice_index is None: + slice_index = volume_shape[2] // 2 camera[CamParams.VIEW_POS] = np.array([(volume_shape[0] - 1) / 2.0, (volume_shape[1] - 1) / 2.0, -eye_distance]) camera[CamParams.VIEW_CENTER] = np.array([(volume_shape[0] - 1) / 2.0, (volume_shape[1] - 1) / 2.0, - 0.0]) + slice_index]) camera[CamParams.VIEW_UP] = np.array([0.0, 1.0, 0.0]) else: raise ValueError('Invalid axis name: {0}'.format(orientation)) return camera -def set_display_extent(slicer_actor, orientation, volume_shape): +def set_display_extent(slicer_actor, orientation, volume_shape, slice_index): """ Set the display extent for a fury actor in ``orientation``. """ if orientation == 'sagittal': - slicer_actor.display_extent(0, 0, 0, volume_shape[1], + if slice_index is None: + slice_index = volume_shape[0] // 2 + slicer_actor.display_extent(slice_index, slice_index, + 0, volume_shape[1], 0, volume_shape[2]) elif orientation == 'coronal': - slicer_actor.display_extent(0, volume_shape[0], 0, 0, + if slice_index is None: + slice_index = volume_shape[1] // 2 + slicer_actor.display_extent(0, volume_shape[0], + slice_index, slice_index, 0, volume_shape[2]) elif orientation == 'axial': + if slice_index is None: + slice_index = volume_shape[2] // 2 slicer_actor.display_extent(0, volume_shape[0], 0, volume_shape[1], - 0, 0) + slice_index, slice_index) else: raise ValueError('Invalid axis name : {0}'.format(orientation)) def create_odf_slicer(sh_fodf, mask, sphere, nb_subdivide, sh_order, sh_basis, full_basis, orientation, - scale, radial_scale, norm, colormap): + scale, radial_scale, norm, colormap, slice_index): """ Create a ODF slicer actor displaying a fODF slice. The input volume is a 3-dimensional grid containing the SH coefficients of the fODF for each @@ -87,20 +103,16 @@ def create_odf_slicer(sh_fodf, mask, sphere, nb_subdivide, if nb_subdivide is not None: sphere = sphere.subdivide(nb_subdivide) - # Convert SH coefficients to SF coefficients - fodf = sh_to_sf(sh_fodf, sphere, sh_order, sh_basis, - full_basis=full_basis) + # SH coefficients to SF coefficients matrix + B_mat = sh_to_sf_matrix(sphere, sh_order, sh_basis, + full_basis, return_inv=False) - # Get mask if supplied, otherwise create a mask discarding empty voxels - if mask is None: - mask = np.linalg.norm(fodf, axis=-1) > 0. 
- - odf_actor = actor.odf_slicer(fodf, mask=mask, norm=norm, + odf_actor = actor.odf_slicer(sh_fodf, mask=mask, norm=norm, radial_scale=radial_scale, sphere=sphere, colormap=colormap, - scale=scale) - set_display_extent(odf_actor, orientation, fodf.shape) + scale=scale, B_matrix=B_mat) + set_display_extent(odf_actor, orientation, sh_fodf.shape[:3], slice_index) return odf_actor @@ -124,25 +136,32 @@ def _get_affine_for_texture(orientation, offset): return affine -def create_texture_slicer(texture, value_range=None, orientation='axial', - opacity=1.0, offset=0.5, interpolation='nearest'): +def create_texture_slicer(texture, mask, slice_index, value_range=None, + orientation='axial', opacity=1.0, offset=0.5, + interpolation='nearest'): """ Create a texture displayed behind the fODF. The texture is applied on a plane with a given offset for the fODF grid. """ affine = _get_affine_for_texture(orientation, offset) - slicer_actor = actor.slicer(texture, affine=affine, + if mask is not None: + masked_texture = np.zeros_like(texture) + masked_texture[mask] = texture[mask] + else: + masked_texture = texture + + slicer_actor = actor.slicer(masked_texture, affine=affine, value_range=value_range, opacity=opacity, interpolation=interpolation) - set_display_extent(slicer_actor, orientation, texture.shape) - + set_display_extent(slicer_actor, orientation, texture.shape, slice_index) return slicer_actor -def create_peaks_slicer(data, orientation, peak_values=None, mask=None, - color=None, peaks_width=1.0): +def create_peaks_slicer(data, orientation, slice_index, peak_values=None, + mask=None, color=None, peaks_width=1.0, + symmetric=False): """ Create a peaks slicer actor rendering a slice of the fODF peaks """ @@ -153,20 +172,21 @@ def create_peaks_slicer(data, orientation, peak_values=None, mask=None, # Instantiate peaks slicer peaks_slicer = actor.peak_slicer(data, peaks_values=peak_values, mask=mask, colors=color, - linewidth=peaks_width) - set_display_extent(peaks_slicer, orientation, data.shape) + linewidth=peaks_width, + symmetric=symmetric) + set_display_extent(peaks_slicer, orientation, data.shape, slice_index) return peaks_slicer -def create_scene(actors, orientation, volume_shape): +def create_scene(actors, orientation, slice_index, volume_shape): """ Create a 3D scene containing actors fitting inside a grid. The camera is placed based on the orientation supplied by the user. The projection mode is parallel. """ # Configure camera - camera = initialize_camera(orientation, volume_shape) + camera = initialize_camera(orientation, slice_index, volume_shape) scene = window.Scene() scene.projection('parallel') @@ -182,15 +202,18 @@ def create_scene(actors, orientation, volume_shape): return scene -def render_scene(scene, window_size, interactor, output, silent): +def render_scene(scene, window_size, interactor, + output, silent, title='Viewer'): """ Render a scene. If a output is supplied, a snapshot of the rendered scene is taken. 
""" if not silent: - showm = window.ShowManager(scene, size=window_size, + showm = window.ShowManager(scene, title=title, + size=window_size, reset_camera=False, interactor_style=interactor) + showm.initialize() showm.start() diff --git a/scripts/scil_extract_b0.py b/scripts/scil_extract_b0.py index e2e093971..44d7e3f9f 100755 --- a/scripts/scil_extract_b0.py +++ b/scripts/scil_extract_b0.py @@ -19,7 +19,8 @@ from scilpy.io.utils import (add_verbose_arg, assert_inputs_exist, add_force_b0_arg) -from scilpy.utils.bvec_bval_tools import check_b0_threshold +from scilpy.utils.bvec_bval_tools import (check_b0_threshold, extract_b0, + B0ExtractionStrategy) from scilpy.utils.filenames import split_name_with_nii logger = logging.getLogger(__file__) @@ -39,13 +40,29 @@ def _build_arg_parser(): p.add_argument('--b0_thr', type=float, default=0.0, help='All b-values with values less than or equal ' 'to b0_thr are considered as b0s i.e. without ' - 'diffusion weighting.') + 'diffusion weighting. [%(default)s]') group = p.add_mutually_exclusive_group() group.add_argument('--all', action='store_true', help='Extract all b0. Index number will be appended to ' 'the output file.') group.add_argument('--mean', action='store_true', help='Extract mean b0.') + group.add_argument('--cluster-mean', action='store_true', + help='Extract mean of each continuous cluster of b0s.') + group.add_argument('--cluster-first', action='store_true', + help='Extract first b0 of each ' + 'continuous cluster of b0s.') + + p.add_argument('--block-size', '-s', + metavar='INT', type=int, + help='Load the data using this block size. ' + 'Useful\nwhen the data is too large to be ' + 'loaded in memory.') + + p.add_argument('--single-image', action='store_true', + help='If output b0 volume has multiple time points, only ' + 'outputs a single image instead of a numbered series ' + 'of images.') add_force_b0_arg(p) add_verbose_arg(p) @@ -53,32 +70,15 @@ def _build_arg_parser(): return p -def _keep_time_step(dwi, time, output): - image = nib.load(dwi) - data = image.get_fdata(dtype=np.float32) - +def _split_time_steps(b0, affine, header, output): fname, fext = split_name_with_nii(os.path.basename(output)) - multi_b0 = len(time) > 1 - for t in time: - t_data = data[..., t] - + multiple_b0 = b0.shape[-1] > 1 + for t in range(b0.shape[-1]): out_name = os.path.join( os.path.dirname(os.path.abspath(output)), - '{}_{}{}'.format(fname, t, fext)) if multi_b0 else output - nib.save(nib.Nifti1Image(t_data, image.affine, image.header), - out_name) - - -def _mean_in_time(dwi, time, output): - image = nib.load(dwi) - - data = image.get_fdata(dtype=np.float32) - data = data[..., time] - data = np.mean(data, axis=3, dtype=data.dtype) - - nib.save(nib.Nifti1Image(data, image.affine, image.header), - output) + '{}_{}{}'.format(fname, t, fext)) if multiple_b0 else output + nib.save(nib.Nifti1Image(b0[..., t], affine, header), out_name) def main(): @@ -103,14 +103,25 @@ def main(): logger.info('Number of b0 images in the data: {}'.format(len(b0_idx))) - if args.mean: - logger.info('Using mean of indices {} for b0'.format(b0_idx)) - _mean_in_time(args.in_dwi, b0_idx, args.out_b0) + strategy, extract_in_cluster = B0ExtractionStrategy.FIRST, False + if args.mean or args.cluster_mean: + strategy = B0ExtractionStrategy.MEAN + extract_in_cluster = args.cluster_mean + elif args.all: + strategy = B0ExtractionStrategy.ALL + elif args.cluster_first: + extract_in_cluster = True + + image = nib.load(args.in_dwi) + + b0_volumes = extract_b0( + image, gtab.b0s_mask, 
extract_in_cluster, strategy, args.block_size) + + if len(b0_volumes.shape) > 3 and not args.single_image: + _split_time_steps(b0_volumes, image.affine, image.header, args.out_b0) else: - if not args.all: - b0_idx = [b0_idx[0]] - logger.info("Keeping {} for b0".format(b0_idx)) - _keep_time_step(args.in_dwi, b0_idx, args.out_b0) + nib.save(nib.Nifti1Image(b0_volumes, image.affine, image.header), + args.out_b0) if __name__ == '__main__': diff --git a/scripts/scil_extract_ushape.py b/scripts/scil_extract_ushape.py new file mode 100755 index 000000000..553d6b0a0 --- /dev/null +++ b/scripts/scil_extract_ushape.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +This script extracts streamlines depending on their U-shapeness. +This script is a replica of Trackvis method. + +When ufactor is close to: +* 0 it defines straight streamlines +* 1 it defines U-fibers +* -1 it defines S-fibers +""" + +import argparse +import json +import logging + +from dipy.io.streamline import save_tractogram +import numpy as np + +from scilpy.io.streamlines import load_tractogram_with_reference +from scilpy.io.utils import (add_json_args, + add_overwrite_arg, + add_reference_arg, + assert_inputs_exist, + assert_outputs_exist, + check_tracts_same_format) +from scilpy.tractanalysis.features import detect_ushape + + +def _build_arg_parser(): + p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, + description=__doc__) + p.add_argument('in_tractogram', + help='Tractogram input file name.') + p.add_argument('out_tractogram', + help='Output tractogram file name.') + p.add_argument('--minU', + default=0.5, type=float, + help='Min ufactor value. [%(default)s]') + p.add_argument('--maxU', + default=1.0, type=float, + help='Max ufactor value. 
[%(default)s]') + + p.add_argument('--remaining_tractogram', + help='If set, saves remaining streamlines.') + p.add_argument('--no_empty', action='store_true', + help='Do not write file if there is no streamline.') + p.add_argument('--display_counts', action='store_true', + help='Print streamline count before and after filtering.') + + add_overwrite_arg(p) + add_reference_arg(p) + add_json_args(p) + + return p + + +def main(): + parser = _build_arg_parser() + args = parser.parse_args() + + assert_inputs_exist(parser, args.in_tractogram) + assert_outputs_exist(parser, args, args.out_tractogram, + optional=args.remaining_tractogram) + check_tracts_same_format(parser, [args.in_tractogram, args.out_tractogram, + args.remaining_tractogram]) + + if not(-1 <= args.minU <= 1 and -1 <= args.maxU <= 1): + parser.error('Min-Max ufactor "{},{}" '.format(args.minU, args.maxU) + + 'must be between -1 and 1.') + + sft = load_tractogram_with_reference( + parser, args, args.in_tractogram) + + ids_c = detect_ushape(sft, args.minU, args.maxU) + ids_l = np.setdiff1d(np.arange(len(sft.streamlines)), ids_c) + + if len(ids_c) == 0: + if args.no_empty: + logging.debug("The file {} won't be written " + "(0 streamline).".format(args.out_tractogram)) + return + + logging.debug('The file {} contains 0 streamline.'.format( + args.out_tractogram)) + + save_tractogram(sft[ids_c], args.out_tractogram) + + if args.display_counts: + sc_bf = len(sft.streamlines) + sc_af = len(ids_c) + print(json.dumps({'streamline_count_before_filtering': int(sc_bf), + 'streamline_count_after_filtering': int(sc_af)}, + indent=args.indent)) + + if args.remaining_tractogram: + if len(ids_l) == 0: + if args.no_empty: + logging.debug("The file {} won't be written (0 streamline" + ").".format(args.remaining_tractogram)) + return + + logging.warning('No remaining streamlines.') + + save_tractogram(sft[ids_l], args.remaining_tractogram) + + +if __name__ == "__main__": + main() diff --git a/scripts/scil_fix_dsi_studio_trk.py b/scripts/scil_fix_dsi_studio_trk.py index a7ccc4537..4bed9d7cd 100755 --- a/scripts/scil_fix_dsi_studio_trk.py +++ b/scripts/scil_fix_dsi_studio_trk.py @@ -23,6 +23,9 @@ This script was tested on various datasets and worked on all of them. However, always verify the results and if a specific case does not work. Open an issue on the Scilpy GitHub repository. + +WARNING: This script is still experimental, DSI-Studio evolves quickly and +results may vary depending on the data itself as well as DSI-studio version. 
""" import argparse diff --git a/scripts/scil_generate_priors_from_bundle.py b/scripts/scil_generate_priors_from_bundle.py index b6d864c6a..ef53b88e6 100755 --- a/scripts/scil_generate_priors_from_bundle.py +++ b/scripts/scil_generate_priors_from_bundle.py @@ -12,13 +12,14 @@ import os from dipy.data import get_sphere -from dipy.io.streamline import load_tractogram from dipy.reconst.shm import sf_to_sh, sh_to_sf import nibabel as nib import numpy as np from scilpy.io.image import get_data_as_mask +from scilpy.io.streamlines import load_tractogram_with_reference from scilpy.io.utils import (add_overwrite_arg, + add_reference_arg, add_sh_basis_args, assert_inputs_exist, assert_outputs_exist) @@ -60,6 +61,7 @@ def _build_arg_parser(): 'default is current directory.') add_overwrite_arg(p) + add_reference_arg(p) return p @@ -93,8 +95,7 @@ def main(): sh_order = find_order_from_nb_coeff(sh_shape) img_mask = nib.load(args.in_mask) - sft = load_tractogram(args.in_bundle, args.in_fodf, - trk_header_check=True) + sft = load_tractogram_with_reference(parser, args, args.in_bundle) sft.to_vox() if len(sft.streamlines) < 1: raise ValueError('The input bundle contains no streamline.') @@ -108,7 +109,8 @@ def main(): # Fancy masking of 1d indices to limit spatial dilation to WM sub_mask_3d = np.logical_and(get_data_as_mask(img_mask), - todi_obj.reshape_to_3d(todi_obj.get_mask())) + todi_obj.reshape_to_3d( + todi_obj.get_mask())) sub_mask_1d = sub_mask_3d.flatten()[todi_obj.get_mask()] todi_sf = todi_obj.get_todi()[sub_mask_1d] ** 2 diff --git a/scripts/scil_run_commit.py b/scripts/scil_run_commit.py index bd1e9ca2b..a5fd1b43a 100755 --- a/scripts/scil_run_commit.py +++ b/scripts/scil_run_commit.py @@ -10,8 +10,10 @@ multi-shell data and a peak file (principal fiber directions in each voxel, typically from a field of fODFs). -It is possible to use the ball-and-stick model for single-shell data. In this -case, the peak file is not mandatory. +It is possible to use the ball-and-stick model for single-shell and multi-shell +data. In this case, the peak file is not mandatory. Multi-shell should follow a +"NODDI protocol" (low and high b-values), multiple shells with similar b-values +should not be used with COMMIT. The output from COMMIT is: - fit_NRMSE.nii.gz @@ -20,9 +22,9 @@ fiting error (Root Mean Square Error) - results.pickle Dictionary containing the experiment parameters and final weights -- compartment_EC.nii.gz (Extra-Cellular) -- compartment_IC.nii.gz (Intra-Cellular) -- compartment_ISO.nii.gz (isotropic volume fraction (freewater comportment)) +- compartment_EC.nii.gz (est. Extra-Cellular signal fraction) +- compartment_IC.nii.gz (est. Intra-Cellular signal fraction) +- compartment_ISO.nii.gz (est. isotropic signal fraction (freewater comportment)) Each of COMMIT compartments - commit_weights.txt Text file containing the commit weights for each streamline of the @@ -32,22 +34,28 @@ - tot_commit_weights Text file containing the total commit weights of each streamline. Equal to commit_weights * streamlines_length (W_i * L_i) -- commit_weights.txt - Text file containing the commit weights for each streamline of the - input tractogram. - essential.trk / non_essential.trk Tractograms containing the streamlines below or equal (essential) and - above (non_essential) the --threshold_weights argument. + above (non_essential) a threshold_weights of 0. This script can divide the input tractogram in two using a threshold to apply -on the streamlines' weight. 
Typically, the threshold should be 0, keeping only +on the streamlines' weight. The threshold used is 0.0, keeping only streamlines that have non-zero weight and that contribute to explain the DWI signal. Streamlines with 0 weight are essentially not necessary according to COMMIT. COMMIT2 is available only for HDF5 data from scil_decompose_connectivity.py and with the --ball_stick option. Use the --commit2 option to activite it, slightly -longer computation time. +longer computation time. This wrapper offers a simplify way to call COMMIT, but +does not allow to use (or fine-tune) every parameters. If you want to use COMMIT +with full access to all parameters, visit: https://github.com/daducci/COMMIT + +When tunning parameters, such as --iso_diff, --para_diff, --perp_diff or +--lambda_commit_2 you should evaluate the quality of results by: + - Looking at the 'density' (GTM) of the connnectome (essential tractogram) + - Confirm the quality of WM bundles reconstruction (essential tractogram) + - Inspect the (N)RMSE map and look for peaks or anomalies + - Compare the density map before and after (essential tractogram) """ import argparse @@ -117,13 +125,13 @@ def _build_arg_parser(): help='Number of directions, on the half of the sphere,\n' 'representing the possible orientations of the ' 'response functions [%(default)s].') - p.add_argument('--nbr_iter', type=int, default=500, + p.add_argument('--nbr_iter', type=int, default=1000, help='Maximum number of iterations [%(default)s].') p.add_argument('--in_peaks', help='Peaks file representing principal direction(s) ' 'locally,\n typically coming from fODFs. This file is ' 'mandatory for the default\n stick-zeppelin-ball ' - 'model, when used with multi-shell data.') + 'model.') p.add_argument('--in_tracking_mask', help='Binary mask where tratography was allowed.\n' 'If not set, uses a binary mask computed from ' @@ -138,8 +146,9 @@ def _build_arg_parser(): g1 = p.add_argument_group(title='Model options') g1.add_argument('--ball_stick', action='store_true', - help='Use the ball&Stick model.\nDisable ' - 'the zeppelin compartment for single-shell data.') + help='Use the ball&Stick model, disable the zeppelin ' + 'compartment.\nOnly model suitable for single-shell ' + 'data.') g1.add_argument('--para_diff', type=float, help='Parallel diffusivity in mm^2/s.\n' 'Default for ball_stick: 1.7E-3\n' @@ -147,7 +156,7 @@ def _build_arg_parser(): g1.add_argument('--perp_diff', nargs='+', type=float, help='Perpendicular diffusivity in mm^2/s.\n' 'Default for ball_stick: None\n' - 'Default for stick_zeppelin_ball: [0.85E-3, 0.51E-3]') + 'Default for stick_zeppelin_ball: [0.51E-3]') g1.add_argument('--iso_diff', nargs='+', type=float, help='Istropic diffusivity in mm^2/s.\n' 'Default for ball_stick: [2.0E-3]\n' @@ -157,11 +166,6 @@ def _build_arg_parser(): g2.add_argument('--keep_whole_tractogram', action='store_true', help='Save a tractogram copy with streamlines weights in ' 'the data_per_streamline\n[%(default)s].') - g2.add_argument('--threshold_weights', metavar='THRESHOLD', - default=0., - help='Split the tractogram in two; essential and\n' - 'nonessential, based on the provided threshold ' - '[%(default)s].\n Use None to skip this step.') g3 = p.add_argument_group(title='Kernels options') kern = g3.add_mutually_exclusive_group() @@ -226,10 +230,8 @@ def _save_results_wrapper(args, tmp_dir, ext, hdf5_file, offsets_list, tmp_commit_weights = \ commit_weights[offsets_list[i]:offsets_list[i+1]] - if args.threshold_weights is None: - args.threshold_weights = 
-1
             essential_ind = np.where(
-                tmp_commit_weights > args.threshold_weights)[0]
+                tmp_commit_weights > 0)[0]
             tmp_commit_weights = tmp_commit_weights[essential_ind]
 
             tmp_streamlines = reconstruct_streamlines(old_group['data'],
@@ -267,41 +269,34 @@ def _save_results_wrapper(args, tmp_dir, ext, hdf5_file, offsets_list,
     for f in files:
         shutil.copy(os.path.join(commit_results_dir, f), out_dir)
 
-    # Save split tractogram (essential/nonessential) and/or saving the
-    # tractogram with data_per_streamline updated
-    if args.keep_whole_tractogram or args.threshold_weights is not None:
-        dps_key = 'commit2_weights' if is_commit_2 else \
-            'commit1_weights'
-        dps_key_tot = 'tot_commit2_weights' if is_commit_2 else \
-            'tot_commit1_weights'
-        # Reload is needed because of COMMIT handling its file by itself
-        sft.data_per_streamline[dps_key] = commit_weights
-        sft.data_per_streamline[dps_key_tot] = commit_weights*length_list
-
-        if args.threshold_weights is None:
-            args.threshold_weights = -1
-        essential_ind = np.where(
-            commit_weights > args.threshold_weights)[0]
-        nonessential_ind = np.where(
-            commit_weights <= args.threshold_weights)[0]
-        logging.debug('{} essential streamlines were kept at '
-                      'threshold {}'.format(len(essential_ind),
-                                            args.threshold_weights))
-        logging.debug('{} nonessential streamlines were kept at '
-                      'threshold {}'.format(len(nonessential_ind),
-                                            args.threshold_weights))
-
-        save_tractogram(sft[essential_ind],
-                        os.path.join(out_dir,
-                                     'essential_tractogram.trk'))
-        save_tractogram(sft[nonessential_ind],
-                        os.path.join(out_dir,
-                                     'nonessential_tractogram.trk'))
-        if args.keep_whole_tractogram:
-            output_filename = os.path.join(out_dir, 'tractogram.trk')
-            logging.debug('Saving tractogram with weights as {}'.format(
-                output_filename))
-            save_tractogram(sft, output_filename)
+    dps_key = 'commit2_weights' if is_commit_2 else \
+        'commit1_weights'
+    dps_key_tot = 'tot_commit2_weights' if is_commit_2 else \
+        'tot_commit1_weights'
+    # Reload is needed because of COMMIT handling its file by itself
+    sft.data_per_streamline[dps_key] = commit_weights
+    sft.data_per_streamline[dps_key_tot] = commit_weights*length_list
+
+    essential_ind = np.where(
+        commit_weights > 0)[0]
+    nonessential_ind = np.where(
+        commit_weights <= 0)[0]
+    logging.debug('{} essential streamlines were kept'.format(
+        len(essential_ind)))
+    logging.debug('{} nonessential streamlines were kept'.format(
+        len(nonessential_ind)))
+
+    save_tractogram(sft[essential_ind],
+                    os.path.join(out_dir,
+                                 'essential_tractogram.trk'))
+    save_tractogram(sft[nonessential_ind],
+                    os.path.join(out_dir,
+                                 'nonessential_tractogram.trk'))
+    if args.keep_whole_tractogram:
+        output_filename = os.path.join(out_dir, 'tractogram.trk')
+        logging.debug('Saving tractogram with weights as {}'.format(
+            output_filename))
+        save_tractogram(sft, output_filename)
 
 
 def main():
@@ -346,15 +341,6 @@ def main():
         parser.error('{} does not have a compatible header with {}'.format(
             args.in_tractogram, args.in_dwi))
 
-    if args.threshold_weights == 'None' or args.threshold_weights == 'none':
-        args.threshold_weights = None
-        if not args.keep_whole_tractogram and ext != '.h5':
-            logging.warning('Not thresholding weight with trk file without '
-                            'the --keep_whole_tractogram will not save a '
-                            'tractogram.')
-    else:
-        args.threshold_weights = float(args.threshold_weights)
-
     # COMMIT has some c-level stdout and non-logging print that cannot
     # be easily stopped. 
Manual redirection of all printed output if args.verbose: diff --git a/scripts/scil_score_tractogram.py b/scripts/scil_score_tractogram.py index 721d57517..5f481ed90 100755 --- a/scripts/scil_score_tractogram.py +++ b/scripts/scil_score_tractogram.py @@ -150,7 +150,13 @@ def main(): logging.info("Verifying compatibility with ground-truth") for gt in args.gt_bundles: - compatible = is_header_compatible(sft, gt) + _, gt_ext = os.path.splitext(gt) + if gt_ext in ['.trk', '.tck']: + gt_bundle = load_tractogram_with_reference( + parser, args, gt, bbox_check=False) + else: + gt_bundle = gt + compatible = is_header_compatible(sft, gt_bundle) if not compatible: parser.error("Input tractogram incompatible with" " {}".format(gt)) diff --git a/scripts/scil_visualize_fodf.py b/scripts/scil_visualize_fodf.py index ae526f4ce..09a11f78b 100755 --- a/scripts/scil_visualize_fodf.py +++ b/scripts/scil_visualize_fodf.py @@ -19,6 +19,7 @@ from dipy.data import get_sphere from dipy.reconst.shm import order_from_ncoef +from scilpy.reconst.utils import get_sh_order_and_fullness from scilpy.io.utils import (add_sh_basis_args, add_overwrite_arg, assert_inputs_exist, assert_outputs_exist) from scilpy.io.image import get_data_as_mask @@ -62,10 +63,6 @@ def _build_arg_parser(): # Optional FODF personalization arguments add_sh_basis_args(p) - p.add_argument('--full_basis', action='store_true', - help='Use full SH basis to reconstruct fODF from ' - 'coefficients.') - sphere_choices = {'symmetric362', 'symmetric642', 'symmetric724', 'repulsion724', 'repulsion100', 'repulsion200'} p.add_argument('--sphere', default='symmetric724', choices=sphere_choices, @@ -154,11 +151,6 @@ def _parse_args(parser): inputs.append(args.background) if args.peaks: - if args.full_basis: - # FURY doesn't support asymmetric peaks visualization - warnings.warn('Asymmetric peaks visualization is not supported ' - 'by FURY. Peaks shown as symmetric peaks.', - UserWarning) inputs.append(args.peaks) if args.peaks_values: inputs.append(args.peaks_values) @@ -173,52 +165,25 @@ def _parse_args(parser): return args -def _crop_along_axis(data, index, axis_name): - """ - Extract a 2-dimensional slice from a 3-dimensional data volume - """ - if axis_name == 'sagittal': - if index is None: - data_slice = data[data.shape[0]//2, :, :] - else: - data_slice = data[index, :, :] - return data_slice[None, ...] - elif axis_name == 'coronal': - if index is None: - data_slice = data[:, data.shape[1]//2, :] - else: - data_slice = data[:, index, :] - return data_slice[:, None, ...] - elif axis_name == 'axial': - if index is None: - data_slice = data[:, :, data.shape[2]//2] - else: - data_slice = data[:, :, index] - return data_slice[:, :, None] - - def _get_data_from_inputs(args): """ Load data given by args. Perform checks to ensure dimensions agree between the data for mask, background, peaks and fODF. 
""" fodf = nib.nifti1.load(args.in_fodf).get_fdata(dtype=np.float32) - data = {'fodf': _crop_along_axis(fodf, args.slice_index, - args.axis_name)} + data = {'fodf': fodf} if args.background: bg = nib.nifti1.load(args.background).get_fdata(dtype=np.float32) if bg.shape[:3] != fodf.shape[:-1]: raise ValueError('Background dimensions {0} do not agree with fODF' ' dimensions {1}.'.format(bg.shape, fodf.shape)) - data['bg'] = _crop_along_axis(bg, args.slice_index, - args.axis_name) + data['bg'] = bg if args.mask: mask = get_data_as_mask(nib.nifti1.load(args.mask), dtype=bool) if mask.shape != fodf.shape[:-1]: raise ValueError('Mask dimensions {0} do not agree with fODF ' 'dimensions {1}.'.format(mask.shape, fodf.shape)) - data['mask'] = _crop_along_axis(mask, args.slice_index, - args.axis_name) + data['mask'] = mask if args.peaks: peaks = nib.nifti1.load(args.peaks).get_fdata(dtype=np.float32) if peaks.shape[:3] != fodf.shape[:-1]: @@ -234,8 +199,7 @@ def _get_data_from_inputs(args): raise ValueError('Peaks volume last dimension ({0}) cannot ' 'be reshaped as (npeaks, 3).' .format(peaks.shape[-1])) - data['peaks'] = _crop_along_axis(peaks, args.slice_index, - args.axis_name) + data['peaks'] = peaks if args.peaks_values: peak_vals =\ nib.nifti1.load(args.peaks_values).get_fdata(dtype=np.float32) @@ -243,36 +207,17 @@ def _get_data_from_inputs(args): raise ValueError('Peaks volume dimensions {0} do not agree ' 'with fODF dimensions {1}.' .format(peak_vals.shape, fodf.shape)) - data['peaks_values'] =\ - _crop_along_axis(peak_vals, args.slice_index, - args.axis_name) - - grid_shape = data['fodf'].shape[:3] - return data, grid_shape - + data['peaks_values'] = peak_vals -def validate_order(sh_order, ncoeffs, full_basis): - """ - Check that the sh order agrees with the number - of coefficients in the input - """ - if full_basis: - expected_ncoeffs = (sh_order + 1)**2 - else: - expected_ncoeffs = (sh_order + 1) * (sh_order + 2) // 2 - return ncoeffs == expected_ncoeffs + return data def main(): parser = _build_arg_parser() args = _parse_args(parser) - data, grid_shape = _get_data_from_inputs(args) + data = _get_data_from_inputs(args) sph = get_sphere(args.sphere) - sh_order = order_from_ncoef(data['fodf'].shape[-1], args.full_basis) - if not validate_order(sh_order, data['fodf'].shape[-1], args.full_basis): - parser.error('Invalid number of coefficients for fODF. 
' - 'Use --full_basis if your input is in ' - 'full SH basis.') + sh_order, full_basis = get_sh_order_and_fullness(data['fodf'].shape[-1]) actors = [] @@ -285,15 +230,17 @@ def main(): # Instantiate the ODF slicer actor odf_actor = create_odf_slicer(data['fodf'], mask, sph, args.sph_subdivide, sh_order, - args.sh_basis, args.full_basis, + args.sh_basis, full_basis, args.axis_name, args.scale, not args.radial_scale_off, - not args.norm_off, args.colormap) + not args.norm_off, args.colormap, + args.slice_index) actors.append(odf_actor) # Instantiate a texture slicer actor if a background image is supplied if 'bg' in data: - bg_actor = create_texture_slicer(data['bg'], + bg_actor = create_texture_slicer(data['bg'], mask, + args.slice_index, args.bg_range, args.axis_name, args.bg_opacity, @@ -311,15 +258,19 @@ def main(): np.ones(data['peaks'].shape[:-1]) * args.peaks_length peaks_actor = create_peaks_slicer(data['peaks'], args.axis_name, + args.slice_index, peaks_values, mask, args.peaks_color, - args.peaks_width) + args.peaks_width, + not full_basis) actors.append(peaks_actor) # Prepare and display the scene - scene = create_scene(actors, args.axis_name, grid_shape) + scene = create_scene(actors, args.axis_name, + args.slice_index, + data['fodf'].shape[:3]) render_scene(scene, args.win_dims, args.interactor, args.output, args.silent) diff --git a/scripts/tests/test_extract_ushape.py b/scripts/tests/test_extract_ushape.py new file mode 100644 index 000000000..83307902b --- /dev/null +++ b/scripts/tests/test_extract_ushape.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os +import tempfile + +from scilpy.io.fetcher import get_testing_files_dict, fetch_data, get_home + + +# If they already exist, this only takes 5 seconds (check md5sum) +fetch_data(get_testing_files_dict(), keys=['tracking.zip']) +tmp_dir = tempfile.TemporaryDirectory() + + +def test_help_option(script_runner): + ret = script_runner.run('scil_extract_ushape.py', + '--help') + assert ret.success + + +def test_execution_processing(script_runner): + os.chdir(os.path.expanduser(tmp_dir.name)) + in_trk = os.path.join(get_home(), 'tracking', 'union.trk') + out_trk = 'ushape.trk' + remaining_trk = 'remaining.trk' + ret = script_runner.run('scil_extract_ushape.py', in_trk, out_trk, + '--minU', '0.5', + '--maxU', '1', + '--remaining_tractogram', remaining_trk, + '--display_counts') + assert ret.success diff --git a/scripts/tests/test_visualize_fodf.py b/scripts/tests/test_visualize_fodf.py index dce8956dd..1f4d04c17 100644 --- a/scripts/tests/test_visualize_fodf.py +++ b/scripts/tests/test_visualize_fodf.py @@ -16,34 +16,6 @@ def test_help_option(script_runner): assert ret.success -def test_peaks_full_basis(script_runner): - os.chdir(os.path.expanduser(tmp_dir.name)) - in_fodf = os.path.join(get_home(), 'tracking', 'fodf.nii.gz') - in_peaks = os.path.join(get_home(), 'tracking', 'peaks.nii.gz') - # Tests that the use of a full SH basis with peaks raises a warning - with warnings.catch_warnings(record=True) as w: - ret = script_runner.run('scil_visualize_fodf.py', in_fodf, - '--full_basis', '--peaks', in_peaks) - assert(len(w) > 0) - assert(issubclass(w[0].category, UserWarning)) - assert('Asymmetric peaks visualization is not supported ' - 'by FURY. Peaks shown as symmetric peaks.' 
in - str(w[0].message)) - - # The whole execution should fail because - # the input fODF is not in full basis - assert (not ret.success) - - -def test_full_basis_input_without_arg(script_runner): - os.chdir(os.path.expanduser(tmp_dir.name)) - in_fodf = os.path.join(get_home(), 'tracking', 'fodf_full.nii.gz') - ret = script_runner.run('scil_visualize_fodf.py', in_fodf) - - # Using a full SH basis without --full_basis argument should fail - assert (not ret.success) - - def test_silent_without_output(script_runner): os.chdir(os.path.expanduser(tmp_dir.name)) in_fodf = os.path.join(get_home(), 'tracking', 'fodf.nii.gz')