Skip to content

Commit

Permalink
Feat: Add a general cli to populate whole brain
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudon committed Mar 8, 2024
1 parent b3880ce commit d3b16c0
Showing 1 changed file with 114 additions and 0 deletions.
114 changes: 114 additions & 0 deletions atlas_direction_vectors/app/direction_vectors.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
"""Generate and save the direction vectors for different regions of the mouse brain"""

# pylint: disable=import-outside-toplevel,too-many-arguments
import json
import logging
from copy import copy
import numpy as np
from functools import partial
from joblib import Parallel, delayed

import click # type: ignore
import voxcell # type: ignore
from atlas_commons.utils import compute_halfway
from atlas_commons.app_utils import (
EXISTING_FILE_PATH,
assert_meta_properties,
Expand All @@ -23,6 +29,9 @@
from atlas_direction_vectors.exceptions import AtlasDirectionVectorsError
from atlas_direction_vectors.interpolation import interpolate_vectors
from atlas_direction_vectors.isocortex import ISOCORTEX_ALGORITHMS
from atlas_direction_vectors.algorithms.layer_based_direction_vectors import (
compute_layered_region_direction_vectors,
)

L = logging.getLogger(__name__)

Expand Down Expand Up @@ -388,3 +397,108 @@ def from_center(annotation_path, hierarchy_path, output_path, region, center):

dir_vectors = direction_vectors_from_center.command(region_map, annotation.raw, region, center)
annotation.with_data(dir_vectors).save_nrrd(output_path)


def _create_subregion_direction_vectors(region_id, region_config, region_map, annotation):
"""Create direciton vectors for a subregion, to be used in parallel."""
acronym = region_map.get(region_id, "acronym")
metadata = copy(region_config["metadata"])

for i, query in enumerate(metadata["layers"]["queries"]):
if "***" in query:
metadata["layers"]["queries"][i] = query.replace("***", acronym)
metadata["region"]["query"] = metadata["region"]["query"].replace("***", acronym)

# names are not needed, but mandatory from atlas-commons
metadata["layers"]["names"] = metadata["layers"]["queries"]

region_to_weight = {}
for key, weight in region_config["region_to_weight"].items():
if "***" in key:
key = key.replace("***", acronym)
region_to_weight[key] = weight

# find if the region is split in hemispheres (no voxel near the halfplane)
z_halfway = compute_halfway(annotation.shape[2])
annot = np.zeros_like(annotation.raw, dtype=int)
subregion_ids = list(region_map.find(region_id, "id", with_descendants=True))
annot[np.isin(annotation.raw, subregion_ids)] = 1

has_hemisphere = not any(annot[:, :, z_halfway - 1 : z_halfway + 1].flatten())
L.info("Subregion %s has hemispheres: %s", region_map.get(region_id, "name"), has_hemisphere)

return compute_layered_region_direction_vectors(
region_map=region_map,
annotation=annotation,
metadata=metadata,
region_to_weight=region_to_weight,
shading_width=region_config.get("shading_width", 4),
expansion_width=region_config.get("expansion_width", 8),
has_hemispheres=has_hemisphere,
)


def _get_region_ids(region, region_map, region_config, annotation):
"""Fetch all the region ids that need to be processed in parallel."""
if not region_config["region_query"].get("with_descendants", False):

return region_map.find(
region_config["region_query"]["query"], region_config["region_query"]["attribute"]
)

subregion_ids = []
all_ids = np.unique(annotation.raw)
for subregion_id in region_map.find(
region_config["region_query"]["query"],
region_config["region_query"]["attribute"],
with_descendants=True,
):
parent_id = region_map.get(subregion_id, "parent_structure_id")
if (
subregion_id in all_ids
and region_map.is_leaf_id(subregion_id)
and parent_id not in subregion_ids
):
subregion_ids.append(parent_id)
return subregion_ids


@app.command()
@common_atlas_options
@click.option(
"--output-path",
required=True,
help="Path of file to write the direction vectors to.",
)
@click.option("--config-path", type=str, required=True, help="path to config file")
@log_args(L)
def from_config(annotation_path, hierarchy_path, output_path, config_path):
"""General function to compute direction vector from a configuration file."""
with open(config_path) as config_file:
config = json.load(config_file)

annotation = voxcell.VoxelData.load_nrrd(annotation_path)
region_map = voxcell.RegionMap.load_json(hierarchy_path)
direction_vectors = np.full(annotation.raw.shape + (3,), np.nan, dtype=np.float32)

for region, region_config in config.items():
if region == 'cerebellum':
continue

region_ids = _get_region_ids(region, region_map, region_config, annotation)
L.info("Computing direction vectors for %s with %s subregions.", region, len(region_ids))
f = partial(
_create_subregion_direction_vectors,
region_config=region_config,
region_map=region_map,
annotation=annotation,
)
with Parallel(n_jobs=20, verbose=10) as parallel:
for subregion_direction_vectors in parallel(
delayed(f)(region_id) for region_id in region_ids
):
# Assembles subregion direction vectors.
subregion_mask = np.logical_not(np.isnan(subregion_direction_vectors))
direction_vectors[subregion_mask] = subregion_direction_vectors[subregion_mask]

annotation.with_data(direction_vectors).save_nrrd(output_path)

0 comments on commit d3b16c0

Please sign in to comment.