Merge pull request #317 from frheault/fix_empty_edge_compute_conn
Fix RAM spike in decompose_connectivity
arnaudbore authored Sep 24, 2020
2 parents b8fb602 + dd637da commit edb5193
Showing 5 changed files with 13 additions and 23 deletions.
4 changes: 2 additions & 2 deletions scilpy/tractanalysis/quick_tools.pyx
@@ -13,7 +13,7 @@ def get_next_real_point(points_to_index, vox_index):
         int map_idx = -1
         int nb_points_to_index
         int internal_vox_index
-        cnp.npy_ulong[:] pts_to_index_view
+        cnp.npy_uint16[:] pts_to_index_view
 
     nb_points_to_index = len(points_to_index)
     internal_vox_index = vox_index
@@ -35,7 +35,7 @@ def get_previous_real_point(points_to_index, vox_index):
         int map_index
         int nb_points_to_index
         int internal_vox_index
-        cnp.npy_ulong[:] pts_to_index_view
+        cnp.npy_uint16[:] pts_to_index_view
 
     nb_points_to_index = len(points_to_index)
     previous_point = nb_points_to_index
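Note: the npy_ulong to npy_uint16 change above (and its twin in uncompress.pyx below) is where the point-index bookkeeping sheds most of its RAM: an unsigned long is 8 bytes on 64-bit platforms, a uint16 is 2. A minimal sketch of the effect, with a hypothetical point count not taken from the commit:

import numpy as np

# Hypothetical tractogram size, for illustration only: 5 million points.
n_points = 5_000_000

# Before: one 8-byte unsigned long per point-to-index entry.
before = np.zeros(n_points, dtype=np.uint64).nbytes  # 40 MB
# After: one 2-byte uint16 per entry; safe as long as no uncompressed
# streamline holds more than 65535 points.
after = np.zeros(n_points, dtype=np.uint16).nbytes   # 10 MB
print(before // after)                               # 4x smaller

That is a 4x reduction on this buffer alone; the uncompress.pyx hunks below compound it.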
5 changes: 2 additions & 3 deletions scilpy/tractanalysis/tools.py
@@ -145,11 +145,10 @@ def extract_longest_segments_from_profile(strl_indices, atlas_data):


 def compute_connectivity(indices, atlas_data, real_labels, segmenting_func):
-    atlas_data = atlas_data.astype(np.int32)
-
     connectivity = {k: {lab: [] for lab in real_labels} for k in real_labels}
 
     for strl_idx, strl_indices in enumerate(indices):
+        if (np.array(strl_indices) > atlas_data.shape).any():
+            continue
         segments_info = segmenting_func(strl_indices, atlas_data)
 
         for si in segments_info:
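Note: the guard added to compute_connectivity skips streamlines whose voxel indices fall outside the atlas, which becomes necessary once tractograms are loaded with bbox_check=False (see scil_decompose_connectivity.py below) instead of being cleaned in a separate filtering pass. A standalone sketch of the same check, with hypothetical data:

import numpy as np

# Hypothetical atlas volume and streamline voxel indices, for illustration.
atlas_data = np.zeros((10, 10, 10), dtype=np.int16)
strl_indices = np.array([[2, 3, 4], [9, 9, 12]])  # second point leaves the volume

# Same test as the commit: skip any streamline that indexes past the volume.
if (np.array(strl_indices) > atlas_data.shape).any():
    print('skipped: out-of-bounds streamline')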
8 changes: 4 additions & 4 deletions scilpy/tractanalysis/uncompress.pyx
@@ -28,7 +28,7 @@ cdef struct Pointers:
     # To bookkeep the final index related to a streamline point
     cnp.npy_intp *pti_lengths_out
     cnp.npy_intp *pti_offsets_out
-    cnp.uint64_t *points_to_index_out
+    cnp.uint16_t *points_to_index_out
 
 
 @cython.boundscheck(False)
@@ -51,7 +51,7 @@ def uncompress(streamlines, return_mapping=False):

         # Multiplying by 6 is simply a heuristic to avoiding resizing too many
         # times. In my bundles tests, I had either 0 or 1 resize.
-        cnp.npy_intp max_points = (streamlines.get_data().size / 3) * 6
+        cnp.npy_intp max_points = (streamlines.get_data().size / 3)
 
     new_array_sequence = nib.streamlines.array_sequence.ArraySequence()
     new_array_sequence._lengths.resize(nb_streamlines)
@@ -61,7 +61,7 @@ def uncompress(streamlines, return_mapping=False):
     points_to_index = nib.streamlines.array_sequence.ArraySequence()
     points_to_index._lengths.resize(nb_streamlines)
     points_to_index._offsets.resize(nb_streamlines)
-    points_to_index._data = np.zeros(int(streamlines.get_data().size / 3), np.uint64)
+    points_to_index._data = np.zeros(int(streamlines.get_data().size / 3), np.uint16)
 
     cdef:
         cnp.npy_intp[:] lengths_view_in = streamlines._lengths
@@ -72,7 +72,7 @@ def uncompress(streamlines, return_mapping=False):
         cnp.uint16_t[:] data_view_out = new_array_sequence._data
         cnp.npy_intp[:] pti_lengths_view_out = points_to_index._lengths
         cnp.npy_intp[:] pti_offsets_view_out = points_to_index._offsets
-        cnp.uint64_t[:] points_to_index_view_out = points_to_index._data
+        cnp.uint16_t[:] points_to_index_view_out = points_to_index._data
 
     cdef Pointers pointers
     pointers.lengths_in = &lengths_view_in[0]
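Note: two allocations shrink in uncompress. The output point buffer loses its 6x over-allocation heuristic (trading a few extra resizes for a smaller initial footprint), and points_to_index drops from uint64 to uint16. A rough estimate, assuming, as the data_view_out declaration above suggests, 3 uint16 voxel coordinates per output point; the point count is hypothetical:

import numpy as np

n_points = 2_000_000  # hypothetical number of compressed input points

# Before: point buffer pre-sized 6x to avoid resizing, plus uint64 indices.
before = (np.zeros((n_points * 6, 3), np.uint16).nbytes
          + np.zeros(n_points, np.uint64).nbytes)  # ~88 MB
# After: exact-sized point buffer, plus uint16 indices.
after = (np.zeros((n_points, 3), np.uint16).nbytes
         + np.zeros(n_points, np.uint16).nbytes)   # ~16 MB
print(round(before / after, 1))                    # ~5.5x smaller up front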
1 change: 1 addition & 0 deletions scripts/scil_compute_connectivity.py
@@ -83,6 +83,7 @@ def _processing_wrapper(args):
     streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
     if len(streamlines) == 0:
         return
+
     affine, dimensions, voxel_sizes, _ = get_reference_info(labels_img)
     measures_to_return = {}

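Note: with the early return above, a worker yields nothing for an edge whose HDF5 group holds no streamlines instead of computing measures on an empty set, which is what the branch name fix_empty_edge_compute_conn refers to. The pattern in isolation, with hypothetical names:

def process_edge(streamlines):
    # Empty edge: bail out before any measure computation.
    if len(streamlines) == 0:
        return None
    return {'streamline_count': len(streamlines)}

assert process_edge([]) is None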
18 changes: 4 additions & 14 deletions scripts/scil_decompose_connectivity.py
@@ -110,8 +110,8 @@ def _save_if_needed(sft, hdf5_file, args,
                      in_label, out_label):
     if step_type == 'final':
         group = hdf5_file.create_group('{}_{}'.format(in_label, out_label))
-        group.create_dataset('data', data=sft.streamlines.get_data(),
-                             dtype=np.float32)
+        group.create_dataset('data', data=sft.streamlines._data,
+                             dtype=np.float16)
         group.create_dataset('offsets', data=sft.streamlines._offsets,
                              dtype=np.int64)
         group.create_dataset('lengths', data=sft.streamlines._lengths,
@@ -260,7 +260,8 @@ def main():

     logging.info('*** Loading streamlines ***')
     time1 = time.time()
-    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
+    sft = load_tractogram_with_reference(parser, args, args.in_tractogram,
+                                         bbox_check=False)
     time2 = time.time()
     logging.info('    Loading {} streamlines took {} sec.'.format(
         len(sft), round(time2 - time1, 2)))
@@ -269,19 +270,8 @@
         raise IOError('{} and {}do not have a compatible header'.format(
             args.in_tractogram, args.in_labels))
 
-    logging.info('*** Filtering streamlines ***')
-    original_len = len(sft)
-    time1 = time.time()
 
-    sft.to_vox()
-    sft.to_corner()
-    sft.remove_invalid_streamlines()
-    time2 = time.time()
-    logging.info(
-        '    Discarded {} streamlines from filtering in {} sec.'.format(
-            original_len - len(sft), round(time2 - time1, 2)))
-    logging.info('    Number of streamlines to process: {}'.format(len(sft)))
 
     # Get all streamlines intersection indices
     logging.info('*** Computing streamlines intersection ***')
     time1 = time.time()
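Note: two choices in this script carry the RAM fix. Streamlines are written to HDF5 as float16 rather than float32, halving every 'data' dataset, and the load-time filtering pass (to_vox, to_corner, remove_invalid_streamlines) is dropped in favor of bbox_check=False plus the out-of-bounds guard added in tools.py above. A minimal sketch of the storage half, with hypothetical names and sizes:

import h5py
import numpy as np

# Hypothetical voxel-space coordinates for one connection's streamlines.
data = (np.random.rand(1_000_000, 3) * 128).astype(np.float32)

with h5py.File('decomposed.h5', 'w') as hdf5_file:  # hypothetical path
    group = hdf5_file.create_group('8_42')          # one in_label/out_label edge
    # float16 keeps roughly 3 significant digits, enough for voxel
    # coordinates, at half the bytes of float32.
    group.create_dataset('data', data=data, dtype=np.float16)

The trade-off: float16 represents integers exactly only up to 2048, so this assumes volumes of typical dimensions.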
