Skip to content

Commit

Permalink
switch to ValueToIndexVoxels in voxcell
Browse files Browse the repository at this point in the history
  • Loading branch information
mgeplf committed Mar 26, 2024
1 parent 4044db5 commit 7bbb615
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
"""
from __future__ import annotations

import itertools as it
import logging
import warnings
from typing import Dict, List, Optional, Set, Tuple
Expand All @@ -48,7 +47,7 @@
from atlas_commons.typing import AnnotationT, FloatArray
from scipy.optimize import linprog
from tqdm import tqdm
from voxcell import RegionMap
from voxcell import RegionMap, voxel_data

from atlas_densities.densities import utils
from atlas_densities.densities.inhibitory_neuron_densities_helper import (
Expand Down Expand Up @@ -171,66 +170,6 @@ def _replace_inf_with_none(bounds: FloatArray) -> List[MinMaxPair]:
return [(float(min_), None if np.isinf(max_) else float(max_)) for min_, max_ in bounds]


class ValueToIndexVoxels:
"""Faster indexing of annotation style voxel volumes"""

def __init__(self, values):
"""Initialize.
Args:
values(np.array): volume with each voxel marked with a value; usually to group regions
"""
self._order = "C" if values.flags["C_CONTIGUOUS"] else "F"

values = values.ravel(order=self._order)
uniques, counts = np.unique(values, return_counts=True)

offsets = np.empty(len(counts) + 1, dtype=np.uint64)
offsets[0] = 0
offsets[1:] = np.cumsum(counts)

self._indices = np.argsort(values, kind="stable")
self._offsets = offsets
self._mapping = {v: i for i, v in enumerate(uniques)}

@property
def values(self):
"""List of values that are found in the original volume."""
return list(self._mapping)

def value_to_1d_indices(self, value):
"""Return the indices array indices corresponding to the 'value'.
Note: These are 1D indices, so the assumption is they are applied to a volume
who has been ValueToIndexVoxels::ravel(volume)
"""
if value not in self._mapping:
return np.array([], dtype=np.uint64)

group_index = self._mapping[value]
return self._indices[self._offsets[group_index] : self._offsets[group_index + 1]]

def ravel(self, voxel_data):
"""Ensures `voxel_data` matches the layout that the 1D indices can be used."""
return voxel_data.ravel(order=self._order)

def apply(self, values, funcs, voxel_data):
"""For pairs of `values` and `funcs`, apply the func as if a mask was created from `value`.
Args:
values(iterable of value): values to be found in original values array
funcs(iterable of funcs): if only a single function is provided,
it is used for all `values`
voxel_data(np.array): Array on which to apply function based on desired `values`
"""
flat_data = self.ravel(voxel_data)
if hasattr(funcs, "__call__"):
funcs = (funcs,)
for value, func in zip(values, it.cycle(funcs)):
idx = self.value_to_1d_indices(value)
yield func(flat_data[idx])


def _compute_region_cell_counts(
annotation: AnnotationT,
density: FloatArray,
Expand Down Expand Up @@ -260,7 +199,7 @@ def _compute_region_cell_counts(
... ...
The index is the sorted list of all region identifiers.
"""
vtiv = ValueToIndexVoxels(annotation)
vtiv = voxel_data.ValueToIndexVoxels(annotation)
density_copy = vtiv.ravel(density.copy())

id_counts = []
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
# from the HiGHS library. We use the "highs" method in the densities module.
"scipy>=1.6.0",
"tqdm>=4.44.1",
"voxcell>=3.0.0",
"voxcell>=3.1.7",
],
extras_require={
"tests": [
Expand Down

0 comments on commit 7bbb615

Please sign in to comment.