From d3b16c023401c0c4b70286acae161f75ca8ac268 Mon Sep 17 00:00:00 2001 From: arnaudon Date: Fri, 8 Mar 2024 16:13:08 +0100 Subject: [PATCH] Feat: Add a general cli to populate whole brain --- .../app/direction_vectors.py | 114 ++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/atlas_direction_vectors/app/direction_vectors.py b/atlas_direction_vectors/app/direction_vectors.py index 8b865b7..7462399 100644 --- a/atlas_direction_vectors/app/direction_vectors.py +++ b/atlas_direction_vectors/app/direction_vectors.py @@ -1,10 +1,16 @@ """Generate and save the direction vectors for different regions of the mouse brain""" + # pylint: disable=import-outside-toplevel,too-many-arguments import json import logging +from copy import copy +import numpy as np +from functools import partial +from joblib import Parallel, delayed import click # type: ignore import voxcell # type: ignore +from atlas_commons.utils import compute_halfway from atlas_commons.app_utils import ( EXISTING_FILE_PATH, assert_meta_properties, @@ -23,6 +29,9 @@ from atlas_direction_vectors.exceptions import AtlasDirectionVectorsError from atlas_direction_vectors.interpolation import interpolate_vectors from atlas_direction_vectors.isocortex import ISOCORTEX_ALGORITHMS +from atlas_direction_vectors.algorithms.layer_based_direction_vectors import ( + compute_layered_region_direction_vectors, +) L = logging.getLogger(__name__) @@ -388,3 +397,108 @@ def from_center(annotation_path, hierarchy_path, output_path, region, center): dir_vectors = direction_vectors_from_center.command(region_map, annotation.raw, region, center) annotation.with_data(dir_vectors).save_nrrd(output_path) + + +def _create_subregion_direction_vectors(region_id, region_config, region_map, annotation): + """Create direciton vectors for a subregion, to be used in parallel.""" + acronym = region_map.get(region_id, "acronym") + metadata = copy(region_config["metadata"]) + + for i, query in enumerate(metadata["layers"]["queries"]): + if "***" in query: + metadata["layers"]["queries"][i] = query.replace("***", acronym) + metadata["region"]["query"] = metadata["region"]["query"].replace("***", acronym) + + # names are not needed, but mandatory from atlas-commons + metadata["layers"]["names"] = metadata["layers"]["queries"] + + region_to_weight = {} + for key, weight in region_config["region_to_weight"].items(): + if "***" in key: + key = key.replace("***", acronym) + region_to_weight[key] = weight + + # find if the region is split in hemispheres (no voxel near the halfplane) + z_halfway = compute_halfway(annotation.shape[2]) + annot = np.zeros_like(annotation.raw, dtype=int) + subregion_ids = list(region_map.find(region_id, "id", with_descendants=True)) + annot[np.isin(annotation.raw, subregion_ids)] = 1 + + has_hemisphere = not any(annot[:, :, z_halfway - 1 : z_halfway + 1].flatten()) + L.info("Subregion %s has hemispheres: %s", region_map.get(region_id, "name"), has_hemisphere) + + return compute_layered_region_direction_vectors( + region_map=region_map, + annotation=annotation, + metadata=metadata, + region_to_weight=region_to_weight, + shading_width=region_config.get("shading_width", 4), + expansion_width=region_config.get("expansion_width", 8), + has_hemispheres=has_hemisphere, + ) + + +def _get_region_ids(region, region_map, region_config, annotation): + """Fetch all the region ids that need to be processed in parallel.""" + if not region_config["region_query"].get("with_descendants", False): + + return region_map.find( + region_config["region_query"]["query"], region_config["region_query"]["attribute"] + ) + + subregion_ids = [] + all_ids = np.unique(annotation.raw) + for subregion_id in region_map.find( + region_config["region_query"]["query"], + region_config["region_query"]["attribute"], + with_descendants=True, + ): + parent_id = region_map.get(subregion_id, "parent_structure_id") + if ( + subregion_id in all_ids + and region_map.is_leaf_id(subregion_id) + and parent_id not in subregion_ids + ): + subregion_ids.append(parent_id) + return subregion_ids + + +@app.command() +@common_atlas_options +@click.option( + "--output-path", + required=True, + help="Path of file to write the direction vectors to.", +) +@click.option("--config-path", type=str, required=True, help="path to config file") +@log_args(L) +def from_config(annotation_path, hierarchy_path, output_path, config_path): + """General function to compute direction vector from a configuration file.""" + with open(config_path) as config_file: + config = json.load(config_file) + + annotation = voxcell.VoxelData.load_nrrd(annotation_path) + region_map = voxcell.RegionMap.load_json(hierarchy_path) + direction_vectors = np.full(annotation.raw.shape + (3,), np.nan, dtype=np.float32) + + for region, region_config in config.items(): + if region == 'cerebellum': + continue + + region_ids = _get_region_ids(region, region_map, region_config, annotation) + L.info("Computing direction vectors for %s with %s subregions.", region, len(region_ids)) + f = partial( + _create_subregion_direction_vectors, + region_config=region_config, + region_map=region_map, + annotation=annotation, + ) + with Parallel(n_jobs=20, verbose=10) as parallel: + for subregion_direction_vectors in parallel( + delayed(f)(region_id) for region_id in region_ids + ): + # Assembles subregion direction vectors. + subregion_mask = np.logical_not(np.isnan(subregion_direction_vectors)) + direction_vectors[subregion_mask] = subregion_direction_vectors[subregion_mask] + + annotation.with_data(direction_vectors).save_nrrd(output_path)