diff --git a/scilpy/io/utils.py b/scilpy/io/utils.py
index a2adc5728..658cd495f 100644
--- a/scilpy/io/utils.py
+++ b/scilpy/io/utils.py
@@ -313,7 +313,7 @@ def check(path):
         if check_dir_exists:
             path_dir = os.path.dirname(path)
             if path_dir and not os.path.isdir(path_dir):
-                parser.error('Directory {} \n for a given output file '
+                parser.error('Directory {}/ \n for a given output file '
                              'does not exists.'.format(path_dir))
 
     if isinstance(required, str):
diff --git a/scilpy/utils/metrics_tools.py b/scilpy/utils/metrics_tools.py
index d39fa4d4a..e28658731 100644
--- a/scilpy/utils/metrics_tools.py
+++ b/scilpy/utils/metrics_tools.py
@@ -11,6 +11,80 @@
 from scilpy.utils.filenames import split_name_with_nii
 
+
+def compute_lesion_stats(map_data, lesion_atlas, single_label=True,
+                         voxel_sizes=[1.0, 1.0, 1.0], min_lesion_vol=7,
+                         precomputed_lesion_labels=None):
+    """
+    Returns information related to lesions inside a binary mask or voxel
+    labels map (bundle, for tractometry).
+
+    Parameters
+    ------------
+    map_data : np.ndarray
+        Either a binary mask (uint8) or a voxel labels map (int16).
+    lesion_atlas : np.ndarray (3)
+        Labelled atlas of lesions. Should be int16.
+    single_label : boolean
+        If True, does not add an extra layer for the number of labels.
+    voxel_sizes : np.ndarray (3)
+        If not specified, returns voxel counts (instead of volumes).
+    min_lesion_vol : float
+        Minimum lesion volume in mm3 (default: 7, cross-shape).
+    precomputed_lesion_labels : np.ndarray (N)
+        For connectivity analysis, when the unique lesion labels are known,
+        providing a pre-computed list of labels saves computation.
+    Returns
+    ---------
+    lesion_load_dict : dict
+        For each label, the lesion volume and lesion count.
+    """
+    voxel_vol = np.prod(voxel_sizes)
+
+    if single_label:
+        labels_list = [1]
+    else:
+        labels_list = np.unique(map_data)[1:].astype(np.int32)
+
+    section_dict = {'lesion_total_volume': {}, 'lesion_volume': {},
+                    'lesion_count': {}}
+    for label in labels_list:
+        zlabel = str(label).zfill(3)
+        if not single_label:
+            tmp_mask = np.zeros(map_data.shape, dtype=np.int16)
+            tmp_mask[map_data == label] = 1
+            tmp_mask *= lesion_atlas
+        else:
+            tmp_mask = lesion_atlas * map_data
+
+        lesion_vols = []
+        if precomputed_lesion_labels is None:
+            computed_lesion_labels = np.unique(tmp_mask)[1:]
+        else:
+            computed_lesion_labels = precomputed_lesion_labels
+
+        for lesion in computed_lesion_labels:
+            curr_vol = np.count_nonzero(tmp_mask[tmp_mask == lesion]) \
+                * voxel_vol
+            if curr_vol >= min_lesion_vol:
+                lesion_vols.append(curr_vol)
+        if lesion_vols:
+            section_dict['lesion_total_volume'][zlabel] = round(
+                np.sum(lesion_vols), 3)
+            section_dict['lesion_volume'][zlabel] = np.round(lesion_vols, 3).tolist()
+            section_dict['lesion_count'][zlabel] = float(len(lesion_vols))
+        else:
+            section_dict['lesion_total_volume'][zlabel] = 0.0
+            section_dict['lesion_volume'][zlabel] = [0.0]
+            section_dict['lesion_count'][zlabel] = 0.0
+
+    if single_label:
+        section_dict = {'lesion_total_volume': section_dict['lesion_total_volume']['001'],
+                        'lesion_volume': section_dict['lesion_volume']['001'],
+                        'lesion_count': section_dict['lesion_count']['001']}
+
+    return section_dict
+
+
 def get_bundle_metrics_profiles(sft, metrics_files):
     """
     Returns the profile of each metric along each streamline from a sft.
@@ -161,11 +235,11 @@ def get_bundle_metrics_mean_std_per_point(streamlines, bundle_name,
     unique_labels = np.unique(labels)
     num_digits_labels = 3
     if density_weighting:
-        track_count = compute_tract_counts_map(streamlines,
-                                               metrics[0].shape)
+        streamlines_count = compute_tract_counts_map(streamlines,
+                                                     metrics[0].shape)
     else:
-        track_count = np.ones(metrics[0].shape)
-    track_count = track_count.astype(np.float64)
+        streamlines_count = np.ones(metrics[0].shape)
+    streamlines_count = streamlines_count.astype(np.float64)
 
     # Bigger weight near the centroid streamline
     distances_to_centroid_streamline = 1.0 / distances_to_centroid_streamline
@@ -190,9 +264,9 @@
         label_metric = metric_data[label_indices[:, 0],
                                    label_indices[:, 1],
                                    label_indices[:, 2]]
-        track_weight = track_count[label_indices[:, 0],
-                                   label_indices[:, 1],
-                                   label_indices[:, 2]]
+        track_weight = streamlines_count[label_indices[:, 0],
+                                         label_indices[:, 1],
+                                         label_indices[:, 2]]
         label_weight = track_weight
         if distance_weighting:
             label_weight *= distances_to_centroid_streamline[labels == i]
@@ -259,8 +333,8 @@ def plot_metrics_stats(means, stds, title=None, xlabel=None,
         std = np.std(means, axis=1)
         alpha = 0.5
     else:
-        mean = np.array(means)
-        std = np.array(stds)
+        mean = np.array(means).ravel()
+        std = np.array(stds).ravel()
         alpha = 0.9
 
     dim = np.arange(1, len(mean)+1, 1)
diff --git a/scripts/scil_analyse_lesions_load.py b/scripts/scil_analyse_lesions_load.py
new file mode 100644
index 000000000..1d01c97fc
--- /dev/null
+++ b/scripts/scil_analyse_lesions_load.py
@@ -0,0 +1,157 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+This script will output information about lesion load in bundle(s).
+The input can either be streamlines, a binary bundle map, or a bundle voxel
+label map.
+
+To be considered a valid lesion, the lesion volume must be at least
+min_lesion_vol mm3. This avoids detecting thousands of single-voxel
+lesions if an automatic lesion segmentation tool is used.
+"""
+
+import argparse
+import json
+import os
+
+import nibabel as nib
+import numpy as np
+import scipy.ndimage as ndi
+
+
+from scilpy.io.image import get_data_as_mask, get_data_as_label
+from scilpy.io.streamlines import load_tractogram_with_reference
+from scilpy.io.utils import (add_overwrite_arg,
+                             assert_inputs_exist,
+                             add_json_args,
+                             assert_outputs_exist,
+                             add_reference_arg)
+from scilpy.segment.streamlines import filter_grid_roi
+from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map
+from scilpy.utils.filenames import split_name_with_nii
+from scilpy.utils.metrics_tools import compute_lesion_stats
+
+
+def _build_arg_parser():
+    p = argparse.ArgumentParser(description=__doc__,
+                                formatter_class=argparse.RawTextHelpFormatter)
+
+    p.add_argument('in_lesion',
+                   help='Binary mask of the lesion(s) (.nii.gz).')
+    p.add_argument('out_json',
+                   help='Output file for lesion information (.json).')
+    p1 = p.add_mutually_exclusive_group()
+    p1.add_argument('--bundle',
+                    help='Path of the bundle file (.trk).')
+    p1.add_argument('--bundle_mask',
+                    help='Path of the bundle binary mask (.nii.gz).')
+    p1.add_argument('--bundle_labels_map',
+                    help='Path of the bundle labels map (.nii.gz).')
+
+    p.add_argument('--min_lesion_vol', type=float, default=7,
+                   help='Minimum lesion volume in mm3 [%(default)s].')
+    p.add_argument('--out_lesion_atlas', metavar='FILE',
+                   help='Save the labelled lesion(s) map (.nii.gz).')
+    p.add_argument('--out_lesion_stats', metavar='FILE',
+                   help='Save the lesion-wise volume measure (.json).')
+    p.add_argument('--out_streamlines_stats', metavar='FILE',
+                   help='Save the lesion-wise streamline count (.json).')
+
+    add_json_args(p)
+    add_overwrite_arg(p)
+    add_reference_arg(p)
+
+    return p
+
+
+def main():
+    parser = _build_arg_parser()
+    args = parser.parse_args()
+
+    if (not args.bundle) and (not args.bundle_mask) \
+            and (not args.bundle_labels_map):
+        parser.error('One of --bundle, --bundle_mask or --bundle_labels_map must be used.')
+
+    assert_inputs_exist(parser, [args.in_lesion],
+                        optional=[args.bundle, args.bundle_mask,
+                                  args.bundle_labels_map])
+    assert_outputs_exist(parser, args, args.out_json,
+                         optional=[args.out_lesion_stats,
+                                   args.out_streamlines_stats])
+
+    lesion_img = nib.load(args.in_lesion)
+    lesion_data = get_data_as_mask(lesion_img, dtype=np.bool)
+
+    if args.bundle:
+        bundle_name, _ = split_name_with_nii(os.path.basename(args.bundle))
+        sft = load_tractogram_with_reference(parser, args, args.bundle)
+        sft.to_vox()
+        sft.to_corner()
+        streamlines = sft.get_streamlines_copy()
+        map_data = compute_tract_counts_map(streamlines,
+                                            lesion_data.shape)
+        map_data[map_data > 0] = 1
+    elif args.bundle_mask:
+        bundle_name, _ = split_name_with_nii(
+            os.path.basename(args.bundle_mask))
+        map_img = nib.load(args.bundle_mask)
+        map_data = get_data_as_mask(map_img)
+    else:
+        bundle_name, _ = split_name_with_nii(os.path.basename(
+            args.bundle_labels_map))
+        map_img = nib.load(args.bundle_labels_map)
+        map_data = get_data_as_label(map_img)
+
+    is_single_label = args.bundle_labels_map is None
+    voxel_sizes = lesion_img.header.get_zooms()[0:3]
+    lesion_atlas, _ = ndi.label(lesion_data)
+
+    lesion_load_dict = compute_lesion_stats(
+        map_data, lesion_atlas, single_label=is_single_label,
+        voxel_sizes=voxel_sizes, min_lesion_vol=args.min_lesion_vol)
+
+    if args.out_lesion_atlas:
+        # lesion_atlas *= map_data.astype(np.bool)
+        nib.save(nib.Nifti1Image(lesion_atlas, lesion_img.affine),
+                 args.out_lesion_atlas)
+
+    volume_dict = {bundle_name: lesion_load_dict}
+    with open(args.out_json, 'w') as outfile:
+        json.dump(volume_dict, outfile,
+                  sort_keys=args.sort_keys, indent=args.indent)
+
+    if args.out_streamlines_stats or args.out_lesion_stats:
+        lesion_dict = {}
+        for lesion in np.unique(lesion_atlas)[1:]:
+            curr_vol = np.count_nonzero(lesion_atlas[lesion_atlas == lesion]) \
+                * np.prod(voxel_sizes)
+            if curr_vol >= args.min_lesion_vol:
+                key = str(lesion).zfill(4)
+                lesion_dict[key] = {'volume': curr_vol}
+                if args.bundle:
+                    tmp = np.zeros(lesion_atlas.shape)
+                    tmp[lesion_atlas == lesion] = 1
+                    new_sft, _ = filter_grid_roi(sft, tmp, 'any', False)
+                    lesion_dict[key]['strs_count'] = len(new_sft)
+
+        lesion_vol_dict = {bundle_name: {}}
+        streamlines_count_dict = {bundle_name: {'streamlines_count': {}}}
+        for key in lesion_dict.keys():
+            lesion_vol_dict[bundle_name][key] = lesion_dict[key]['volume']
+            if args.bundle:
+                streamlines_count_dict[bundle_name]['streamlines_count'][key] = \
+                    lesion_dict[key]['strs_count']
+
+        if args.out_lesion_stats:
+            with open(args.out_lesion_stats, 'w') as outfile:
+                json.dump(lesion_vol_dict, outfile,
+                          sort_keys=args.sort_keys, indent=args.indent)
+        if args.out_streamlines_stats:
+            with open(args.out_streamlines_stats, 'w') as outfile:
+                json.dump(streamlines_count_dict, outfile,
+                          sort_keys=args.sort_keys, indent=args.indent)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/scil_compute_connectivity.py b/scripts/scil_compute_connectivity.py
index 8e8a60cc7..cc69b31bf 100755
--- a/scripts/scil_compute_connectivity.py
+++ b/scripts/scil_compute_connectivity.py
@@ -13,8 +13,9 @@
 This script only generates matrices in the form of array, does not visualize
 or reorder the labels (node).
 
-The parameter --similarity expects a folder with density maps (LABEL1_LABEL2.nii.gz)
-following the same naming convention as the input directory.
+The parameter --similarity expects a folder with density maps
+(LABEL1_LABEL2.nii.gz) following the same naming convention as the input
+directory.
 The bundles should be averaged version in the same space. This will
 compute the weighted-dice between each node and their homologuous average
 version.
@@ -27,6 +28,13 @@
 pre-computed maps (LABEL1_LABEL2.nii.gz) following the same naming convention
 as the input directory. Each will generate a matrix. The average non-zeros
 value in the map will be reported in the matrices nodes.
+
+The parameter --lesion_load will compute 3 lesion(s) related matrices:
+lesion_count.npy, lesion_vol.npy, lesion_sc.npy and save them inside of a
+specified folder. They represent the number of lesions, the total volume of
+lesion(s) and the total number of streamlines going through the lesion(s) for
+each connection. Each connection can be seen as a 'bundle' and then something
+similar to scil_analyse_lesions_load.py is run for each 'bundle'.
 """
 
 import argparse
@@ -39,11 +47,13 @@
 import coloredlogs
 from dipy.io.utils import is_header_compatible, get_reference_info
 from dipy.tracking.streamlinespeed import length
+from dipy.tracking.vox2track import _streamlines_in_mask
 import h5py
 import nibabel as nib
 import numpy as np
+import scipy.ndimage as ndi
 
-from scilpy.io.image import get_data_as_label
+from scilpy.io.image import get_data_as_label, get_data_as_mask
 from scilpy.io.streamlines import reconstruct_streamlines_from_hdf5
 from scilpy.io.utils import (add_overwrite_arg, add_processes_arg,
                              add_verbose_arg,
@@ -51,6 +61,7 @@
                              validate_nbr_processes)
 from scilpy.tractanalysis.reproducibility_measures import compute_bundle_adjacency_voxel
 from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map
+from scilpy.utils.metrics_tools import compute_lesion_stats
 
 
 def load_node_nifti(directory, in_label, out_label, ref_img):
@@ -75,6 +86,7 @@ def _processing_wrapper(args):
     similarity_directory = args[4][0]
     weighted = args[5]
     include_dps = args[6]
+    min_lesion_vol = args[7]
 
     hdf5_file = h5py.File(hdf5_filename, 'r')
     key = '{}_{}'.format(in_label, out_label)
@@ -132,6 +144,7 @@
         measures_to_compute.remove('similarity')
 
     for measure in measures_to_compute:
+        # Maps
        if isinstance(measure, str) and os.path.isdir(measure):
             map_dirname = measure
             map_data = load_node_nifti(map_dirname,
@@ -139,23 +152,61 @@
                                        labels_img)
             measures_to_return[map_dirname] = np.average(
                 map_data[map_data > 0])
-        elif isinstance(measure, tuple) and os.path.isfile(measure[0]):
-            metric_filename = measure[0]
-            metric_img = measure[1]
-            if not is_header_compatible(metric_img, labels_img):
-                logging.error('{} do not have a compatible header'.format(
-                    metric_filename))
-                raise IOError
-
-            metric_data = metric_img.get_fdata(dtype=np.float64)
-            if weighted:
-                density = density / np.max(density)
-                voxels_value = metric_data * density
-                voxels_value = voxels_value[voxels_value > 0]
-            else:
-                voxels_value = metric_data[density > 0]
+        elif isinstance(measure, tuple):
+            if not isinstance(measure[0], tuple) \
+                    and os.path.isfile(measure[0]):
+                metric_filename = measure[0]
+                metric_img = measure[1]
+                if not is_header_compatible(metric_img, labels_img):
+                    logging.error('{} do not have a compatible header'.format(
+                        metric_filename))
+                    raise IOError
+
+                metric_data = metric_img.get_fdata(dtype=np.float64)
+                if weighted:
+                    density = density / np.max(density)
+                    voxels_value = metric_data * density
+                    voxels_value = voxels_value[voxels_value > 0]
+                else:
+                    voxels_value = metric_data[density > 0]
 
-            measures_to_return[metric_filename] = np.average(voxels_value)
+                measures_to_return[metric_filename] = np.average(voxels_value)
+            # lesion
+            else:
+                lesion_filename = measure[0][0]
+                computed_lesion_labels = measure[0][1]
+                lesion_img = measure[1]
+                if not is_header_compatible(lesion_img, labels_img):
+                    logging.error('{} do not have a compatible header'.format(
+                        lesion_filename))
+                    raise IOError
+
+                voxel_sizes = lesion_img.header.get_zooms()[0:3]
+                lesion_img.set_filename('tmp.nii.gz')
+                lesion_atlas = get_data_as_label(lesion_img)
+                tmp_dict = compute_lesion_stats(
+                    density.astype(np.bool), lesion_atlas,
+                    voxel_sizes=voxel_sizes, single_label=True,
+                    min_lesion_vol=min_lesion_vol,
+                    precomputed_lesion_labels=computed_lesion_labels)
+
+                tmp_ind = _streamlines_in_mask(list(streamlines),
+                                               lesion_atlas.astype(np.uint8),
+                                               np.eye(3), [0, 0, 0])
+                streamlines_count = len(
+                    np.where(tmp_ind == [0, 1][True])[0].tolist())
+
+                if tmp_dict:
+                    measures_to_return[lesion_filename+'vol'] = \
+                        tmp_dict['lesion_total_volume']
+                    measures_to_return[lesion_filename+'count'] = \
+                        tmp_dict['lesion_count']
+                    measures_to_return[lesion_filename+'sc'] = \
+                        streamlines_count
+                else:
+                    measures_to_return[lesion_filename+'vol'] = 0
+                    measures_to_return[lesion_filename+'count'] = 0
+                    measures_to_return[lesion_filename+'sc'] = 0
 
     if include_dps:
         for dps_key in hdf5_file[key].keys():
@@ -201,6 +252,11 @@
                    metavar=('IN_FILE', 'OUT_FILE'),
                    help='Input (.nii.gz). and output file (.npy) for a metric '
                    'weighted matrix.')
+    p.add_argument('--lesion_load', nargs=2, metavar=('IN_FILE', 'OUT_DIR'),
+                   help='Input binary mask (.nii.gz) and output directory '
+                        'for all lesion-related matrices.')
+    p.add_argument('--min_lesion_vol', type=float, default=7,
+                   help='Minimum lesion volume in mm3 [%(default)s].')
     p.add_argument('--density_weighting', action="store_true",
                    help='Use density-weighting for the metric weighted matrix.')
@@ -269,6 +325,25 @@
             dict_metrics_out_name[in_name] = out_name
             measures_output_filename.append(out_name)
 
+    dict_lesion_out_name = {}
+    if args.lesion_load is not None:
+        in_name = args.lesion_load[0]
+        lesion_img = nib.load(in_name)
+        lesion_data = get_data_as_mask(lesion_img, dtype=np.bool)
+        lesion_atlas, _ = ndi.label(lesion_data)
+        measures_to_compute.append(((in_name, np.unique(lesion_atlas)[1:]),
+                                    nib.Nifti1Image(lesion_atlas,
+                                                    lesion_img.affine)))
+
+        out_name_1 = os.path.join(args.lesion_load[1], 'lesion_vol.npy')
+        out_name_2 = os.path.join(args.lesion_load[1], 'lesion_count.npy')
+        out_name_3 = os.path.join(args.lesion_load[1], 'lesion_sc.npy')
+
+        dict_lesion_out_name[in_name+'vol'] = out_name_1
+        dict_lesion_out_name[in_name+'count'] = out_name_2
+        dict_lesion_out_name[in_name+'sc'] = out_name_3
+        measures_output_filename.extend([out_name_1, out_name_2, out_name_3])
+
     assert_outputs_exist(parser, args, measures_output_filename)
     if not measures_to_compute:
         parser.error('No connectivity measures were selected, nothing '
@@ -303,7 +378,8 @@
                                           measures_to_compute,
                                           args.similarity,
                                           args.density_weighting,
-                                          args.include_dps]))
+                                          args.include_dps,
+                                          args.min_lesion_vol]))
     else:
         pool = multiprocessing.Pool(nbr_cpu)
         measures_dict_list = pool.map(_processing_wrapper,
@@ -315,7 +391,8 @@
                                                 itertools.repeat(args.similarity),
                                                 itertools.repeat(
                                                     args.density_weighting),
-                                                itertools.repeat(args.include_dps)))
+                                                itertools.repeat(args.include_dps),
+                                                itertools.repeat(args.min_lesion_vol)))
         pool.close()
         pool.join()
 
@@ -341,6 +418,7 @@
     # Filling out all the matrices (symmetric) in the order of labels_list
     nbr_of_measures = len(list(measures_dict.values())[0])
     matrix = np.zeros((len(labels_list), len(labels_list), nbr_of_measures))
+
     for in_label, out_label in measures_dict:
         curr_node_dict = measures_dict[(in_label, out_label)]
         measures_ordering = list(curr_node_dict.keys())
@@ -348,7 +426,6 @@
         for i, measure in enumerate(curr_node_dict):
             in_pos = labels_list.index(in_label)
             out_pos = labels_list.index(out_label)
-
             matrix[in_pos, out_pos, i] = curr_node_dict[measure]
             matrix[out_pos, in_pos, i] = curr_node_dict[measure]
 
@@ -366,6 +443,8 @@
             matrix_basename = dict_metrics_out_name[measure]
         elif measure in dict_maps_out_name:
             matrix_basename = dict_maps_out_name[measure]
+        elif measure in dict_lesion_out_name:
+            matrix_basename = dict_lesion_out_name[measure]
         else:
             matrix_basename = measure
 
diff --git a/scripts/scil_convert_json_to_xlsx.py b/scripts/scil_convert_json_to_xlsx.py
index cc593f796..2c48e04ef 100755
--- a/scripts/scil_convert_json_to_xlsx.py
+++ b/scripts/scil_convert_json_to_xlsx.py
@@ -38,7 +38,6 @@ def _get_metrics_names(stats):
     for bundles in iter(stats.values()):
         for val in iter(bundles.values()):
             mnames |= set(val.keys())
-
     return mnames
 
 
@@ -66,11 +65,17 @@ def _find_stat_name(stats):
 def _get_stats_parse_function(stats, stats_over_population):
     first_sub_stats = stats[list(stats.keys())[0]]
     first_bundle_stats = first_sub_stats[list(first_sub_stats.keys())[0]]
-    first_bundle_substat = first_bundle_stats[list(first_bundle_stats.keys())[0]]
+    first_bundle_substat = first_bundle_stats[list(
+        first_bundle_stats.keys())[0]]
 
     if len(first_bundle_stats.keys()) == 1 and\
             _are_all_elements_scalars(first_bundle_stats):
         return _parse_scalar_stats
+    elif len(first_bundle_stats.keys()) == 4 and \
+            set(first_bundle_stats.keys()) == \
+            set(['lesion_total_vol', 'lesion_avg_vol', 'lesion_std_vol',
+                 'lesion_count']):
+        return _parse_lesion
     elif len(first_bundle_stats.keys()) == 4 and \
             set(first_bundle_stats.keys()) == \
             set(['min_length', 'max_length', 'mean_length', 'std_length']):
@@ -156,6 +161,42 @@ def _parse_scalar_meanstd(stats, subs, bundles):
     return dataframes, df_names
 
 
+def _parse_scalar_lesions(stats, subs, bundles):
+    metric_names = _get_metrics_names(stats)
+    nb_subs = len(subs)
+    nb_bundles = len(bundles)
+    nb_metrics = len(metric_names)
+
+    means = np.full((nb_subs, nb_bundles, nb_metrics), np.NaN)
+    stddev = np.full((nb_subs, nb_bundles, nb_metrics), np.NaN)
+
+    for sub_id, sub_name in enumerate(subs):
+        for bundle_id, bundle_name in enumerate(bundles):
+            for metric_id, metric_name in enumerate(metric_names):
+                b_stat = stats[sub_name].get(bundle_name)
+
+                if b_stat is not None:
+                    m_stat = b_stat.get(metric_name)
+
+                    if m_stat is not None:
+                        means[sub_id, bundle_id, metric_id] = m_stat['mean']
+                        stddev[sub_id, bundle_id, metric_id] = m_stat['std']
+
+    dataframes = []
+    df_names = []
+
+    for metric_id, metric_name in enumerate(metric_names):
+        dataframes.append(pd.DataFrame(data=means[:, :, metric_id],
+                                       index=subs, columns=bundles))
+        df_names.append(metric_name + "_mean")
+
+        dataframes.append(pd.DataFrame(data=stddev[:, :, metric_id],
+                                       index=subs, columns=bundles))
+        df_names.append(metric_name + "_std")
+
+    return dataframes, df_names
+
+
 def _parse_lengths(stats, subs, bundles):
     nb_subs = len(subs)
     nb_bundles = len(bundles)
@@ -193,6 +234,44 @@ def _parse_lengths(stats, subs, bundles):
     return dataframes, df_names
 
 
+def _parse_lesion(stats, subs, bundles):
+    nb_subs = len(subs)
+    nb_bundles = len(bundles)
+
+    total_volume = np.full((nb_subs, nb_bundles), np.NaN)
+    avg_volume = np.full((nb_subs, nb_bundles), np.NaN)
+    std_volume = np.full((nb_subs, nb_bundles), np.NaN)
+    lesion_count = np.full((nb_subs, nb_bundles), np.NaN)
+
+    for sub_id, sub_name in enumerate(subs):
+        for bundle_id, bundle_name in enumerate(bundles):
+            b_stat = stats[sub_name].get(bundle_name)
+
+            if b_stat is not None:
+                total_volume[sub_id, bundle_id] = b_stat['lesion_total_vol']
+                avg_volume[sub_id, bundle_id] = b_stat['lesion_avg_vol']
+                std_volume[sub_id, bundle_id] = b_stat['lesion_std_vol']
+                lesion_count[sub_id, bundle_id] = b_stat['lesion_count']
+
+    dataframes = [pd.DataFrame(data=total_volume,
+                               index=subs,
+                               columns=bundles),
+                  pd.DataFrame(data=avg_volume,
+                               index=subs,
+                               columns=bundles),
+                  pd.DataFrame(data=std_volume,
+                               index=subs,
+                               columns=bundles),
+                  pd.DataFrame(data=lesion_count,
+                               index=subs,
+                               columns=bundles)]
+
+    df_names = ["lesion_total_vol", "lesion_avg_vol",
+                "lesion_std_vol", "lesion_count"]
+
+    return dataframes, df_names
+
+
 def _parse_per_label_scalar(stats, subs, bundles):
     labels = _get_labels(stats)
     labels.sort()
diff --git a/scripts/scil_merge_json.py b/scripts/scil_merge_json.py
index 68a51e25f..b287d5e21 100755
--- a/scripts/scil_merge_json.py
+++ b/scripts/scil_merge_json.py
@@ -10,6 +10,8 @@
 import json
 import os
 
+import numpy as np
+
 from scilpy.io.utils import (add_overwrite_arg, add_json_args,
                              assert_inputs_exist, assert_outputs_exist)
 
@@ -42,6 +44,20 @@ def _merge_dict(dict_1, dict_2, no_list=False, recursive=False):
     return new_dict
 
 
+def _average_dict(dict_1):
+    for key in dict_1.keys():
+        if isinstance(dict_1[key], dict):
+            dict_1[key] = _average_dict(dict_1[key])
+        elif isinstance(dict_1[key], list) or np.isscalar(dict_1[key]):
+            new_dict = {}
+            for subkey in dict_1.keys():
+                new_dict[subkey] = {'mean': np.average(dict_1[subkey]),
+                                    'std': np.std(dict_1[subkey])}
+            return new_dict
+
+    return dict_1
+
+
 def _build_arg_parser():
     p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
                                 description=__doc__)
@@ -61,6 +77,8 @@ def _build_arg_parser():
                    help='Merge ignoring parent key (e.g for population).')
     p.add_argument('--recursive', action='store_true',
                    help='Merge all entries at the lowest layers.')
+    p.add_argument('--average_last_layer', action='store_true',
+                   help='Average all entries at the lowest layers.')
 
     add_json_args(p)
     add_overwrite_arg(p)
@@ -87,6 +105,9 @@ def main():
                                    no_list=args.no_list,
                                    recursive=args.recursive)
 
+    if args.average_last_layer:
+        out_dict = _average_dict(out_dict)
+
     with open(args.out_json, 'w') as outfile:
         if args.add_parent_key:
             out_dict = {args.add_parent_key: out_dict}
diff --git a/scripts/tests/test_plot_mean_std_per_point.py b/scripts/tests/test_plot_mean_std_per_point.py
index 0899fd0bc..455a93d7b 100644
--- a/scripts/tests/test_plot_mean_std_per_point.py
+++ b/scripts/tests/test_plot_mean_std_per_point.py
@@ -22,6 +22,6 @@ def test_execution_tractometry(script_runner):
     in_json = os.path.join(get_home(), 'tractometry',
                            'metric_label.json')
     ret = script_runner.run('scil_plot_mean_std_per_point.py', in_json,
-                            'out/')
+                            'out/', '--stats_over_population')
     assert ret.success
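
A minimal usage sketch of the new compute_lesion_stats() helper added to scilpy/utils/metrics_tools.py above. It assumes the patch is installed so the import works; the toy arrays and the 1 mm isotropic voxel size are made up for illustration, only the signature and the returned keys come from the function itself.

```python
import numpy as np
import scipy.ndimage as ndi

from scilpy.utils.metrics_tools import compute_lesion_stats  # added by this patch

# Toy 10x10x10 volume with assumed 1 mm isotropic voxels.
bundle_mask = np.zeros((10, 10, 10), dtype=np.uint8)
bundle_mask[2:8, 2:8, 2:8] = 1          # binary bundle map

lesions = np.zeros((10, 10, 10), dtype=np.uint8)
lesions[3:5, 3:5, 3:5] = 1              # 8 voxels -> 8 mm3, kept
lesions[7, 7, 7] = 1                    # 1 voxel  -> 1 mm3, below 7 mm3, dropped

# Same connected-component labelling the new script performs with ndi.label().
lesion_atlas, _ = ndi.label(lesions)

stats = compute_lesion_stats(bundle_mask, lesion_atlas,
                             single_label=True,
                             voxel_sizes=[1.0, 1.0, 1.0],
                             min_lesion_vol=7)
print(stats)
# With this toy input:
# {'lesion_total_volume': 8.0, 'lesion_volume': [8.0], 'lesion_count': 1.0}
```

In the single-label case the returned dict has the three flat keys shown above; with single_label=False the same three keys hold per-label sub-dicts keyed by zero-padded label strings ('001', '002', ...), which is how the bundle labels map input is handled.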
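
scil_analyse_lesions_load.py separates lesions with connected-component labelling (scipy.ndimage.label) and then discards components smaller than --min_lesion_vol. A standalone sketch of that filtering step, assuming 1 mm isotropic voxels (the script reads the actual zooms from the lesion NIfTI header):

```python
import numpy as np
import scipy.ndimage as ndi

lesion_data = np.zeros((10, 10, 10), dtype=bool)
lesion_data[1:3, 1:3, 1:3] = True       # 8 voxels -> 8 mm3, kept
lesion_data[8, 8, 8] = True             # 1 voxel  -> 1 mm3, discarded

voxel_vol = 1.0                          # np.prod(voxel_sizes) in the script
min_lesion_vol = 7                       # default, roughly a 3D cross shape

# Each connected lesion gets its own integer label (1..N).
lesion_atlas, nb_lesions = ndi.label(lesion_data)

kept = {}
for lesion in np.unique(lesion_atlas)[1:]:
    curr_vol = np.count_nonzero(lesion_atlas == lesion) * voxel_vol
    if curr_vol >= min_lesion_vol:
        kept[str(lesion).zfill(4)] = {'volume': curr_vol}

print(nb_lesions, kept)   # 2 {'0001': {'volume': 8.0}}
```

With the default of 7 mm3 and typical 1 mm isotropic data, anything smaller than a 7-voxel cross is treated as segmentation noise and ignored, which is the rationale stated in the script docstring.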
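
With --lesion_load IN_FILE OUT_DIR, scil_compute_connectivity.py ends up with three values per connection (total lesion volume, lesion count and streamline count through lesions) and writes each as a symmetric matrix (lesion_vol.npy, lesion_count.npy, lesion_sc.npy). A simplified sketch of the final matrix-filling step, using made-up labels, shortened measure keys and made-up values:

```python
import numpy as np

labels_list = [1, 2, 3]
# One dict per connection, as returned by the processing wrapper
# (keys shortened here; the real keys are prefixed with the lesion filename).
measures_dict = {(1, 2): {'vol': 8.0, 'count': 1.0, 'sc': 12},
                 (2, 3): {'vol': 0.0, 'count': 0.0, 'sc': 0}}

nbr_of_measures = len(next(iter(measures_dict.values())))
matrix = np.zeros((len(labels_list), len(labels_list), nbr_of_measures))

for in_label, out_label in measures_dict:
    for i, measure in enumerate(measures_dict[(in_label, out_label)]):
        in_pos = labels_list.index(in_label)
        out_pos = labels_list.index(out_label)
        value = measures_dict[(in_label, out_label)][measure]
        matrix[in_pos, out_pos, i] = value   # symmetric connectivity matrix
        matrix[out_pos, in_pos, i] = value

# matrix[..., 0] corresponds to what would be saved as lesion_vol.npy, etc.
print(matrix[:, :, 0])
```

In the actual script each value is keyed by the lesion filename plus a suffix ('vol', 'count', 'sc'), and dict_lesion_out_name routes each plane of the matrix to the right .npy file.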
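
The new --average_last_layer flag of scil_merge_json.py collapses each lowest-level entry into its mean and standard deviation. A simplified re-implementation of the idea, with made-up values, to show the expected transformation (the real _average_dict() recurses through the merged dict in place):

```python
import numpy as np


def average_last_layer(d):
    # Recurse until the values are lists/scalars, then replace that whole
    # layer with {'mean': ..., 'std': ...} entries, as _average_dict() does.
    if all(isinstance(v, dict) for v in d.values()):
        return {k: average_last_layer(v) for k, v in d.items()}
    return {k: {'mean': float(np.average(v)), 'std': float(np.std(v))}
            for k, v in d.items()}


merged = {'AF_L': {'lesion_count': [2.0, 3.0, 1.0],
                   'lesion_total_volume': [40.0, 31.5, 12.0]}}
print(average_last_layer(merged))
# {'AF_L': {'lesion_count': {'mean': 2.0, 'std': 0.816...},
#           'lesion_total_volume': {'mean': 27.83..., 'std': 11.7...}}}
```

This is the layout _parse_scalar_meanstd-style consumers (e.g. scil_convert_json_to_xlsx.py) expect when converting a merged population JSON into per-metric mean/std sheets.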